Refs #33304 -- Enclosed aggregate ordering logic in an expression.

This greatly simplifies the implementation of contrib.postgres'
OrderableAggMixin and allows for reuse in Window expressions.
This commit is contained in:
Simon Charette 2021-11-23 00:34:48 -05:00 committed by Mariusz Felisiak
parent a17becf4c7
commit e06dc4571e
2 changed files with 36 additions and 37 deletions

View File

@ -1,48 +1,27 @@
from django.db.models import F, OrderBy from django.db.models.expressions import OrderByList
class OrderableAggMixin: class OrderableAggMixin:
def __init__(self, *expressions, ordering=(), **extra): def __init__(self, *expressions, ordering=(), **extra):
if not isinstance(ordering, (list, tuple)): if isinstance(ordering, (list, tuple)):
ordering = [ordering] self.order_by = OrderByList(*ordering)
ordering = ordering or [] else:
# Transform minus sign prefixed strings into an OrderBy() expression. self.order_by = OrderByList(ordering)
ordering = (
(OrderBy(F(o[1:]), descending=True) if isinstance(o, str) and o[0] == '-' else o)
for o in ordering
)
super().__init__(*expressions, **extra) super().__init__(*expressions, **extra)
self.ordering = self._parse_expressions(*ordering)
def resolve_expression(self, *args, **kwargs): def resolve_expression(self, *args, **kwargs):
self.ordering = [expr.resolve_expression(*args, **kwargs) for expr in self.ordering] self.order_by = self.order_by.resolve_expression(*args, **kwargs)
return super().resolve_expression(*args, **kwargs) return super().resolve_expression(*args, **kwargs)
def as_sql(self, compiler, connection): def get_source_expressions(self):
if self.ordering: return super().get_source_expressions() + [self.order_by]
ordering_params = []
ordering_expr_sql = []
for expr in self.ordering:
expr_sql, expr_params = compiler.compile(expr)
ordering_expr_sql.append(expr_sql)
ordering_params.extend(expr_params)
sql, sql_params = super().as_sql(compiler, connection, ordering=(
'ORDER BY ' + ', '.join(ordering_expr_sql)
))
return sql, (*sql_params, *ordering_params)
return super().as_sql(compiler, connection, ordering='')
def set_source_expressions(self, exprs): def set_source_expressions(self, exprs):
# Extract the ordering expressions because ORDER BY clause is handled *exprs, self.order_by = exprs
# in a custom way. return super().set_source_expressions(exprs)
self.ordering = exprs[self._get_ordering_expressions_index():]
return super().set_source_expressions(exprs[:self._get_ordering_expressions_index()])
def get_source_expressions(self): def as_sql(self, compiler, connection):
return super().get_source_expressions() + self.ordering order_by_sql, order_by_params = compiler.compile(self.order_by)
sql, sql_params = super().as_sql(compiler, connection, ordering=order_by_sql)
def _get_ordering_expressions_index(self): return sql, (*sql_params, *order_by_params)
"""Return the index at which the ordering expressions start."""
source_expressions = self.get_source_expressions()
return len(source_expressions) - len(self.ordering)

View File

@ -915,8 +915,8 @@ class Ref(Expression):
class ExpressionList(Func): class ExpressionList(Func):
""" """
An expression containing multiple expressions. Can be used to provide a An expression containing multiple expressions. Can be used to provide a
list of expressions as an argument to another expression, like an list of expressions as an argument to another expression, like a partition
ordering clause. clause.
""" """
template = '%(expressions)s' template = '%(expressions)s'
@ -933,6 +933,26 @@ class ExpressionList(Func):
return self.as_sql(compiler, connection, **extra_context) return self.as_sql(compiler, connection, **extra_context)
class OrderByList(Func):
template = 'ORDER BY %(expressions)s'
def __init__(self, *expressions, **extra):
expressions = (
(
OrderBy(F(expr[1:]), descending=True)
if isinstance(expr, str) and expr[0] == '-'
else expr
)
for expr in expressions
)
super().__init__(*expressions, **extra)
def as_sql(self, *args, **kwargs):
if not self.source_expressions:
return '', ()
return super().as_sql(*args, **kwargs)
class ExpressionWrapper(SQLiteNumericMixin, Expression): class ExpressionWrapper(SQLiteNumericMixin, Expression):
""" """
An expression that can wrap another expression so that it can provide An expression that can wrap another expression so that it can provide