Fixed #28658 -- Added DISTINCT handling to the Aggregate class.

This commit is contained in:
Simon Charette 2019-01-09 17:52:36 -05:00 committed by Tim Graham
parent 222caab68a
commit bc05547cd8
9 changed files with 83 additions and 24 deletions

View File

@ -11,14 +11,12 @@ __all__ = [
class ArrayAgg(OrderableAggMixin, Aggregate): class ArrayAgg(OrderableAggMixin, Aggregate):
function = 'ARRAY_AGG' function = 'ARRAY_AGG'
template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)' template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)'
allow_distinct = True
@property @property
def output_field(self): def output_field(self):
return ArrayField(self.source_expressions[0].output_field) return ArrayField(self.source_expressions[0].output_field)
def __init__(self, expression, distinct=False, **extra):
super().__init__(expression, distinct='DISTINCT ' if distinct else '', **extra)
def convert_value(self, value, expression, connection): def convert_value(self, value, expression, connection):
if not value: if not value:
return [] return []
@ -54,10 +52,10 @@ class JSONBAgg(Aggregate):
class StringAgg(OrderableAggMixin, Aggregate): class StringAgg(OrderableAggMixin, Aggregate):
function = 'STRING_AGG' function = 'STRING_AGG'
template = "%(function)s(%(distinct)s%(expressions)s, '%(delimiter)s'%(ordering)s)" template = "%(function)s(%(distinct)s%(expressions)s, '%(delimiter)s'%(ordering)s)"
allow_distinct = True
def __init__(self, expression, delimiter, distinct=False, **extra): def __init__(self, expression, delimiter, **extra):
distinct = 'DISTINCT ' if distinct else '' super().__init__(expression, delimiter=delimiter, **extra)
super().__init__(expression, delimiter=delimiter, distinct=distinct, **extra)
def convert_value(self, value, expression, connection): def convert_value(self, value, expression, connection):
if not value: if not value:

View File

@ -57,6 +57,11 @@ class DatabaseOperations(BaseDatabaseOperations):
'aggregations on date/time fields in sqlite3 ' 'aggregations on date/time fields in sqlite3 '
'since date/time is saved as text.' 'since date/time is saved as text.'
) )
if isinstance(expression, aggregates.Aggregate) and len(expression.source_expressions) > 1:
raise utils.NotSupportedError(
"SQLite doesn't support DISTINCT on aggregate functions "
"accepting multiple arguments."
)
def date_extract_sql(self, lookup_type, field_name): def date_extract_sql(self, lookup_type, field_name):
""" """

View File

