Fixed #28658 -- Added DISTINCT handling to the Aggregate class.
This commit is contained in:
parent
222caab68a
commit
bc05547cd8
|
@ -11,14 +11,12 @@ __all__ = [
|
||||||
class ArrayAgg(OrderableAggMixin, Aggregate):
|
class ArrayAgg(OrderableAggMixin, Aggregate):
|
||||||
function = 'ARRAY_AGG'
|
function = 'ARRAY_AGG'
|
||||||
template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)'
|
template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)'
|
||||||
|
allow_distinct = True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_field(self):
|
def output_field(self):
|
||||||
return ArrayField(self.source_expressions[0].output_field)
|
return ArrayField(self.source_expressions[0].output_field)
|
||||||
|
|
||||||
def __init__(self, expression, distinct=False, **extra):
|
|
||||||
super().__init__(expression, distinct='DISTINCT ' if distinct else '', **extra)
|
|
||||||
|
|
||||||
def convert_value(self, value, expression, connection):
|
def convert_value(self, value, expression, connection):
|
||||||
if not value:
|
if not value:
|
||||||
return []
|
return []
|
||||||
|
@ -54,10 +52,10 @@ class JSONBAgg(Aggregate):
|
||||||
class StringAgg(OrderableAggMixin, Aggregate):
|
class StringAgg(OrderableAggMixin, Aggregate):
|
||||||
function = 'STRING_AGG'
|
function = 'STRING_AGG'
|
||||||
template = "%(function)s(%(distinct)s%(expressions)s, '%(delimiter)s'%(ordering)s)"
|
template = "%(function)s(%(distinct)s%(expressions)s, '%(delimiter)s'%(ordering)s)"
|
||||||
|
allow_distinct = True
|
||||||
|
|
||||||
def __init__(self, expression, delimiter, distinct=False, **extra):
|
def __init__(self, expression, delimiter, **extra):
|
||||||
distinct = 'DISTINCT ' if distinct else ''
|
super().__init__(expression, delimiter=delimiter, **extra)
|
||||||
super().__init__(expression, delimiter=delimiter, distinct=distinct, **extra)
|
|
||||||
|
|
||||||
def convert_value(self, value, expression, connection):
|
def convert_value(self, value, expression, connection):
|
||||||
if not value:
|
if not value:
|
||||||
|
|
|
@ -57,6 +57,11 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||||
'aggregations on date/time fields in sqlite3 '
|
'aggregations on date/time fields in sqlite3 '
|
||||||
'since date/time is saved as text.'
|
'since date/time is saved as text.'
|
||||||
)
|
)
|
||||||
|
if isinstance(expression, aggregates.Aggregate) and len(expression.source_expressions) > 1:
|
||||||
|
raise utils.NotSupportedError(
|
||||||
|
"SQLite doesn't support DISTINCT on aggregate functions "
|
||||||
|
"accepting multiple arguments."
|
||||||
|
)
|
||||||
|
|
||||||
def date_extract_sql(self, lookup_type, field_name):
|
def date_extract_sql(self, lookup_type, field_name):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -11,14 +11,19 @@ __all__ = [
|
||||||
|
|
||||||
|
|
||||||
class Aggregate(Func):
|
class Aggregate(Func):
|
||||||
|
template = '%(function)s(%(distinct)s%(expressions)s)'
|
||||||
contains_aggregate = True
|
contains_aggregate = True
|
||||||
name = None
|
name = None
|
||||||
filter_template = '%s FILTER (WHERE %%(filter)s)'
|
filter_template = '%s FILTER (WHERE %%(filter)s)'
|
||||||
window_compatible = True
|
window_compatible = True
|
||||||
|
allow_distinct = False
|
||||||
|
|
||||||
def __init__(self, *args, filter=None, **kwargs):
|
def __init__(self, *expressions, distinct=False, filter=None, **extra):
|
||||||
|
if distinct and not self.allow_distinct:
|
||||||
|
raise TypeError("%s does not allow distinct." % self.__class__.__name__)
|
||||||
|
self.distinct = distinct
|
||||||
self.filter = filter
|
self.filter = filter
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*expressions, **extra)
|
||||||
|
|
||||||
def get_source_fields(self):
|
def get_source_fields(self):
|
||||||
# Don't return the filter expression since it's not a source field.
|
# Don't return the filter expression since it's not a source field.
|
||||||
|
@ -60,6 +65,7 @@ class Aggregate(Func):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def as_sql(self, compiler, connection, **extra_context):
|
def as_sql(self, compiler, connection, **extra_context):
|
||||||
|
extra_context['distinct'] = 'DISTINCT' if self.distinct else ''
|
||||||
if self.filter:
|
if self.filter:
|
||||||
if connection.features.supports_aggregate_filter_clause:
|
if connection.features.supports_aggregate_filter_clause:
|
||||||
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
|
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
|
||||||
|
@ -80,8 +86,10 @@ class Aggregate(Func):
|
||||||
|
|
||||||
def _get_repr_options(self):
|
def _get_repr_options(self):
|
||||||
options = super()._get_repr_options()
|
options = super()._get_repr_options()
|
||||||
|
if self.distinct:
|
||||||
|
options['distinct'] = self.distinct
|
||||||
if self.filter:
|
if self.filter:
|
||||||
options.update({'filter': self.filter})
|
options['filter'] = self.filter
|
||||||
return options
|
return options
|
||||||
|
|
||||||
|
|
||||||
|
@ -114,21 +122,15 @@ class Avg(Aggregate):
|
||||||
class Count(Aggregate):
|
class Count(Aggregate):
|
||||||
function = 'COUNT'
|
function = 'COUNT'
|
||||||
name = 'Count'
|
name = 'Count'
|
||||||
template = '%(function)s(%(distinct)s%(expressions)s)'
|
|
||||||
output_field = IntegerField()
|
output_field = IntegerField()
|
||||||
|
allow_distinct = True
|
||||||
|
|
||||||
def __init__(self, expression, distinct=False, filter=None, **extra):
|
def __init__(self, expression, filter=None, **extra):
|
||||||
if expression == '*':
|
if expression == '*':
|
||||||
expression = Star()
|
expression = Star()
|
||||||
if isinstance(expression, Star) and filter is not None:
|
if isinstance(expression, Star) and filter is not None:
|
||||||
raise ValueError('Star cannot be used with filter. Please specify a field.')
|
raise ValueError('Star cannot be used with filter. Please specify a field.')
|
||||||
super().__init__(
|
super().__init__(expression, filter=filter, **extra)
|
||||||
expression, distinct='DISTINCT ' if distinct else '',
|
|
||||||
filter=filter, **extra
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_repr_options(self):
|
|
||||||
return {**super()._get_repr_options(), 'distinct': self.extra['distinct'] != ''}
|
|
||||||
|
|
||||||
def convert_value(self, value, expression, connection):
|
def convert_value(self, value, expression, connection):
|
||||||
return 0 if value is None else value
|
return 0 if value is None else value
|
||||||
|
|
|
@ -373,7 +373,7 @@ some complex computations::
|
||||||
|
|
||||||
The ``Aggregate`` API is as follows:
|
The ``Aggregate`` API is as follows:
|
||||||
|
|
||||||
.. class:: Aggregate(*expressions, output_field=None, filter=None, **extra)
|
.. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, **extra)
|
||||||
|
|
||||||
.. attribute:: template
|
.. attribute:: template
|
||||||
|
|
||||||
|
@ -392,6 +392,14 @@ The ``Aggregate`` API is as follows:
|
||||||
Defaults to ``True`` since most aggregate functions can be used as the
|
Defaults to ``True`` since most aggregate functions can be used as the
|
||||||
source expression in :class:`~django.db.models.expressions.Window`.
|
source expression in :class:`~django.db.models.expressions.Window`.
|
||||||
|
|
||||||
|
.. attribute:: allow_distinct
|
||||||
|
|
||||||
|
.. versionadded:: 2.2
|
||||||
|
|
||||||
|
A class attribute determining whether or not this aggregate function
|
||||||
|
allows passing a ``distinct`` keyword argument. If set to ``False``
|
||||||
|
(default), ``TypeError`` is raised if ``distinct=True`` is passed.
|
||||||
|
|
||||||
The ``expressions`` positional arguments can include expressions or the names
|
The ``expressions`` positional arguments can include expressions or the names
|
||||||
of model fields. They will be converted to a string and used as the
|
of model fields. They will be converted to a string and used as the
|
||||||
``expressions`` placeholder within the ``template``.
|
``expressions`` placeholder within the ``template``.
|
||||||
|
@ -409,6 +417,11 @@ should define the desired ``output_field``. For example, adding an
|
||||||
``IntegerField()`` and a ``FloatField()`` together should probably have
|
``IntegerField()`` and a ``FloatField()`` together should probably have
|
||||||
``output_field=FloatField()`` defined.
|
``output_field=FloatField()`` defined.
|
||||||
|
|
||||||
|
The ``distinct`` argument determines whether or not the aggregate function
|
||||||
|
should be invoked for each distinct value of ``expressions`` (or set of
|
||||||
|
values, for multiple ``expressions``). The argument is only supported on
|
||||||
|
aggregates that have :attr:`~Aggregate.allow_distinct` set to ``True``.
|
||||||
|
|
||||||
The ``filter`` argument takes a :class:`Q object <django.db.models.Q>` that's
|
The ``filter`` argument takes a :class:`Q object <django.db.models.Q>` that's
|
||||||
used to filter the rows that are aggregated. See :ref:`conditional-aggregation`
|
used to filter the rows that are aggregated. See :ref:`conditional-aggregation`
|
||||||
and :ref:`filtering-on-annotations` for example usage.
|
and :ref:`filtering-on-annotations` for example usage.
|
||||||
|
@ -416,6 +429,10 @@ and :ref:`filtering-on-annotations` for example usage.
|
||||||
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
|
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
|
||||||
into the ``template`` attribute.
|
into the ``template`` attribute.
|
||||||
|
|
||||||
|
.. versionadded:: 2.2
|
||||||
|
|
||||||
|
The ``allow_distinct`` attribute and ``distinct`` argument were added.
|
||||||
|
|
||||||
Creating your own Aggregate Functions
|
Creating your own Aggregate Functions
|
||||||
-------------------------------------
|
-------------------------------------
|
||||||
|
|
||||||
|
|
|
@ -239,6 +239,13 @@ Models
|
||||||
* Added SQLite support for the :class:`~django.db.models.StdDev` and
|
* Added SQLite support for the :class:`~django.db.models.StdDev` and
|
||||||
:class:`~django.db.models.Variance` functions.
|
:class:`~django.db.models.Variance` functions.
|
||||||
|
|
||||||
|
* The handling of ``DISTINCT`` aggregation is added to the
|
||||||
|
:class:`~django.db.models.Aggregate` class. Adding :attr:`allow_distinct =
|
||||||
|
True <django.db.models.Aggregate.allow_distinct>` as a class attribute on
|
||||||
|
``Aggregate`` subclasses allows a ``distinct`` keyword argument to be
|
||||||
|
specified on initialization to ensure that the aggregate function is only
|
||||||
|
called for each distinct value of ``expressions``.
|
||||||
|
|
||||||
Requests and Responses
|
Requests and Responses
|
||||||
~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|
|
@ -1026,7 +1026,7 @@ class AggregateTestCase(TestCase):
|
||||||
# test completely changing how the output is rendered
|
# test completely changing how the output is rendered
|
||||||
def lower_case_function_override(self, compiler, connection):
|
def lower_case_function_override(self, compiler, connection):
|
||||||
sql, params = compiler.compile(self.source_expressions[0])
|
sql, params = compiler.compile(self.source_expressions[0])
|
||||||
substitutions = {'function': self.function.lower(), 'expressions': sql}
|
substitutions = {'function': self.function.lower(), 'expressions': sql, 'distinct': ''}
|
||||||
substitutions.update(self.extra)
|
substitutions.update(self.extra)
|
||||||
return self.template % substitutions, params
|
return self.template % substitutions, params
|
||||||
setattr(MySum, 'as_' + connection.vendor, lower_case_function_override)
|
setattr(MySum, 'as_' + connection.vendor, lower_case_function_override)
|
||||||
|
@ -1053,7 +1053,7 @@ class AggregateTestCase(TestCase):
|
||||||
|
|
||||||
# test overriding all parts of the template
|
# test overriding all parts of the template
|
||||||
def be_evil(self, compiler, connection):
|
def be_evil(self, compiler, connection):
|
||||||
substitutions = {'function': 'MAX', 'expressions': '2'}
|
substitutions = {'function': 'MAX', 'expressions': '2', 'distinct': ''}
|
||||||
substitutions.update(self.extra)
|
substitutions.update(self.extra)
|
||||||
return self.template % substitutions, ()
|
return self.template % substitutions, ()
|
||||||
setattr(MySum, 'as_' + connection.vendor, be_evil)
|
setattr(MySum, 'as_' + connection.vendor, be_evil)
|
||||||
|
|
|
@ -11,6 +11,7 @@ from django.db.models import (
|
||||||
Avg, Case, Count, DecimalField, F, IntegerField, Max, Q, StdDev, Sum,
|
Avg, Case, Count, DecimalField, F, IntegerField, Max, Q, StdDev, Sum,
|
||||||
Value, Variance, When,
|
Value, Variance, When,
|
||||||
)
|
)
|
||||||
|
from django.db.models.aggregates import Aggregate
|
||||||
from django.test import (
|
from django.test import (
|
||||||
TestCase, ignore_warnings, skipUnlessAnyDBFeature, skipUnlessDBFeature,
|
TestCase, ignore_warnings, skipUnlessAnyDBFeature, skipUnlessDBFeature,
|
||||||
)
|
)
|
||||||
|
@ -1496,6 +1497,16 @@ class AggregationTests(TestCase):
|
||||||
qs = Author.objects.values_list('age', flat=True).annotate(age_count=Count('age')).filter(age_count__gt=1)
|
qs = Author.objects.values_list('age', flat=True).annotate(age_count=Count('age')).filter(age_count__gt=1)
|
||||||
self.assertSequenceEqual(qs, [29])
|
self.assertSequenceEqual(qs, [29])
|
||||||
|
|
||||||
|
def test_allow_distinct(self):
|
||||||
|
class MyAggregate(Aggregate):
|
||||||
|
pass
|
||||||
|
with self.assertRaisesMessage(TypeError, 'MyAggregate does not allow distinct'):
|
||||||
|
MyAggregate('foo', distinct=True)
|
||||||
|
|
||||||
|
class DistinctAggregate(Aggregate):
|
||||||
|
allow_distinct = True
|
||||||
|
DistinctAggregate('foo', distinct=True)
|
||||||
|
|
||||||
|
|
||||||
class JoinPromotionTests(TestCase):
|
class JoinPromotionTests(TestCase):
|
||||||
def test_ticket_21150(self):
|
def test_ticket_21150(self):
|
||||||
|
|
|
@ -4,6 +4,7 @@ import unittest
|
||||||
|
|
||||||
from django.db import connection, transaction
|
from django.db import connection, transaction
|
||||||
from django.db.models import Avg, StdDev, Sum, Variance
|
from django.db.models import Avg, StdDev, Sum, Variance
|
||||||
|
from django.db.models.aggregates import Aggregate
|
||||||
from django.db.models.fields import CharField
|
from django.db.models.fields import CharField
|
||||||
from django.db.utils import NotSupportedError
|
from django.db.utils import NotSupportedError
|
||||||
from django.test import (
|
from django.test import (
|
||||||
|
@ -34,6 +35,17 @@ class Tests(TestCase):
|
||||||
**{'complex': aggregate('last_modified') + aggregate('last_modified')}
|
**{'complex': aggregate('last_modified') + aggregate('last_modified')}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_distinct_aggregation(self):
|
||||||
|
class DistinctAggregate(Aggregate):
|
||||||
|
allow_distinct = True
|
||||||
|
aggregate = DistinctAggregate('first', 'second', distinct=True)
|
||||||
|
msg = (
|
||||||
|
"SQLite doesn't support DISTINCT on aggregate functions accepting "
|
||||||
|
"multiple arguments."
|
||||||
|
)
|
||||||
|
with self.assertRaisesMessage(NotSupportedError, msg):
|
||||||
|
connection.ops.check_expression_support(aggregate)
|
||||||
|
|
||||||
def test_memory_db_test_name(self):
|
def test_memory_db_test_name(self):
|
||||||
"""A named in-memory db should be allowed where supported."""
|
"""A named in-memory db should be allowed where supported."""
|
||||||
from django.db.backends.sqlite3.base import DatabaseWrapper
|
from django.db.backends.sqlite3.base import DatabaseWrapper
|
||||||
|
|
|
@ -1481,18 +1481,22 @@ class ReprTests(SimpleTestCase):
|
||||||
|
|
||||||
def test_aggregates(self):
|
def test_aggregates(self):
|
||||||
self.assertEqual(repr(Avg('a')), "Avg(F(a))")
|
self.assertEqual(repr(Avg('a')), "Avg(F(a))")
|
||||||
self.assertEqual(repr(Count('a')), "Count(F(a), distinct=False)")
|
self.assertEqual(repr(Count('a')), "Count(F(a))")
|
||||||
self.assertEqual(repr(Count('*')), "Count('*', distinct=False)")
|
self.assertEqual(repr(Count('*')), "Count('*')")
|
||||||
self.assertEqual(repr(Max('a')), "Max(F(a))")
|
self.assertEqual(repr(Max('a')), "Max(F(a))")
|
||||||
self.assertEqual(repr(Min('a')), "Min(F(a))")
|
self.assertEqual(repr(Min('a')), "Min(F(a))")
|
||||||
self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)")
|
self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)")
|
||||||
self.assertEqual(repr(Sum('a')), "Sum(F(a))")
|
self.assertEqual(repr(Sum('a')), "Sum(F(a))")
|
||||||
self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)")
|
self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)")
|
||||||
|
|
||||||
|
def test_distinct_aggregates(self):
|
||||||
|
self.assertEqual(repr(Count('a', distinct=True)), "Count(F(a), distinct=True)")
|
||||||
|
self.assertEqual(repr(Count('*', distinct=True)), "Count('*', distinct=True)")
|
||||||
|
|
||||||
def test_filtered_aggregates(self):
|
def test_filtered_aggregates(self):
|
||||||
filter = Q(a=1)
|
filter = Q(a=1)
|
||||||
self.assertEqual(repr(Avg('a', filter=filter)), "Avg(F(a), filter=(AND: ('a', 1)))")
|
self.assertEqual(repr(Avg('a', filter=filter)), "Avg(F(a), filter=(AND: ('a', 1)))")
|
||||||
self.assertEqual(repr(Count('a', filter=filter)), "Count(F(a), distinct=False, filter=(AND: ('a', 1)))")
|
self.assertEqual(repr(Count('a', filter=filter)), "Count(F(a), filter=(AND: ('a', 1)))")
|
||||||
self.assertEqual(repr(Max('a', filter=filter)), "Max(F(a), filter=(AND: ('a', 1)))")
|
self.assertEqual(repr(Max('a', filter=filter)), "Max(F(a), filter=(AND: ('a', 1)))")
|
||||||
self.assertEqual(repr(Min('a', filter=filter)), "Min(F(a), filter=(AND: ('a', 1)))")
|
self.assertEqual(repr(Min('a', filter=filter)), "Min(F(a), filter=(AND: ('a', 1)))")
|
||||||
self.assertEqual(repr(StdDev('a', filter=filter)), "StdDev(F(a), filter=(AND: ('a', 1)), sample=False)")
|
self.assertEqual(repr(StdDev('a', filter=filter)), "StdDev(F(a), filter=(AND: ('a', 1)), sample=False)")
|
||||||
|
@ -1501,6 +1505,9 @@ class ReprTests(SimpleTestCase):
|
||||||
repr(Variance('a', sample=True, filter=filter)),
|
repr(Variance('a', sample=True, filter=filter)),
|
||||||
"Variance(F(a), filter=(AND: ('a', 1)), sample=True)"
|
"Variance(F(a), filter=(AND: ('a', 1)), sample=True)"
|
||||||
)
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
repr(Count('a', filter=filter, distinct=True)), "Count(F(a), distinct=True, filter=(AND: ('a', 1)))"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CombinableTests(SimpleTestCase):
|
class CombinableTests(SimpleTestCase):
|
||||||
|
|
Loading…
Reference in New Issue