Fixed #35339 -- Fixed PostgreSQL aggregate's filter and order_by params order.

Updated OrderableAggMixin.as_sql() to separate the order_by parameters
from the filter parameters. Previously, the parameters and SQL were
calculated by the Aggregate parent class, resulting in a mixture of
order_by and filter parameters.

Thanks Simon Charette for the review.
This commit is contained in:
Chris Muthig 2024-04-03 16:09:44 -06:00 committed by nessita
parent 42b567ab4c
commit c8df2f9941
2 changed files with 32 additions and 7 deletions

View File

@ -1,3 +1,4 @@
from django.core.exceptions import FullResultSet
from django.db.models.expressions import OrderByList
@ -24,9 +25,23 @@ class OrderableAggMixin:
return super().set_source_expressions(exprs)
def as_sql(self, compiler, connection):
if self.order_by is not None:
order_by_sql, order_by_params = compiler.compile(self.order_by)
else:
order_by_sql, order_by_params = "", ()
sql, sql_params = super().as_sql(compiler, connection, ordering=order_by_sql)
return sql, (*sql_params, *order_by_params)
*source_exprs, filtering_expr, ordering_expr = self.get_source_expressions()
order_by_sql = ""
order_by_params = []
if ordering_expr is not None:
order_by_sql, order_by_params = compiler.compile(ordering_expr)
filter_params = []
if filtering_expr is not None:
try:
_, filter_params = compiler.compile(filtering_expr)
except FullResultSet:
pass
source_params = []
for source_expr in source_exprs:
source_params += compiler.compile(source_expr)[1]
sql, _ = super().as_sql(compiler, connection, ordering=order_by_sql)
return sql, (*source_params, *order_by_params, *filter_params)

View File

@ -12,7 +12,7 @@ from django.db.models import (
Window,
)
from django.db.models.fields.json import KeyTextTransform, KeyTransform
from django.db.models.functions import Cast, Concat, Substr
from django.db.models.functions import Cast, Concat, LPad, Substr
from django.test import skipUnlessDBFeature
from django.test.utils import Approximate
from django.utils import timezone
@ -238,6 +238,16 @@ class TestGeneralAggregate(PostgreSQLTestCase):
)
self.assertEqual(values, {"arrayagg": ["en", "pl"]})
def test_array_agg_filter_and_ordering_params(self):
values = AggregateTestModel.objects.aggregate(
arrayagg=ArrayAgg(
"char_field",
filter=Q(json_field__has_key="lang"),
ordering=LPad(Cast("integer_field", CharField()), 2, Value("0")),
)
)
self.assertEqual(values, {"arrayagg": ["Foo2", "Foo4"]})
def test_array_agg_filter(self):
values = AggregateTestModel.objects.aggregate(
arrayagg=ArrayAgg("integer_field", filter=Q(integer_field__gt=0)),