Fixed #26067 -- Added ordering support to ArrayAgg and StringAgg.

This commit is contained in:
Floris den Hengst 2016-07-05 11:47:24 +02:00 committed by Tim Graham
parent 2a0116266c
commit 96199e562d
5 changed files with 146 additions and 12 deletions

View File

@ -1,14 +1,16 @@
from django.contrib.postgres.fields import ArrayField, JSONField from django.contrib.postgres.fields import ArrayField, JSONField
from django.db.models.aggregates import Aggregate from django.db.models.aggregates import Aggregate
from .mixins import OrderableAggMixin
__all__ = [ __all__ = [
'ArrayAgg', 'BitAnd', 'BitOr', 'BoolAnd', 'BoolOr', 'JSONBAgg', 'StringAgg', 'ArrayAgg', 'BitAnd', 'BitOr', 'BoolAnd', 'BoolOr', 'JSONBAgg', 'StringAgg',
] ]
class ArrayAgg(Aggregate): class ArrayAgg(OrderableAggMixin, Aggregate):
function = 'ARRAY_AGG' function = 'ARRAY_AGG'
template = '%(function)s(%(distinct)s%(expressions)s)' template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)'
@property @property
def output_field(self): def output_field(self):
@ -49,9 +51,9 @@ class JSONBAgg(Aggregate):
return value return value
class StringAgg(Aggregate): class StringAgg(OrderableAggMixin, Aggregate):
function = 'STRING_AGG' function = 'STRING_AGG'
template = "%(function)s(%(distinct)s%(expressions)s, '%(delimiter)s')" template = "%(function)s(%(distinct)s%(expressions)s, '%(delimiter)s'%(ordering)s)"
def __init__(self, expression, delimiter, distinct=False, **extra): def __init__(self, expression, delimiter, distinct=False, **extra):
distinct = 'DISTINCT ' if distinct else '' distinct = 'DISTINCT ' if distinct else ''

View File

@ -0,0 +1,47 @@
from django.db.models.expressions import F, OrderBy
class OrderableAggMixin:
def __init__(self, expression, ordering=(), **extra):
if not isinstance(ordering, (list, tuple)):
ordering = [ordering]
ordering = ordering or []
# Transform minus sign prefixed strings into an OrderBy() expression.
ordering = (
(OrderBy(F(o[1:]), descending=True) if isinstance(o, str) and o[0] == '-' else o)
for o in ordering
)
super().__init__(expression, **extra)
self.ordering = self._parse_expressions(*ordering)
def resolve_expression(self, *args, **kwargs):
self.ordering = [expr.resolve_expression(*args, **kwargs) for expr in self.ordering]
return super().resolve_expression(*args, **kwargs)
def as_sql(self, compiler, connection):
if self.ordering:
self.extra['ordering'] = 'ORDER BY ' + ', '.join((
ordering_element.as_sql(compiler, connection)[0]
for ordering_element in self.ordering
))
else:
self.extra['ordering'] = ''
return super().as_sql(compiler, connection)
def get_source_expressions(self):
return self.source_expressions + self.ordering
def get_source_fields(self):
# Filter out fields contributed by the ordering expressions as
# these should not be used to determine which the return type of the
# expression.
return [
e._output_field_or_none
for e in self.get_source_expressions()[:self._get_ordering_expressions_index()]
]
def _get_ordering_expressions_index(self):
"""Return the index at which the ordering expressions start."""
source_expressions = self.get_source_expressions()
return len(source_expressions) - len(self.ordering)

View File