@ -11,14 +11,19 @@ __all__ = [
class Aggregate(Func): class Aggregate(Func):
template = '%(function)s(%(distinct)s%(expressions)s)'
contains_aggregate = True contains_aggregate = True
name = None name = None
filter_template = '%s FILTER (WHERE %%(filter)s)' filter_template = '%s FILTER (WHERE %%(filter)s)'
window_compatible = True window_compatible = True
allow_distinct = False
def __init__(self, *args, filter=None, **kwargs): def __init__(self, *expressions, distinct=False, filter=None, **extra):
if distinct and not self.allow_distinct:
raise TypeError("%s does not allow distinct." % self.__class__.__name__)
self.distinct = distinct
self.filter = filter self.filter = filter
super().__init__(*args, **kwargs) super().__init__(*expressions, **extra)
def get_source_fields(self): def get_source_fields(self):
# Don't return the filter expression since it's not a source field. # Don't return the filter expression since it's not a source field.
@ -60,6 +65,7 @@ class Aggregate(Func):
return [] return []
def as_sql(self, compiler, connection, **extra_context): def as_sql(self, compiler, connection, **extra_context):
extra_context['distinct'] = 'DISTINCT' if self.distinct else ''
if self.filter: if self.filter:
if connection.features.supports_aggregate_filter_clause: if connection.features.supports_aggregate_filter_clause:
filter_sql, filter_params = self.filter.as_sql(compiler, connection) filter_sql, filter_params = self.filter.as_sql(compiler, connection)
@ -80,8 +86,10 @@ class Aggregate(Func):
def _get_repr_options(self): def _get_repr_options(self):
options = super()._get_repr_options() options = super()._get_repr_options()
if self.distinct:
options['distinct'] = self.distinct
if self.filter: if self.filter:
options.update({'filter': self.filter}) options['filter'] = self.filter
return options return options
@ -114,21 +122,15 @@ class Avg(Aggregate):
class Count(Aggregate): class Count(Aggregate):
function = 'COUNT' function = 'COUNT'
name = 'Count' name = 'Count'
template = '%(function)s(%(distinct)s%(expressions)s)'
output_field = IntegerField() output_field = IntegerField()
allow_distinct = True
def __init__(self, expression, distinct=False, filter=None, **extra): def __init__(self, expression, filter=None, **extra):
if expression == '*': if expression == '*':
expression = Star() expression = Star()
if isinstance(expression, Star) and filter is not None: if isinstance(expression, Star) and filter is not None:
raise ValueError('Star cannot be used with filter. Please specify a field.') raise ValueError('Star cannot be used with filter. Please specify a field.')
super().__init__( super().__init__(expression, filter=filter, **extra)
expression, distinct='DISTINCT ' if distinct else '',
filter=filter, **extra
)
def _get_repr_options(self):
return {**super()._get_repr_options(), 'distinct': self.extra['distinct'] != ''}
def convert_value(self, value, expression, connection): def convert_value(self, value, expression, connection):
return 0 if value is None else value return 0 if value is None else value

View File

@ -373,7 +373,7 @@ some complex computations::
The ``Aggregate`` API is as follows: The ``Aggregate`` API is as follows:
.. class:: Aggregate(*expressions, output_field=None, filter=None, **extra) .. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, **extra)
.. attribute:: template .. attribute:: template
@ -392,6 +392,14 @@ The ``Aggregate`` API is as follows:
Defaults to ``True`` since most aggregate functions can be used as the Defaults to ``True`` since most aggregate functions can be used as the
source expression in :class:`~django.db.models.expressions.Window`. source expression in :class:`~django.db.models.expressions.Window`.
.. attribute:: allow_distinct
.. versionadded:: 2.2
A class attribute determining whether or not this aggregate function
allows passing a ``distinct`` keyword argument. If set to ``False``
(default), ``TypeError`` is raised if ``distinct=True`` is passed.
The ``expressions`` positional arguments can include expressions or the names The ``expressions`` positional arguments can include expressions or the names
of model fields. They will be converted to a string and used as the of model fields. They will be converted to a string and used as the
``expressions`` placeholder within the ``template``. ``expressions`` placeholder within the ``template``.
@ -409,6 +417,11 @@ should define the desired ``output_field``. For example, adding an
``IntegerField()`` and a ``FloatField()`` together should probably have ``IntegerField()`` and a ``FloatField()`` together should probably have
``output_field=FloatField()`` defined. ``output_field=FloatField()`` defined.
The ``distinct`` argument determines whether or not the aggregate function
should be invoked for each distinct value of ``expressions`` (or set of
values, for multiple ``expressions``). The argument is only supported on
aggregates that have :attr:`~Aggregate.allow_distinct` set to ``True``.
The ``filter`` argument takes a :class:`Q object <django.db.models.Q>` that's The ``filter`` argument takes a :class:`Q object <django.db.models.Q>` that's
used to filter the rows that are aggregated. See :ref:`conditional-aggregation` used to filter the rows that are aggregated. See :ref:`conditional-aggregation`
and :ref:`filtering-on-annotations` for example usage. and :ref:`filtering-on-annotations` for example usage.
@ -416,6 +429,10 @@ and :ref:`filtering-on-annotations` for example usage.
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
into the ``template`` attribute. into the ``template`` attribute.
.. versionadded:: 2.2
The ``allow_distinct`` attribute and ``distinct`` argument were added.
Creating your own Aggregate Functions Creating your own Aggregate Functions
------------------------------------- -------------------------------------

View File

