2009-01-15 19:06:34 +08:00
|
|
|
"""
|
|
|
|
Classes to represent the definitions of aggregate functions.
|
|
|
|
"""
|
2013-12-25 21:13:18 +08:00
|
|
|
from django.core.exceptions import FieldError
|
2017-04-22 23:44:51 +08:00
|
|
|
from django.db.models.expressions import Case, Func, Star, When
|
2016-04-06 11:48:08 +08:00
|
|
|
from django.db.models.fields import DecimalField, FloatField, IntegerField
|
Refactored qs.add_q() and utils/tree.py
The sql/query.py add_q method did a lot of where/having tree hacking to
get complex queries to work correctly. The logic was refactored so that
it should be simpler to understand. The new logic should also produce
leaner WHERE conditions.
The changes cascade somewhat, as some other parts of Django (like
add_filter() and WhereNode) expect boolean trees in certain format or
they fail to work. So to fix add_q(), one must also fix utils/tree.py,
some things in add_filter(), WhereNode, and so on.
This commit also fixed add_filter() to see negated clauses further up the
path. A query like .exclude(Q(reversefk__in=a_list)) didn't behave the same
as .filter(~Q(reversefk__in=a_list)). The reason for this is that only
the immediate parent's negate clause was seen by add_filter(), and thus a
tree like AND: (NOT AND: (AND: condition)) was not handled
correctly, as there is one intermediary AND node in the tree. The
example tree is generated by .exclude(~Q(reversefk__in=a_list)).
Still, aggregation lost connectors in OR cases, and F() objects and
aggregates in the same filter clause caused GROUP BY problems on some
databases.
Fixed #17600, fixed #13198, fixed #17025, fixed #17000, fixed #11293.
2012-05-25 05:27:24 +08:00
|
|
|
|
2013-10-18 19:25:30 +08:00
|
|
|
# Public API of this module: the aggregate base class and the concrete
# aggregate functions defined below.
__all__ = [
    'Aggregate',
    'Avg',
    'Count',
    'Max',
    'Min',
    'StdDev',
    'Sum',
    'Variance',
]
|
|
|
|
|
2013-07-08 08:39:54 +08:00
|
|
|
|
2013-12-25 21:13:18 +08:00
|
|
|
class Aggregate(Func):
    """
    Base class for SQL aggregate expressions.

    Subclasses set ``function`` (the SQL function to call) and ``name``
    (used by ``default_alias`` and in error messages).

    The optional ``filter`` keyword restricts which rows are aggregated.
    When the backend's ``supports_aggregate_filter_clause`` feature is set,
    it is rendered through ``filter_template`` as
    ``... FILTER (WHERE <filter>)``. Otherwise ``as_sql()`` emulates it by
    wrapping the first source expression in
    ``Case(When(filter, then=<expression>))``.
    """
    contains_aggregate = True
    name = None
    # ``%s`` receives the base template. ``%%(filter)s`` survives that first
    # substitution as the ``%(filter)s`` placeholder filled in by as_sql().
    filter_template = '%s FILTER (WHERE %%(filter)s)'
    window_compatible = True

    def __init__(self, *args, filter=None, **kwargs):
        self.filter = filter
        super().__init__(*args, **kwargs)

    def get_source_fields(self):
        # Don't return the filter expression since it's not a source field.
        return [e._output_field_or_none for e in super().get_source_expressions()]

    def get_source_expressions(self):
        """
        Return the aggregated expressions, followed by ``self.filter`` if set.

        Builds a new list instead of extending the parent's result in place
        (``+=``). The parent may hand back its own internal list, and
        appending to it would add the filter to this aggregate's arguments
        on every call.
        """
        source_expressions = super().get_source_expressions()
        if self.filter:
            return source_expressions + [self.filter]
        return source_expressions

    def set_source_expressions(self, exprs):
        # Mirror of get_source_expressions(): when a filter is set, the last
        # item of ``exprs`` is the (possibly replaced) filter expression.
        # Note that ``exprs`` is popped in place.
        self.filter = self.filter and exprs.pop()
        return super().set_source_expressions(exprs)

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        """
        Resolve the source expressions and the filter against ``query``.

        Raises FieldError when, outside of summarization, one of the
        aggregated expressions is itself an aggregate.
        """
        # Aggregates are not allowed in UPDATE queries, so ignore for_save
        c = super().resolve_expression(query, allow_joins, reuse, summarize)
        c.filter = c.filter and c.filter.resolve_expression(query, allow_joins, reuse, summarize)
        if not summarize:
            # Call Func.get_source_expressions() (via super(Aggregate, c)) so
            # that c.filter is not included in this loop.
            expressions = super(Aggregate, c).get_source_expressions()
            for index, expr in enumerate(expressions):
                if expr.contains_aggregate:
                    # Name the offending expression as written, i.e. before
                    # it was resolved.
                    before_resolved = self.get_source_expressions()[index]
                    name = before_resolved.name if hasattr(before_resolved, 'name') else repr(before_resolved)
                    raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (c.name, name, name))
        return c

    @property
    def default_alias(self):
        """
        Return the automatic alias ``<source name>__<aggregate name>``.

        Only available when get_source_expressions() yields exactly one
        expression and that expression has a ``name``. Because the filter is
        part of get_source_expressions(), a filtered aggregate also needs an
        explicit alias.

        Raises TypeError otherwise.
        """
        expressions = self.get_source_expressions()
        if len(expressions) == 1 and hasattr(expressions[0], 'name'):
            return '%s__%s' % (expressions[0].name, self.name.lower())
        raise TypeError("Complex expressions require an alias")

    def get_group_by_cols(self):
        # An aggregate adds no columns to GROUP BY.
        return []

    def as_sql(self, compiler, connection, **extra_context):
        """
        Compile the aggregate, rendering ``self.filter`` when present.

        ``extra_context`` is forwarded to Func.as_sql() on every path. With
        native FILTER support, ``template`` (wrapped by ``filter_template``)
        and ``filter`` are overridden.
        """
        if self.filter:
            if connection.features.supports_aggregate_filter_clause:
                filter_sql, filter_params = self.filter.as_sql(compiler, connection)
                template = self.filter_template % extra_context.get('template', self.template)
                # Previously this branch dropped the caller's extra_context.
                # Merging avoids a duplicate 'template' keyword.
                extra_context = {**extra_context, 'template': template, 'filter': filter_sql}
                sql, params = super().as_sql(compiler, connection, **extra_context)
                return sql, params + filter_params
            else:
                # No FILTER clause support: emulate it on a copy so self is
                # left untouched.
                clone = self.copy()
                clone.filter = None
                source_expressions = clone.get_source_expressions()
                condition = When(self.filter, then=source_expressions[0])
                clone.set_source_expressions([Case(condition)] + source_expressions[1:])
                return super(Aggregate, clone).as_sql(compiler, connection, **extra_context)
        return super().as_sql(compiler, connection, **extra_context)

    def _get_repr_options(self):
        # Include the filter in repr() only when one is set.
        options = super()._get_repr_options()
        if self.filter:
            options.update({'filter': self.filter})
        return options