@ -22,7 +22,7 @@ General-purpose aggregation functions
``ArrayAgg`` ``ArrayAgg``
------------ ------------
.. class:: ArrayAgg(expression, distinct=False, filter=None, **extra) .. class:: ArrayAgg(expression, distinct=False, filter=None, ordering=(), **extra)
Returns a list of values, including nulls, concatenated into an array. Returns a list of values, including nulls, concatenated into an array.
@ -31,6 +31,22 @@ General-purpose aggregation functions
An optional boolean argument that determines if array values An optional boolean argument that determines if array values
will be distinct. Defaults to ``False``. will be distinct. Defaults to ``False``.
.. attribute:: ordering
.. versionadded:: 2.2
An optional string of a field name (with an optional ``"-"`` prefix
which indicates descending order) or an expression (or a tuple or list
of strings and/or expressions) that specifies the ordering of the
elements in the result list.
Examples::
'some_field'
'-some_field'
from django.db.models import F
F('some_field').desc()
``BitAnd`` ``BitAnd``
---------- ----------
@ -73,7 +89,7 @@ General-purpose aggregation functions
``StringAgg`` ``StringAgg``
------------- -------------
.. class:: StringAgg(expression, delimiter, distinct=False, filter=None) .. class:: StringAgg(expression, delimiter, distinct=False, filter=None, ordering=())
Returns the input values concatenated into a string, separated by Returns the input values concatenated into a string, separated by
the ``delimiter`` string. the ``delimiter`` string.
@ -87,6 +103,17 @@ General-purpose aggregation functions
An optional boolean argument that determines if concatenated values An optional boolean argument that determines if concatenated values
will be distinct. Defaults to ``False``. will be distinct. Defaults to ``False``.
.. attribute:: ordering
.. versionadded:: 2.2
An optional string of a field name (with an optional ``"-"`` prefix
which indicates descending order) or an expression (or a tuple or list
of strings and/or expressions) that specifies the ordering of the
elements in the result string.
Examples are the same as for :attr:`ArrayAgg.ordering`.
Aggregate functions for statistics Aggregate functions for statistics
================================== ==================================

View File

