Refs #30158 -- Added alias argument to Expression.get_group_by_cols().

This commit is contained in:
Simon Charette 2019-03-19 01:05:47 -04:00 committed by Tim Graham
parent 2aaabe2004
commit 9dc367dc10
9 changed files with 67 additions and 18 deletions

View File

@ -64,7 +64,7 @@ class Aggregate(Func):
return '%s__%s' % (expressions[0].name, self.name.lower()) return '%s__%s' % (expressions[0].name, self.name.lower())
raise TypeError("Complex expressions require an alias") raise TypeError("Complex expressions require an alias")
def get_group_by_cols(self): def get_group_by_cols(self, alias=None):
return [] return []
def as_sql(self, compiler, connection, **extra_context): def as_sql(self, compiler, connection, **extra_context):

View File

@ -332,7 +332,7 @@ class BaseExpression:
def copy(self): def copy(self):
return copy.copy(self) return copy.copy(self)
def get_group_by_cols(self): def get_group_by_cols(self, alias=None):
if not self.contains_aggregate: if not self.contains_aggregate:
return [self] return [self]
cols = [] cols = []
@ -669,7 +669,7 @@ class Value(Expression):
c.for_save = for_save c.for_save = for_save
return c return c
def get_group_by_cols(self): def get_group_by_cols(self, alias=None):
return [] return []
@ -694,7 +694,7 @@ class RawSQL(Expression):
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
return '(%s)' % self.sql, self.params return '(%s)' % self.sql, self.params
def get_group_by_cols(self): def get_group_by_cols(self, alias=None):
return [self] return [self]
@ -737,7 +737,7 @@ class Col(Expression):
def relabeled_clone(self, relabels): def relabeled_clone(self, relabels):
return self.__class__(relabels.get(self.alias, self.alias), self.target, self.output_field) return self.__class__(relabels.get(self.alias, self.alias), self.target, self.output_field)
def get_group_by_cols(self): def get_group_by_cols(self, alias=None):
return [self] return [self]
def get_db_converters(self, connection): def get_db_converters(self, connection):
@ -769,7 +769,7 @@ class SimpleCol(Expression):
qn = compiler.quote_name_unless_alias qn = compiler.quote_name_unless_alias
return qn(self.target.column), [] return qn(self.target.column), []
def get_group_by_cols(self): def get_group_by_cols(self, alias=None):
return [self] return [self]
def get_db_converters(self, connection): def get_db_converters(self, connection):
@ -810,7 +810,7 @@ class Ref(Expression):
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
return connection.ops.quote_name(self.refs), [] return connection.ops.quote_name(self.refs), []
def get_group_by_cols(self): def get_group_by_cols(self, alias=None):
return [self] return [self]
@ -905,7 +905,7 @@ class When(Expression):
template = template or self.template template = template or self.template
return template % template_params, sql_params return template % template_params, sql_params
def get_group_by_cols(self): def get_group_by_cols(self, alias=None):
# This is not a complete expression and cannot be used in GROUP BY. # This is not a complete expression and cannot be used in GROUP BY.
cols = [] cols = []
for source in self.get_source_expressions(): for source in self.get_source_expressions():
@ -1171,7 +1171,7 @@ class OrderBy(BaseExpression):
template = 'IF(ISNULL(%(expression)s),0,1), %(expression)s %(ordering)s ' template = 'IF(ISNULL(%(expression)s),0,1), %(expression)s %(ordering)s '
return self.as_sql(compiler, connection, template=template) return self.as_sql(compiler, connection, template=template)
def get_group_by_cols(self): def get_group_by_cols(self, alias=None):
cols = [] cols = []
for source in self.get_source_expressions(): for source in self.get_source_expressions():
cols.extend(source.get_group_by_cols()) cols.extend(source.get_group_by_cols())
@ -1281,7 +1281,7 @@ class Window(Expression):
def __repr__(self): def __repr__(self):
return '<%s: %s>' % (self.__class__.__name__, self) return '<%s: %s>' % (self.__class__.__name__, self)
def get_group_by_cols(self): def get_group_by_cols(self, alias=None):
return [] return []
@ -1317,7 +1317,7 @@ class WindowFrame(Expression):
def __repr__(self): def __repr__(self):
return '<%s: %s>' % (self.__class__.__name__, self) return '<%s: %s>' % (self.__class__.__name__, self)
def get_group_by_cols(self): def get_group_by_cols(self, alias=None):
return [] return []
def __str__(self): def __str__(self):

View File

@ -104,7 +104,7 @@ class Lookup:
new.rhs = new.rhs.relabeled_clone(relabels) new.rhs = new.rhs.relabeled_clone(relabels)
return new return new
def get_group_by_cols(self): def get_group_by_cols(self, alias=None):
cols = self.lhs.get_group_by_cols() cols = self.lhs.get_group_by_cols()
if hasattr(self.rhs, 'get_group_by_cols'): if hasattr(self.rhs, 'get_group_by_cols'):
cols.extend(self.rhs.get_group_by_cols()) cols.extend(self.rhs.get_group_by_cols())

