From bc05547cd8c1dd511c6b6a6c873a1bc63417b111 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Wed, 9 Jan 2019 17:52:36 -0500 Subject: [PATCH] Fixed #28658 -- Added DISTINCT handling to the Aggregate class. --- django/contrib/postgres/aggregates/general.py | 10 +++---- django/db/backends/sqlite3/operations.py | 5 ++++ django/db/models/aggregates.py | 26 ++++++++++--------- docs/ref/models/expressions.txt | 19 +++++++++++++- docs/releases/2.2.txt | 7 +++++ tests/aggregation/tests.py | 4 +-- tests/aggregation_regress/tests.py | 11 ++++++++ tests/backends/sqlite/tests.py | 12 +++++++++ tests/expressions/tests.py | 13 +++++++--- 9 files changed, 83 insertions(+), 24 deletions(-) diff --git a/django/contrib/postgres/aggregates/general.py b/django/contrib/postgres/aggregates/general.py index 4b2da0b101..918373e926 100644 --- a/django/contrib/postgres/aggregates/general.py +++ b/django/contrib/postgres/aggregates/general.py @@ -11,14 +11,12 @@ __all__ = [ class ArrayAgg(OrderableAggMixin, Aggregate): function = 'ARRAY_AGG' template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)' + allow_distinct = True @property def output_field(self): return ArrayField(self.source_expressions[0].output_field) - def __init__(self, expression, distinct=False, **extra): - super().__init__(expression, distinct='DISTINCT ' if distinct else '', **extra) - def convert_value(self, value, expression, connection): if not value: return [] @@ -54,10 +52,10 @@ class JSONBAgg(Aggregate): class StringAgg(OrderableAggMixin, Aggregate): function = 'STRING_AGG' template = "%(function)s(%(distinct)s%(expressions)s, '%(delimiter)s'%(ordering)s)" + allow_distinct = True - def __init__(self, expression, delimiter, distinct=False, **extra): - distinct = 'DISTINCT ' if distinct else '' - super().__init__(expression, delimiter=delimiter, distinct=distinct, **extra) + def __init__(self, expression, delimiter, **extra): + super().__init__(expression, delimiter=delimiter, **extra) def 
convert_value(self, value, expression, connection): if not value: diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py index 6ec4859f0e..c4b02e5c60 100644 --- a/django/db/backends/sqlite3/operations.py +++ b/django/db/backends/sqlite3/operations.py @@ -57,6 +57,11 @@ class DatabaseOperations(BaseDatabaseOperations): 'aggregations on date/time fields in sqlite3 ' 'since date/time is saved as text.' ) + if isinstance(expression, aggregates.Aggregate) and len(expression.source_expressions) > 1: + raise utils.NotSupportedError( + "SQLite doesn't support DISTINCT on aggregate functions " + "accepting multiple arguments." + ) def date_extract_sql(self, lookup_type, field_name): """ diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index b270640ea5..a7dc55ee98 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -11,14 +11,19 @@ __all__ = [ class Aggregate(Func): + template = '%(function)s(%(distinct)s%(expressions)s)' contains_aggregate = True name = None filter_template = '%s FILTER (WHERE %%(filter)s)' window_compatible = True + allow_distinct = False - def __init__(self, *args, filter=None, **kwargs): + def __init__(self, *expressions, distinct=False, filter=None, **extra): + if distinct and not self.allow_distinct: + raise TypeError("%s does not allow distinct." % self.__class__.__name__) + self.distinct = distinct self.filter = filter - super().__init__(*args, **kwargs) + super().__init__(*expressions, **extra) def get_source_fields(self): # Don't return the filter expression since it's not a source field. 
@@ -60,6 +65,7 @@ class Aggregate(Func): return [] def as_sql(self, compiler, connection, **extra_context): + extra_context['distinct'] = 'DISTINCT' if self.distinct else '' if self.filter: if connection.features.supports_aggregate_filter_clause: filter_sql, filter_params = self.filter.as_sql(compiler, connection) @@ -80,8 +86,10 @@ class Aggregate(Func): def _get_repr_options(self): options = super()._get_repr_options() + if self.distinct: + options['distinct'] = self.distinct if self.filter: - options.update({'filter': self.filter}) + options['filter'] = self.filter return options @@ -114,21 +122,15 @@ class Avg(Aggregate): class Count(Aggregate): function = 'COUNT' name = 'Count' - template = '%(function)s(%(distinct)s%(expressions)s)' output_field = IntegerField() + allow_distinct = True - def __init__(self, expression, distinct=False, filter=None, **extra): + def __init__(self, expression, filter=None, **extra): if expression == '*': expression = Star() if isinstance(expression, Star) and filter is not None: raise ValueError('Star cannot be used with filter. Please specify a field.') - super().__init__( - expression, distinct='DISTINCT ' if distinct else '', - filter=filter, **extra - ) - - def _get_repr_options(self): - return {**super()._get_repr_options(), 'distinct': self.extra['distinct'] != ''} + super().__init__(expression, filter=filter, **extra) def convert_value(self, value, expression, connection): return 0 if value is None else value diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index 7a358a5ce8..2413952228 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -373,7 +373,7 @@ some complex computations:: The ``Aggregate`` API is as follows: -.. class:: Aggregate(*expressions, output_field=None, filter=None, **extra) +.. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, **extra) .. 
attribute:: template @@ -392,6 +392,14 @@ The ``Aggregate`` API is as follows: Defaults to ``True`` since most aggregate functions can be used as the source expression in :class:`~django.db.models.expressions.Window`. + .. attribute:: allow_distinct + + .. versionadded:: 2.2 + + A class attribute determining whether or not this aggregate function + allows passing a ``distinct`` keyword argument. If set to ``False`` + (default), ``TypeError`` is raised if ``distinct=True`` is passed. + The ``expressions`` positional arguments can include expressions or the names of model fields. They will be converted to a string and used as the ``expressions`` placeholder within the ``template``. @@ -409,6 +417,11 @@ should define the desired ``output_field``. For example, adding an ``IntegerField()`` and a ``FloatField()`` together should probably have ``output_field=FloatField()`` defined. +The ``distinct`` argument determines whether or not the aggregate function +should be invoked for each distinct value of ``expressions`` (or set of +values, for multiple ``expressions``). The argument is only supported on +aggregates that have :attr:`~Aggregate.allow_distinct` set to ``True``. + The ``filter`` argument takes a :class:`Q object <django.db.models.Q>` that's used to filter the rows that are aggregated. See :ref:`conditional-aggregation` and :ref:`filtering-on-annotations` for example usage. @@ -416,6 +429,10 @@ and :ref:`filtering-on-annotations` for example usage. The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated into the ``template`` attribute. +.. versionadded:: 2.2 + + The ``allow_distinct`` attribute and ``distinct`` argument were added. 
+ Creating your own Aggregate Functions ------------------------------------- diff --git a/docs/releases/2.2.txt b/docs/releases/2.2.txt index c8515d5ba6..150fe413db 100644 --- a/docs/releases/2.2.txt +++ b/docs/releases/2.2.txt @@ -239,6 +239,13 @@ Models * Added SQLite support for the :class:`~django.db.models.StdDev` and :class:`~django.db.models.Variance` functions. +* The handling of ``DISTINCT`` aggregation is added to the + :class:`~django.db.models.Aggregate` class. Adding :attr:`allow_distinct = + True <django.db.models.Aggregate.allow_distinct>` as a class attribute on + ``Aggregate`` subclasses allows a ``distinct`` keyword argument to be + specified on initialization to ensure that the aggregate function is only + called for each distinct value of ``expressions``. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index a55ccfbfa2..75d2ecb1c5 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -1026,7 +1026,7 @@ class AggregateTestCase(TestCase): # test completely changing how the output is rendered def lower_case_function_override(self, compiler, connection): sql, params = compiler.compile(self.source_expressions[0]) - substitutions = {'function': self.function.lower(), 'expressions': sql} + substitutions = {'function': self.function.lower(), 'expressions': sql, 'distinct': ''} substitutions.update(self.extra) return self.template % substitutions, params setattr(MySum, 'as_' + connection.vendor, lower_case_function_override) @@ -1053,7 +1053,7 @@ class AggregateTestCase(TestCase): # test overriding all parts of the template def be_evil(self, compiler, connection): - substitutions = {'function': 'MAX', 'expressions': '2'} + substitutions = {'function': 'MAX', 'expressions': '2', 'distinct': ''} substitutions.update(self.extra) return self.template % substitutions, () setattr(MySum, 'as_' + connection.vendor, be_evil) diff --git a/tests/aggregation_regress/tests.py b/tests/aggregation_regress/tests.py index 
29b32c4987..2b3948a0b4 100644 --- a/tests/aggregation_regress/tests.py +++ b/tests/aggregation_regress/tests.py @@ -11,6 +11,7 @@ from django.db.models import ( Avg, Case, Count, DecimalField, F, IntegerField, Max, Q, StdDev, Sum, Value, Variance, When, ) +from django.db.models.aggregates import Aggregate from django.test import ( TestCase, ignore_warnings, skipUnlessAnyDBFeature, skipUnlessDBFeature, ) @@ -1496,6 +1497,16 @@ class AggregationTests(TestCase): qs = Author.objects.values_list('age', flat=True).annotate(age_count=Count('age')).filter(age_count__gt=1) self.assertSequenceEqual(qs, [29]) + def test_allow_distinct(self): + class MyAggregate(Aggregate): + pass + with self.assertRaisesMessage(TypeError, 'MyAggregate does not allow distinct'): + MyAggregate('foo', distinct=True) + + class DistinctAggregate(Aggregate): + allow_distinct = True + DistinctAggregate('foo', distinct=True) + class JoinPromotionTests(TestCase): def test_ticket_21150(self): diff --git a/tests/backends/sqlite/tests.py b/tests/backends/sqlite/tests.py index bddaf8620f..c681d39775 100644 --- a/tests/backends/sqlite/tests.py +++ b/tests/backends/sqlite/tests.py @@ -4,6 +4,7 @@ import unittest from django.db import connection, transaction from django.db.models import Avg, StdDev, Sum, Variance +from django.db.models.aggregates import Aggregate from django.db.models.fields import CharField from django.db.utils import NotSupportedError from django.test import ( @@ -34,6 +35,17 @@ class Tests(TestCase): **{'complex': aggregate('last_modified') + aggregate('last_modified')} ) + def test_distinct_aggregation(self): + class DistinctAggregate(Aggregate): + allow_distinct = True + aggregate = DistinctAggregate('first', 'second', distinct=True) + msg = ( + "SQLite doesn't support DISTINCT on aggregate functions accepting " + "multiple arguments." 
+ ) + with self.assertRaisesMessage(NotSupportedError, msg): + connection.ops.check_expression_support(aggregate) + def test_memory_db_test_name(self): """A named in-memory db should be allowed where supported.""" from django.db.backends.sqlite3.base import DatabaseWrapper diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index 2ed928915a..ee3676e64a 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -1481,18 +1481,22 @@ class ReprTests(SimpleTestCase): def test_aggregates(self): self.assertEqual(repr(Avg('a')), "Avg(F(a))") - self.assertEqual(repr(Count('a')), "Count(F(a), distinct=False)") - self.assertEqual(repr(Count('*')), "Count('*', distinct=False)") + self.assertEqual(repr(Count('a')), "Count(F(a))") + self.assertEqual(repr(Count('*')), "Count('*')") self.assertEqual(repr(Max('a')), "Max(F(a))") self.assertEqual(repr(Min('a')), "Min(F(a))") self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)") self.assertEqual(repr(Sum('a')), "Sum(F(a))") self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)") + def test_distinct_aggregates(self): + self.assertEqual(repr(Count('a', distinct=True)), "Count(F(a), distinct=True)") + self.assertEqual(repr(Count('*', distinct=True)), "Count('*', distinct=True)") + def test_filtered_aggregates(self): filter = Q(a=1) self.assertEqual(repr(Avg('a', filter=filter)), "Avg(F(a), filter=(AND: ('a', 1)))") - self.assertEqual(repr(Count('a', filter=filter)), "Count(F(a), distinct=False, filter=(AND: ('a', 1)))") + self.assertEqual(repr(Count('a', filter=filter)), "Count(F(a), filter=(AND: ('a', 1)))") self.assertEqual(repr(Max('a', filter=filter)), "Max(F(a), filter=(AND: ('a', 1)))") self.assertEqual(repr(Min('a', filter=filter)), "Min(F(a), filter=(AND: ('a', 1)))") self.assertEqual(repr(StdDev('a', filter=filter)), "StdDev(F(a), filter=(AND: ('a', 1)), sample=False)") @@ -1501,6 +1505,9 @@ class ReprTests(SimpleTestCase): repr(Variance('a', sample=True, 
filter=filter)), "Variance(F(a), filter=(AND: ('a', 1)), sample=True)" ) + self.assertEqual( + repr(Count('a', filter=filter, distinct=True)), "Count(F(a), distinct=True, filter=(AND: ('a', 1)))" + ) class CombinableTests(SimpleTestCase):