Refs #35339 -- Updated Aggregate class to return consistent source expressions.

Refactored the filter and order_by expressions in the Aggregate class to
return a list of Expression (or None) values, ensuring that the list
item is always available and represents the filter expression.
For the PostgreSQL OrderableAggMixin, the returned list will always
include the filter and the order_by value as the last two elements.

Lastly, emtpy Q objects passed directly into aggregate objects using
Aggregate.filter in admin facets are filtered out when resolving the
expression to avoid errors in get_refs().

Thanks Simon Charette for the review.
This commit is contained in:
Chris Muthig 2024-04-03 16:06:39 -06:00 committed by nessita
parent ec8552417d
commit 42b567ab4c
3 changed files with 12 additions and 13 deletions

View File

@ -17,13 +17,10 @@ class OrderableAggMixin:
return super().resolve_expression(*args, **kwargs)
def get_source_expressions(self):
if self.order_by is not None:
return super().get_source_expressions() + [self.order_by]
return super().get_source_expressions()
return super().get_source_expressions() + [self.order_by]
def set_source_expressions(self, exprs):
if isinstance(exprs[-1], OrderByList):
*exprs, self.order_by = exprs
*exprs, self.order_by = exprs
return super().set_source_expressions(exprs)
def as_sql(self, compiler, connection):

View File

@ -50,12 +50,10 @@ class Aggregate(Func):
def get_source_expressions(self):
source_expressions = super().get_source_expressions()
if self.filter:
return source_expressions + [self.filter]
return source_expressions
return source_expressions + [self.filter]
def set_source_expressions(self, exprs):
self.filter = self.filter and exprs.pop()
*exprs, self.filter = exprs
return super().set_source_expressions(exprs)
def resolve_expression(
@ -63,8 +61,10 @@ class Aggregate(Func):
):
# Aggregates are not allowed in UPDATE queries, so ignore for_save
c = super().resolve_expression(query, allow_joins, reuse, summarize)
c.filter = c.filter and c.filter.resolve_expression(
query, allow_joins, reuse, summarize
c.filter = (
c.filter.resolve_expression(query, allow_joins, reuse, summarize)
if c.filter
else None
)
if summarize:
# Summarized aggregates cannot refer to summarized aggregates.
@ -104,7 +104,9 @@ class Aggregate(Func):
@property
def default_alias(self):
expressions = self.get_source_expressions()
expressions = [
expr for expr in self.get_source_expressions() if expr is not None
]
if len(expressions) == 1 and hasattr(expressions[0], "name"):
return "%s__%s" % (expressions[0].name, self.name.lower())
raise TypeError("Complex expressions require an alias")

View File

@ -1291,7 +1291,7 @@ class AggregateTestCase(TestCase):
def as_sql(self, compiler, connection):
copy = self.copy()
copy.set_source_expressions(copy.get_source_expressions()[0:1])
copy.set_source_expressions(copy.get_source_expressions()[0:1] + [None])
return super(MyMax, copy).as_sql(compiler, connection)
with self.assertRaisesMessage(TypeError, "Complex aggregates require an alias"):