Fixed #26067 -- Added ordering support to ArrayAgg and StringAgg.
This commit is contained in:
parent
2a0116266c
commit
96199e562d
|
@ -1,14 +1,16 @@
|
|||
from django.contrib.postgres.fields import ArrayField, JSONField
|
||||
from django.db.models.aggregates import Aggregate
|
||||
|
||||
from .mixins import OrderableAggMixin
|
||||
|
||||
__all__ = [
|
||||
'ArrayAgg', 'BitAnd', 'BitOr', 'BoolAnd', 'BoolOr', 'JSONBAgg', 'StringAgg',
|
||||
]
|
||||
|
||||
|
||||
class ArrayAgg(Aggregate):
|
||||
class ArrayAgg(OrderableAggMixin, Aggregate):
|
||||
function = 'ARRAY_AGG'
|
||||
template = '%(function)s(%(distinct)s%(expressions)s)'
|
||||
template = '%(function)s(%(distinct)s%(expressions)s %(ordering)s)'
|
||||
|
||||
@property
|
||||
def output_field(self):
|
||||
|
@ -49,9 +51,9 @@ class JSONBAgg(Aggregate):
|
|||
return value
|
||||
|
||||
|
||||
class StringAgg(Aggregate):
|
||||
class StringAgg(OrderableAggMixin, Aggregate):
|
||||
function = 'STRING_AGG'
|
||||
template = "%(function)s(%(distinct)s%(expressions)s, '%(delimiter)s')"
|
||||
template = "%(function)s(%(distinct)s%(expressions)s, '%(delimiter)s'%(ordering)s)"
|
||||
|
||||
def __init__(self, expression, delimiter, distinct=False, **extra):
|
||||
distinct = 'DISTINCT ' if distinct else ''
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
from django.db.models.expressions import F, OrderBy
|
||||
|
||||
|
||||
class OrderableAggMixin:
|
||||
|
||||
def __init__(self, expression, ordering=(), **extra):
|
||||
if not isinstance(ordering, (list, tuple)):
|
||||
ordering = [ordering]
|
||||
ordering = ordering or []
|
||||
# Transform minus sign prefixed strings into an OrderBy() expression.
|
||||
ordering = (
|
||||
(OrderBy(F(o[1:]), descending=True) if isinstance(o, str) and o[0] == '-' else o)
|
||||
for o in ordering
|
||||
)
|
||||
super().__init__(expression, **extra)
|
||||
self.ordering = self._parse_expressions(*ordering)
|
||||
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
self.ordering = [expr.resolve_expression(*args, **kwargs) for expr in self.ordering]
|
||||
return super().resolve_expression(*args, **kwargs)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if self.ordering:
|
||||
self.extra['ordering'] = 'ORDER BY ' + ', '.join((
|
||||
ordering_element.as_sql(compiler, connection)[0]
|
||||
for ordering_element in self.ordering
|
||||
))
|
||||
else:
|
||||
self.extra['ordering'] = ''
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
def get_source_expressions(self):
|
||||
return self.source_expressions + self.ordering
|
||||
|
||||
def get_source_fields(self):
|
||||
# Filter out fields contributed by the ordering expressions as
|
||||
# these should not be used to determine which the return type of the
|
||||
# expression.
|
||||
return [
|
||||
e._output_field_or_none
|
||||
for e in self.get_source_expressions()[:self._get_ordering_expressions_index()]
|
||||
]
|
||||
|
||||
def _get_ordering_expressions_index(self):
|
||||
"""Return the index at which the ordering expressions start."""
|
||||
source_expressions = self.get_source_expressions()
|
||||
return len(source_expressions) - len(self.ordering)
|
|
@ -22,7 +22,7 @@ General-purpose aggregation functions
|
|||
``ArrayAgg``
|
||||
------------
|
||||
|
||||
.. class:: ArrayAgg(expression, distinct=False, filter=None, **extra)
|
||||
.. class:: ArrayAgg(expression, distinct=False, filter=None, ordering=(), **extra)
|
||||
|
||||
Returns a list of values, including nulls, concatenated into an array.
|
||||
|
||||
|
@ -31,6 +31,22 @@ General-purpose aggregation functions
|
|||
An optional boolean argument that determines if array values
|
||||
will be distinct. Defaults to ``False``.
|
||||
|
||||
.. attribute:: ordering
|
||||
|
||||
.. versionadded:: 2.2
|
||||
|
||||
An optional string of a field name (with an optional ``"-"`` prefix
|
||||
which indicates descending order) or an expression (or a tuple or list
|
||||
of strings and/or expressions) that specifies the ordering of the
|
||||
elements in the result list.
|
||||
|
||||
Examples::
|
||||
|
||||
'some_field'
|
||||
'-some_field'
|
||||
from django.db.models import F
|
||||
F('some_field').desc()
|
||||
|
||||
``BitAnd``
|
||||
----------
|
||||
|
||||
|
@ -73,7 +89,7 @@ General-purpose aggregation functions
|
|||
``StringAgg``
|
||||
-------------
|
||||
|
||||
.. class:: StringAgg(expression, delimiter, distinct=False, filter=None)
|
||||
.. class:: StringAgg(expression, delimiter, distinct=False, filter=None, ordering=())
|
||||
|
||||
Returns the input values concatenated into a string, separated by
|
||||
the ``delimiter`` string.
|
||||
|
@ -87,6 +103,17 @@ General-purpose aggregation functions
|
|||
An optional boolean argument that determines if concatenated values
|
||||
will be distinct. Defaults to ``False``.
|
||||
|
||||
.. attribute:: ordering
|
||||
|
||||
.. versionadded:: 2.2
|
||||
|
||||
An optional string of a field name (with an optional ``"-"`` prefix
|
||||
which indicates descending order) or an expression (or a tuple or list
|
||||
of strings and/or expressions) that specifies the ordering of the
|
||||
elements in the result string.
|
||||
|
||||
Examples are the same as for :attr:`ArrayAgg.ordering`.
|
||||
|
||||
Aggregate functions for statistics
|
||||
==================================
|
||||
|
||||
|
|
|
@ -70,7 +70,10 @@ Minor features
|
|||
:mod:`django.contrib.postgres`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
* ...
|
||||
* The new ``ordering`` argument for
|
||||
:class:`~django.contrib.postgres.aggregates.ArrayAgg` and
|
||||
:class:`~django.contrib.postgres.aggregates.StringAgg` determines the
|
||||
ordering of the aggregated elements.
|
||||
|
||||
:mod:`django.contrib.redirects`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
|
|
@ -22,21 +22,57 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
|||
def setUpTestData(cls):
|
||||
AggregateTestModel.objects.create(boolean_field=True, char_field='Foo1', integer_field=0)
|
||||
AggregateTestModel.objects.create(boolean_field=False, char_field='Foo2', integer_field=1)
|
||||
AggregateTestModel.objects.create(boolean_field=False, char_field='Foo3', integer_field=2)
|
||||
AggregateTestModel.objects.create(boolean_field=True, char_field='Foo4', integer_field=0)
|
||||
AggregateTestModel.objects.create(boolean_field=False, char_field='Foo4', integer_field=2)
|
||||
AggregateTestModel.objects.create(boolean_field=True, char_field='Foo3', integer_field=0)
|
||||
|
||||
def test_array_agg_charfield(self):
|
||||
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field'))
|
||||
self.assertEqual(values, {'arrayagg': ['Foo1', 'Foo2', 'Foo3', 'Foo4']})
|
||||
self.assertEqual(values, {'arrayagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']})
|
||||
|
||||
def test_array_agg_charfield_ordering(self):
|
||||
ordering_test_cases = (
|
||||
(F('char_field').desc(), ['Foo4', 'Foo3', 'Foo2', 'Foo1']),
|
||||
(F('char_field').asc(), ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
|
||||
(F('char_field'), ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
|
||||
([F('boolean_field'), F('char_field').desc()], ['Foo4', 'Foo2', 'Foo3', 'Foo1']),
|
||||
((F('boolean_field'), F('char_field').desc()), ['Foo4', 'Foo2', 'Foo3', 'Foo1']),
|
||||
('char_field', ['Foo1', 'Foo2', 'Foo3', 'Foo4']),
|
||||
('-char_field', ['Foo4', 'Foo3', 'Foo2', 'Foo1']),
|
||||
)
|
||||
for ordering, expected_output in ordering_test_cases:
|
||||
with self.subTest(ordering=ordering, expected_output=expected_output):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
arrayagg=ArrayAgg('char_field', ordering=ordering)
|
||||
)
|
||||
self.assertEqual(values, {'arrayagg': expected_output})
|
||||
|
||||
def test_array_agg_integerfield(self):
|
||||
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('integer_field'))
|
||||
self.assertEqual(values, {'arrayagg': [0, 1, 2, 0]})
|
||||
|
||||
def test_array_agg_integerfield_ordering(self):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
arrayagg=ArrayAgg('integer_field', ordering=F('integer_field').desc())
|
||||
)
|
||||
self.assertEqual(values, {'arrayagg': [2, 1, 0, 0]})
|
||||
|
||||
def test_array_agg_booleanfield(self):
|
||||
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('boolean_field'))
|
||||
self.assertEqual(values, {'arrayagg': [True, False, False, True]})
|
||||
|
||||
def test_array_agg_booleanfield_ordering(self):
|
||||
ordering_test_cases = (
|
||||
(F('boolean_field').asc(), [False, False, True, True]),
|
||||
(F('boolean_field').desc(), [True, True, False, False]),
|
||||
(F('boolean_field'), [False, False, True, True]),
|
||||
)
|
||||
for ordering, expected_output in ordering_test_cases:
|
||||
with self.subTest(ordering=ordering, expected_output=expected_output):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
arrayagg=ArrayAgg('boolean_field', ordering=ordering)
|
||||
)
|
||||
self.assertEqual(values, {'arrayagg': expected_output})
|
||||
|
||||
def test_array_agg_empty_result(self):
|
||||
AggregateTestModel.objects.all().delete()
|
||||
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field'))
|
||||
|
@ -122,17 +158,36 @@ class TestGeneralAggregate(PostgreSQLTestCase):
|
|||
|
||||
def test_string_agg_charfield(self):
|
||||
values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=';'))
|
||||
self.assertEqual(values, {'stringagg': 'Foo1;Foo2;Foo3;Foo4'})
|
||||
self.assertEqual(values, {'stringagg': 'Foo1;Foo2;Foo4;Foo3'})
|
||||
|
||||
def test_string_agg_charfield_ordering(self):
|
||||
ordering_test_cases = (
|
||||
(F('char_field').desc(), 'Foo4;Foo3;Foo2;Foo1'),
|
||||
(F('char_field').asc(), 'Foo1;Foo2;Foo3;Foo4'),
|
||||
(F('char_field'), 'Foo1;Foo2;Foo3;Foo4'),
|
||||
)
|
||||
for ordering, expected_output in ordering_test_cases:
|
||||
with self.subTest(ordering=ordering, expected_output=expected_output):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
stringagg=StringAgg('char_field', delimiter=';', ordering=ordering)
|
||||
)
|
||||
self.assertEqual(values, {'stringagg': expected_output})
|
||||
|
||||
def test_string_agg_empty_result(self):
|
||||
AggregateTestModel.objects.all().delete()
|
||||
values = AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field', delimiter=';'))
|
||||
self.assertEqual(values, {'stringagg': ''})
|
||||
|
||||
def test_orderable_agg_alternative_fields(self):
|
||||
values = AggregateTestModel.objects.aggregate(
|
||||
arrayagg=ArrayAgg('integer_field', ordering=F('char_field').asc())
|
||||
)
|
||||
self.assertEqual(values, {'arrayagg': [0, 1, 0, 2]})
|
||||
|
||||
@skipUnlessDBFeature('has_jsonb_agg')
|
||||
def test_json_agg(self):
|
||||
values = AggregateTestModel.objects.aggregate(jsonagg=JSONBAgg('char_field'))
|
||||
self.assertEqual(values, {'jsonagg': ['Foo1', 'Foo2', 'Foo3', 'Foo4']})
|
||||
self.assertEqual(values, {'jsonagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']})
|
||||
|
||||
@skipUnlessDBFeature('has_jsonb_agg')
|
||||
def test_json_agg_empty(self):
|
||||
|
|
Loading…
Reference in New Issue