|
|
|
|
|
2013-07-08 08:39:54 +08:00
|
|
|
|
2009-01-15 19:06:34 +08:00
|
|
|
class Avg(Aggregate):
    """
    Average of an expression's values (SQL ``AVG``).

    Integer and decimal sources produce a FloatField result. DurationField
    results get backend-specific handling on MySQL and Oracle.
    """
    name = 'Avg'
    function = 'AVG'

    def _resolve_output_field(self):
        first_source = self.get_source_fields()[0]
        if not isinstance(first_source, (IntegerField, DecimalField)):
            return super()._resolve_output_field()
        return FloatField()

    def as_mysql(self, compiler, connection):
        sql, params = super().as_sql(compiler, connection)
        if self.output_field.get_internal_type() != 'DurationField':
            return sql, params
        # Durations are wrapped in an integer CAST on MySQL.
        return 'CAST(%s as SIGNED)' % sql, params

    def as_oracle(self, compiler, connection):
        if self.output_field.get_internal_type() != 'DurationField':
            return super().as_sql(compiler, connection)
        duration_source = self.get_source_expressions()[0]
        from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
        # Average the interval converted to seconds, keeping the same filter,
        # then convert the result back to an interval.
        averaged_seconds = Avg(IntervalToSeconds(duration_source), filter=self.filter)
        return compiler.compile(SecondsToInterval(averaged_seconds))
