2019-08-20 15:54:41 +08:00
|
|
|
from django.db.models import F, OrderBy
|
2016-07-05 17:47:24 +08:00
|
|
|
|
|
|
|
|
|
|
|
class OrderableAggMixin:
|
|
|
|
|
2020-01-01 01:46:06 +08:00
|
|
|
def __init__(self, *expressions, ordering=(), **extra):
|
2016-07-05 17:47:24 +08:00
|
|
|
if not isinstance(ordering, (list, tuple)):
|
|
|
|
ordering = [ordering]
|
|
|
|
ordering = ordering or []
|
|
|
|
# Transform minus sign prefixed strings into an OrderBy() expression.
|
|
|
|
ordering = (
|
|
|
|
(OrderBy(F(o[1:]), descending=True) if isinstance(o, str) and o[0] == '-' else o)
|
|
|
|
for o in ordering
|
|
|
|
)
|
2020-01-01 01:46:06 +08:00
|
|
|
super().__init__(*expressions, **extra)
|
2016-07-05 17:47:24 +08:00
|
|
|
self.ordering = self._parse_expressions(*ordering)
|
|
|
|
|
|
|
|
def resolve_expression(self, *args, **kwargs):
|
|
|
|
self.ordering = [expr.resolve_expression(*args, **kwargs) for expr in self.ordering]
|
|
|
|
return super().resolve_expression(*args, **kwargs)
|
|
|
|
|
|
|
|
def as_sql(self, compiler, connection):
|
|
|
|
if self.ordering:
|
2019-04-06 19:45:22 +08:00
|
|
|
ordering_params = []
|
|
|
|
ordering_expr_sql = []
|
|
|
|
for expr in self.ordering:
|
2020-10-12 19:02:12 +08:00
|
|
|
expr_sql, expr_params = compiler.compile(expr)
|
2019-04-06 19:45:22 +08:00
|
|
|
ordering_expr_sql.append(expr_sql)
|
|
|
|
ordering_params.extend(expr_params)
|
|
|
|
sql, sql_params = super().as_sql(compiler, connection, ordering=(
|
|
|
|
'ORDER BY ' + ', '.join(ordering_expr_sql)
|
2016-07-05 17:47:24 +08:00
|
|
|
))
|
2019-04-06 19:45:22 +08:00
|
|
|
return sql, sql_params + ordering_params
|
|
|
|
return super().as_sql(compiler, connection, ordering='')
|
2016-07-05 17:47:24 +08:00
|
|
|
|
2019-05-25 21:19:32 +08:00
|
|
|
def set_source_expressions(self, exprs):
|
|
|
|
# Extract the ordering expressions because ORDER BY clause is handled
|
|
|
|
# in a custom way.
|
|
|
|
self.ordering = exprs[self._get_ordering_expressions_index():]
|
|
|
|
return super().set_source_expressions(exprs[:self._get_ordering_expressions_index()])
|
|
|
|
|
2016-07-05 17:47:24 +08:00
|
|
|
def get_source_expressions(self):
|
2019-12-29 05:42:46 +08:00
|
|
|
return super().get_source_expressions() + self.ordering
|
2016-07-05 17:47:24 +08:00
|
|
|
|
|
|
|
def _get_ordering_expressions_index(self):
|
|
|
|
"""Return the index at which the ordering expressions start."""
|
|
|
|
source_expressions = self.get_source_expressions()
|
|
|
|
return len(source_expressions) - len(self.ordering)
|