Fixed #27849 -- Added filtering support to aggregates.
This commit is contained in:
parent
489421b015
commit
b78d100fa6
|
@ -8,10 +8,10 @@ __all__ = [
|
||||||
|
|
||||||
|
|
||||||
class StatAggregate(Aggregate):
|
class StatAggregate(Aggregate):
|
||||||
def __init__(self, y, x, output_field=FloatField()):
|
def __init__(self, y, x, output_field=FloatField(), filter=None):
|
||||||
if not x or not y:
|
if not x or not y:
|
||||||
raise ValueError('Both y and x must be provided.')
|
raise ValueError('Both y and x must be provided.')
|
||||||
super().__init__(y, x, output_field=output_field)
|
super().__init__(y, x, output_field=output_field, filter=filter)
|
||||||
|
|
||||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||||
return super().resolve_expression(query, allow_joins, reuse, summarize)
|
return super().resolve_expression(query, allow_joins, reuse, summarize)
|
||||||
|
@ -22,9 +22,9 @@ class Corr(StatAggregate):
|
||||||
|
|
||||||
|
|
||||||
class CovarPop(StatAggregate):
|
class CovarPop(StatAggregate):
|
||||||
def __init__(self, y, x, sample=False):
|
def __init__(self, y, x, sample=False, filter=None):
|
||||||
self.function = 'COVAR_SAMP' if sample else 'COVAR_POP'
|
self.function = 'COVAR_SAMP' if sample else 'COVAR_POP'
|
||||||
super().__init__(y, x)
|
super().__init__(y, x, filter=filter)
|
||||||
|
|
||||||
|
|
||||||
class RegrAvgX(StatAggregate):
|
class RegrAvgX(StatAggregate):
|
||||||
|
@ -38,8 +38,8 @@ class RegrAvgY(StatAggregate):
|
||||||
class RegrCount(StatAggregate):
|
class RegrCount(StatAggregate):
|
||||||
function = 'REGR_COUNT'
|
function = 'REGR_COUNT'
|
||||||
|
|
||||||
def __init__(self, y, x):
|
def __init__(self, y, x, filter=None):
|
||||||
super().__init__(y=y, x=x, output_field=IntegerField())
|
super().__init__(y=y, x=x, output_field=IntegerField(), filter=filter)
|
||||||
|
|
||||||
def convert_value(self, value, expression, connection):
|
def convert_value(self, value, expression, connection):
|
||||||
if value is None:
|
if value is None:
|
||||||
|
|
|
@ -229,6 +229,10 @@ class BaseDatabaseFeatures:
|
||||||
supports_select_difference = True
|
supports_select_difference = True
|
||||||
supports_slicing_ordering_in_compound = False
|
supports_slicing_ordering_in_compound = False
|
||||||
|
|
||||||
|
# Does the database support SQL 2003 FILTER (WHERE ...) in aggregate
|
||||||
|
# expressions?
|
||||||
|
supports_aggregate_filter_clause = False
|
||||||
|
|
||||||
# Does the backend support indexing a TextField?
|
# Does the backend support indexing a TextField?
|
||||||
supports_index_on_text_field = True
|
supports_index_on_text_field = True
|
||||||
|
|
||||||
|
|
|
@ -50,6 +50,10 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||||
END;
|
END;
|
||||||
$$ LANGUAGE plpgsql;"""
|
$$ LANGUAGE plpgsql;"""
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def supports_aggregate_filter_clause(self):
|
||||||
|
return self.connection.pg_version >= 90400
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def has_select_for_update_skip_locked(self):
|
def has_select_for_update_skip_locked(self):
|
||||||
return self.connection.pg_version >= 90500
|
return self.connection.pg_version >= 90500
|
||||||
|
|
|
@ -2,8 +2,9 @@
|
||||||
Classes to represent the definitions of aggregate functions.
|
Classes to represent the definitions of aggregate functions.
|
||||||
"""
|
"""
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.db.models.expressions import Func, Star
|
from django.db.models.expressions import Case, Func, Star, When
|
||||||
from django.db.models.fields import DecimalField, FloatField, IntegerField
|
from django.db.models.fields import DecimalField, FloatField, IntegerField
|
||||||
|
from django.db.models.query_utils import Q
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
|
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
|
||||||
|
@ -13,12 +14,36 @@ __all__ = [
|
||||||
class Aggregate(Func):
|
class Aggregate(Func):
|
||||||
contains_aggregate = True
|
contains_aggregate = True
|
||||||
name = None
|
name = None
|
||||||
|
filter_template = '%s FILTER (WHERE %%(filter)s)'
|
||||||
|
|
||||||
|
def __init__(self, *args, filter=None, **kwargs):
|
||||||
|
self.filter = filter
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def get_source_fields(self):
|
||||||
|
# Don't return the filter expression since it's not a source field.
|
||||||
|
return [e._output_field_or_none for e in super().get_source_expressions()]
|
||||||
|
|
||||||
|
def get_source_expressions(self):
|
||||||
|
source_expressions = super().get_source_expressions()
|
||||||
|
if self.filter:
|
||||||
|
source_expressions += [self.filter]
|
||||||
|
return source_expressions
|
||||||
|
|
||||||
|
def set_source_expressions(self, exprs):
|
||||||
|
if self.filter:
|
||||||
|
self.filter = exprs.pop()
|
||||||
|
return super().set_source_expressions(exprs)
|
||||||
|
|
||||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||||
# Aggregates are not allowed in UPDATE queries, so ignore for_save
|
# Aggregates are not allowed in UPDATE queries, so ignore for_save
|
||||||
c = super().resolve_expression(query, allow_joins, reuse, summarize)
|
c = super().resolve_expression(query, allow_joins, reuse, summarize)
|
||||||
|
if c.filter:
|
||||||
|
c.filter = c.filter.resolve_expression(query, allow_joins, reuse, summarize)
|
||||||
if not summarize:
|
if not summarize:
|
||||||
expressions = c.get_source_expressions()
|
# Call Aggregate.get_source_expressions() to avoid
|
||||||
|
# returning self.filter and including that in this loop.
|
||||||
|
expressions = super(Aggregate, c).get_source_expressions()
|
||||||
for index, expr in enumerate(expressions):
|
for index, expr in enumerate(expressions):
|
||||||
if expr.contains_aggregate:
|
if expr.contains_aggregate:
|
||||||
before_resolved = self.get_source_expressions()[index]
|
before_resolved = self.get_source_expressions()[index]
|
||||||
|
@ -36,6 +61,29 @@ class Aggregate(Func):
|
||||||
def get_group_by_cols(self):
|
def get_group_by_cols(self):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def as_sql(self, compiler, connection, **extra_context):
|
||||||
|
if self.filter:
|
||||||
|
if connection.features.supports_aggregate_filter_clause:
|
||||||
|
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
|
||||||
|
template = self.filter_template % extra_context.get('template', self.template)
|
||||||
|
sql, params = super().as_sql(compiler, connection, template=template, filter=filter_sql)
|
||||||
|
return sql, params + filter_params
|
||||||
|
else:
|
||||||
|
copy = self.copy()
|
||||||
|
copy.filter = None
|
||||||
|
condition = When(Q())
|
||||||
|
source_expressions = copy.get_source_expressions()
|
||||||
|
condition.set_source_expressions([self.filter, source_expressions[0]])
|
||||||
|
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
|
||||||
|
return super(Aggregate, copy).as_sql(compiler, connection, **extra_context)
|
||||||
|
return super().as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
|
def _get_repr_options(self):
|
||||||
|
options = super()._get_repr_options()
|
||||||
|
if self.filter:
|
||||||
|
options.update({'filter': self.filter})
|
||||||
|
return options
|
||||||
|
|
||||||
|
|
||||||
class Avg(Aggregate):
|
class Avg(Aggregate):
|
||||||
function = 'AVG'
|
function = 'AVG'
|
||||||
|
@ -52,7 +100,7 @@ class Avg(Aggregate):
|
||||||
expression = self.get_source_expressions()[0]
|
expression = self.get_source_expressions()[0]
|
||||||
from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
|
from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
|
||||||
return compiler.compile(
|
return compiler.compile(
|
||||||
SecondsToInterval(Avg(IntervalToSeconds(expression)))
|
SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter))
|
||||||
)
|
)
|
||||||
return super().as_sql(compiler, connection)
|
return super().as_sql(compiler, connection)
|
||||||
|
|
||||||
|
@ -62,16 +110,19 @@ class Count(Aggregate):
|
||||||
name = 'Count'
|
name = 'Count'
|
||||||
template = '%(function)s(%(distinct)s%(expressions)s)'
|
template = '%(function)s(%(distinct)s%(expressions)s)'
|
||||||
|
|
||||||
def __init__(self, expression, distinct=False, **extra):
|
def __init__(self, expression, distinct=False, filter=None, **extra):
|
||||||
if expression == '*':
|
if expression == '*':
|
||||||
expression = Star()
|
expression = Star()
|
||||||
|
if isinstance(expression, Star) and filter is not None:
|
||||||
|
raise ValueError('Star cannot be used with filter. Please specify a field.')
|
||||||
super().__init__(
|
super().__init__(
|
||||||
expression, distinct='DISTINCT ' if distinct else '',
|
expression, distinct='DISTINCT ' if distinct else '',
|
||||||
output_field=IntegerField(), **extra
|
output_field=IntegerField(), filter=filter, **extra
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_repr_options(self):
|
def _get_repr_options(self):
|
||||||
return {'distinct': self.extra['distinct'] != ''}
|
options = super()._get_repr_options()
|
||||||
|
return dict(options, distinct=self.extra['distinct'] != '')
|
||||||
|
|
||||||
def convert_value(self, value, expression, connection):
|
def convert_value(self, value, expression, connection):
|
||||||
if value is None:
|
if value is None:
|
||||||
|
@ -97,7 +148,8 @@ class StdDev(Aggregate):
|
||||||
super().__init__(expression, output_field=FloatField(), **extra)
|
super().__init__(expression, output_field=FloatField(), **extra)
|
||||||
|
|
||||||
def _get_repr_options(self):
|
def _get_repr_options(self):
|
||||||
return {'sample': self.function == 'STDDEV_SAMP'}
|
options = super()._get_repr_options()
|
||||||
|
return dict(options, sample=self.function == 'STDDEV_SAMP')
|
||||||
|
|
||||||
def convert_value(self, value, expression, connection):
|
def convert_value(self, value, expression, connection):
|
||||||
if value is None:
|
if value is None:
|
||||||
|
@ -127,7 +179,8 @@ class Variance(Aggregate):
|
||||||
super().__init__(expression, output_field=FloatField(), **extra)
|
super().__init__(expression, output_field=FloatField(), **extra)
|
||||||
|
|
||||||
def _get_repr_options(self):
|
def _get_repr_options(self):
|
||||||
return {'sample': self.function == 'VAR_SAMP'}
|
options = super()._get_repr_options()
|
||||||
|
return dict(options, sample=self.function == 'VAR_SAMP')
|
||||||
|
|
||||||
def convert_value(self, value, expression, connection):
|
def convert_value(self, value, expression, connection):
|
||||||
if value is None:
|
if value is None:
|
||||||
|
|
|
@ -22,7 +22,7 @@ General-purpose aggregation functions
|
||||||
``ArrayAgg``
|
``ArrayAgg``
|
||||||
------------
|
------------
|
||||||
|
|
||||||
.. class:: ArrayAgg(expression, distinct=False, **extra)
|
.. class:: ArrayAgg(expression, distinct=False, filter=None, **extra)
|
||||||
|
|
||||||
Returns a list of values, including nulls, concatenated into an array.
|
Returns a list of values, including nulls, concatenated into an array.
|
||||||
|
|
||||||
|
@ -36,7 +36,7 @@ General-purpose aggregation functions
|
||||||
``BitAnd``
|
``BitAnd``
|
||||||
----------
|
----------
|
||||||
|
|
||||||
.. class:: BitAnd(expression, **extra)
|
.. class:: BitAnd(expression, filter=None, **extra)
|
||||||
|
|
||||||
Returns an ``int`` of the bitwise ``AND`` of all non-null input values, or
|
Returns an ``int`` of the bitwise ``AND`` of all non-null input values, or
|
||||||
``None`` if all values are null.
|
``None`` if all values are null.
|
||||||
|
@ -44,7 +44,7 @@ General-purpose aggregation functions
|
||||||
``BitOr``
|
``BitOr``
|
||||||
---------
|
---------
|
||||||
|
|
||||||
.. class:: BitOr(expression, **extra)
|
.. class:: BitOr(expression, filter=None, **extra)
|
||||||
|
|
||||||
Returns an ``int`` of the bitwise ``OR`` of all non-null input values, or
|
Returns an ``int`` of the bitwise ``OR`` of all non-null input values, or
|
||||||
``None`` if all values are null.
|
``None`` if all values are null.
|
||||||
|
@ -52,7 +52,7 @@ General-purpose aggregation functions
|
||||||
``BoolAnd``
|
``BoolAnd``
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
.. class:: BoolAnd(expression, **extra)
|
.. class:: BoolAnd(expression, filter=None, **extra)
|
||||||
|
|
||||||
Returns ``True``, if all input values are true, ``None`` if all values are
|
Returns ``True``, if all input values are true, ``None`` if all values are
|
||||||
null or if there are no values, otherwise ``False`` .
|
null or if there are no values, otherwise ``False`` .
|
||||||
|
@ -60,7 +60,7 @@ General-purpose aggregation functions
|
||||||
``BoolOr``
|
``BoolOr``
|
||||||
----------
|
----------
|
||||||
|
|
||||||
.. class:: BoolOr(expression, **extra)
|
.. class:: BoolOr(expression, filter=None, **extra)
|
||||||
|
|
||||||
Returns ``True`` if at least one input value is true, ``None`` if all
|
Returns ``True`` if at least one input value is true, ``None`` if all
|
||||||
values are null or if there are no values, otherwise ``False``.
|
values are null or if there are no values, otherwise ``False``.
|
||||||
|
@ -68,7 +68,7 @@ General-purpose aggregation functions
|
||||||
``JSONBAgg``
|
``JSONBAgg``
|
||||||
------------
|
------------
|
||||||
|
|
||||||
.. class:: JSONBAgg(expressions, **extra)
|
.. class:: JSONBAgg(expressions, filter=None, **extra)
|
||||||
|
|
||||||
.. versionadded:: 1.11
|
.. versionadded:: 1.11
|
||||||
|
|
||||||
|
@ -77,7 +77,7 @@ General-purpose aggregation functions
|
||||||
``StringAgg``
|
``StringAgg``
|
||||||
-------------
|
-------------
|
||||||
|
|
||||||
.. class:: StringAgg(expression, delimiter, distinct=False)
|
.. class:: StringAgg(expression, delimiter, distinct=False, filter=None)
|
||||||
|
|
||||||
Returns the input values concatenated into a string, separated by
|
Returns the input values concatenated into a string, separated by
|
||||||
the ``delimiter`` string.
|
the ``delimiter`` string.
|
||||||
|
@ -105,7 +105,7 @@ field or an expression returning a numeric data. Both are required.
|
||||||
``Corr``
|
``Corr``
|
||||||
--------
|
--------
|
||||||
|
|
||||||
.. class:: Corr(y, x)
|
.. class:: Corr(y, x, filter=None)
|
||||||
|
|
||||||
Returns the correlation coefficient as a ``float``, or ``None`` if there
|
Returns the correlation coefficient as a ``float``, or ``None`` if there
|
||||||
aren't any matching rows.
|
aren't any matching rows.
|
||||||
|
@ -113,7 +113,7 @@ field or an expression returning a numeric data. Both are required.
|
||||||
``CovarPop``
|
``CovarPop``
|
||||||
------------
|
------------
|
||||||
|
|
||||||
.. class:: CovarPop(y, x, sample=False)
|
.. class:: CovarPop(y, x, sample=False, filter=None)
|
||||||
|
|
||||||
Returns the population covariance as a ``float``, or ``None`` if there
|
Returns the population covariance as a ``float``, or ``None`` if there
|
||||||
aren't any matching rows.
|
aren't any matching rows.
|
||||||
|
@ -129,7 +129,7 @@ field or an expression returning a numeric data. Both are required.
|
||||||
``RegrAvgX``
|
``RegrAvgX``
|
||||||
------------
|
------------
|
||||||
|
|
||||||
.. class:: RegrAvgX(y, x)
|
.. class:: RegrAvgX(y, x, filter=None)
|
||||||
|
|
||||||
Returns the average of the independent variable (``sum(x)/N``) as a
|
Returns the average of the independent variable (``sum(x)/N``) as a
|
||||||
``float``, or ``None`` if there aren't any matching rows.
|
``float``, or ``None`` if there aren't any matching rows.
|
||||||
|
@ -137,7 +137,7 @@ field or an expression returning a numeric data. Both are required.
|
||||||
``RegrAvgY``
|
``RegrAvgY``
|
||||||
------------
|
------------
|
||||||
|
|
||||||
.. class:: RegrAvgY(y, x)
|
.. class:: RegrAvgY(y, x, filter=None)
|
||||||
|
|
||||||
Returns the average of the dependent variable (``sum(y)/N``) as a
|
Returns the average of the dependent variable (``sum(y)/N``) as a
|
||||||
``float``, or ``None`` if there aren't any matching rows.
|
``float``, or ``None`` if there aren't any matching rows.
|
||||||
|
@ -145,7 +145,7 @@ field or an expression returning a numeric data. Both are required.
|
||||||
``RegrCount``
|
``RegrCount``
|
||||||
-------------
|
-------------
|
||||||
|
|
||||||
.. class:: RegrCount(y, x)
|
.. class:: RegrCount(y, x, filter=None)
|
||||||
|
|
||||||
Returns an ``int`` of the number of input rows in which both expressions
|
Returns an ``int`` of the number of input rows in which both expressions
|
||||||
are not null.
|
are not null.
|
||||||
|
@ -153,7 +153,7 @@ field or an expression returning a numeric data. Both are required.
|
||||||
``RegrIntercept``
|
``RegrIntercept``
|
||||||
-----------------
|
-----------------
|
||||||
|
|
||||||
.. class:: RegrIntercept(y, x)
|
.. class:: RegrIntercept(y, x, filter=None)
|
||||||
|
|
||||||
Returns the y-intercept of the least-squares-fit linear equation determined
|
Returns the y-intercept of the least-squares-fit linear equation determined
|
||||||
by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any
|
by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any
|
||||||
|
@ -162,7 +162,7 @@ field or an expression returning a numeric data. Both are required.
|
||||||
``RegrR2``
|
``RegrR2``
|
||||||
----------
|
----------
|
||||||
|
|
||||||
.. class:: RegrR2(y, x)
|
.. class:: RegrR2(y, x, filter=None)
|
||||||
|
|
||||||
Returns the square of the correlation coefficient as a ``float``, or
|
Returns the square of the correlation coefficient as a ``float``, or
|
||||||
``None`` if there aren't any matching rows.
|
``None`` if there aren't any matching rows.
|
||||||
|
@ -170,7 +170,7 @@ field or an expression returning a numeric data. Both are required.
|
||||||
``RegrSlope``
|
``RegrSlope``
|
||||||
-------------
|
-------------
|
||||||
|
|
||||||
.. class:: RegrSlope(y, x)
|
.. class:: RegrSlope(y, x, filter=None)
|
||||||
|
|
||||||
Returns the slope of the least-squares-fit linear equation determined
|
Returns the slope of the least-squares-fit linear equation determined
|
||||||
by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any
|
by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any
|
||||||
|
@ -179,7 +179,7 @@ field or an expression returning a numeric data. Both are required.
|
||||||
``RegrSXX``
|
``RegrSXX``
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
.. class:: RegrSXX(y, x)
|
.. class:: RegrSXX(y, x, filter=None)
|
||||||
|
|
||||||
Returns ``sum(x^2) - sum(x)^2/N`` ("sum of squares" of the independent
|
Returns ``sum(x^2) - sum(x)^2/N`` ("sum of squares" of the independent
|
||||||
variable) as a ``float``, or ``None`` if there aren't any matching rows.
|
variable) as a ``float``, or ``None`` if there aren't any matching rows.
|
||||||
|
@ -187,7 +187,7 @@ field or an expression returning a numeric data. Both are required.
|
||||||
``RegrSXY``
|
``RegrSXY``
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
.. class:: RegrSXY(y, x)
|
.. class:: RegrSXY(y, x, filter=None)
|
||||||
|
|
||||||
Returns ``sum(x*y) - sum(x) * sum(y)/N`` ("sum of products" of independent
|
Returns ``sum(x*y) - sum(x) * sum(y)/N`` ("sum of products" of independent
|
||||||
times dependent variable) as a ``float``, or ``None`` if there aren't any
|
times dependent variable) as a ``float``, or ``None`` if there aren't any
|
||||||
|
@ -196,7 +196,7 @@ field or an expression returning a numeric data. Both are required.
|
||||||
``RegrSYY``
|
``RegrSYY``
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
.. class:: RegrSYY(y, x)
|
.. class:: RegrSYY(y, x, filter=None)
|
||||||
|
|
||||||
Returns ``sum(y^2) - sum(y)^2/N`` ("sum of squares" of the dependent
|
Returns ``sum(y^2) - sum(y)^2/N`` ("sum of squares" of the dependent
|
||||||
variable) as a ``float``, or ``None`` if there aren't any matching rows.
|
variable) as a ``float``, or ``None`` if there aren't any matching rows.
|
||||||
|
|
|
@ -184,12 +184,14 @@ their registration dates. We can do this using a conditional expression and the
|
||||||
>>> Client.objects.values_list('name', 'account_type')
|
>>> Client.objects.values_list('name', 'account_type')
|
||||||
<QuerySet [('Jane Doe', 'G'), ('James Smith', 'R'), ('Jack Black', 'P')]>
|
<QuerySet [('Jane Doe', 'G'), ('James Smith', 'R'), ('Jack Black', 'P')]>
|
||||||
|
|
||||||
|
.. _conditional-aggregation:
|
||||||
|
|
||||||
Conditional aggregation
|
Conditional aggregation
|
||||||
-----------------------
|
-----------------------
|
||||||
|
|
||||||
What if we want to find out how many clients there are for each
|
What if we want to find out how many clients there are for each
|
||||||
``account_type``? We can nest conditional expression within
|
``account_type``? We can use the ``filter`` argument of :ref:`aggregate
|
||||||
:ref:`aggregate functions <aggregation-functions>` to achieve this::
|
functions <aggregation-functions>` to achieve this::
|
||||||
|
|
||||||
>>> # Create some more Clients first so we can have something to count
|
>>> # Create some more Clients first so we can have something to count
|
||||||
>>> Client.objects.create(
|
>>> Client.objects.create(
|
||||||
|
@ -207,17 +209,30 @@ What if we want to find out how many clients there are for each
|
||||||
>>> # Get counts for each value of account_type
|
>>> # Get counts for each value of account_type
|
||||||
>>> from django.db.models import IntegerField, Sum
|
>>> from django.db.models import IntegerField, Sum
|
||||||
>>> Client.objects.aggregate(
|
>>> Client.objects.aggregate(
|
||||||
... regular=Sum(
|
... regular=Count('pk', filter=Q(account_type=Client.REGULAR)),
|
||||||
... Case(When(account_type=Client.REGULAR, then=1),
|
... gold=Count('pk', filter=Q(account_type=Client.GOLD)),
|
||||||
... output_field=IntegerField())
|
... platinum=Count('pk', filter=Q(account_type=Client.PLATINUM)),
|
||||||
... ),
|
|
||||||
... gold=Sum(
|
|
||||||
... Case(When(account_type=Client.GOLD, then=1),
|
|
||||||
... output_field=IntegerField())
|
|
||||||
... ),
|
|
||||||
... platinum=Sum(
|
|
||||||
... Case(When(account_type=Client.PLATINUM, then=1),
|
|
||||||
... output_field=IntegerField())
|
|
||||||
... )
|
|
||||||
... )
|
... )
|
||||||
{'regular': 2, 'gold': 1, 'platinum': 3}
|
{'regular': 2, 'gold': 1, 'platinum': 3}
|
||||||
|
|
||||||
|
This aggregate produces a query with the SQL 2003 ``FILTER WHERE`` syntax
|
||||||
|
on databases that support it:
|
||||||
|
|
||||||
|
.. code-block:: sql
|
||||||
|
|
||||||
|
SELECT count('id') FILTER (WHERE account_type=1) as regular,
|
||||||
|
count('id') FILTER (WHERE account_type=2) as gold,
|
||||||
|
count('id') FILTER (WHERE account_type=3) as platinum
|
||||||
|
FROM clients;
|
||||||
|
|
||||||
|
On other databases, this is emulated using a ``CASE`` statement:
|
||||||
|
|
||||||
|
.. code-block:: sql
|
||||||
|
|
||||||
|
SELECT count(CASE WHEN account_type=1 THEN id ELSE null) as regular,
|
||||||
|
count(CASE WHEN account_type=2 THEN id ELSE null) as gold,
|
||||||
|
count(CASE WHEN account_type=3 THEN id ELSE null) as platinum
|
||||||
|
FROM clients;
|
||||||
|
|
||||||
|
The two SQL statements are functionally equivalent but the more explicit
|
||||||
|
``FILTER`` may perform better.
|
||||||
|
|
|
@ -339,7 +339,7 @@ some complex computations::
|
||||||
|
|
||||||
The ``Aggregate`` API is as follows:
|
The ``Aggregate`` API is as follows:
|
||||||
|
|
||||||
.. class:: Aggregate(expression, output_field=None, **extra)
|
.. class:: Aggregate(expression, output_field=None, filter=None, **extra)
|
||||||
|
|
||||||
.. attribute:: template
|
.. attribute:: template
|
||||||
|
|
||||||
|
@ -370,9 +370,17 @@ should define the desired ``output_field``. For example, adding an
|
||||||
``IntegerField()`` and a ``FloatField()`` together should probably have
|
``IntegerField()`` and a ``FloatField()`` together should probably have
|
||||||
``output_field=FloatField()`` defined.
|
``output_field=FloatField()`` defined.
|
||||||
|
|
||||||
|
The ``filter`` argument takes a :class:`Q object <django.db.models.Q>` that's
|
||||||
|
used to filter the rows that are aggregated. See :ref:`conditional-aggregation`
|
||||||
|
and :ref:`filtering-on-annotations` for example usage.
|
||||||
|
|
||||||
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
|
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
|
||||||
into the ``template`` attribute.
|
into the ``template`` attribute.
|
||||||
|
|
||||||
|
.. versionchanged:: 2.0
|
||||||
|
|
||||||
|
The ``filter`` argument was added.
|
||||||
|
|
||||||
Creating your own Aggregate Functions
|
Creating your own Aggregate Functions
|
||||||
-------------------------------------
|
-------------------------------------
|
||||||
|
|
||||||
|
|
|
@ -3085,6 +3085,17 @@ of the return value
|
||||||
``output_field`` if all fields are of the same type. Otherwise, you
|
``output_field`` if all fields are of the same type. Otherwise, you
|
||||||
must provide the ``output_field`` yourself.
|
must provide the ``output_field`` yourself.
|
||||||
|
|
||||||
|
``filter``
|
||||||
|
~~~~~~~~~~
|
||||||
|
|
||||||
|
.. versionadded:: 2.0
|
||||||
|
|
||||||
|
An optional :class:`Q object <django.db.models.Q>` that's used to filter the
|
||||||
|
rows that are aggregated.
|
||||||
|
|
||||||
|
See :ref:`conditional-aggregation` and :ref:`filtering-on-annotations` for
|
||||||
|
example usage.
|
||||||
|
|
||||||
``**extra``
|
``**extra``
|
||||||
~~~~~~~~~~~
|
~~~~~~~~~~~
|
||||||
|
|
||||||
|
@ -3094,7 +3105,7 @@ by the aggregate.
|
||||||
``Avg``
|
``Avg``
|
||||||
~~~~~~~
|
~~~~~~~
|
||||||
|
|
||||||
.. class:: Avg(expression, output_field=FloatField(), **extra)
|
.. class:: Avg(expression, output_field=FloatField(), filter=None, **extra)
|
||||||
|
|
||||||
Returns the mean value of the given expression, which must be numeric
|
Returns the mean value of the given expression, which must be numeric
|
||||||
unless you specify a different ``output_field``.
|
unless you specify a different ``output_field``.
|
||||||
|
@ -3106,7 +3117,7 @@ by the aggregate.
|
||||||
``Count``
|
``Count``
|
||||||
~~~~~~~~~
|
~~~~~~~~~
|
||||||
|
|
||||||
.. class:: Count(expression, distinct=False, **extra)
|
.. class:: Count(expression, distinct=False, filter=None, **extra)
|
||||||
|
|
||||||
Returns the number of objects that are related through the provided
|
Returns the number of objects that are related through the provided
|
||||||
expression.
|
expression.
|
||||||
|
@ -3125,7 +3136,7 @@ by the aggregate.
|
||||||
``Max``
|
``Max``
|
||||||
~~~~~~~
|
~~~~~~~
|
||||||
|
|
||||||
.. class:: Max(expression, output_field=None, **extra)
|
.. class:: Max(expression, output_field=None, filter=None, **extra)
|
||||||
|
|
||||||
Returns the maximum value of the given expression.
|
Returns the maximum value of the given expression.
|
||||||
|
|
||||||
|
@ -3135,7 +3146,7 @@ by the aggregate.
|
||||||
``Min``
|
``Min``
|
||||||
~~~~~~~
|
~~~~~~~
|
||||||
|
|
||||||
.. class:: Min(expression, output_field=None, **extra)
|
.. class:: Min(expression, output_field=None, filter=None, **extra)
|
||||||
|
|
||||||
Returns the minimum value of the given expression.
|
Returns the minimum value of the given expression.
|
||||||
|
|
||||||
|
@ -3145,7 +3156,7 @@ by the aggregate.
|
||||||
``StdDev``
|
``StdDev``
|
||||||
~~~~~~~~~~
|
~~~~~~~~~~
|
||||||
|
|
||||||
.. class:: StdDev(expression, sample=False, **extra)
|
.. class:: StdDev(expression, sample=False, filter=None, **extra)
|
||||||
|
|
||||||
Returns the standard deviation of the data in the provided expression.
|
Returns the standard deviation of the data in the provided expression.
|
||||||
|
|
||||||
|
@ -3169,7 +3180,7 @@ by the aggregate.
|
||||||
``Sum``
|
``Sum``
|
||||||
~~~~~~~
|
~~~~~~~
|
||||||
|
|
||||||
.. class:: Sum(expression, output_field=None, **extra)
|
.. class:: Sum(expression, output_field=None, filter=None, **extra)
|
||||||
|
|
||||||
Computes the sum of all values of the given expression.
|
Computes the sum of all values of the given expression.
|
||||||
|
|
||||||
|
@ -3179,7 +3190,7 @@ by the aggregate.
|
||||||
``Variance``
|
``Variance``
|
||||||
~~~~~~~~~~~~
|
~~~~~~~~~~~~
|
||||||
|
|
||||||
.. class:: Variance(expression, sample=False, **extra)
|
.. class:: Variance(expression, sample=False, filter=None, **extra)
|
||||||
|
|
||||||
Returns the variance of the data in the provided expression.
|
Returns the variance of the data in the provided expression.
|
||||||
|
|
||||||
|
|
|
@ -273,6 +273,10 @@ Models
|
||||||
parameters, if the backend supports this feature. Of Django's built-in
|
parameters, if the backend supports this feature. Of Django's built-in
|
||||||
backends, only Oracle supports it.
|
backends, only Oracle supports it.
|
||||||
|
|
||||||
|
* The new ``filter`` argument for built-in aggregates allows :ref:`adding
|
||||||
|
different conditionals <conditional-aggregation>` to multiple aggregations
|
||||||
|
over the same fields or relations.
|
||||||
|
|
||||||
Requests and Responses
|
Requests and Responses
|
||||||
~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|
|
@ -84,6 +84,16 @@ In a hurry? Here's how to do common aggregate queries, assuming the models above
|
||||||
>>> pubs[0].num_books
|
>>> pubs[0].num_books
|
||||||
73
|
73
|
||||||
|
|
||||||
|
# Each publisher, with a separate count of books with a rating above and below 5
|
||||||
|
>>> from django.db.models import Q
|
||||||
|
>>> above_5 = Count('book', filter=Q(book__rating__gt=5))
|
||||||
|
>>> below_5 = Count('book', filter=Q(book__rating__lte=5))
|
||||||
|
>>> pubs = Publisher.objects.annotate(below_5=below_5).annotate(above_5=above_5)
|
||||||
|
>>> pubs[0].above_5
|
||||||
|
23
|
||||||
|
>>> pubs[0].below_5
|
||||||
|
12
|
||||||
|
|
||||||
# The top 5 publishers, in order by number of books.
|
# The top 5 publishers, in order by number of books.
|
||||||
>>> pubs = Publisher.objects.annotate(num_books=Count('book')).order_by('-num_books')[:5]
|
>>> pubs = Publisher.objects.annotate(num_books=Count('book')).order_by('-num_books')[:5]
|
||||||
>>> pubs[0].num_books
|
>>> pubs[0].num_books
|
||||||
|
@ -324,6 +334,8 @@ title that starts with "Django" using the query::
|
||||||
|
|
||||||
>>> Book.objects.filter(name__startswith="Django").aggregate(Avg('price'))
|
>>> Book.objects.filter(name__startswith="Django").aggregate(Avg('price'))
|
||||||
|
|
||||||
|
.. _filtering-on-annotations:
|
||||||
|
|
||||||
Filtering on annotations
|
Filtering on annotations
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
@ -339,6 +351,27 @@ you can issue the query::
|
||||||
This query generates an annotated result set, and then generates a filter
|
This query generates an annotated result set, and then generates a filter
|
||||||
based upon that annotation.
|
based upon that annotation.
|
||||||
|
|
||||||
|
If you need two annotations with two separate filters you can use the
|
||||||
|
``filter`` argument with any aggregate. For example, to generate a list of
|
||||||
|
authors with a count of highly rated books::
|
||||||
|
|
||||||
|
>>> highly_rated = Count('books', filter=Q(books__rating__gte=7))
|
||||||
|
>>> Author.objects.annotate(num_books=Count('books'), highly_rated_books=highly_rated)
|
||||||
|
|
||||||
|
Each ``Author`` in the result set will have the ``num_books`` and
|
||||||
|
``highly_rated_books`` attributes.
|
||||||
|
|
||||||
|
.. admonition:: Choosing between ``filter`` and ``QuerySet.filter()``
|
||||||
|
|
||||||
|
Avoid using the ``filter`` argument with a single annotation or
|
||||||
|
aggregation. It's more efficient to use ``QuerySet.filter()`` to exclude
|
||||||
|
rows. The aggregation ``filter`` argument is only useful when using two or
|
||||||
|
more aggregations over the same relations with different conditionals.
|
||||||
|
|
||||||
|
.. versionchanged:: 2.0
|
||||||
|
|
||||||
|
The ``filter`` argument was added to aggregates.
|
||||||
|
|
||||||
Order of ``annotate()`` and ``filter()`` clauses
|
Order of ``annotate()`` and ``filter()`` clauses
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
import datetime
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
from django.db.models import Case, Count, F, Q, Sum, When
|
||||||
|
from django.test import TestCase
|
||||||
|
|
||||||
|
from .models import Author, Book, Publisher
|
||||||
|
|
||||||
|
|
||||||
|
class FilteredAggregateTests(TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpTestData(cls):
|
||||||
|
cls.a1 = Author.objects.create(name='test', age=40)
|
||||||
|
cls.a2 = Author.objects.create(name='test2', age=60)
|
||||||
|
cls.a3 = Author.objects.create(name='test3', age=100)
|
||||||
|
cls.p1 = Publisher.objects.create(name='Apress', num_awards=3, duration=datetime.timedelta(days=1))
|
||||||
|
cls.b1 = Book.objects.create(
|
||||||
|
isbn='159059725', name='The Definitive Guide to Django: Web Development Done Right',
|
||||||
|
pages=447, rating=4.5, price=Decimal('30.00'), contact=cls.a1, publisher=cls.p1,
|
||||||
|
pubdate=datetime.date(2007, 12, 6),
|
||||||
|
)
|
||||||
|
cls.b2 = Book.objects.create(
|
||||||
|
isbn='067232959', name='Sams Teach Yourself Django in 24 Hours',
|
||||||
|
pages=528, rating=3.0, price=Decimal('23.09'), contact=cls.a2, publisher=cls.p1,
|
||||||
|
pubdate=datetime.date(2008, 3, 3),
|
||||||
|
)
|
||||||
|
cls.b3 = Book.objects.create(
|
||||||
|
isbn='159059996', name='Practical Django Projects',
|
||||||
|
pages=600, rating=4.5, price=Decimal('29.69'), contact=cls.a3, publisher=cls.p1,
|
||||||
|
pubdate=datetime.date(2008, 6, 23),
|
||||||
|
)
|
||||||
|
cls.a1.friends.add(cls.a2)
|
||||||
|
cls.a1.friends.add(cls.a3)
|
||||||
|
cls.b1.authors.add(cls.a1)
|
||||||
|
cls.b1.authors.add(cls.a3)
|
||||||
|
cls.b2.authors.add(cls.a2)
|
||||||
|
cls.b3.authors.add(cls.a3)
|
||||||
|
|
||||||
|
def test_filtered_aggregates(self):
|
||||||
|
agg = Sum('age', filter=Q(name__startswith='test'))
|
||||||
|
self.assertEqual(Author.objects.aggregate(age=agg)['age'], 200)
|
||||||
|
|
||||||
|
def test_double_filtered_aggregates(self):
|
||||||
|
agg = Sum('age', filter=Q(Q(name='test2') & ~Q(name='test')))
|
||||||
|
self.assertEqual(Author.objects.aggregate(age=agg)['age'], 60)
|
||||||
|
|
||||||
|
def test_excluded_aggregates(self):
|
||||||
|
agg = Sum('age', filter=~Q(name='test2'))
|
||||||
|
self.assertEqual(Author.objects.aggregate(age=agg)['age'], 140)
|
||||||
|
|
||||||
|
def test_related_aggregates_m2m(self):
|
||||||
|
agg = Sum('friends__age', filter=~Q(friends__name='test'))
|
||||||
|
self.assertEqual(Author.objects.filter(name='test').aggregate(age=agg)['age'], 160)
|
||||||
|
|
||||||
|
def test_related_aggregates_m2m_and_fk(self):
|
||||||
|
q = Q(friends__book__publisher__name='Apress') & ~Q(friends__name='test3')
|
||||||
|
agg = Sum('friends__book__pages', filter=q)
|
||||||
|
self.assertEqual(Author.objects.filter(name='test').aggregate(pages=agg)['pages'], 528)
|
||||||
|
|
||||||
|
def test_plain_annotate(self):
|
||||||
|
agg = Sum('book__pages', filter=Q(book__rating__gt=3))
|
||||||
|
qs = Author.objects.annotate(pages=agg).order_by('pk')
|
||||||
|
self.assertSequenceEqual([a.pages for a in qs], [447, None, 1047])
|
||||||
|
|
||||||
|
def test_filtered_aggregate_on_annotate(self):
|
||||||
|
pages_annotate = Sum('book__pages', filter=Q(book__rating__gt=3))
|
||||||
|
age_agg = Sum('age', filter=Q(total_pages__gte=400))
|
||||||
|
aggregated = Author.objects.annotate(total_pages=pages_annotate).aggregate(summed_age=age_agg)
|
||||||
|
self.assertEqual(aggregated, {'summed_age': 140})
|
||||||
|
|
||||||
|
def test_case_aggregate(self):
|
||||||
|
agg = Sum(
|
||||||
|
Case(When(friends__age=40, then=F('friends__age'))),
|
||||||
|
filter=Q(friends__name__startswith='test'),
|
||||||
|
)
|
||||||
|
self.assertEqual(Author.objects.aggregate(age=agg)['age'], 80)
|
||||||
|
|
||||||
|
def test_sum_star_exception(self):
|
||||||
|
msg = 'Star cannot be used with filter. Please specify a field.'
|
||||||
|
with self.assertRaisesMessage(ValueError, msg):
|
||||||
|
Count('*', filter=Q(age=40))
|
|
@ -5,7 +5,7 @@ from copy import deepcopy
|
||||||
|
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.db import DatabaseError, connection, models, transaction
|
from django.db import DatabaseError, connection, models, transaction
|
||||||
from django.db.models import CharField, TimeField, UUIDField
|
from django.db.models import CharField, Q, TimeField, UUIDField
|
||||||
from django.db.models.aggregates import (
|
from django.db.models.aggregates import (
|
||||||
Avg, Count, Max, Min, StdDev, Sum, Variance,
|
Avg, Count, Max, Min, StdDev, Sum, Variance,
|
||||||
)
|
)
|
||||||
|
@ -1369,3 +1369,16 @@ class ReprTests(TestCase):
|
||||||
self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)")
|
self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)")
|
||||||
self.assertEqual(repr(Sum('a')), "Sum(F(a))")
|
self.assertEqual(repr(Sum('a')), "Sum(F(a))")
|
||||||
self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)")
|
self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)")
|
||||||
|
|
||||||
|
def test_filtered_aggregates(self):
|
||||||
|
filter = Q(a=1)
|
||||||
|
self.assertEqual(repr(Avg('a', filter=filter)), "Avg(F(a), filter=(AND: ('a', 1)))")
|
||||||
|
self.assertEqual(repr(Count('a', filter=filter)), "Count(F(a), distinct=False, filter=(AND: ('a', 1)))")
|
||||||
|
self.assertEqual(repr(Max('a', filter=filter)), "Max(F(a), filter=(AND: ('a', 1)))")
|
||||||
|
self.assertEqual(repr(Min('a', filter=filter)), "Min(F(a), filter=(AND: ('a', 1)))")
|
||||||
|
self.assertEqual(repr(StdDev('a', filter=filter)), "StdDev(F(a), filter=(AND: ('a', 1)), sample=False)")
|
||||||
|
self.assertEqual(repr(Sum('a', filter=filter)), "Sum(F(a), filter=(AND: ('a', 1)))")
|
||||||
|
self.assertEqual(
|
||||||
|
repr(Variance('a', sample=True, filter=filter)),
|
||||||
|
"Variance(F(a), filter=(AND: ('a', 1)), sample=True)"
|
||||||
|
)
|
||||||
|
|
|
@ -1253,6 +1253,15 @@ class CaseDocumentationExamples(TestCase):
|
||||||
account_type=Client.PLATINUM,
|
account_type=Client.PLATINUM,
|
||||||
registered_on=date.today(),
|
registered_on=date.today(),
|
||||||
)
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
Client.objects.aggregate(
|
||||||
|
regular=models.Count('pk', filter=Q(account_type=Client.REGULAR)),
|
||||||
|
gold=models.Count('pk', filter=Q(account_type=Client.GOLD)),
|
||||||
|
platinum=models.Count('pk', filter=Q(account_type=Client.PLATINUM)),
|
||||||
|
),
|
||||||
|
{'regular': 2, 'gold': 1, 'platinum': 3}
|
||||||
|
)
|
||||||
|
# This was the example before the filter argument was added.
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
Client.objects.aggregate(
|
Client.objects.aggregate(
|
||||||
regular=models.Sum(Case(
|
regular=models.Sum(Case(
|
||||||
|
|
Loading…
Reference in New Issue