|
2015-05-23 16:12:09 +08:00
|
|
|
|
2013-07-08 08:39:54 +08:00
|
|
|
|
2009-01-15 19:06:34 +08:00
|
|
|
class Count(Aggregate):
    """
    Number of rows or values (SQL ``COUNT``), optionally ``DISTINCT``.

    ``Count('*')`` is converted to a ``Star()`` expression, which may not be
    combined with ``filter``. A NULL database result is converted to 0.
    """
    name = 'Count'
    function = 'COUNT'
    template = '%(function)s(%(distinct)s%(expressions)s)'
    output_field = IntegerField()

    def __init__(self, expression, distinct=False, filter=None, **extra):
        if expression == '*':
            expression = Star()
        if filter is not None and isinstance(expression, Star):
            raise ValueError('Star cannot be used with filter. Please specify a field.')
        # The template expects the literal SQL prefix, not a boolean.
        distinct_sql = 'DISTINCT ' if distinct else ''
        super().__init__(expression, distinct=distinct_sql, filter=filter, **extra)

    def _get_repr_options(self):
        options = dict(super()._get_repr_options())
        options['distinct'] = self.extra['distinct'] != ''
        return options

    def convert_value(self, value, expression, connection):
        if value is None:
            return 0
        return value
|
2009-01-15 19:06:34 +08:00
|
|
|
|
2013-07-08 08:39:54 +08:00
|
|
|
|
2009-01-15 19:06:34 +08:00
|
|
|
class Max(Aggregate):
    """Largest value of an expression (SQL ``MAX``)."""
    name = 'Max'
    function = 'MAX'
|
|
|
|
|
2013-07-08 08:39:54 +08:00
|
|
|
|
2009-01-15 19:06:34 +08:00
|
|
|
class Min(Aggregate):
    """Smallest value of an expression (SQL ``MIN``)."""
    name = 'Min'
    function = 'MIN'
|
|
|
|
|
2013-07-08 08:39:54 +08:00
|
|
|
|
2009-01-15 19:06:34 +08:00
|
|
|
class StdDev(Aggregate):
    """
    Standard deviation of an expression, typed as a FloatField.

    Uses SQL ``STDDEV_SAMP`` when ``sample`` is true, ``STDDEV_POP``
    otherwise.
    """
    output_field = FloatField()
    name = 'StdDev'

    def __init__(self, expression, sample=False, **extra):
        if sample:
            self.function = 'STDDEV_SAMP'
        else:
            self.function = 'STDDEV_POP'
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        options = dict(super()._get_repr_options())
        options['sample'] = self.function == 'STDDEV_SAMP'
        return options
|
2015-01-27 10:40:32 +08:00
|
|
|
|
2013-07-08 08:39:54 +08:00
|
|
|
|
2009-01-15 19:06:34 +08:00
|
|
|
class Sum(Aggregate):
    """
    Sum of an expression's values (SQL ``SUM``).

    DurationField results get backend-specific handling on MySQL and Oracle.
    """
    function = 'SUM'
    name = 'Sum'

    def as_mysql(self, compiler, connection):
        sql, params = super().as_sql(compiler, connection)
        if self.output_field.get_internal_type() == 'DurationField':
            # Durations are wrapped in an integer CAST on MySQL.
            sql = 'CAST(%s as SIGNED)' % sql
        return sql, params

    def as_oracle(self, compiler, connection):
        if self.output_field.get_internal_type() == 'DurationField':
            expression = self.get_source_expressions()[0]
            from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
            # Sum the interval converted to seconds, then convert back to an
            # interval. The filter must be forwarded to the inner Sum, as
            # Avg.as_oracle() does. Dropping it made a filtered Sum over a
            # DurationField aggregate every row on Oracle.
            return compiler.compile(
                SecondsToInterval(Sum(IntervalToSeconds(expression), filter=self.filter))
            )
        return super().as_sql(compiler, connection)
|
2015-05-23 16:12:09 +08:00
|
|
|
|
2013-07-08 08:39:54 +08:00
|
|
|
|
2009-01-15 19:06:34 +08:00
|
|
|
class Variance(Aggregate):
    """
    Variance of an expression, typed as a FloatField.

    Uses SQL ``VAR_SAMP`` when ``sample`` is true, ``VAR_POP`` otherwise.
    """
    output_field = FloatField()
    name = 'Variance'

    def __init__(self, expression, sample=False, **extra):
        if sample:
            self.function = 'VAR_SAMP'
        else:
            self.function = 'VAR_POP'
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        options = dict(super()._get_repr_options())
        options['sample'] = self.function == 'VAR_SAMP'
        return options
|