Refs #30158 -- Added alias argument to Expression.get_group_by_cols().

This commit is contained in:
Simon Charette 2019-03-19 01:05:47 -04:00 committed by Tim Graham
parent 2aaabe2004
commit 9dc367dc10
9 changed files with 67 additions and 18 deletions

View File

@ -64,7 +64,7 @@ class Aggregate(Func):
return '%s__%s' % (expressions[0].name, self.name.lower())
raise TypeError("Complex expressions require an alias")
def get_group_by_cols(self):
def get_group_by_cols(self, alias=None):
return []
def as_sql(self, compiler, connection, **extra_context):

View File

@ -332,7 +332,7 @@ class BaseExpression:
def copy(self):
return copy.copy(self)
def get_group_by_cols(self):
def get_group_by_cols(self, alias=None):
if not self.contains_aggregate:
return [self]
cols = []
@ -669,7 +669,7 @@ class Value(Expression):
c.for_save = for_save
return c
def get_group_by_cols(self):
def get_group_by_cols(self, alias=None):
return []
@ -694,7 +694,7 @@ class RawSQL(Expression):
def as_sql(self, compiler, connection):
return '(%s)' % self.sql, self.params
def get_group_by_cols(self):
def get_group_by_cols(self, alias=None):
return [self]
@ -737,7 +737,7 @@ class Col(Expression):
def relabeled_clone(self, relabels):
return self.__class__(relabels.get(self.alias, self.alias), self.target, self.output_field)
def get_group_by_cols(self):
def get_group_by_cols(self, alias=None):
return [self]
def get_db_converters(self, connection):
@ -769,7 +769,7 @@ class SimpleCol(Expression):
qn = compiler.quote_name_unless_alias
return qn(self.target.column), []
def get_group_by_cols(self):
def get_group_by_cols(self, alias=None):
return [self]
def get_db_converters(self, connection):
@ -810,7 +810,7 @@ class Ref(Expression):
def as_sql(self, compiler, connection):
return connection.ops.quote_name(self.refs), []
def get_group_by_cols(self):
def get_group_by_cols(self, alias=None):
return [self]
@ -905,7 +905,7 @@ class When(Expression):
template = template or self.template
return template % template_params, sql_params
def get_group_by_cols(self):
def get_group_by_cols(self, alias=None):
# This is not a complete expression and cannot be used in GROUP BY.
cols = []
for source in self.get_source_expressions():
@ -1171,7 +1171,7 @@ class OrderBy(BaseExpression):
template = 'IF(ISNULL(%(expression)s),0,1), %(expression)s %(ordering)s '
return self.as_sql(compiler, connection, template=template)
def get_group_by_cols(self):
def get_group_by_cols(self, alias=None):
cols = []
for source in self.get_source_expressions():
cols.extend(source.get_group_by_cols())
@ -1281,7 +1281,7 @@ class Window(Expression):
def __repr__(self):
return '<%s: %s>' % (self.__class__.__name__, self)
def get_group_by_cols(self):
def get_group_by_cols(self, alias=None):
return []
@ -1317,7 +1317,7 @@ class WindowFrame(Expression):
def __repr__(self):
return '<%s: %s>' % (self.__class__.__name__, self)
def get_group_by_cols(self):
def get_group_by_cols(self, alias=None):
return []
def __str__(self):

View File

@ -104,7 +104,7 @@ class Lookup:
new.rhs = new.rhs.relabeled_clone(relabels)
return new
def get_group_by_cols(self):
def get_group_by_cols(self, alias=None):
cols = self.lhs.get_group_by_cols()
if hasattr(self.rhs, 'get_group_by_cols'):
cols.extend(self.rhs.get_group_by_cols())

View File

@ -8,6 +8,8 @@ all about the internals of models in order to get the information it needs.
"""
import difflib
import functools
import inspect
import warnings
from collections import Counter, namedtuple
from collections.abc import Iterator, Mapping
from itertools import chain, count, product
@ -35,6 +37,7 @@ from django.db.models.sql.datastructures import (
from django.db.models.sql.where import (
AND, OR, ExtraWhere, NothingNode, WhereNode,
)
from django.utils.deprecation import RemovedInDjango40Warning
from django.utils.functional import cached_property
from django.utils.tree import Node
@ -1818,9 +1821,20 @@ class Query:
"""
group_by = list(self.select)
if self.annotation_select:
for annotation in self.annotation_select.values():
for col in annotation.get_group_by_cols():
group_by.append(col)
for alias, annotation in self.annotation_select.items():
try:
inspect.getcallargs(annotation.get_group_by_cols, alias=alias)
except TypeError:
annotation_class = annotation.__class__
msg = (
'`alias=None` must be added to the signature of '
'%s.%s.get_group_by_cols().'
) % (annotation_class.__module__, annotation_class.__qualname__)
warnings.warn(msg, category=RemovedInDjango40Warning)
group_by_cols = annotation.get_group_by_cols()
else:
group_by_cols = annotation.get_group_by_cols(alias=alias)
group_by.extend(group_by_cols)
self.group_by = tuple(group_by)
def add_select_related(self, fields):

View File

@ -114,7 +114,7 @@ class WhereNode(tree.Node):
sql_string = '(%s)' % sql_string
return sql_string, result_params
def get_group_by_cols(self):
def get_group_by_cols(self, alias=None):
cols = []
for child in self.children:
cols.extend(child.get_group_by_cols())

View File

@ -27,6 +27,9 @@ details on these changes.
* ``django.views.i18n.set_language()`` will no longer set the user language in
``request.session`` (key ``django.utils.translation.LANGUAGE_SESSION_KEY``).
* ``alias=None`` will be required in the signature of
``django.db.models.Expression.get_group_by_cols()`` subclasses.
.. _deprecation-removed-in-3.1:
3.1

View File

@ -974,12 +974,17 @@ calling the appropriate methods on the wrapped expression.
A hook allowing the expression to coerce ``value`` into a more
appropriate type.
.. method:: get_group_by_cols()
.. method:: get_group_by_cols(alias=None)
Responsible for returning the list of columns references by
this expression. ``get_group_by_cols()`` should be called on any
nested expressions. ``F()`` objects, in particular, hold a reference
to a column.
to a column. The ``alias`` parameter will be ``None`` unless the
expression has been annotated and is used for grouping.
.. versionchanged:: 3.0
The ``alias`` parameter was added.
.. method:: asc(nulls_first=False, nulls_last=False)

View File

@ -366,6 +366,9 @@ Miscellaneous
in the session in Django 4.0. Since Django 2.1, the language is always stored
in the :setting:`LANGUAGE_COOKIE_NAME` cookie.
* ``alias=None`` is added to the signature of
:meth:`.Expression.get_group_by_cols`.
.. _removed-features-3.0:
Features removed in 3.0

View File

@ -0,0 +1,24 @@
from django.db.models import Count, Func
from django.test import SimpleTestCase
from django.utils.deprecation import RemovedInDjango40Warning
from .models import Employee
class MissingAliasFunc(Func):
template = '1'
def get_group_by_cols(self):
return []
class GetGroupByColsTest(SimpleTestCase):
def test_missing_alias(self):
msg = (
'`alias=None` must be added to the signature of '
'expressions.test_deprecation.MissingAliasFunc.get_group_by_cols().'
)
with self.assertRaisesMessage(RemovedInDjango40Warning, msg):
Employee.objects.values(
one=MissingAliasFunc(),
).annotate(cnt=Count('company_ceo_set'))