@ -239,6 +239,13 @@ Models
* Added SQLite support for the :class:`~django.db.models.StdDev` and * Added SQLite support for the :class:`~django.db.models.StdDev` and
:class:`~django.db.models.Variance` functions. :class:`~django.db.models.Variance` functions.
* The handling of ``DISTINCT`` aggregation is added to the
:class:`~django.db.models.Aggregate` class. Adding :attr:`allow_distinct =
True <django.db.models.Aggregate.allow_distinct>` as a class attribute on
``Aggregate`` subclasses allows a ``distinct`` keyword argument to be
specified on initialization to ensure that the aggregate function is only
called for each distinct value of ``expressions``.
Requests and Responses Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~

View File

@ -1026,7 +1026,7 @@ class AggregateTestCase(TestCase):
# test completely changing how the output is rendered # test completely changing how the output is rendered
def lower_case_function_override(self, compiler, connection): def lower_case_function_override(self, compiler, connection):
sql, params = compiler.compile(self.source_expressions[0]) sql, params = compiler.compile(self.source_expressions[0])
substitutions = {'function': self.function.lower(), 'expressions': sql} substitutions = {'function': self.function.lower(), 'expressions': sql, 'distinct': ''}
substitutions.update(self.extra) substitutions.update(self.extra)
return self.template % substitutions, params return self.template % substitutions, params
setattr(MySum, 'as_' + connection.vendor, lower_case_function_override) setattr(MySum, 'as_' + connection.vendor, lower_case_function_override)
@ -1053,7 +1053,7 @@ class AggregateTestCase(TestCase):
# test overriding all parts of the template # test overriding all parts of the template
def be_evil(self, compiler, connection): def be_evil(self, compiler, connection):
substitutions = {'function': 'MAX', 'expressions': '2'} substitutions = {'function': 'MAX', 'expressions': '2', 'distinct': ''}
substitutions.update(self.extra) substitutions.update(self.extra)
return self.template % substitutions, () return self.template % substitutions, ()
setattr(MySum, 'as_' + connection.vendor, be_evil) setattr(MySum, 'as_' + connection.vendor, be_evil)

View File

@ -11,6 +11,7 @@ from django.db.models import (
Avg, Case, Count, DecimalField, F, IntegerField, Max, Q, StdDev, Sum, Avg, Case, Count, DecimalField, F, IntegerField, Max, Q, StdDev, Sum,
Value, Variance, When, Value, Variance, When,
) )
from django.db.models.aggregates import Aggregate
from django.test import ( from django.test import (
TestCase, ignore_warnings, skipUnlessAnyDBFeature, skipUnlessDBFeature, TestCase, ignore_warnings, skipUnlessAnyDBFeature, skipUnlessDBFeature,
) )
@ -1496,6 +1497,16 @@ class AggregationTests(TestCase):
qs = Author.objects.values_list('age', flat=True).annotate(age_count=Count('age')).filter(age_count__gt=1) qs = Author.objects.values_list('age', flat=True).annotate(age_count=Count('age')).filter(age_count__gt=1)
self.assertSequenceEqual(qs, [29]) self.assertSequenceEqual(qs, [29])
def test_allow_distinct(self):
class MyAggregate(Aggregate):
pass
with self.assertRaisesMessage(TypeError, 'MyAggregate does not allow distinct'):
MyAggregate('foo', distinct=True)
class DistinctAggregate(Aggregate):
allow_distinct = True
DistinctAggregate('foo', distinct=True)
class JoinPromotionTests(TestCase): class JoinPromotionTests(TestCase):
def test_ticket_21150(self): def test_ticket_21150(self):

View File

