From c6350d594c359151ee17b0c4f354bb44f28ff69e Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Wed, 28 Sep 2022 10:51:06 -0400 Subject: [PATCH] Refs #30158 -- Removed alias argument for Expression.get_group_by_cols(). Recent refactors allowed GROUP BY aliasing allowed for aliasing to be entirely handled by the sql.Query.set_group_by and compiler layers. --- django/db/models/aggregates.py | 2 +- django/db/models/expressions.py | 38 ++++++++++++++---------------- django/db/models/functions/math.py | 2 +- django/db/models/lookups.py | 2 +- django/db/models/sql/query.py | 19 ++++++++------- django/db/models/sql/where.py | 2 +- docs/ref/models/expressions.txt | 9 ++++--- docs/releases/4.2.txt | 2 ++ tests/expressions/tests.py | 4 ++-- 9 files changed, 43 insertions(+), 37 deletions(-) diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 2ffed7cd2c..ec21e5fd11 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -97,7 +97,7 @@ class Aggregate(Func): return "%s__%s" % (expressions[0].name, self.name.lower()) raise TypeError("Complex expressions require an alias") - def get_group_by_cols(self, alias=None): + def get_group_by_cols(self): return [] def as_sql(self, compiler, connection, **extra_context): diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 387a7aa357..6c3bc5c4de 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -417,10 +417,8 @@ class BaseExpression: ) return clone - def get_group_by_cols(self, alias=None): + def get_group_by_cols(self): if not self.contains_aggregate: - if alias: - return [Ref(alias, self)] return [self] cols = [] for source in self.get_source_expressions(): @@ -858,7 +856,7 @@ class ResolvedOuterRef(F): def relabeled_clone(self, relabels): return self - def get_group_by_cols(self, alias=None): + def get_group_by_cols(self): return [] @@ -1021,7 +1019,7 @@ class Value(SQLiteNumericMixin, Expression): c.for_save = for_save return c - def get_group_by_cols(self, alias=None): + def get_group_by_cols(self): return [] def _resolve_output_field(self): @@ -1066,7 +1064,7 @@ class RawSQL(Expression): def as_sql(self, compiler, connection): return "(%s)" % self.sql, self.params - def get_group_by_cols(self, alias=None): + def get_group_by_cols(self): return [self] def resolve_expression( @@ -1124,7 +1122,7 @@ class Col(Expression): relabels.get(self.alias, self.alias), self.target, self.output_field ) - def get_group_by_cols(self, alias=None): + def get_group_by_cols(self): return [self] def get_db_converters(self, connection): @@ -1167,7 +1165,7 @@ class Ref(Expression): def as_sql(self, compiler, connection): return connection.ops.quote_name(self.refs), [] - def get_group_by_cols(self, alias=None): + def get_group_by_cols(self): return [self] @@ -1238,14 +1236,14 @@ class ExpressionWrapper(SQLiteNumericMixin, Expression): def get_source_expressions(self): return [self.expression] - def get_group_by_cols(self, alias=None): + def get_group_by_cols(self): if isinstance(self.expression, Expression): expression = self.expression.copy() expression.output_field = self.output_field - return expression.get_group_by_cols(alias=alias) + return expression.get_group_by_cols() # For non-expressions e.g. an SQL WHERE clause, the entire # `expression` must be included in the GROUP BY clause. - return super().get_group_by_cols(alias=alias) + return super().get_group_by_cols() def as_sql(self, compiler, connection): return compiler.compile(self.expression) @@ -1330,7 +1328,7 @@ class When(Expression): *result_params, ) - def get_group_by_cols(self, alias=None): + def get_group_by_cols(self): # This is not a complete expression and cannot be used in GROUP BY. cols = [] for source in self.get_source_expressions(): @@ -1426,10 +1424,10 @@ class Case(SQLiteNumericMixin, Expression): sql = connection.ops.unification_cast_sql(self.output_field) % sql return sql, sql_params - def get_group_by_cols(self, alias=None): + def get_group_by_cols(self): if not self.cases: - return self.default.get_group_by_cols(alias) - return super().get_group_by_cols(alias) + return self.default.get_group_by_cols() + return super().get_group_by_cols() class Subquery(BaseExpression, Combinable): @@ -1480,8 +1478,8 @@ class Subquery(BaseExpression, Combinable): sql = template % template_params return sql, sql_params - def get_group_by_cols(self, alias=None): - return self.query.get_group_by_cols(alias=alias, wrapper=self) + def get_group_by_cols(self): + return self.query.get_group_by_cols(wrapper=self) class Exists(Subquery): @@ -1602,7 +1600,7 @@ class OrderBy(Expression): return copy.as_sql(compiler, connection) return self.as_sql(compiler, connection) - def get_group_by_cols(self, alias=None): + def get_group_by_cols(self): cols = [] for source in self.get_source_expressions(): cols.extend(source.get_group_by_cols()) @@ -1732,7 +1730,7 @@ class Window(SQLiteNumericMixin, Expression): def __repr__(self): return "<%s: %s>" % (self.__class__.__name__, self) - def get_group_by_cols(self, alias=None): + def get_group_by_cols(self): group_by_cols = [] if self.partition_by: group_by_cols.extend(self.partition_by.get_group_by_cols()) @@ -1780,7 +1778,7 @@ class WindowFrame(Expression): def __repr__(self): return "<%s: %s>" % (self.__class__.__name__, self) - def get_group_by_cols(self, alias=None): + def get_group_by_cols(self): return [] def __str__(self): diff --git a/django/db/models/functions/math.py b/django/db/models/functions/math.py index 8b5fd79c3a..460143ba5a 100644 --- a/django/db/models/functions/math.py +++ b/django/db/models/functions/math.py @@ -169,7 +169,7 @@ class Random(NumericOutputFieldMixin, Func): def as_sqlite(self, compiler, connection, **extra_context): return super().as_sql(compiler, connection, function="RAND", **extra_context) - def get_group_by_cols(self, alias=None): + def get_group_by_cols(self): return [] diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 866e38df83..9b4bdb9bd6 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -128,7 +128,7 @@ class Lookup(Expression): def rhs_is_direct_value(self): return not hasattr(self.rhs, "as_sql") - def get_group_by_cols(self, alias=None): + def get_group_by_cols(self): cols = [] for source in self.get_source_expressions(): cols.extend(source.get_group_by_cols()) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 9d4ca0ffb6..c2a71ff589 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -1147,13 +1147,11 @@ class Query(BaseExpression): if col.alias in self.external_aliases ] - def get_group_by_cols(self, alias=None, wrapper=None): + def get_group_by_cols(self, wrapper=None): # If wrapper is referenced by an alias for an explicit GROUP BY through # values() a reference to this expression and not the self must be # returned to ensure external column references are not grouped against # as well. - if alias: - return [Ref(alias, wrapper or self)] external_cols = self.get_external_cols() if any(col.possibly_multivalued for col in external_cols): return [wrapper or self] @@ -2247,11 +2245,16 @@ class Query(BaseExpression): self.select = tuple(values_select.values()) self.values_select = tuple(values_select) group_by = list(self.select) - if self.annotation_select: - for alias, annotation in self.annotation_select.items(): - if not allow_aliases or alias in column_names: - alias = None - group_by_cols = annotation.get_group_by_cols(alias=alias) + for alias, annotation in self.annotation_select.items(): + if not (group_by_cols := annotation.get_group_by_cols()): + continue + if ( + allow_aliases + and alias not in column_names + and not annotation.contains_aggregate + ): + group_by.append(Ref(alias, annotation)) + else: group_by.extend(group_by_cols) self.group_by = tuple(group_by) diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 5ce7e4e55c..63fdf58d9d 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -178,7 +178,7 @@ class WhereNode(tree.Node): sql_string = "(%s)" % sql_string return sql_string, result_params - def get_group_by_cols(self, alias=None): + def get_group_by_cols(self): cols = [] for child in self.children: cols.extend(child.get_group_by_cols()) diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index 95f093e2a3..ccd18670b1 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -1036,13 +1036,16 @@ calling the appropriate methods on the wrapped expression. ``expression`` is the same as ``self``. - .. method:: get_group_by_cols(alias=None) + .. method:: get_group_by_cols() Responsible for returning the list of columns references by this expression. ``get_group_by_cols()`` should be called on any nested expressions. ``F()`` objects, in particular, hold a reference - to a column. The ``alias`` parameter will be ``None`` unless the - expression has been annotated and is used for grouping. + to a column. + + .. versionchanged:: 4.2 + + The ``alias=None`` keyword argument was removed. .. method:: asc(nulls_first=None, nulls_last=None) diff --git a/docs/releases/4.2.txt b/docs/releases/4.2.txt index 0c5492e125..5a849cbbe5 100644 --- a/docs/releases/4.2.txt +++ b/docs/releases/4.2.txt @@ -329,6 +329,8 @@ Miscellaneous * The :option:`makemigrations --check` option no longer creates missing migration files. +* The ``alias`` argument for :meth:`.Expression.get_group_by_cols` is removed. + .. _deprecated-features-4.2: Features deprecated in 4.2 diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index d63415ae72..7a759feda5 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -2525,13 +2525,13 @@ class CombinedExpressionTests(SimpleTestCase): class ExpressionWrapperTests(SimpleTestCase): def test_empty_group_by(self): expr = ExpressionWrapper(Value(3), output_field=IntegerField()) - self.assertEqual(expr.get_group_by_cols(alias=None), []) + self.assertEqual(expr.get_group_by_cols(), []) def test_non_empty_group_by(self): value = Value("f") value.output_field = None expr = ExpressionWrapper(Lower(value), output_field=IntegerField()) - group_by_cols = expr.get_group_by_cols(alias=None) + group_by_cols = expr.get_group_by_cols() self.assertEqual(group_by_cols, [expr.expression]) self.assertEqual(group_by_cols[0].output_field, expr.output_field)