From b78d100fa62cd4fbbc70f2bae77c192cb36c1ccd Mon Sep 17 00:00:00 2001 From: Tom Date: Sat, 22 Apr 2017 16:44:51 +0100 Subject: [PATCH] Fixed #27849 -- Added filtering support to aggregates. --- .../contrib/postgres/aggregates/statistics.py | 12 +-- django/db/backends/base/features.py | 4 + django/db/backends/postgresql/features.py | 4 + django/db/models/aggregates.py | 69 ++++++++++++++-- docs/ref/contrib/postgres/aggregates.txt | 36 ++++----- docs/ref/models/conditional-expressions.txt | 43 ++++++---- docs/ref/models/expressions.txt | 10 ++- docs/ref/models/querysets.txt | 25 ++++-- docs/releases/2.0.txt | 4 + docs/topics/db/aggregation.txt | 33 ++++++++ tests/aggregation/test_filter_argument.py | 81 +++++++++++++++++++ tests/expressions/tests.py | 15 +++- tests/expressions_case/tests.py | 9 +++ 13 files changed, 290 insertions(+), 55 deletions(-) create mode 100644 tests/aggregation/test_filter_argument.py diff --git a/django/contrib/postgres/aggregates/statistics.py b/django/contrib/postgres/aggregates/statistics.py index b9a8ba07c5..19f26ec53c 100644 --- a/django/contrib/postgres/aggregates/statistics.py +++ b/django/contrib/postgres/aggregates/statistics.py @@ -8,10 +8,10 @@ __all__ = [ class StatAggregate(Aggregate): - def __init__(self, y, x, output_field=FloatField()): + def __init__(self, y, x, output_field=FloatField(), filter=None): if not x or not y: raise ValueError('Both y and x must be provided.') - super().__init__(y, x, output_field=output_field) + super().__init__(y, x, output_field=output_field, filter=filter) def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): return super().resolve_expression(query, allow_joins, reuse, summarize) @@ -22,9 +22,9 @@ class Corr(StatAggregate): class CovarPop(StatAggregate): - def __init__(self, y, x, sample=False): + def __init__(self, y, x, sample=False, filter=None): self.function = 'COVAR_SAMP' if sample else 'COVAR_POP' - super().__init__(y, x) + 
super().__init__(y, x, filter=filter) class RegrAvgX(StatAggregate): @@ -38,8 +38,8 @@ class RegrAvgY(StatAggregate): class RegrCount(StatAggregate): function = 'REGR_COUNT' - def __init__(self, y, x): - super().__init__(y=y, x=x, output_field=IntegerField()) + def __init__(self, y, x, filter=None): + super().__init__(y=y, x=x, output_field=IntegerField(), filter=filter) def convert_value(self, value, expression, connection): if value is None: diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index 7626595741..22c2990e77 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -229,6 +229,10 @@ class BaseDatabaseFeatures: supports_select_difference = True supports_slicing_ordering_in_compound = False + # Does the database support SQL 2003 FILTER (WHERE ...) in aggregate + # expressions? + supports_aggregate_filter_clause = False + # Does the backend support indexing a TextField? supports_index_on_text_field = True diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py index 647fb9dc7f..3c7a7af80f 100644 --- a/django/db/backends/postgresql/features.py +++ b/django/db/backends/postgresql/features.py @@ -50,6 +50,10 @@ class DatabaseFeatures(BaseDatabaseFeatures): END; $$ LANGUAGE plpgsql;""" + @cached_property + def supports_aggregate_filter_clause(self): + return self.connection.pg_version >= 90400 + @cached_property def has_select_for_update_skip_locked(self): return self.connection.pg_version >= 90500 diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 2472a24663..b70ea03838 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -2,8 +2,9 @@ Classes to represent the definitions of aggregate functions. 
""" from django.core.exceptions import FieldError -from django.db.models.expressions import Func, Star +from django.db.models.expressions import Case, Func, Star, When from django.db.models.fields import DecimalField, FloatField, IntegerField +from django.db.models.query_utils import Q __all__ = [ 'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance', @@ -13,12 +14,36 @@ __all__ = [ class Aggregate(Func): contains_aggregate = True name = None + filter_template = '%s FILTER (WHERE %%(filter)s)' + + def __init__(self, *args, filter=None, **kwargs): + self.filter = filter + super().__init__(*args, **kwargs) + + def get_source_fields(self): + # Don't return the filter expression since it's not a source field. + return [e._output_field_or_none for e in super().get_source_expressions()] + + def get_source_expressions(self): + source_expressions = super().get_source_expressions() + if self.filter: + source_expressions += [self.filter] + return source_expressions + + def set_source_expressions(self, exprs): + if self.filter: + self.filter = exprs.pop() + return super().set_source_expressions(exprs) def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): # Aggregates are not allowed in UPDATE queries, so ignore for_save c = super().resolve_expression(query, allow_joins, reuse, summarize) + if c.filter: + c.filter = c.filter.resolve_expression(query, allow_joins, reuse, summarize) if not summarize: - expressions = c.get_source_expressions() + # Call Aggregate.get_source_expressions() to avoid + # returning self.filter and including that in this loop. 
+ expressions = super(Aggregate, c).get_source_expressions() for index, expr in enumerate(expressions): if expr.contains_aggregate: before_resolved = self.get_source_expressions()[index] @@ -36,6 +61,29 @@ class Aggregate(Func): def get_group_by_cols(self): return [] + def as_sql(self, compiler, connection, **extra_context): + if self.filter: + if connection.features.supports_aggregate_filter_clause: + filter_sql, filter_params = self.filter.as_sql(compiler, connection) + template = self.filter_template % extra_context.get('template', self.template) + sql, params = super().as_sql(compiler, connection, template=template, filter=filter_sql) + return sql, params + filter_params + else: + copy = self.copy() + copy.filter = None + condition = When(Q()) + source_expressions = copy.get_source_expressions() + condition.set_source_expressions([self.filter, source_expressions[0]]) + copy.set_source_expressions([Case(condition)] + source_expressions[1:]) + return super(Aggregate, copy).as_sql(compiler, connection, **extra_context) + return super().as_sql(compiler, connection, **extra_context) + + def _get_repr_options(self): + options = super()._get_repr_options() + if self.filter: + options.update({'filter': self.filter}) + return options + class Avg(Aggregate): function = 'AVG' @@ -52,7 +100,7 @@ class Avg(Aggregate): expression = self.get_source_expressions()[0] from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval return compiler.compile( - SecondsToInterval(Avg(IntervalToSeconds(expression))) + SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter)) ) return super().as_sql(compiler, connection) @@ -62,16 +110,19 @@ class Count(Aggregate): name = 'Count' template = '%(function)s(%(distinct)s%(expressions)s)' - def __init__(self, expression, distinct=False, **extra): + def __init__(self, expression, distinct=False, filter=None, **extra): if expression == '*': expression = Star() + if isinstance(expression, Star) and filter 
is not None: + raise ValueError('Star cannot be used with filter. Please specify a field.') super().__init__( expression, distinct='DISTINCT ' if distinct else '', - output_field=IntegerField(), **extra + output_field=IntegerField(), filter=filter, **extra ) def _get_repr_options(self): - return {'distinct': self.extra['distinct'] != ''} + options = super()._get_repr_options() + return dict(options, distinct=self.extra['distinct'] != '') def convert_value(self, value, expression, connection): if value is None: @@ -97,7 +148,8 @@ class StdDev(Aggregate): super().__init__(expression, output_field=FloatField(), **extra) def _get_repr_options(self): - return {'sample': self.function == 'STDDEV_SAMP'} + options = super()._get_repr_options() + return dict(options, sample=self.function == 'STDDEV_SAMP') def convert_value(self, value, expression, connection): if value is None: @@ -127,7 +179,8 @@ class Variance(Aggregate): super().__init__(expression, output_field=FloatField(), **extra) def _get_repr_options(self): - return {'sample': self.function == 'VAR_SAMP'} + options = super()._get_repr_options() + return dict(options, sample=self.function == 'VAR_SAMP') def convert_value(self, value, expression, connection): if value is None: diff --git a/docs/ref/contrib/postgres/aggregates.txt b/docs/ref/contrib/postgres/aggregates.txt index 1c738c35ae..a51249b6c4 100644 --- a/docs/ref/contrib/postgres/aggregates.txt +++ b/docs/ref/contrib/postgres/aggregates.txt @@ -22,7 +22,7 @@ General-purpose aggregation functions ``ArrayAgg`` ------------ -.. class:: ArrayAgg(expression, distinct=False, **extra) +.. class:: ArrayAgg(expression, distinct=False, filter=None, **extra) Returns a list of values, including nulls, concatenated into an array. @@ -36,7 +36,7 @@ General-purpose aggregation functions ``BitAnd`` ---------- -.. class:: BitAnd(expression, **extra) +.. 
class:: BitAnd(expression, filter=None, **extra) Returns an ``int`` of the bitwise ``AND`` of all non-null input values, or ``None`` if all values are null. @@ -44,7 +44,7 @@ General-purpose aggregation functions ``BitOr`` --------- -.. class:: BitOr(expression, **extra) +.. class:: BitOr(expression, filter=None, **extra) Returns an ``int`` of the bitwise ``OR`` of all non-null input values, or ``None`` if all values are null. @@ -52,7 +52,7 @@ General-purpose aggregation functions ``BoolAnd`` ----------- -.. class:: BoolAnd(expression, **extra) +.. class:: BoolAnd(expression, filter=None, **extra) Returns ``True``, if all input values are true, ``None`` if all values are null or if there are no values, otherwise ``False`` . @@ -60,7 +60,7 @@ General-purpose aggregation functions ``BoolOr`` ---------- -.. class:: BoolOr(expression, **extra) +.. class:: BoolOr(expression, filter=None, **extra) Returns ``True`` if at least one input value is true, ``None`` if all values are null or if there are no values, otherwise ``False``. @@ -68,7 +68,7 @@ General-purpose aggregation functions ``JSONBAgg`` ------------ -.. class:: JSONBAgg(expressions, **extra) +.. class:: JSONBAgg(expressions, filter=None, **extra) .. versionadded:: 1.11 @@ -77,7 +77,7 @@ General-purpose aggregation functions ``StringAgg`` ------------- -.. class:: StringAgg(expression, delimiter, distinct=False) +.. class:: StringAgg(expression, delimiter, distinct=False, filter=None) Returns the input values concatenated into a string, separated by the ``delimiter`` string. @@ -105,7 +105,7 @@ field or an expression returning a numeric data. Both are required. ``Corr`` -------- -.. class:: Corr(y, x) +.. class:: Corr(y, x, filter=None) Returns the correlation coefficient as a ``float``, or ``None`` if there aren't any matching rows. @@ -113,7 +113,7 @@ field or an expression returning a numeric data. Both are required. ``CovarPop`` ------------ -.. class:: CovarPop(y, x, sample=False) +.. 
class:: CovarPop(y, x, sample=False, filter=None) Returns the population covariance as a ``float``, or ``None`` if there aren't any matching rows. @@ -129,7 +129,7 @@ field or an expression returning a numeric data. Both are required. ``RegrAvgX`` ------------ -.. class:: RegrAvgX(y, x) +.. class:: RegrAvgX(y, x, filter=None) Returns the average of the independent variable (``sum(x)/N``) as a ``float``, or ``None`` if there aren't any matching rows. @@ -137,7 +137,7 @@ field or an expression returning a numeric data. Both are required. ``RegrAvgY`` ------------ -.. class:: RegrAvgY(y, x) +.. class:: RegrAvgY(y, x, filter=None) Returns the average of the dependent variable (``sum(y)/N``) as a ``float``, or ``None`` if there aren't any matching rows. @@ -145,7 +145,7 @@ field or an expression returning a numeric data. Both are required. ``RegrCount`` ------------- -.. class:: RegrCount(y, x) +.. class:: RegrCount(y, x, filter=None) Returns an ``int`` of the number of input rows in which both expressions are not null. @@ -153,7 +153,7 @@ field or an expression returning a numeric data. Both are required. ``RegrIntercept`` ----------------- -.. class:: RegrIntercept(y, x) +.. class:: RegrIntercept(y, x, filter=None) Returns the y-intercept of the least-squares-fit linear equation determined by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any @@ -162,7 +162,7 @@ field or an expression returning a numeric data. Both are required. ``RegrR2`` ---------- -.. class:: RegrR2(y, x) +.. class:: RegrR2(y, x, filter=None) Returns the square of the correlation coefficient as a ``float``, or ``None`` if there aren't any matching rows. @@ -170,7 +170,7 @@ field or an expression returning a numeric data. Both are required. ``RegrSlope`` ------------- -.. class:: RegrSlope(y, x) +.. 
class:: RegrSlope(y, x, filter=None) Returns the slope of the least-squares-fit linear equation determined by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any @@ -179,7 +179,7 @@ field or an expression returning a numeric data. Both are required. ``RegrSXX`` ----------- -.. class:: RegrSXX(y, x) +.. class:: RegrSXX(y, x, filter=None) Returns ``sum(x^2) - sum(x)^2/N`` ("sum of squares" of the independent variable) as a ``float``, or ``None`` if there aren't any matching rows. @@ -187,7 +187,7 @@ field or an expression returning a numeric data. Both are required. ``RegrSXY`` ----------- -.. class:: RegrSXY(y, x) +.. class:: RegrSXY(y, x, filter=None) Returns ``sum(x*y) - sum(x) * sum(y)/N`` ("sum of products" of independent times dependent variable) as a ``float``, or ``None`` if there aren't any @@ -196,7 +196,7 @@ field or an expression returning a numeric data. Both are required. ``RegrSYY`` ----------- -.. class:: RegrSYY(y, x) +.. class:: RegrSYY(y, x, filter=None) Returns ``sum(y^2) - sum(y)^2/N`` ("sum of squares" of the dependent variable) as a ``float``, or ``None`` if there aren't any matching rows. diff --git a/docs/ref/models/conditional-expressions.txt b/docs/ref/models/conditional-expressions.txt index af134d4561..35f0c808f1 100644 --- a/docs/ref/models/conditional-expressions.txt +++ b/docs/ref/models/conditional-expressions.txt @@ -184,12 +184,14 @@ their registration dates. We can do this using a conditional expression and the >>> Client.objects.values_list('name', 'account_type') +.. _conditional-aggregation: + Conditional aggregation ----------------------- What if we want to find out how many clients there are for each -``account_type``? We can nest conditional expression within -:ref:`aggregate functions ` to achieve this:: +``account_type``? 
We can use the ``filter`` argument of :ref:`aggregate +functions <aggregation-functions>` to achieve this:: >>> # Create some more Clients first so we can have something to count >>> Client.objects.create( @@ -207,17 +209,30 @@ What if we want to find out how many clients there are for each >>> # Get counts for each value of account_type >>> from django.db.models import IntegerField, Sum >>> Client.objects.aggregate( - ... regular=Sum( - ... Case(When(account_type=Client.REGULAR, then=1), - ... output_field=IntegerField()) - ... ), - ... gold=Sum( - ... Case(When(account_type=Client.GOLD, then=1), - ... output_field=IntegerField()) - ... ), - ... platinum=Sum( - ... Case(When(account_type=Client.PLATINUM, then=1), - ... output_field=IntegerField()) - ... ) + ... regular=Count('pk', filter=Q(account_type=Client.REGULAR)), + ... gold=Count('pk', filter=Q(account_type=Client.GOLD)), + ... platinum=Count('pk', filter=Q(account_type=Client.PLATINUM)), ... ) {'regular': 2, 'gold': 1, 'platinum': 3} + +This aggregate produces a query with the SQL 2003 ``FILTER WHERE`` syntax +on databases that support it: + +.. code-block:: sql + + SELECT count('id') FILTER (WHERE account_type=1) as regular, + count('id') FILTER (WHERE account_type=2) as gold, + count('id') FILTER (WHERE account_type=3) as platinum + FROM clients; + +On other databases, this is emulated using a ``CASE`` statement: + +.. code-block:: sql + + SELECT count(CASE WHEN account_type=1 THEN id ELSE null END) as regular, + count(CASE WHEN account_type=2 THEN id ELSE null END) as gold, + count(CASE WHEN account_type=3 THEN id ELSE null END) as platinum + FROM clients; + +The two SQL statements are functionally equivalent but the more explicit +``FILTER`` may perform better. diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index 361b170412..0974a9dd51 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -339,7 +339,7 @@ some complex computations:: The ``Aggregate`` API is as follows: -.. 
class:: Aggregate(expression, output_field=None, **extra) +.. class:: Aggregate(expression, output_field=None, filter=None, **extra) .. attribute:: template @@ -370,9 +370,17 @@ should define the desired ``output_field``. For example, adding an ``IntegerField()`` and a ``FloatField()`` together should probably have ``output_field=FloatField()`` defined. +The ``filter`` argument takes a :class:`Q object <django.db.models.Q>` that's +used to filter the rows that are aggregated. See :ref:`conditional-aggregation` +and :ref:`filtering-on-annotations` for example usage. + The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated into the ``template`` attribute. +.. versionchanged:: 2.0 + + The ``filter`` argument was added. + Creating your own Aggregate Functions ------------------------------------- diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index 74f83ab8c5..9c329d48ee 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -3085,6 +3085,17 @@ of the return value ``output_field`` if all fields are of the same type. Otherwise, you must provide the ``output_field`` yourself. +``filter`` +~~~~~~~~~~ + +.. versionadded:: 2.0 + +An optional :class:`Q object <django.db.models.Q>` that's used to filter the +rows that are aggregated. + +See :ref:`conditional-aggregation` and :ref:`filtering-on-annotations` for +example usage. + ``**extra`` ~~~~~~~~~~~ @@ -3094,7 +3105,7 @@ by the aggregate. ``Avg`` ~~~~~~~ -.. class:: Avg(expression, output_field=FloatField(), **extra) +.. class:: Avg(expression, output_field=FloatField(), filter=None, **extra) Returns the mean value of the given expression, which must be numeric unless you specify a different ``output_field``. @@ -3106,7 +3117,7 @@ by the aggregate. ``Count`` ~~~~~~~~~ -.. class:: Count(expression, distinct=False, **extra) +.. class:: Count(expression, distinct=False, filter=None, **extra) Returns the number of objects that are related through the provided expression. 
@@ -3125,7 +3136,7 @@ by the aggregate. ``Max`` ~~~~~~~ -.. class:: Max(expression, output_field=None, **extra) +.. class:: Max(expression, output_field=None, filter=None, **extra) Returns the maximum value of the given expression. @@ -3135,7 +3146,7 @@ by the aggregate. ``Min`` ~~~~~~~ -.. class:: Min(expression, output_field=None, **extra) +.. class:: Min(expression, output_field=None, filter=None, **extra) Returns the minimum value of the given expression. @@ -3145,7 +3156,7 @@ by the aggregate. ``StdDev`` ~~~~~~~~~~ -.. class:: StdDev(expression, sample=False, **extra) +.. class:: StdDev(expression, sample=False, filter=None, **extra) Returns the standard deviation of the data in the provided expression. @@ -3169,7 +3180,7 @@ by the aggregate. ``Sum`` ~~~~~~~ -.. class:: Sum(expression, output_field=None, **extra) +.. class:: Sum(expression, output_field=None, filter=None, **extra) Computes the sum of all values of the given expression. @@ -3179,7 +3190,7 @@ by the aggregate. ``Variance`` ~~~~~~~~~~~~ -.. class:: Variance(expression, sample=False, **extra) +.. class:: Variance(expression, sample=False, filter=None, **extra) Returns the variance of the data in the provided expression. diff --git a/docs/releases/2.0.txt b/docs/releases/2.0.txt index 1584b87293..36aea3aefc 100644 --- a/docs/releases/2.0.txt +++ b/docs/releases/2.0.txt @@ -273,6 +273,10 @@ Models parameters, if the backend supports this feature. Of Django's built-in backends, only Oracle supports it. +* The new ``filter`` argument for built-in aggregates allows :ref:`adding + different conditionals <conditional-aggregation>` to multiple aggregations + over the same fields or relations. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/topics/db/aggregation.txt b/docs/topics/db/aggregation.txt index 1f59c02b4d..523f6e0aaa 100644 --- a/docs/topics/db/aggregation.txt +++ b/docs/topics/db/aggregation.txt @@ -84,6 +84,16 @@ In a hurry? 
Here's how to do common aggregate queries, assuming the models above >>> pubs[0].num_books 73 + # Each publisher, with a separate count of books with a rating above and below 5 + >>> from django.db.models import Q + >>> above_5 = Count('book', filter=Q(book__rating__gt=5)) + >>> below_5 = Count('book', filter=Q(book__rating__lte=5)) + >>> pubs = Publisher.objects.annotate(below_5=below_5).annotate(above_5=above_5) + >>> pubs[0].above_5 + 23 + >>> pubs[0].below_5 + 12 + # The top 5 publishers, in order by number of books. >>> pubs = Publisher.objects.annotate(num_books=Count('book')).order_by('-num_books')[:5] >>> pubs[0].num_books @@ -324,6 +334,8 @@ title that starts with "Django" using the query:: >>> Book.objects.filter(name__startswith="Django").aggregate(Avg('price')) +.. _filtering-on-annotations: + Filtering on annotations ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -339,6 +351,27 @@ you can issue the query:: This query generates an annotated result set, and then generates a filter based upon that annotation. +If you need two annotations with two separate filters you can use the +``filter`` argument with any aggregate. For example, to generate a list of +authors with a count of highly rated books:: + + >>> highly_rated = Count('books', filter=Q(books__rating__gte=7)) + >>> Author.objects.annotate(num_books=Count('books'), highly_rated_books=highly_rated) + +Each ``Author`` in the result set will have the ``num_books`` and +``highly_rated_books`` attributes. + +.. admonition:: Choosing between ``filter`` and ``QuerySet.filter()`` + + Avoid using the ``filter`` argument with a single annotation or + aggregation. It's more efficient to use ``QuerySet.filter()`` to exclude + rows. The aggregation ``filter`` argument is only useful when using two or + more aggregations over the same relations with different conditionals. + +.. versionchanged:: 2.0 + + The ``filter`` argument was added to aggregates. 
+ Order of ``annotate()`` and ``filter()`` clauses ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/aggregation/test_filter_argument.py b/tests/aggregation/test_filter_argument.py new file mode 100644 index 0000000000..54836178c4 --- /dev/null +++ b/tests/aggregation/test_filter_argument.py @@ -0,0 +1,81 @@ +import datetime +from decimal import Decimal + +from django.db.models import Case, Count, F, Q, Sum, When +from django.test import TestCase + +from .models import Author, Book, Publisher + + +class FilteredAggregateTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.a1 = Author.objects.create(name='test', age=40) + cls.a2 = Author.objects.create(name='test2', age=60) + cls.a3 = Author.objects.create(name='test3', age=100) + cls.p1 = Publisher.objects.create(name='Apress', num_awards=3, duration=datetime.timedelta(days=1)) + cls.b1 = Book.objects.create( + isbn='159059725', name='The Definitive Guide to Django: Web Development Done Right', + pages=447, rating=4.5, price=Decimal('30.00'), contact=cls.a1, publisher=cls.p1, + pubdate=datetime.date(2007, 12, 6), + ) + cls.b2 = Book.objects.create( + isbn='067232959', name='Sams Teach Yourself Django in 24 Hours', + pages=528, rating=3.0, price=Decimal('23.09'), contact=cls.a2, publisher=cls.p1, + pubdate=datetime.date(2008, 3, 3), + ) + cls.b3 = Book.objects.create( + isbn='159059996', name='Practical Django Projects', + pages=600, rating=4.5, price=Decimal('29.69'), contact=cls.a3, publisher=cls.p1, + pubdate=datetime.date(2008, 6, 23), + ) + cls.a1.friends.add(cls.a2) + cls.a1.friends.add(cls.a3) + cls.b1.authors.add(cls.a1) + cls.b1.authors.add(cls.a3) + cls.b2.authors.add(cls.a2) + cls.b3.authors.add(cls.a3) + + def test_filtered_aggregates(self): + agg = Sum('age', filter=Q(name__startswith='test')) + self.assertEqual(Author.objects.aggregate(age=agg)['age'], 200) + + def test_double_filtered_aggregates(self): + agg = Sum('age', filter=Q(Q(name='test2') & ~Q(name='test'))) + 
self.assertEqual(Author.objects.aggregate(age=agg)['age'], 60) + + def test_excluded_aggregates(self): + agg = Sum('age', filter=~Q(name='test2')) + self.assertEqual(Author.objects.aggregate(age=agg)['age'], 140) + + def test_related_aggregates_m2m(self): + agg = Sum('friends__age', filter=~Q(friends__name='test')) + self.assertEqual(Author.objects.filter(name='test').aggregate(age=agg)['age'], 160) + + def test_related_aggregates_m2m_and_fk(self): + q = Q(friends__book__publisher__name='Apress') & ~Q(friends__name='test3') + agg = Sum('friends__book__pages', filter=q) + self.assertEqual(Author.objects.filter(name='test').aggregate(pages=agg)['pages'], 528) + + def test_plain_annotate(self): + agg = Sum('book__pages', filter=Q(book__rating__gt=3)) + qs = Author.objects.annotate(pages=agg).order_by('pk') + self.assertSequenceEqual([a.pages for a in qs], [447, None, 1047]) + + def test_filtered_aggregate_on_annotate(self): + pages_annotate = Sum('book__pages', filter=Q(book__rating__gt=3)) + age_agg = Sum('age', filter=Q(total_pages__gte=400)) + aggregated = Author.objects.annotate(total_pages=pages_annotate).aggregate(summed_age=age_agg) + self.assertEqual(aggregated, {'summed_age': 140}) + + def test_case_aggregate(self): + agg = Sum( + Case(When(friends__age=40, then=F('friends__age'))), + filter=Q(friends__name__startswith='test'), + ) + self.assertEqual(Author.objects.aggregate(age=agg)['age'], 80) + + def test_sum_star_exception(self): + msg = 'Star cannot be used with filter. Please specify a field.' 
+ with self.assertRaisesMessage(ValueError, msg): + Count('*', filter=Q(age=40)) diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index 8952045002..7865492ed5 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -5,7 +5,7 @@ from copy import deepcopy from django.core.exceptions import FieldError from django.db import DatabaseError, connection, models, transaction -from django.db.models import CharField, TimeField, UUIDField +from django.db.models import CharField, Q, TimeField, UUIDField from django.db.models.aggregates import ( Avg, Count, Max, Min, StdDev, Sum, Variance, ) @@ -1369,3 +1369,16 @@ class ReprTests(TestCase): self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)") self.assertEqual(repr(Sum('a')), "Sum(F(a))") self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)") + + def test_filtered_aggregates(self): + filter = Q(a=1) + self.assertEqual(repr(Avg('a', filter=filter)), "Avg(F(a), filter=(AND: ('a', 1)))") + self.assertEqual(repr(Count('a', filter=filter)), "Count(F(a), distinct=False, filter=(AND: ('a', 1)))") + self.assertEqual(repr(Max('a', filter=filter)), "Max(F(a), filter=(AND: ('a', 1)))") + self.assertEqual(repr(Min('a', filter=filter)), "Min(F(a), filter=(AND: ('a', 1)))") + self.assertEqual(repr(StdDev('a', filter=filter)), "StdDev(F(a), filter=(AND: ('a', 1)), sample=False)") + self.assertEqual(repr(Sum('a', filter=filter)), "Sum(F(a), filter=(AND: ('a', 1)))") + self.assertEqual( + repr(Variance('a', sample=True, filter=filter)), + "Variance(F(a), filter=(AND: ('a', 1)), sample=True)" + ) diff --git a/tests/expressions_case/tests.py b/tests/expressions_case/tests.py index 20d1e801ec..090607e8b8 100644 --- a/tests/expressions_case/tests.py +++ b/tests/expressions_case/tests.py @@ -1253,6 +1253,15 @@ class CaseDocumentationExamples(TestCase): account_type=Client.PLATINUM, registered_on=date.today(), ) + self.assertEqual( + Client.objects.aggregate( + 
regular=models.Count('pk', filter=Q(account_type=Client.REGULAR)), + gold=models.Count('pk', filter=Q(account_type=Client.GOLD)), + platinum=models.Count('pk', filter=Q(account_type=Client.PLATINUM)), + ), + {'regular': 2, 'gold': 1, 'platinum': 3} + ) + # This was the example before the filter argument was added. self.assertEqual( Client.objects.aggregate( regular=models.Sum(Case(