diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index ea88c54b0d..da4ff928aa 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -64,7 +64,7 @@ class Aggregate(Func): return '%s__%s' % (expressions[0].name, self.name.lower()) raise TypeError("Complex expressions require an alias") - def get_group_by_cols(self): + def get_group_by_cols(self, alias=None): return [] def as_sql(self, compiler, connection, **extra_context): diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index ccb67876e2..30c2f465f0 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -332,7 +332,7 @@ class BaseExpression: def copy(self): return copy.copy(self) - def get_group_by_cols(self): + def get_group_by_cols(self, alias=None): if not self.contains_aggregate: return [self] cols = [] @@ -669,7 +669,7 @@ class Value(Expression): c.for_save = for_save return c - def get_group_by_cols(self): + def get_group_by_cols(self, alias=None): return [] @@ -694,7 +694,7 @@ class RawSQL(Expression): def as_sql(self, compiler, connection): return '(%s)' % self.sql, self.params - def get_group_by_cols(self): + def get_group_by_cols(self, alias=None): return [self] @@ -737,7 +737,7 @@ class Col(Expression): def relabeled_clone(self, relabels): return self.__class__(relabels.get(self.alias, self.alias), self.target, self.output_field) - def get_group_by_cols(self): + def get_group_by_cols(self, alias=None): return [self] def get_db_converters(self, connection): @@ -769,7 +769,7 @@ class SimpleCol(Expression): qn = compiler.quote_name_unless_alias return qn(self.target.column), [] - def get_group_by_cols(self): + def get_group_by_cols(self, alias=None): return [self] def get_db_converters(self, connection): @@ -810,7 +810,7 @@ class Ref(Expression): def as_sql(self, compiler, connection): return connection.ops.quote_name(self.refs), [] - def get_group_by_cols(self): + def get_group_by_cols(self, alias=None): return [self] @@ -905,7 +905,7 @@ class When(Expression): template = template or self.template return template % template_params, sql_params - def get_group_by_cols(self): + def get_group_by_cols(self, alias=None): # This is not a complete expression and cannot be used in GROUP BY. cols = [] for source in self.get_source_expressions(): @@ -1171,7 +1171,7 @@ class OrderBy(BaseExpression): template = 'IF(ISNULL(%(expression)s),0,1), %(expression)s %(ordering)s ' return self.as_sql(compiler, connection, template=template) - def get_group_by_cols(self): + def get_group_by_cols(self, alias=None): cols = [] for source in self.get_source_expressions(): cols.extend(source.get_group_by_cols()) @@ -1281,7 +1281,7 @@ class Window(Expression): def __repr__(self): return '<%s: %s>' % (self.__class__.__name__, self) - def get_group_by_cols(self): + def get_group_by_cols(self, alias=None): return [] @@ -1317,7 +1317,7 @@ class WindowFrame(Expression): def __repr__(self): return '<%s: %s>' % (self.__class__.__name__, self) - def get_group_by_cols(self): + def get_group_by_cols(self, alias=None): return [] def __str__(self): diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index fa4561dabf..e78ffdf390 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -104,7 +104,7 @@ class Lookup: new.rhs = new.rhs.relabeled_clone(relabels) return new - def get_group_by_cols(self): + def get_group_by_cols(self, alias=None): cols = self.lhs.get_group_by_cols() if hasattr(self.rhs, 'get_group_by_cols'): cols.extend(self.rhs.get_group_by_cols()) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 2e4c6c35af..ba4baca2b8 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -8,6 +8,8 @@ all about the internals of models in order to get the information it needs. """ import difflib import functools +import inspect +import warnings from collections import Counter, namedtuple from collections.abc import Iterator, Mapping from itertools import chain, count, product @@ -35,6 +37,7 @@ from django.db.models.sql.datastructures import ( from django.db.models.sql.where import ( AND, OR, ExtraWhere, NothingNode, WhereNode, ) +from django.utils.deprecation import RemovedInDjango40Warning from django.utils.functional import cached_property from django.utils.tree import Node @@ -1818,9 +1821,20 @@ class Query: """ group_by = list(self.select) if self.annotation_select: - for annotation in self.annotation_select.values(): - for col in annotation.get_group_by_cols(): - group_by.append(col) + for alias, annotation in self.annotation_select.items(): + try: + inspect.getcallargs(annotation.get_group_by_cols, alias=alias) + except TypeError: + annotation_class = annotation.__class__ + msg = ( + '`alias=None` must be added to the signature of ' + '%s.%s.get_group_by_cols().' + ) % (annotation_class.__module__, annotation_class.__qualname__) + warnings.warn(msg, category=RemovedInDjango40Warning) + group_by_cols = annotation.get_group_by_cols() + else: + group_by_cols = annotation.get_group_by_cols(alias=alias) + group_by.extend(group_by_cols) self.group_by = tuple(group_by) def add_select_related(self, fields): diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 879de0474a..9d3d6a9366 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -114,7 +114,7 @@ class WhereNode(tree.Node): sql_string = '(%s)' % sql_string return sql_string, result_params - def get_group_by_cols(self): + def get_group_by_cols(self, alias=None): cols = [] for child in self.children: cols.extend(child.get_group_by_cols()) diff --git a/docs/internals/deprecation.txt b/docs/internals/deprecation.txt index 02935c01a2..93b6a04594 100644 --- a/docs/internals/deprecation.txt +++ b/docs/internals/deprecation.txt @@ -27,6 +27,9 @@ details on these changes. * ``django.views.i18n.set_language()`` will no longer set the user language in ``request.session`` (key ``django.utils.translation.LANGUAGE_SESSION_KEY``). +* ``alias=None`` will be required in the signature of + ``django.db.models.Expression.get_group_by_cols()`` subclasses. + .. _deprecation-removed-in-3.1: 3.1 diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index bfd5ddea2c..4dc8d9476a 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -974,12 +974,17 @@ calling the appropriate methods on the wrapped expression. A hook allowing the expression to coerce ``value`` into a more appropriate type. - .. method:: get_group_by_cols() + .. method:: get_group_by_cols(alias=None) Responsible for returning the list of columns references by this expression. ``get_group_by_cols()`` should be called on any nested expressions. ``F()`` objects, in particular, hold a reference - to a column. + to a column. The ``alias`` parameter will be ``None`` unless the + expression has been annotated and is used for grouping. + + .. versionchanged:: 3.0 + + The ``alias`` parameter was added. .. method:: asc(nulls_first=False, nulls_last=False) diff --git a/docs/releases/3.0.txt b/docs/releases/3.0.txt index cf50eea1b5..72ecdd3c74 100644 --- a/docs/releases/3.0.txt +++ b/docs/releases/3.0.txt @@ -366,6 +366,9 @@ Miscellaneous in the session in Django 4.0. Since Django 2.1, the language is always stored in the :setting:`LANGUAGE_COOKIE_NAME` cookie. +* ``alias=None`` is added to the signature of + :meth:`.Expression.get_group_by_cols`. + .. _removed-features-3.0: Features removed in 3.0 diff --git a/tests/expressions/test_deprecation.py b/tests/expressions/test_deprecation.py new file mode 100644 index 0000000000..cdb1e43af6 --- /dev/null +++ b/tests/expressions/test_deprecation.py @@ -0,0 +1,24 @@ +from django.db.models import Count, Func +from django.test import SimpleTestCase +from django.utils.deprecation import RemovedInDjango40Warning + +from .models import Employee + + +class MissingAliasFunc(Func): + template = '1' + + def get_group_by_cols(self): + return [] + + +class GetGroupByColsTest(SimpleTestCase): + def test_missing_alias(self): + msg = ( + '`alias=None` must be added to the signature of ' + 'expressions.test_deprecation.MissingAliasFunc.get_group_by_cols().' + ) + with self.assertRaisesMessage(RemovedInDjango40Warning, msg): + Employee.objects.values( + one=MissingAliasFunc(), + ).annotate(cnt=Count('company_ceo_set'))