Fixed #32662 -- Refactored a generator out of SQLCompiler.get_order_by().

This also renames the `asc` variable to `default_order`, markes the
`desc` variable as unused, fixes a typo in SQLCompiler.get_order_by()
docstring, and reorders some blocks in SQLCompiler._order_by_pairs().
This commit is contained in:
Chris Jerdonek 2021-04-21 00:04:30 -07:00 committed by Mariusz Felisiak
parent 6d0cbe42c3
commit 0461b7a6b6
1 changed files with 46 additions and 37 deletions

View File

@ -269,15 +269,7 @@ class SQLCompiler:
ret.append((col, (sql, params), alias)) ret.append((col, (sql, params), alias))
return ret, klass_info, annotations return ret, klass_info, annotations
def get_order_by(self): def _order_by_pairs(self):
"""
Return a list of 2-tuples of form (expr, (sql, params, is_ref)) for the
ORDER BY clause.
The order_by clause can alter the select clause (for example it
can add aliases to clauses that do not yet have one, or it can
add totally new select clauses).
"""
if self.query.extra_order_by: if self.query.extra_order_by:
ordering = self.query.extra_order_by ordering = self.query.extra_order_by
elif not self.query.default_ordering: elif not self.query.default_ordering:
@ -290,11 +282,10 @@ class SQLCompiler:
else: else:
ordering = [] ordering = []
if self.query.standard_ordering: if self.query.standard_ordering:
asc, desc = ORDER_DIR['ASC'] default_order, _ = ORDER_DIR['ASC']
else: else:
asc, desc = ORDER_DIR['DESC'] default_order, _ = ORDER_DIR['DESC']
order_by = []
for field in ordering: for field in ordering:
if hasattr(field, 'resolve_expression'): if hasattr(field, 'resolve_expression'):
if isinstance(field, Value): if isinstance(field, Value):
@ -305,20 +296,24 @@ class SQLCompiler:
if not self.query.standard_ordering: if not self.query.standard_ordering:
field = field.copy() field = field.copy()
field.reverse_ordering() field.reverse_ordering()
order_by.append((field, False)) yield field, False
continue continue
if field == '?': # random if field == '?': # random
order_by.append((OrderBy(Random()), False)) yield OrderBy(Random()), False
continue continue
col, order = get_order_dir(field, asc) col, order = get_order_dir(field, default_order)
descending = order == 'DESC' descending = order == 'DESC'
if col in self.query.annotation_select: if col in self.query.annotation_select:
# Reference to expression in SELECT clause # Reference to expression in SELECT clause
order_by.append(( yield (
OrderBy(Ref(col, self.query.annotation_select[col]), descending=descending), OrderBy(
True)) Ref(col, self.query.annotation_select[col]),
descending=descending,
),
True,
)
continue continue
if col in self.query.annotations: if col in self.query.annotations:
# References to an expression which is masked out of the SELECT # References to an expression which is masked out of the SELECT
@ -332,44 +327,58 @@ class SQLCompiler:
if isinstance(expr, Value): if isinstance(expr, Value):
# output_field must be resolved for constants. # output_field must be resolved for constants.
expr = Cast(expr, expr.output_field) expr = Cast(expr, expr.output_field)
order_by.append((OrderBy(expr, descending=descending), False)) yield OrderBy(expr, descending=descending), False
continue continue
if '.' in field: if '.' in field:
# This came in through an extra(order_by=...) addition. Pass it # This came in through an extra(order_by=...) addition. Pass it
# on verbatim. # on verbatim.
table, col = col.split('.', 1) table, col = col.split('.', 1)
order_by.append(( yield (
OrderBy( OrderBy(
RawSQL('%s.%s' % (self.quote_name_unless_alias(table), col), []), RawSQL('%s.%s' % (self.quote_name_unless_alias(table), col), []),
descending=descending descending=descending,
), False)) ),
False,
)
continue continue
if not self.query.extra or col not in self.query.extra: if self.query.extra and col in self.query.extra:
if col in self.query.extra_select:
yield (
OrderBy(Ref(col, RawSQL(*self.query.extra[col])), descending=descending),
True,
)
else:
yield (
OrderBy(RawSQL(*self.query.extra[col]), descending=descending),
False,
)
else:
if self.query.combinator and self.select: if self.query.combinator and self.select:
# Don't use the first model's field because other # Don't use the first model's field because other
# combinated queries might define it differently. # combinated queries might define it differently.
order_by.append((OrderBy(F(col), descending=descending), False)) yield OrderBy(F(col), descending=descending), False
else: else:
# 'col' is of the form 'field' or 'field1__field2' or # 'col' is of the form 'field' or 'field1__field2' or
# '-field1__field2__field', etc. # '-field1__field2__field', etc.
order_by.extend(self.find_ordering_name( yield from self.find_ordering_name(
field, self.query.get_meta(), default_order=asc, field, self.query.get_meta(), default_order=default_order,
)) )
else:
if col not in self.query.extra_select: def get_order_by(self):
order_by.append(( """
OrderBy(RawSQL(*self.query.extra[col]), descending=descending), Return a list of 2-tuples of the form (expr, (sql, params, is_ref)) for
False)) the ORDER BY clause.
else:
order_by.append(( The order_by clause can alter the select clause (for example it can add
OrderBy(Ref(col, RawSQL(*self.query.extra[col])), descending=descending), aliases to clauses that do not yet have one, or it can add totally new
True)) select clauses).
"""
result = [] result = []
seen = set() seen = set()
for expr, is_ref in order_by: for expr, is_ref in self._order_by_pairs():
resolved = expr.resolve_expression(self.query, allow_joins=True, reuse=None) resolved = expr.resolve_expression(self.query, allow_joins=True, reuse=None)
if self.query.combinator and self.select: if self.query.combinator and self.select:
src = resolved.get_source_expressions()[0] src = resolved.get_source_expressions()[0]