
# (Source-viewer header: 144 lines, 4.9 KiB.)

"""
Classes to represent the definitions of aggregate functions.
"""
from django.core.exceptions import FieldError
from django.db.models.expressions import Func, Star
from django.db.models.fields import FloatField, IntegerField
# Public aggregate classes exported by this module.
__all__ = [
    'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
]
class Aggregate(Func):
    """Base class for SQL aggregate expressions.

    Subclasses set ``function`` (the SQL function name) and ``name`` (the
    display name used in error messages and in default aliases).
    """

    # Flags this expression as an aggregate so that nesting one aggregate
    # inside another can be detected in resolve_expression().
    contains_aggregate = True
    # Display name; overridden by each concrete subclass (e.g. 'Avg').
    name = None

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        """Resolve the source expressions and return the resolved copy.

        Raises:
            FieldError: when not summarizing and one of the resolved source
                expressions itself contains an aggregate
                (e.g. ``Sum(Avg(...))``).
        """
        # Aggregates are not allowed in UPDATE queries, so ignore for_save
        c = super(Aggregate, self).resolve_expression(query, allow_joins, reuse, summarize)
        if not summarize:
            expressions = c.get_source_expressions()
            for index, expr in enumerate(expressions):
                if expr.contains_aggregate:
                    # Name the offending argument as the caller wrote it,
                    # i.e. the unresolved expression at the same position.
                    before_resolved = self.get_source_expressions()[index]
                    name = before_resolved.name if hasattr(before_resolved, 'name') else repr(before_resolved)
                    raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (c.name, name, name))
        return c

    def default_alias(self):
        """Return the alias to use when none is given explicitly.

        The alias is ``<source name>__<lower-cased aggregate name>``; for
        example ``price__avg`` for an ``Avg`` whose single source expression
        is named ``price``.

        Raises:
            TypeError: unless there is exactly one source expression and it
                has a ``name`` attribute.
        """
        # NOTE(review): this is a plain method here. If callers read
        # ``agg.default_alias`` as an attribute, it needs a @property
        # decorator -- confirm against the call sites.
        expressions = self.get_source_expressions()
        if len(expressions) == 1 and hasattr(expressions[0], 'name'):
            return '%s__%s' % (expressions[0].name, self.name.lower())
        raise TypeError("Complex expressions require an alias")

    def get_group_by_cols(self):
        """Return no GROUP BY columns for this aggregate (always ``[]``)."""
        return []
class Avg(Aggregate):
    """SQL ``AVG`` aggregate.

    The output field defaults to ``FloatField()`` but can be overridden with
    an explicit ``output_field`` keyword.
    """

    function = 'AVG'
    name = 'Avg'

    def __init__(self, expression, **extra):
        # Honour an explicit output_field (e.g. a DurationField, which
        # as_oracle() special-cases); otherwise default to FloatField.
        output_field = extra.pop('output_field', FloatField())
        super(Avg, self).__init__(expression, output_field=output_field, **extra)

    def as_oracle(self, compiler, connection):
        """Oracle-specific compilation of the aggregate.

        For a ``DurationField`` output, the source is wrapped in
        ``IntervalToSeconds``, averaged, and the result is wrapped in
        ``SecondsToInterval``. Otherwise, the generic ``as_sql()`` is used.
        """
        if self.output_field.get_internal_type() == 'DurationField':
            expression = self.get_source_expressions()[0]
            # Imported lazily: only needed for this backend-specific branch.
            from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
            # The inner Avg gets no output_field, so it defaults to
            # FloatField and does not re-enter this DurationField branch.
            return compiler.compile(
                SecondsToInterval(Avg(IntervalToSeconds(expression)))
            )
        return super(Avg, self).as_sql(compiler, connection)