@ -70,7 +70,10 @@ Minor features
:mod:`django.contrib.postgres` :mod:`django.contrib.postgres`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* ... * The new ``ordering`` argument for
:class:`~django.contrib.postgres.aggregates.ArrayAgg` and
:class:`~django.contrib.postgres.aggregates.StringAgg` determines the
ordering of the aggregated elements.
:mod:`django.contrib.redirects` :mod:`django.contrib.redirects`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -22,21 +22,57 @@ class TestGeneralAggregate(PostgreSQLTestCase):
def setUpTestData(cls): def setUpTestData(cls):
AggregateTestModel.objects.create(boolean_field=True, char_field='Foo1', integer_field=0) AggregateTestModel.objects.create(boolean_field=True, char_field='Foo1', integer_field=0)
AggregateTestModel.objects.create(boolean_field=False, char_field='Foo2', integer_field=1) AggregateTestModel.objects.create(boolean_field=False, char_field='Foo2', integer_field=1)
AggregateTestModel.objects.create(boolean_field=False, char_field='Foo3', integer_field=2) AggregateTestModel.objects.create(boolean_field=False, char_field='Foo4', integer_field=2)
AggregateTestModel.objects.create(boolean_field=True, char_field='Foo4', integer_field=0) AggregateTestModel.objects.create(boolean_field=True, char_field='Foo3', integer_field=0)
def test_array_agg_charfield(self): def test_array_agg_charfield(self):
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field')) values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field'))
self.assertEqual(values, {'arrayagg': ['Foo1', 'Foo2', 'Foo3', 'Foo4']}) self.assertEqual(values, {'arrayagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']})
def test_array_agg_charfield_ordering(self):
ordering_test_cases = (
(F('char_field').desc(), ['Foo4', 'Foo3', 'Foo2', 'Foo1']),
(F('char_field').asc(), ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
(F('char_field'), ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
([F('boolean_field'), F('char_field').desc()], ['Foo4', 'Foo2', 'Foo3', 'Foo1']),
((F('boolean_field'), F('char_field').desc()), ['Foo4', 'Foo2', 'Foo3', 'Foo1']),
('char_field', ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
('-char_field', ['Foo4', 'Foo3', 'Foo2', 'Foo1']),
)
for ordering, expected_output in ordering_test_cases:
with self.subTest(ordering=ordering, expected_output=expected_output):
values = AggregateTestModel.objects.aggregate(
arrayagg=ArrayAgg('char_field', ordering=ordering)
)
self.assertEqual(values, {'arrayagg': expected_output})
def test_array_agg_integerfield(self): def test_array_agg_integerfield(self):
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('integer_field')) values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('integer_field'))
self.assertEqual(values, {'arrayagg': [0, 1, 2, 0]}) self.assertEqual(values, {'arrayagg': [0, 1, 2, 0]})
def test_array_agg_integerfield_ordering(self):
values = AggregateTestModel.objects.aggregate(
arrayagg=ArrayAgg('integer_field', ordering=F('integer_field').desc())
)
self.assertEqual(values, {'arrayagg': [2, 1, 0, 0]})
def test_array_agg_booleanfield(self): def test_array_agg_booleanfield(self):
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('boolean_field')) values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('boolean_field'))
self.assertEqual(values, {'arrayagg': [True, False, False, True]}) self.assertEqual(values, {'arrayagg': [True, False, False, True]})
def test_array_agg_booleanfield_ordering(self):
ordering_test_cases = (
(F('boolean_field').asc(), [False, False, True, True]),
(F('boolean_field').desc(), [True, True, False, False]),
(F('boolean_field'), [False, False, True, True]),
)
for ordering, expected_output in ordering_test_cases:
with self.subTest(ordering=ordering, expected_output=expected_output):
values = AggregateTestModel.objects.aggregate(
arrayagg=ArrayAgg('boolean_field', ordering=ordering)
)
self.assertEqual(values, {'arrayagg': expected_output})
def test_array_agg_empty_result(self): def test_array_agg_empty_result(self):
AggregateTestModel.objects.all().delete() AggregateTestModel.objects.all().delete()
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field')) values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field'))
@ -122,17 +158,36 @@ class TestGeneralAggregate(PostgreSQLTestCase):
def test_string_agg_charfield(self): def test_string_agg_charfield(self):
values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=';')) values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=';'))
self.assertEqual(values, {'stringagg': 'Foo1;Foo2;Foo3;Foo4'}) self.assertEqual(values, {'stringagg': 'Foo1;Foo2;Foo4;Foo3'})
def test_string_agg_charfield_ordering(self):
ordering_test_cases = (
(F('char_field').desc(), 'Foo4;Foo3;Foo2;Foo1'),
(F('char_field').asc(), 'Foo1;Foo2;Foo3;Foo4'),
(F('char_field'), 'Foo1;Foo2;Foo3;Foo4'),
)
for ordering, expected_output in ordering_test_cases:
with self.subTest(ordering=ordering, expected_output=expected_output):
values = AggregateTestModel.objects.aggregate(
stringagg=StringAgg('char_field', delimiter=';', ordering=ordering)
)
self.assertEqual(values, {'stringagg': expected_output})
def test_string_agg_empty_result(self): def test_string_agg_empty_result(self):
AggregateTestModel.objects.all().delete() AggregateTestModel.objects.all().delete()
values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=';')) values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=';'))
self.assertEqual(values, {'stringagg': ''}) self.assertEqual(values, {'stringagg': ''})
def test_orderable_agg_alternative_fields(self):
values = AggregateTestModel.objects.aggregate(
arrayagg=ArrayAgg('integer_field', ordering=F('char_field').asc())
)
self.assertEqual(values, {'arrayagg': [0, 1, 0, 2]})
@skipUnlessDBFeature('has_jsonb_agg') @skipUnlessDBFeature('has_jsonb_agg')
def test_json_agg(self): def test_json_agg(self):
values = AggregateTestModel.objects.aggregate(jsonagg=JSONBAgg('char_field')) values = AggregateTestModel.objects.aggregate(jsonagg=JSONBAgg('char_field'))
self.assertEqual(values, {'jsonagg': ['Foo1', 'Foo2', 'Foo3', 'Foo4']}) self.assertEqual(values, {'jsonagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']})
@skipUnlessDBFeature('has_jsonb_agg') @skipUnlessDBFeature('has_jsonb_agg')
def test_json_agg_empty(self): def test_json_agg_empty(self):