From 96199e562dcc409ab4bdc2b2146fa7cf73c7c5fe Mon Sep 17 00:00:00 2001 From: Floris den Hengst Date: Tue, 5 Jul 2016 11:47:24 +0200 Subject: [PATCH] Fixed #26067 -- Added ordering support to ArrayAgg and StringAgg. --- django/contrib/postgres/aggregates/general.py | 10 +-- django/contrib/postgres/aggregates/mixins.py | 47 ++++++++++++++ docs/ref/contrib/postgres/aggregates.txt | 31 ++++++++- docs/releases/2.2.txt | 5 +- tests/postgres_tests/test_aggregates.py | 65 +++++++++++++++++-- 5 files changed, 146 insertions(+), 12 deletions(-) create mode 100644 django/contrib/postgres/aggregates/mixins.py diff --git a/django/contrib/postgres/aggregates/general.py b/django/contrib/postgres/aggregates/general.py index 806ecd1b78..4b2da0b101 100644 --- a/django/contrib/postgres/aggregates/general.py +++ b/django/contrib/postgres/aggregates/general.py @@ -1,14 +1,16 @@ from django.contrib.postgres.fields import ArrayField, JSONField from django.db.models.aggregates import Aggregate +from .mixins import OrderableAggMixin + __all__ = [ 'ArrayAgg', 'BitAnd', 'BitOr', 'BoolAnd', 'BoolOr', 'JSONBAgg', 'StringAgg', ] -class ArrayAgg(Aggregate): +class ArrayAgg(OrderableAggMixin, Aggregate): function = 'ARRAY_AGG' - template = '%(function)s(%(distinct)s%(expressions)s)' + template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)' @property def output_field(self): @@ -49,9 +51,9 @@ class JSONBAgg(Aggregate): return value -class StringAgg(Aggregate): +class StringAgg(OrderableAggMixin, Aggregate): function = 'STRING_AGG' - template = "%(function)s(%(distinct)s%(expressions)s, '%(delimiter)s')" + template = "%(function)s(%(distinct)s%(expressions)s, '%(delimiter)s'%(ordering)s)" def __init__(self, expression, delimiter, distinct=False, **extra): distinct = 'DISTINCT ' if distinct else '' diff --git a/django/contrib/postgres/aggregates/mixins.py b/django/contrib/postgres/aggregates/mixins.py new file mode 100644 index 0000000000..b270a5b653 --- /dev/null +++ b/django/contrib/postgres/aggregates/mixins.py @@ -0,0 +1,47 @@ +from django.db.models.expressions import F, OrderBy + + +class OrderableAggMixin: + + def __init__(self, expression, ordering=(), **extra): + if not isinstance(ordering, (list, tuple)): + ordering = [ordering] + ordering = ordering or [] + # Transform minus sign prefixed strings into an OrderBy() expression. + ordering = ( + (OrderBy(F(o[1:]), descending=True) if isinstance(o, str) and o[0] == '-' else o) + for o in ordering + ) + super().__init__(expression, **extra) + self.ordering = self._parse_expressions(*ordering) + + def resolve_expression(self, *args, **kwargs): + self.ordering = [expr.resolve_expression(*args, **kwargs) for expr in self.ordering] + return super().resolve_expression(*args, **kwargs) + + def as_sql(self, compiler, connection): + if self.ordering: + self.extra['ordering'] = 'ORDER BY ' + ', '.join(( + ordering_element.as_sql(compiler, connection)[0] + for ordering_element in self.ordering + )) + else: + self.extra['ordering'] = '' + return super().as_sql(compiler, connection) + + def get_source_expressions(self): + return self.source_expressions + self.ordering + + def get_source_fields(self): + # Filter out fields contributed by the ordering expressions as + # these should not be used to determine which the return type of the + # expression. + return [ + e._output_field_or_none + for e in self.get_source_expressions()[:self._get_ordering_expressions_index()] + ] + + def _get_ordering_expressions_index(self): + """Return the index at which the ordering expressions start.""" + source_expressions = self.get_source_expressions() + return len(source_expressions) - len(self.ordering) diff --git a/docs/ref/contrib/postgres/aggregates.txt b/docs/ref/contrib/postgres/aggregates.txt index 480c230c40..a605bc831c 100644 --- a/docs/ref/contrib/postgres/aggregates.txt +++ b/docs/ref/contrib/postgres/aggregates.txt @@ -22,7 +22,7 @@ General-purpose aggregation functions ``ArrayAgg`` ------------ -.. class:: ArrayAgg(expression, distinct=False, filter=None, **extra) +.. class:: ArrayAgg(expression, distinct=False, filter=None, ordering=(), **extra) Returns a list of values, including nulls, concatenated into an array. @@ -31,6 +31,22 @@ General-purpose aggregation functions An optional boolean argument that determines if array values will be distinct. Defaults to ``False``. + .. attribute:: ordering + + .. versionadded:: 2.2 + + An optional string of a field name (with an optional ``"-"`` prefix + which indicates descending order) or an expression (or a tuple or list + of strings and/or expressions) that specifies the ordering of the + elements in the result list. + + Examples:: + + 'some_field' + '-some_field' + from django.db.models import F + F('some_field').desc() + ``BitAnd`` ---------- @@ -73,7 +89,7 @@ General-purpose aggregation functions ``StringAgg`` ------------- -.. class:: StringAgg(expression, delimiter, distinct=False, filter=None) +.. class:: StringAgg(expression, delimiter, distinct=False, filter=None, ordering=()) Returns the input values concatenated into a string, separated by the ``delimiter`` string. @@ -87,6 +103,17 @@ General-purpose aggregation functions An optional boolean argument that determines if concatenated values will be distinct. Defaults to ``False``. + .. attribute:: ordering + + .. versionadded:: 2.2 + + An optional string of a field name (with an optional ``"-"`` prefix + which indicates descending order) or an expression (or a tuple or list + of strings and/or expressions) that specifies the ordering of the + elements in the result string. + + Examples are the same as for :attr:`ArrayAgg.ordering`. + Aggregate functions for statistics ================================== diff --git a/docs/releases/2.2.txt b/docs/releases/2.2.txt index 840d4b4d0d..742f4893be 100644 --- a/docs/releases/2.2.txt +++ b/docs/releases/2.2.txt @@ -70,7 +70,10 @@ Minor features :mod:`django.contrib.postgres` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -* ... +* The new ``ordering`` argument for + :class:`~django.contrib.postgres.aggregates.ArrayAgg` and + :class:`~django.contrib.postgres.aggregates.StringAgg` determines the + ordering of the aggregated elements. :mod:`django.contrib.redirects` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/postgres_tests/test_aggregates.py b/tests/postgres_tests/test_aggregates.py index d4a01ff027..85d6f45fd1 100644 --- a/tests/postgres_tests/test_aggregates.py +++ b/tests/postgres_tests/test_aggregates.py @@ -22,21 +22,57 @@ class TestGeneralAggregate(PostgreSQLTestCase): def setUpTestData(cls): AggregateTestModel.objects.create(boolean_field=True, char_field='Foo1', integer_field=0) AggregateTestModel.objects.create(boolean_field=False, char_field='Foo2', integer_field=1) - AggregateTestModel.objects.create(boolean_field=False, char_field='Foo3', integer_field=2) - AggregateTestModel.objects.create(boolean_field=True, char_field='Foo4', integer_field=0) + AggregateTestModel.objects.create(boolean_field=False, char_field='Foo4', integer_field=2) + AggregateTestModel.objects.create(boolean_field=True, char_field='Foo3', integer_field=0) def test_array_agg_charfield(self): values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field')) - self.assertEqual(values, {'arrayagg': ['Foo1', 'Foo2', 'Foo3', 'Foo4']}) + self.assertEqual(values, {'arrayagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']}) + + def test_array_agg_charfield_ordering(self): + ordering_test_cases = ( + (F('char_field').desc(), ['Foo4', 'Foo3', 'Foo2', 'Foo1']), + (F('char_field').asc(), ['Foo1', 'Foo2', 'Foo3', 'Foo4']), + (F('char_field'), ['Foo1', 'Foo2', 'Foo3', 'Foo4']), + ([F('boolean_field'), F('char_field').desc()], ['Foo4', 'Foo2', 'Foo3', 'Foo1']), + ((F('boolean_field'), F('char_field').desc()), ['Foo4', 'Foo2', 'Foo3', 'Foo1']), + ('char_field', ['Foo1', 'Foo2', 'Foo3', 'Foo4']), + ('-char_field', ['Foo4', 'Foo3', 'Foo2', 'Foo1']), + ) + for ordering, expected_output in ordering_test_cases: + with self.subTest(ordering=ordering, expected_output=expected_output): + values = AggregateTestModel.objects.aggregate( + arrayagg=ArrayAgg('char_field', ordering=ordering) + ) + self.assertEqual(values, {'arrayagg': expected_output}) def test_array_agg_integerfield(self): values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('integer_field')) self.assertEqual(values, {'arrayagg': [0, 1, 2, 0]}) + def test_array_agg_integerfield_ordering(self): + values = AggregateTestModel.objects.aggregate( + arrayagg=ArrayAgg('integer_field', ordering=F('integer_field').desc()) + ) + self.assertEqual(values, {'arrayagg': [2, 1, 0, 0]}) + def test_array_agg_booleanfield(self): values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('boolean_field')) self.assertEqual(values, {'arrayagg': [True, False, False, True]}) + def test_array_agg_booleanfield_ordering(self): + ordering_test_cases = ( + (F('boolean_field').asc(), [False, False, True, True]), + (F('boolean_field').desc(), [True, True, False, False]), + (F('boolean_field'), [False, False, True, True]), + ) + for ordering, expected_output in ordering_test_cases: + with self.subTest(ordering=ordering, expected_output=expected_output): + values = AggregateTestModel.objects.aggregate( + arrayagg=ArrayAgg('boolean_field', ordering=ordering) + ) + self.assertEqual(values, {'arrayagg': expected_output}) + def test_array_agg_empty_result(self): AggregateTestModel.objects.all().delete() values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field')) @@ -122,17 +158,36 @@ class TestGeneralAggregate(PostgreSQLTestCase): def test_string_agg_charfield(self): values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=';')) - self.assertEqual(values, {'stringagg': 'Foo1;Foo2;Foo3;Foo4'}) + self.assertEqual(values, {'stringagg': 'Foo1;Foo2;Foo4;Foo3'}) + + def test_string_agg_charfield_ordering(self): + ordering_test_cases = ( + (F('char_field').desc(), 'Foo4;Foo3;Foo2;Foo1'), + (F('char_field').asc(), 'Foo1;Foo2;Foo3;Foo4'), + (F('char_field'), 'Foo1;Foo2;Foo3;Foo4'), + ) + for ordering, expected_output in ordering_test_cases: + with self.subTest(ordering=ordering, expected_output=expected_output): + values = AggregateTestModel.objects.aggregate( + stringagg=StringAgg('char_field', delimiter=';', ordering=ordering) + ) + self.assertEqual(values, {'stringagg': expected_output}) def test_string_agg_empty_result(self): AggregateTestModel.objects.all().delete() values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=';')) self.assertEqual(values, {'stringagg': ''}) + def test_orderable_agg_alternative_fields(self): + values = AggregateTestModel.objects.aggregate( + arrayagg=ArrayAgg('integer_field', ordering=F('char_field').asc()) + ) + self.assertEqual(values, {'arrayagg': [0, 1, 0, 2]}) + @skipUnlessDBFeature('has_jsonb_agg') def test_json_agg(self): values = AggregateTestModel.objects.aggregate(jsonagg=JSONBAgg('char_field')) - self.assertEqual(values, {'jsonagg': ['Foo1', 'Foo2', 'Foo3', 'Foo4']}) + self.assertEqual(values, {'jsonagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']}) @skipUnlessDBFeature('has_jsonb_agg') def test_json_agg_empty(self):