class Count(Aggregate):
    """SQL ``COUNT`` aggregate, optionally ``COUNT(DISTINCT ...)``.

    ``expression`` may be the string ``'*'``, which is replaced by ``Star()``.
    The output field is always ``IntegerField()``.
    """

    function = 'COUNT'
    name = 'Count'
    # %(distinct)s is either '' or 'DISTINCT ' (set in __init__).
    template = '%(function)s(%(distinct)s%(expressions)s)'

    def __init__(self, expression, distinct=False, **extra):
        if expression == '*':
            expression = Star()
        super(Count, self).__init__(
            expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra)

    def __repr__(self):
        # Three placeholders: class name, joined arguments, distinct flag.
        # The class name must be supplied, or str.format raises IndexError.
        return "{}({}, distinct={})".format(
            self.__class__.__name__,
            self.arg_joiner.join(str(arg) for arg in self.source_expressions),
            'False' if self.extra['distinct'] == '' else 'True',
        )

    def convert_value(self, value, expression, connection, context):
        """Return the database value as an int, mapping ``None`` to 0."""
        if value is None:
            return 0
        return int(value)
class Max(Aggregate):
    """SQL ``MAX`` aggregate (display name ``'Max'``)."""

    name = 'Max'
    function = 'MAX'
class Min(Aggregate):
    """SQL ``MIN`` aggregate (display name ``'Min'``)."""

    name = 'Min'
    function = 'MIN'
class StdDev(Aggregate):
    """SQL standard-deviation aggregate.

    Uses ``STDDEV_POP`` by default and ``STDDEV_SAMP`` when ``sample=True``.
    The output field is always ``FloatField()``.
    """

    name = 'StdDev'

    def __init__(self, expression, sample=False, **extra):
        # The SQL function depends on ``sample``, so it is chosen per
        # instance rather than declared as a class attribute.
        self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
        super(StdDev, self).__init__(expression, output_field=FloatField(), **extra)

    def __repr__(self):
        # Three placeholders: class name, joined arguments, sample flag.
        # The class name must be supplied, or str.format raises IndexError.
        return "{}({}, sample={})".format(
            self.__class__.__name__,
            self.arg_joiner.join(str(arg) for arg in self.source_expressions),
            'False' if self.function == 'STDDEV_POP' else 'True',
        )

    def convert_value(self, value, expression, connection, context):
        """Return the database value as a float, passing ``None`` through."""
        if value is None:
            return value
        return float(value)
class Sum(Aggregate):
    """SQL ``SUM`` aggregate."""

    function = 'SUM'
    name = 'Sum'

    def as_oracle(self, compiler, connection):
        """Oracle-specific compilation of the aggregate.

        For a ``DurationField`` output, the source is wrapped in
        ``IntervalToSeconds``, summed, and the result is wrapped in
        ``SecondsToInterval``. Otherwise, the generic ``as_sql()`` is used.
        """
        if self.output_field.get_internal_type() == 'DurationField':
            expression = self.get_source_expressions()[0]
            # Imported lazily: only needed for this backend-specific branch.
            from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
            # NOTE(review): this relies on the inner Sum's output_field
            # (derived from IntervalToSeconds) not being a DurationField;
            # otherwise it would re-enter this branch -- confirm.
            return compiler.compile(
                SecondsToInterval(Sum(IntervalToSeconds(expression)))
            )
        return super(Sum, self).as_sql(compiler, connection)
class Variance(Aggregate):
    """SQL variance aggregate.

    Uses ``VAR_POP`` by default and ``VAR_SAMP`` when ``sample=True``.
    The output field is always ``FloatField()``.
    """

    name = 'Variance'

    def __init__(self, expression, sample=False, **extra):
        # The SQL function depends on ``sample``, so it is chosen per
        # instance rather than declared as a class attribute.
        self.function = 'VAR_SAMP' if sample else 'VAR_POP'
        super(Variance, self).__init__(expression, output_field=FloatField(), **extra)

    def __repr__(self):
        # Three placeholders: class name, joined arguments, sample flag.
        # The class name must be supplied, or str.format raises IndexError.
        return "{}({}, sample={})".format(
            self.__class__.__name__,
            self.arg_joiner.join(str(arg) for arg in self.source_expressions),
            'False' if self.function == 'VAR_POP' else 'True',
        )

    def convert_value(self, value, expression, connection, context):
        """Return the database value as a float, passing ``None`` through."""
        if value is None:
            return value
        return float(value)