Fixed #27849 -- Added filtering support to aggregates.

This commit is contained in:
Tom 2017-04-22 16:44:51 +01:00 committed by Tim Graham
parent 489421b015
commit b78d100fa6
13 changed files with 290 additions and 55 deletions

View File

@ -8,10 +8,10 @@ __all__ = [
class StatAggregate(Aggregate):
def __init__(self, y, x, output_field=FloatField()):
def __init__(self, y, x, output_field=FloatField(), filter=None):
if not x or not y:
raise ValueError('Both y and x must be provided.')
super().__init__(y, x, output_field=output_field)
super().__init__(y, x, output_field=output_field, filter=filter)
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
return super().resolve_expression(query, allow_joins, reuse, summarize)
@ -22,9 +22,9 @@ class Corr(StatAggregate):
class CovarPop(StatAggregate):
def __init__(self, y, x, sample=False):
def __init__(self, y, x, sample=False, filter=None):
self.function = 'COVAR_SAMP' if sample else 'COVAR_POP'
super().__init__(y, x)
super().__init__(y, x, filter=filter)
class RegrAvgX(StatAggregate):
@ -38,8 +38,8 @@ class RegrAvgY(StatAggregate):
class RegrCount(StatAggregate):
function = 'REGR_COUNT'
def __init__(self, y, x):
super().__init__(y=y, x=x, output_field=IntegerField())
def __init__(self, y, x, filter=None):
super().__init__(y=y, x=x, output_field=IntegerField(), filter=filter)
def convert_value(self, value, expression, connection):
if value is None:

View File

@ -229,6 +229,10 @@ class BaseDatabaseFeatures:
supports_select_difference = True
supports_slicing_ordering_in_compound = False
# Does the database support SQL 2003 FILTER (WHERE ...) in aggregate
# expressions?
supports_aggregate_filter_clause = False
# Does the backend support indexing a TextField?
supports_index_on_text_field = True

View File

@ -50,6 +50,10 @@ class DatabaseFeatures(BaseDatabaseFeatures):
END;
$$ LANGUAGE plpgsql;"""
@cached_property
def supports_aggregate_filter_clause(self):
return self.connection.pg_version >= 90400
@cached_property
def has_select_for_update_skip_locked(self):
return self.connection.pg_version >= 90500

View File

@ -2,8 +2,9 @@
Classes to represent the definitions of aggregate functions.
"""
from django.core.exceptions import FieldError
from django.db.models.expressions import Func, Star
from django.db.models.expressions import Case, Func, Star, When
from django.db.models.fields import DecimalField, FloatField, IntegerField
from django.db.models.query_utils import Q
__all__ = [
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
@ -13,12 +14,36 @@ __all__ = [
class Aggregate(Func):
contains_aggregate = True
name = None
filter_template = '%s FILTER (WHERE %%(filter)s)'
def __init__(self, *args, filter=None, **kwargs):
self.filter = filter
super().__init__(*args, **kwargs)
def get_source_fields(self):
# Don't return the filter expression since it's not a source field.
return [e._output_field_or_none for e in super().get_source_expressions()]
def get_source_expressions(self):
source_expressions = super().get_source_expressions()
if self.filter:
source_expressions += [self.filter]
return source_expressions
def set_source_expressions(self, exprs):
if self.filter:
self.filter = exprs.pop()
return super().set_source_expressions(exprs)
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
# Aggregates are not allowed in UPDATE queries, so ignore for_save
c = super().resolve_expression(query, allow_joins, reuse, summarize)
if c.filter:
c.filter = c.filter.resolve_expression(query, allow_joins, reuse, summarize)
if not summarize:
expressions = c.get_source_expressions()
# Call Aggregate.get_source_expressions() to avoid
# returning self.filter and including that in this loop.
expressions = super(Aggregate, c).get_source_expressions()
for index, expr in enumerate(expressions):
if expr.contains_aggregate:
before_resolved = self.get_source_expressions()[index]
@ -36,6 +61,29 @@ class Aggregate(Func):
def get_group_by_cols(self):
return []
def as_sql(self, compiler, connection, **extra_context):
if self.filter:
if connection.features.supports_aggregate_filter_clause:
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
template = self.filter_template % extra_context.get('template', self.template)
sql, params = super().as_sql(compiler, connection, template=template, filter=filter_sql)
return sql, params + filter_params
else:
copy = self.copy()
copy.filter = None
condition = When(Q())
source_expressions = copy.get_source_expressions()
condition.set_source_expressions([self.filter, source_expressions[0]])
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
return super(Aggregate, copy).as_sql(compiler, connection, **extra_context)
return super().as_sql(compiler, connection, **extra_context)
def _get_repr_options(self):
options = super()._get_repr_options()
if self.filter:
options.update({'filter': self.filter})
return options
class Avg(Aggregate):
function = 'AVG'
@ -52,7 +100,7 @@ class Avg(Aggregate):
expression = self.get_source_expressions()[0]
from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
return compiler.compile(
SecondsToInterval(Avg(IntervalToSeconds(expression)))
SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter))
)
return super().as_sql(compiler, connection)
@ -62,16 +110,19 @@ class Count(Aggregate):
name = 'Count'
template = '%(function)s(%(distinct)s%(expressions)s)'
def __init__(self, expression, distinct=False, **extra):
def __init__(self, expression, distinct=False, filter=None, **extra):
if expression == '*':
expression = Star()
if isinstance(expression, Star) and filter is not None:
raise ValueError('Star cannot be used with filter. Please specify a field.')
super().__init__(
expression, distinct='DISTINCT ' if distinct else '',
output_field=IntegerField(), **extra
output_field=IntegerField(), filter=filter, **extra
)
def _get_repr_options(self):
return {'distinct': self.extra['distinct'] != ''}
options = super()._get_repr_options()
return dict(options, distinct=self.extra['distinct'] != '')
def convert_value(self, value, expression, connection):
if value is None:
@ -97,7 +148,8 @@ class StdDev(Aggregate):
super().__init__(expression, output_field=FloatField(), **extra)
def _get_repr_options(self):
return {'sample': self.function == 'STDDEV_SAMP'}
options = super()._get_repr_options()
return dict(options, sample=self.function == 'STDDEV_SAMP')
def convert_value(self, value, expression, connection):
if value is None:
@ -127,7 +179,8 @@ class Variance(Aggregate):
super().__init__(expression, output_field=FloatField(), **extra)
def _get_repr_options(self):
return {'sample': self.function == 'VAR_SAMP'}
options = super()._get_repr_options()
return dict(options, sample=self.function == 'VAR_SAMP')
def convert_value(self, value, expression, connection):
if value is None:

View File

@ -22,7 +22,7 @@ General-purpose aggregation functions
``ArrayAgg``
------------
.. class:: ArrayAgg(expression, distinct=False, **extra)
.. class:: ArrayAgg(expression, distinct=False, filter=None, **extra)
Returns a list of values, including nulls, concatenated into an array.
@ -36,7 +36,7 @@ General-purpose aggregation functions
``BitAnd``
----------
.. class:: BitAnd(expression, **extra)
.. class:: BitAnd(expression, filter=None, **extra)
Returns an ``int`` of the bitwise ``AND`` of all non-null input values, or
``None`` if all values are null.
@ -44,7 +44,7 @@ General-purpose aggregation functions
``BitOr``
---------
.. class:: BitOr(expression, **extra)
.. class:: BitOr(expression, filter=None, **extra)
Returns an ``int`` of the bitwise ``OR`` of all non-null input values, or
``None`` if all values are null.
@ -52,7 +52,7 @@ General-purpose aggregation functions
``BoolAnd``
-----------
.. class:: BoolAnd(expression, **extra)
.. class:: BoolAnd(expression, filter=None, **extra)
Returns ``True``, if all input values are true, ``None`` if all values are
null or if there are no values, otherwise ``False`` .
@ -60,7 +60,7 @@ General-purpose aggregation functions
``BoolOr``
----------
.. class:: BoolOr(expression, **extra)
.. class:: BoolOr(expression, filter=None, **extra)
Returns ``True`` if at least one input value is true, ``None`` if all
values are null or if there are no values, otherwise ``False``.
@ -68,7 +68,7 @@ General-purpose aggregation functions
``JSONBAgg``
------------
.. class:: JSONBAgg(expressions, **extra)
.. class:: JSONBAgg(expressions, filter=None, **extra)
.. versionadded:: 1.11
@ -77,7 +77,7 @@ General-purpose aggregation functions
``StringAgg``
-------------
.. class:: StringAgg(expression, delimiter, distinct=False)
.. class:: StringAgg(expression, delimiter, distinct=False, filter=None)
Returns the input values concatenated into a string, separated by
the ``delimiter`` string.
@ -105,7 +105,7 @@ field or an expression returning a numeric data. Both are required.
``Corr``
--------
.. class:: Corr(y, x)
.. class:: Corr(y, x, filter=None)
Returns the correlation coefficient as a ``float``, or ``None`` if there
aren't any matching rows.
@ -113,7 +113,7 @@ field or an expression returning a numeric data. Both are required.
``CovarPop``
------------
.. class:: CovarPop(y, x, sample=False)
.. class:: CovarPop(y, x, sample=False, filter=None)
Returns the population covariance as a ``float``, or ``None`` if there
aren't any matching rows.
@ -129,7 +129,7 @@ field or an expression returning a numeric data. Both are required.
``RegrAvgX``
------------
.. class:: RegrAvgX(y, x)
.. class:: RegrAvgX(y, x, filter=None)
Returns the average of the independent variable (``sum(x)/N``) as a
``float``, or ``None`` if there aren't any matching rows.
@ -137,7 +137,7 @@ field or an expression returning a numeric data. Both are required.
``RegrAvgY``
------------
.. class:: RegrAvgY(y, x)
.. class:: RegrAvgY(y, x, filter=None)
Returns the average of the dependent variable (``sum(y)/N``) as a
``float``, or ``None`` if there aren't any matching rows.
@ -145,7 +145,7 @@ field or an expression returning a numeric data. Both are required.
``RegrCount``
-------------
.. class:: RegrCount(y, x)
.. class:: RegrCount(y, x, filter=None)
Returns an ``int`` of the number of input rows in which both expressions
are not null.
@ -153,7 +153,7 @@ field or an expression returning a numeric data. Both are required.
``RegrIntercept``
-----------------
.. class:: RegrIntercept(y, x)
.. class:: RegrIntercept(y, x, filter=None)
Returns the y-intercept of the least-squares-fit linear equation determined
by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any
@ -162,7 +162,7 @@ field or an expression returning a numeric data. Both are required.
``RegrR2``
----------
.. class:: RegrR2(y, x)
.. class:: RegrR2(y, x, filter=None)
Returns the square of the correlation coefficient as a ``float``, or
``None`` if there aren't any matching rows.
@ -170,7 +170,7 @@ field or an expression returning a numeric data. Both are required.
``RegrSlope``
-------------
.. class:: RegrSlope(y, x)
.. class:: RegrSlope(y, x, filter=None)
Returns the slope of the least-squares-fit linear equation determined
by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any
@ -179,7 +179,7 @@ field or an expression returning a numeric data. Both are required.
``RegrSXX``
-----------
.. class:: RegrSXX(y, x)
.. class:: RegrSXX(y, x, filter=None)
Returns ``sum(x^2) - sum(x)^2/N`` ("sum of squares" of the independent
variable) as a ``float``, or ``None`` if there aren't any matching rows.
@ -187,7 +187,7 @@ field or an expression returning a numeric data. Both are required.
``RegrSXY``
-----------
.. class:: RegrSXY(y, x)
.. class:: RegrSXY(y, x, filter=None)
Returns ``sum(x*y) - sum(x) * sum(y)/N`` ("sum of products" of independent
times dependent variable) as a ``float``, or ``None`` if there aren't any
@ -196,7 +196,7 @@ field or an expression returning a numeric data. Both are required.
``RegrSYY``
-----------
.. class:: RegrSYY(y, x)
.. class:: RegrSYY(y, x, filter=None)
Returns ``sum(y^2) - sum(y)^2/N`` ("sum of squares" of the dependent
variable) as a ``float``, or ``None`` if there aren't any matching rows.

View File

@ -184,12 +184,14 @@ their registration dates. We can do this using a conditional expression and the
>>> Client.objects.values_list('name', 'account_type')
<QuerySet [('Jane Doe', 'G'), ('James Smith', 'R'), ('Jack Black', 'P')]>
.. _conditional-aggregation:
Conditional aggregation
-----------------------
What if we want to find out how many clients there are for each
``account_type``? We can nest conditional expression within
:ref:`aggregate functions <aggregation-functions>` to achieve this::
``account_type``? We can use the ``filter`` argument of :ref:`aggregate
functions <aggregation-functions>` to achieve this::
>>> # Create some more Clients first so we can have something to count
>>> Client.objects.create(
@ -207,17 +209,30 @@ What if we want to find out how many clients there are for each
>>> # Get counts for each value of account_type
>>> from django.db.models import IntegerField, Sum
>>> Client.objects.aggregate(
... regular=Sum(
... Case(When(account_type=Client.REGULAR, then=1),
... output_field=IntegerField())
... ),
... gold=Sum(
... Case(When(account_type=Client.GOLD, then=1),
... output_field=IntegerField())
... ),
... platinum=Sum(
... Case(When(account_type=Client.PLATINUM, then=1),
... output_field=IntegerField())
... )
... regular=Count('pk', filter=Q(account_type=Client.REGULAR)),
... gold=Count('pk', filter=Q(account_type=Client.GOLD)),
... platinum=Count('pk', filter=Q(account_type=Client.PLATINUM)),
... )
{'regular': 2, 'gold': 1, 'platinum': 3}
This aggregate produces a query with the SQL 2003 ``FILTER WHERE`` syntax
on databases that support it:
.. code-block:: sql
SELECT count('id') FILTER (WHERE account_type=1) as regular,
count('id') FILTER (WHERE account_type=2) as gold,
count('id') FILTER (WHERE account_type=3) as platinum
FROM clients;
On other databases, this is emulated using a ``CASE`` statement:
.. code-block:: sql
SELECT count(CASE WHEN account_type=1 THEN id ELSE null) as regular,
count(CASE WHEN account_type=2 THEN id ELSE null) as gold,
count(CASE WHEN account_type=3 THEN id ELSE null) as platinum
FROM clients;
The two SQL statements are functionally equivalent but the more explicit
``FILTER`` may perform better.

View File

@ -339,7 +339,7 @@ some complex computations::
The ``Aggregate`` API is as follows:
.. class:: Aggregate(expression, output_field=None, **extra)
.. class:: Aggregate(expression, output_field=None, filter=None, **extra)
.. attribute:: template
@ -370,9 +370,17 @@ should define the desired ``output_field``. For example, adding an
``IntegerField()`` and a ``FloatField()`` together should probably have
``output_field=FloatField()`` defined.
The ``filter`` argument takes a :class:`Q object <django.db.models.Q>` that's
used to filter the rows that are aggregated. See :ref:`conditional-aggregation`
and :ref:`filtering-on-annotations` for example usage.
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
into the ``template`` attribute.
.. versionchanged:: 2.0
The ``filter`` argument was added.
Creating your own Aggregate Functions
-------------------------------------

View File

@ -3085,6 +3085,17 @@ of the return value
``output_field`` if all fields are of the same type. Otherwise, you
must provide the ``output_field`` yourself.
``filter``
~~~~~~~~~~
.. versionadded:: 2.0
An optional :class:`Q object <django.db.models.Q>` that's used to filter the
rows that are aggregated.
See :ref:`conditional-aggregation` and :ref:`filtering-on-annotations` for
example usage.
``**extra``
~~~~~~~~~~~
@ -3094,7 +3105,7 @@ by the aggregate.
``Avg``
~~~~~~~
.. class:: Avg(expression, output_field=FloatField(), **extra)
.. class:: Avg(expression, output_field=FloatField(), filter=None, **extra)
Returns the mean value of the given expression, which must be numeric
unless you specify a different ``output_field``.
@ -3106,7 +3117,7 @@ by the aggregate.
``Count``
~~~~~~~~~
.. class:: Count(expression, distinct=False, **extra)
.. class:: Count(expression, distinct=False, filter=None, **extra)
Returns the number of objects that are related through the provided
expression.
@ -3125,7 +3136,7 @@ by the aggregate.
``Max``
~~~~~~~
.. class:: Max(expression, output_field=None, **extra)
.. class:: Max(expression, output_field=None, filter=None, **extra)
Returns the maximum value of the given expression.
@ -3135,7 +3146,7 @@ by the aggregate.
``Min``
~~~~~~~
.. class:: Min(expression, output_field=None, **extra)
.. class:: Min(expression, output_field=None, filter=None, **extra)
Returns the minimum value of the given expression.
@ -3145,7 +3156,7 @@ by the aggregate.
``StdDev``
~~~~~~~~~~
.. class:: StdDev(expression, sample=False, **extra)
.. class:: StdDev(expression, sample=False, filter=None, **extra)
Returns the standard deviation of the data in the provided expression.
@ -3169,7 +3180,7 @@ by the aggregate.
``Sum``
~~~~~~~
.. class:: Sum(expression, output_field=None, **extra)
.. class:: Sum(expression, output_field=None, filter=None, **extra)
Computes the sum of all values of the given expression.
@ -3179,7 +3190,7 @@ by the aggregate.
``Variance``
~~~~~~~~~~~~
.. class:: Variance(expression, sample=False, **extra)
.. class:: Variance(expression, sample=False, filter=None, **extra)
Returns the variance of the data in the provided expression.

View File

@ -273,6 +273,10 @@ Models
parameters, if the backend supports this feature. Of Django's built-in
backends, only Oracle supports it.
* The new ``filter`` argument for built-in aggregates allows :ref:`adding
different conditionals <conditional-aggregation>` to multiple aggregations
over the same fields or relations.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~

View File

@ -84,6 +84,16 @@ In a hurry? Here's how to do common aggregate queries, assuming the models above
>>> pubs[0].num_books
73
# Each publisher, with a separate count of books with a rating above and below 5
>>> from django.db.models import Q
>>> above_5 = Count('book', filter=Q(book__rating__gt=5))
>>> below_5 = Count('book', filter=Q(book__rating__lte=5))
>>> pubs = Publisher.objects.annotate(below_5=below_5).annotate(above_5=above_5)
>>> pubs[0].above_5
23
>>> pubs[0].below_5
12
# The top 5 publishers, in order by number of books.
>>> pubs = Publisher.objects.annotate(num_books=Count('book')).order_by('-num_books')[:5]
>>> pubs[0].num_books
@ -324,6 +334,8 @@ title that starts with "Django" using the query::
>>> Book.objects.filter(name__startswith="Django").aggregate(Avg('price'))
.. _filtering-on-annotations:
Filtering on annotations
~~~~~~~~~~~~~~~~~~~~~~~~
@ -339,6 +351,27 @@ you can issue the query::
This query generates an annotated result set, and then generates a filter
based upon that annotation.
If you need two annotations with two separate filters you can use the
``filter`` argument with any aggregate. For example, to generate a list of
authors with a count of highly rated books::
>>> highly_rated = Count('books', filter=Q(books__rating__gte=7))
>>> Author.objects.annotate(num_books=Count('books'), highly_rated_books=highly_rated)
Each ``Author`` in the result set will have the ``num_books`` and
``highly_rated_books`` attributes.
.. admonition:: Choosing between ``filter`` and ``QuerySet.filter()``
Avoid using the ``filter`` argument with a single annotation or
aggregation. It's more efficient to use ``QuerySet.filter()`` to exclude
rows. The aggregation ``filter`` argument is only useful when using two or
more aggregations over the same relations with different conditionals.
.. versionchanged:: 2.0
The ``filter`` argument was added to aggregates.
Order of ``annotate()`` and ``filter()`` clauses
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -0,0 +1,81 @@
import datetime
from decimal import Decimal
from django.db.models import Case, Count, F, Q, Sum, When
from django.test import TestCase
from .models import Author, Book, Publisher
class FilteredAggregateTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.a1 = Author.objects.create(name='test', age=40)
cls.a2 = Author.objects.create(name='test2', age=60)
cls.a3 = Author.objects.create(name='test3', age=100)
cls.p1 = Publisher.objects.create(name='Apress', num_awards=3, duration=datetime.timedelta(days=1))
cls.b1 = Book.objects.create(
isbn='159059725', name='The Definitive Guide to Django: Web Development Done Right',
pages=447, rating=4.5, price=Decimal('30.00'), contact=cls.a1, publisher=cls.p1,
pubdate=datetime.date(2007, 12, 6),
)
cls.b2 = Book.objects.create(
isbn='067232959', name='Sams Teach Yourself Django in 24 Hours',
pages=528, rating=3.0, price=Decimal('23.09'), contact=cls.a2, publisher=cls.p1,
pubdate=datetime.date(2008, 3, 3),
)
cls.b3 = Book.objects.create(
isbn='159059996', name='Practical Django Projects',
pages=600, rating=4.5, price=Decimal('29.69'), contact=cls.a3, publisher=cls.p1,
pubdate=datetime.date(2008, 6, 23),
)
cls.a1.friends.add(cls.a2)
cls.a1.friends.add(cls.a3)
cls.b1.authors.add(cls.a1)
cls.b1.authors.add(cls.a3)
cls.b2.authors.add(cls.a2)
cls.b3.authors.add(cls.a3)
def test_filtered_aggregates(self):
agg = Sum('age', filter=Q(name__startswith='test'))
self.assertEqual(Author.objects.aggregate(age=agg)['age'], 200)
def test_double_filtered_aggregates(self):
agg = Sum('age', filter=Q(Q(name='test2') & ~Q(name='test')))
self.assertEqual(Author.objects.aggregate(age=agg)['age'], 60)
def test_excluded_aggregates(self):
agg = Sum('age', filter=~Q(name='test2'))
self.assertEqual(Author.objects.aggregate(age=agg)['age'], 140)
def test_related_aggregates_m2m(self):
agg = Sum('friends__age', filter=~Q(friends__name='test'))
self.assertEqual(Author.objects.filter(name='test').aggregate(age=agg)['age'], 160)
def test_related_aggregates_m2m_and_fk(self):
q = Q(friends__book__publisher__name='Apress') & ~Q(friends__name='test3')
agg = Sum('friends__book__pages', filter=q)
self.assertEqual(Author.objects.filter(name='test').aggregate(pages=agg)['pages'], 528)
def test_plain_annotate(self):
agg = Sum('book__pages', filter=Q(book__rating__gt=3))
qs = Author.objects.annotate(pages=agg).order_by('pk')
self.assertSequenceEqual([a.pages for a in qs], [447, None, 1047])
def test_filtered_aggregate_on_annotate(self):
pages_annotate = Sum('book__pages', filter=Q(book__rating__gt=3))
age_agg = Sum('age', filter=Q(total_pages__gte=400))
aggregated = Author.objects.annotate(total_pages=pages_annotate).aggregate(summed_age=age_agg)
self.assertEqual(aggregated, {'summed_age': 140})
def test_case_aggregate(self):
agg = Sum(
Case(When(friends__age=40, then=F('friends__age'))),
filter=Q(friends__name__startswith='test'),
)
self.assertEqual(Author.objects.aggregate(age=agg)['age'], 80)
def test_sum_star_exception(self):
msg = 'Star cannot be used with filter. Please specify a field.'
with self.assertRaisesMessage(ValueError, msg):
Count('*', filter=Q(age=40))

View File

@ -5,7 +5,7 @@ from copy import deepcopy
from django.core.exceptions import FieldError
from django.db import DatabaseError, connection, models, transaction
from django.db.models import CharField, TimeField, UUIDField
from django.db.models import CharField, Q, TimeField, UUIDField
from django.db.models.aggregates import (
Avg, Count, Max, Min, StdDev, Sum, Variance,
)
@ -1369,3 +1369,16 @@ class ReprTests(TestCase):
self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)")
self.assertEqual(repr(Sum('a')), "Sum(F(a))")
self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)")
def test_filtered_aggregates(self):
filter = Q(a=1)
self.assertEqual(repr(Avg('a', filter=filter)), "Avg(F(a), filter=(AND: ('a', 1)))")
self.assertEqual(repr(Count('a', filter=filter)), "Count(F(a), distinct=False, filter=(AND: ('a', 1)))")
self.assertEqual(repr(Max('a', filter=filter)), "Max(F(a), filter=(AND: ('a', 1)))")
self.assertEqual(repr(Min('a', filter=filter)), "Min(F(a), filter=(AND: ('a', 1)))")
self.assertEqual(repr(StdDev('a', filter=filter)), "StdDev(F(a), filter=(AND: ('a', 1)), sample=False)")
self.assertEqual(repr(Sum('a', filter=filter)), "Sum(F(a), filter=(AND: ('a', 1)))")
self.assertEqual(
repr(Variance('a', sample=True, filter=filter)),
"Variance(F(a), filter=(AND: ('a', 1)), sample=True)"
)

View File

@ -1253,6 +1253,15 @@ class CaseDocumentationExamples(TestCase):
account_type=Client.PLATINUM,
registered_on=date.today(),
)
self.assertEqual(
Client.objects.aggregate(
regular=models.Count('pk', filter=Q(account_type=Client.REGULAR)),
gold=models.Count('pk', filter=Q(account_type=Client.GOLD)),
platinum=models.Count('pk', filter=Q(account_type=Client.PLATINUM)),
),
{'regular': 2, 'gold': 1, 'platinum': 3}
)
# This was the example before the filter argument was added.
self.assertEqual(
Client.objects.aggregate(
regular=models.Sum(Case(