@ -4,6 +4,7 @@ import unittest
from django.db import connection, transaction from django.db import connection, transaction
from django.db.models import Avg, StdDev, Sum, Variance from django.db.models import Avg, StdDev, Sum, Variance
from django.db.models.aggregates import Aggregate
from django.db.models.fields import CharField from django.db.models.fields import CharField
from django.db.utils import NotSupportedError from django.db.utils import NotSupportedError
from django.test import ( from django.test import (
@ -34,6 +35,17 @@ class Tests(TestCase):
**{'complex': aggregate('last_modified') + aggregate('last_modified')} **{'complex': aggregate('last_modified') + aggregate('last_modified')}
) )
def test_distinct_aggregation(self):
class DistinctAggregate(Aggregate):
allow_distinct = True
aggregate = DistinctAggregate('first', 'second', distinct=True)
msg = (
"SQLite doesn't support DISTINCT on aggregate functions accepting "
"multiple arguments."
)
with self.assertRaisesMessage(NotSupportedError, msg):
connection.ops.check_expression_support(aggregate)
def test_memory_db_test_name(self): def test_memory_db_test_name(self):
"""A named in-memory db should be allowed where supported.""" """A named in-memory db should be allowed where supported."""
from django.db.backends.sqlite3.base import DatabaseWrapper from django.db.backends.sqlite3.base import DatabaseWrapper

View File

@ -1481,18 +1481,22 @@ class ReprTests(SimpleTestCase):
def test_aggregates(self): def test_aggregates(self):
self.assertEqual(repr(Avg('a')), "Avg(F(a))") self.assertEqual(repr(Avg('a')), "Avg(F(a))")
self.assertEqual(repr(Count('a')), "Count(F(a), distinct=False)") self.assertEqual(repr(Count('a')), "Count(F(a))")
self.assertEqual(repr(Count('*')), "Count('*', distinct=False)") self.assertEqual(repr(Count('*')), "Count('*')")
self.assertEqual(repr(Max('a')), "Max(F(a))") self.assertEqual(repr(Max('a')), "Max(F(a))")
self.assertEqual(repr(Min('a')), "Min(F(a))") self.assertEqual(repr(Min('a')), "Min(F(a))")
self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)") self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)")
self.assertEqual(repr(Sum('a')), "Sum(F(a))") self.assertEqual(repr(Sum('a')), "Sum(F(a))")
self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)") self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)")
def test_distinct_aggregates(self):
self.assertEqual(repr(Count('a', distinct=True)), "Count(F(a), distinct=True)")
self.assertEqual(repr(Count('*', distinct=True)), "Count('*', distinct=True)")
def test_filtered_aggregates(self): def test_filtered_aggregates(self):
filter = Q(a=1) filter = Q(a=1)
self.assertEqual(repr(Avg('a', filter=filter)), "Avg(F(a), filter=(AND: ('a', 1)))") self.assertEqual(repr(Avg('a', filter=filter)), "Avg(F(a), filter=(AND: ('a', 1)))")
self.assertEqual(repr(Count('a', filter=filter)), "Count(F(a), distinct=False, filter=(AND: ('a', 1)))") self.assertEqual(repr(Count('a', filter=filter)), "Count(F(a), filter=(AND: ('a', 1)))")
self.assertEqual(repr(Max('a', filter=filter)), "Max(F(a), filter=(AND: ('a', 1)))") self.assertEqual(repr(Max('a', filter=filter)), "Max(F(a), filter=(AND: ('a', 1)))")
self.assertEqual(repr(Min('a', filter=filter)), "Min(F(a), filter=(AND: ('a', 1)))") self.assertEqual(repr(Min('a', filter=filter)), "Min(F(a), filter=(AND: ('a', 1)))")
self.assertEqual(repr(StdDev('a', filter=filter)), "StdDev(F(a), filter=(AND: ('a', 1)), sample=False)") self.assertEqual(repr(StdDev('a', filter=filter)), "StdDev(F(a), filter=(AND: ('a', 1)), sample=False)")
@ -1501,6 +1505,9 @@ class ReprTests(SimpleTestCase):
repr(Variance('a', sample=True, filter=filter)), repr(Variance('a', sample=True, filter=filter)),
"Variance(F(a), filter=(AND: ('a', 1)), sample=True)" "Variance(F(a), filter=(AND: ('a', 1)), sample=True)"
) )
self.assertEqual(
repr(Count('a', filter=filter, distinct=True)), "Count(F(a), distinct=True, filter=(AND: ('a', 1)))"
)
class CombinableTests(SimpleTestCase): class CombinableTests(SimpleTestCase):