From b181cae2e3697b2e53b5b67ac67e59f3b05a6f0d Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Sat, 12 Nov 2022 10:06:00 -0500 Subject: [PATCH] Refs #25307 -- Replaced SQLQuery.rewrite_cols() by replace_expressions(). The latter offers a more generic interface that doesn't require specialized expression types handling. --- django/db/models/sql/query.py | 87 ++++++++++------------------------- 1 file changed, 25 insertions(+), 62 deletions(-) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index fa72c124c72..2d150ed6d8a 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -381,60 +381,6 @@ class Query(BaseExpression): alias = None return target.get_col(alias, field) - def rewrite_cols(self, annotation, col_cnt): - # We must make sure the inner query has the referred columns in it. - # If we are aggregating over an annotation, then Django uses Ref() - # instances to note this. However, if we are annotating over a column - # of a related model, then it might be that column isn't part of the - # SELECT clause of the inner query, and we must manually make sure - # the column is selected. An example case is: - # .aggregate(Sum('author__awards')) - # Resolving this expression results in a join to author, but there - # is no guarantee the awards column of author is in the select clause - # of the query. Thus we must manually add the column to the inner - # query. - orig_exprs = annotation.get_source_expressions() - new_exprs = [] - for expr in orig_exprs: - # FIXME: These conditions are fairly arbitrary. Identify a better - # method of having expressions decide which code path they should - # take. - if isinstance(expr, Ref): - # Its already a Ref to subquery (see resolve_ref() for - # details) - new_exprs.append(expr) - elif isinstance(expr, (WhereNode, Lookup)): - # Decompose the subexpressions further. The code here is - # copied from the else clause, but this condition must appear - # before the contains_aggregate/is_summary condition below. - new_expr, col_cnt = self.rewrite_cols(expr, col_cnt) - new_exprs.append(new_expr) - else: - # Reuse aliases of expressions already selected in subquery. - for col_alias, selected_annotation in self.annotation_select.items(): - if selected_annotation is expr: - new_expr = Ref(col_alias, expr) - break - else: - # An expression that is not selected the subquery. - if isinstance(expr, Col) or ( - expr.contains_aggregate and not expr.is_summary - ): - # Reference column or another aggregate. Select it - # under a non-conflicting alias. - col_cnt += 1 - col_alias = "__col%d" % col_cnt - self.annotations[col_alias] = expr - self.append_annotation_mask([col_alias]) - new_expr = Ref(col_alias, expr) - else: - # Some other expression not referencing database values - # directly. Its subexpression might contain Cols. - new_expr, col_cnt = self.rewrite_cols(expr, col_cnt) - new_exprs.append(new_expr) - annotation.set_source_expressions(new_exprs) - return annotation, col_cnt - def get_aggregation(self, using, added_aggregate_names): """ Return the dictionary with the values of the existing aggregations. @@ -508,17 +454,31 @@ class Query(BaseExpression): annotation_mask |= inner_query.annotations[name].get_refs() inner_query.set_annotation_mask(annotation_mask) - relabels = {t: "subquery" for t in inner_query.alias_map} - relabels[None] = "subquery" - # Remove any aggregates marked for reduction from the subquery - # and move them to the outer AggregateQuery. - col_cnt = 0 + # Remove any aggregates marked for reduction from the subquery and + # move them to the outer AggregateQuery. This requires making sure + # all columns referenced by the aggregates are selected in the + # subquery. It is achieved by retrieving all column references from + # the aggregates, explicitly selecting them if they are not + # already, and making sure the aggregates are repointed to + # referenced to them. + col_refs = {} for alias, expression in list(inner_query.annotation_select.items()): if not expression.is_summary: continue annotation_select_mask = inner_query.annotation_select_mask - expression, col_cnt = inner_query.rewrite_cols(expression, col_cnt) - outer_query.annotations[alias] = expression.relabeled_clone(relabels) + replacements = {} + for col in self._gen_cols([expression], resolve_refs=False): + if not (col_ref := col_refs.get(col)): + index = len(col_refs) + 1 + col_alias = f"__col{index}" + col_ref = Ref(col_alias, col) + col_refs[col] = col_ref + inner_query.annotations[col_alias] = col + inner_query.append_annotation_mask([col_alias]) + replacements[col] = col_ref + outer_query.annotations[alias] = expression.replace_expressions( + replacements + ) del inner_query.annotations[alias] annotation_select_mask.remove(alias) # Make sure the annotation_select wont use cached results. @@ -1923,7 +1883,7 @@ class Query(BaseExpression): return targets, joins[-1], joins @classmethod - def _gen_cols(cls, exprs, include_external=False): + def _gen_cols(cls, exprs, include_external=False, resolve_refs=True): for expr in exprs: if isinstance(expr, Col): yield expr @@ -1932,9 +1892,12 @@ class Query(BaseExpression): ): yield from expr.get_external_cols() elif hasattr(expr, "get_source_expressions"): + if not resolve_refs and isinstance(expr, Ref): + continue yield from cls._gen_cols( expr.get_source_expressions(), include_external=include_external, + resolve_refs=resolve_refs, ) @classmethod