View File

@ -8,6 +8,8 @@ all about the internals of models in order to get the information it needs.
""" """
import difflib import difflib
import functools import functools
import inspect
import warnings
from collections import Counter, namedtuple from collections import Counter, namedtuple
from collections.abc import Iterator, Mapping from collections.abc import Iterator, Mapping
from itertools import chain, count, product from itertools import chain, count, product
@ -35,6 +37,7 @@ from django.db.models.sql.datastructures import (
from django.db.models.sql.where import ( from django.db.models.sql.where import (
AND, OR, ExtraWhere, NothingNode, WhereNode, AND, OR, ExtraWhere, NothingNode, WhereNode,
) )
from django.utils.deprecation import RemovedInDjango40Warning
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.tree import Node from django.utils.tree import Node
@ -1818,9 +1821,20 @@ class Query:
""" """
group_by = list(self.select) group_by = list(self.select)
if self.annotation_select: if self.annotation_select:
for annotation in self.annotation_select.values(): for alias, annotation in self.annotation_select.items():
for col in annotation.get_group_by_cols(): try:
group_by.append(col) inspect.getcallargs(annotation.get_group_by_cols, alias=alias)
except TypeError:
annotation_class = annotation.__class__
msg = (
'`alias=None` must be added to the signature of '
'%s.%s.get_group_by_cols().'
) % (annotation_class.__module__, annotation_class.__qualname__)
warnings.warn(msg, category=RemovedInDjango40Warning)
group_by_cols = annotation.get_group_by_cols()
else:
group_by_cols = annotation.get_group_by_cols(alias=alias)
group_by.extend(group_by_cols)
self.group_by = tuple(group_by) self.group_by = tuple(group_by)
def add_select_related(self, fields): def add_select_related(self, fields):

View File

@ -114,7 +114,7 @@ class WhereNode(tree.Node):
sql_string = '(%s)' % sql_string sql_string = '(%s)' % sql_string
return sql_string, result_params return sql_string, result_params
def get_group_by_cols(self): def get_group_by_cols(self, alias=None):
cols = [] cols = []
for child in self.children: for child in self.children:
cols.extend(child.get_group_by_cols()) cols.extend(child.get_group_by_cols())

View File

@ -27,6 +27,9 @@ details on these changes.
* ``django.views.i18n.set_language()`` will no longer set the user language in * ``django.views.i18n.set_language()`` will no longer set the user language in
``request.session`` (key ``django.utils.translation.LANGUAGE_SESSION_KEY``). ``request.session`` (key ``django.utils.translation.LANGUAGE_SESSION_KEY``).
* ``alias=None`` will be required in the signature of
``django.db.models.Expression.get_group_by_cols()`` subclasses.
.. _deprecation-removed-in-3.1: .. _deprecation-removed-in-3.1:
3.1 3.1

View File

@ -974,12 +974,17 @@ calling the appropriate methods on the wrapped expression.
A hook allowing the expression to coerce ``value`` into a more A hook allowing the expression to coerce ``value`` into a more
appropriate type. appropriate type.
.. method:: get_group_by_cols() .. method:: get_group_by_cols(alias=None)
Responsible for returning the list of columns references by Responsible for returning the list of columns references by
this expression. ``get_group_by_cols()`` should be called on any this expression. ``get_group_by_cols()`` should be called on any
nested expressions. ``F()`` objects, in particular, hold a reference nested expressions. ``F()`` objects, in particular, hold a reference
to a column. to a column. The ``alias`` parameter will be ``None`` unless the
expression has been annotated and is used for grouping.
.. versionchanged:: 3.0
The ``alias`` parameter was added.
.. method:: asc(nulls_first=False, nulls_last=False) .. method:: asc(nulls_first=False, nulls_last=False)

View File

@ -366,6 +366,9 @@ Miscellaneous
in the session in Django 4.0. Since Django 2.1, the language is always stored in the session in Django 4.0. Since Django 2.1, the language is always stored
in the :setting:`LANGUAGE_COOKIE_NAME` cookie. in the :setting:`LANGUAGE_COOKIE_NAME` cookie.
* ``alias=None`` is added to the signature of
:meth:`.Expression.get_group_by_cols`.
.. _removed-features-3.0: .. _removed-features-3.0:
Features removed in 3.0 Features removed in 3.0

View File

@ -0,0 +1,24 @@
from django.db.models import Count, Func
from django.test import SimpleTestCase
from django.utils.deprecation import RemovedInDjango40Warning
from .models import Employee
class MissingAliasFunc(Func):
template = '1'
def get_group_by_cols(self):
return []
class GetGroupByColsTest(SimpleTestCase):
def test_missing_alias(self):
msg = (
'`alias=None` must be added to the signature of '
'expressions.test_deprecation.MissingAliasFunc.get_group_by_cols().'
)
with self.assertRaisesMessage(RemovedInDjango40Warning, msg):
Employee.objects.values(
one=MissingAliasFunc(),
).annotate(cnt=Count('company_ceo_set'))