Refs #30158 -- Added alias argument to Expression.get_group_by_cols().
This commit is contained in:
parent
2aaabe2004
commit
9dc367dc10
|
@ -64,7 +64,7 @@ class Aggregate(Func):
|
||||||
return '%s__%s' % (expressions[0].name, self.name.lower())
|
return '%s__%s' % (expressions[0].name, self.name.lower())
|
||||||
raise TypeError("Complex expressions require an alias")
|
raise TypeError("Complex expressions require an alias")
|
||||||
|
|
||||||
def get_group_by_cols(self):
|
def get_group_by_cols(self, alias=None):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def as_sql(self, compiler, connection, **extra_context):
|
def as_sql(self, compiler, connection, **extra_context):
|
||||||
|
|
|
@ -332,7 +332,7 @@ class BaseExpression:
|
||||||
def copy(self):
|
def copy(self):
|
||||||
return copy.copy(self)
|
return copy.copy(self)
|
||||||
|
|
||||||
def get_group_by_cols(self):
|
def get_group_by_cols(self, alias=None):
|
||||||
if not self.contains_aggregate:
|
if not self.contains_aggregate:
|
||||||
return [self]
|
return [self]
|
||||||
cols = []
|
cols = []
|
||||||
|
@ -669,7 +669,7 @@ class Value(Expression):
|
||||||
c.for_save = for_save
|
c.for_save = for_save
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def get_group_by_cols(self):
|
def get_group_by_cols(self, alias=None):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
@ -694,7 +694,7 @@ class RawSQL(Expression):
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
return '(%s)' % self.sql, self.params
|
return '(%s)' % self.sql, self.params
|
||||||
|
|
||||||
def get_group_by_cols(self):
|
def get_group_by_cols(self, alias=None):
|
||||||
return [self]
|
return [self]
|
||||||
|
|
||||||
|
|
||||||
|
@ -737,7 +737,7 @@ class Col(Expression):
|
||||||
def relabeled_clone(self, relabels):
|
def relabeled_clone(self, relabels):
|
||||||
return self.__class__(relabels.get(self.alias, self.alias), self.target, self.output_field)
|
return self.__class__(relabels.get(self.alias, self.alias), self.target, self.output_field)
|
||||||
|
|
||||||
def get_group_by_cols(self):
|
def get_group_by_cols(self, alias=None):
|
||||||
return [self]
|
return [self]
|
||||||
|
|
||||||
def get_db_converters(self, connection):
|
def get_db_converters(self, connection):
|
||||||
|
@ -769,7 +769,7 @@ class SimpleCol(Expression):
|
||||||
qn = compiler.quote_name_unless_alias
|
qn = compiler.quote_name_unless_alias
|
||||||
return qn(self.target.column), []
|
return qn(self.target.column), []
|
||||||
|
|
||||||
def get_group_by_cols(self):
|
def get_group_by_cols(self, alias=None):
|
||||||
return [self]
|
return [self]
|
||||||
|
|
||||||
def get_db_converters(self, connection):
|
def get_db_converters(self, connection):
|
||||||
|
@ -810,7 +810,7 @@ class Ref(Expression):
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
return connection.ops.quote_name(self.refs), []
|
return connection.ops.quote_name(self.refs), []
|
||||||
|
|
||||||
def get_group_by_cols(self):
|
def get_group_by_cols(self, alias=None):
|
||||||
return [self]
|
return [self]
|
||||||
|
|
||||||
|
|
||||||
|
@ -905,7 +905,7 @@ class When(Expression):
|
||||||
template = template or self.template
|
template = template or self.template
|
||||||
return template % template_params, sql_params
|
return template % template_params, sql_params
|
||||||
|
|
||||||
def get_group_by_cols(self):
|
def get_group_by_cols(self, alias=None):
|
||||||
# This is not a complete expression and cannot be used in GROUP BY.
|
# This is not a complete expression and cannot be used in GROUP BY.
|
||||||
cols = []
|
cols = []
|
||||||
for source in self.get_source_expressions():
|
for source in self.get_source_expressions():
|
||||||
|
@ -1171,7 +1171,7 @@ class OrderBy(BaseExpression):
|
||||||
template = 'IF(ISNULL(%(expression)s),0,1), %(expression)s %(ordering)s '
|
template = 'IF(ISNULL(%(expression)s),0,1), %(expression)s %(ordering)s '
|
||||||
return self.as_sql(compiler, connection, template=template)
|
return self.as_sql(compiler, connection, template=template)
|
||||||
|
|
||||||
def get_group_by_cols(self):
|
def get_group_by_cols(self, alias=None):
|
||||||
cols = []
|
cols = []
|
||||||
for source in self.get_source_expressions():
|
for source in self.get_source_expressions():
|
||||||
cols.extend(source.get_group_by_cols())
|
cols.extend(source.get_group_by_cols())
|
||||||
|
@ -1281,7 +1281,7 @@ class Window(Expression):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<%s: %s>' % (self.__class__.__name__, self)
|
return '<%s: %s>' % (self.__class__.__name__, self)
|
||||||
|
|
||||||
def get_group_by_cols(self):
|
def get_group_by_cols(self, alias=None):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
@ -1317,7 +1317,7 @@ class WindowFrame(Expression):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<%s: %s>' % (self.__class__.__name__, self)
|
return '<%s: %s>' % (self.__class__.__name__, self)
|
||||||
|
|
||||||
def get_group_by_cols(self):
|
def get_group_by_cols(self, alias=None):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
|
|
@ -104,7 +104,7 @@ class Lookup:
|
||||||
new.rhs = new.rhs.relabeled_clone(relabels)
|
new.rhs = new.rhs.relabeled_clone(relabels)
|
||||||
return new
|
return new
|
||||||
|
|
||||||
def get_group_by_cols(self):
|
def get_group_by_cols(self, alias=None):
|
||||||
cols = self.lhs.get_group_by_cols()
|
cols = self.lhs.get_group_by_cols()
|
||||||
if hasattr(self.rhs, 'get_group_by_cols'):
|
if hasattr(self.rhs, 'get_group_by_cols'):
|
||||||
cols.extend(self.rhs.get_group_by_cols())
|
cols.extend(self.rhs.get_group_by_cols())
|
||||||
|
|
|
@ -8,6 +8,8 @@ all about the internals of models in order to get the information it needs.
|
||||||
"""
|
"""
|
||||||
import difflib
|
import difflib
|
||||||
import functools
|
import functools
|
||||||
|
import inspect
|
||||||
|
import warnings
|
||||||
from collections import Counter, namedtuple
|
from collections import Counter, namedtuple
|
||||||
from collections.abc import Iterator, Mapping
|
from collections.abc import Iterator, Mapping
|
||||||
from itertools import chain, count, product
|
from itertools import chain, count, product
|
||||||
|
@ -35,6 +37,7 @@ from django.db.models.sql.datastructures import (
|
||||||
from django.db.models.sql.where import (
|
from django.db.models.sql.where import (
|
||||||
AND, OR, ExtraWhere, NothingNode, WhereNode,
|
AND, OR, ExtraWhere, NothingNode, WhereNode,
|
||||||
)
|
)
|
||||||
|
from django.utils.deprecation import RemovedInDjango40Warning
|
||||||
from django.utils.functional import cached_property
|
from django.utils.functional import cached_property
|
||||||
from django.utils.tree import Node
|
from django.utils.tree import Node
|
||||||
|
|
||||||
|
@ -1818,9 +1821,20 @@ class Query:
|
||||||
"""
|
"""
|
||||||
group_by = list(self.select)
|
group_by = list(self.select)
|
||||||
if self.annotation_select:
|
if self.annotation_select:
|
||||||
for annotation in self.annotation_select.values():
|
for alias, annotation in self.annotation_select.items():
|
||||||
for col in annotation.get_group_by_cols():
|
try:
|
||||||
group_by.append(col)
|
inspect.getcallargs(annotation.get_group_by_cols, alias=alias)
|
||||||
|
except TypeError:
|
||||||
|
annotation_class = annotation.__class__
|
||||||
|
msg = (
|
||||||
|
'`alias=None` must be added to the signature of '
|
||||||
|
'%s.%s.get_group_by_cols().'
|
||||||
|
) % (annotation_class.__module__, annotation_class.__qualname__)
|
||||||
|
warnings.warn(msg, category=RemovedInDjango40Warning)
|
||||||
|
group_by_cols = annotation.get_group_by_cols()
|
||||||
|
else:
|
||||||
|
group_by_cols = annotation.get_group_by_cols(alias=alias)
|
||||||
|
group_by.extend(group_by_cols)
|
||||||
self.group_by = tuple(group_by)
|
self.group_by = tuple(group_by)
|
||||||
|
|
||||||
def add_select_related(self, fields):
|
def add_select_related(self, fields):
|
||||||
|
|
|
@ -114,7 +114,7 @@ class WhereNode(tree.Node):
|
||||||
sql_string = '(%s)' % sql_string
|
sql_string = '(%s)' % sql_string
|
||||||
return sql_string, result_params
|
return sql_string, result_params
|
||||||
|
|
||||||
def get_group_by_cols(self):
|
def get_group_by_cols(self, alias=None):
|
||||||
cols = []
|
cols = []
|
||||||
for child in self.children:
|
for child in self.children:
|
||||||
cols.extend(child.get_group_by_cols())
|
cols.extend(child.get_group_by_cols())
|
||||||
|
|
|
@ -27,6 +27,9 @@ details on these changes.
|
||||||
* ``django.views.i18n.set_language()`` will no longer set the user language in
|
* ``django.views.i18n.set_language()`` will no longer set the user language in
|
||||||
``request.session`` (key ``django.utils.translation.LANGUAGE_SESSION_KEY``).
|
``request.session`` (key ``django.utils.translation.LANGUAGE_SESSION_KEY``).
|
||||||
|
|
||||||
|
* ``alias=None`` will be required in the signature of
|
||||||
|
``django.db.models.Expression.get_group_by_cols()`` subclasses.
|
||||||
|
|
||||||
.. _deprecation-removed-in-3.1:
|
.. _deprecation-removed-in-3.1:
|
||||||
|
|
||||||
3.1
|
3.1
|
||||||
|
|
|
@ -974,12 +974,17 @@ calling the appropriate methods on the wrapped expression.
|
||||||
A hook allowing the expression to coerce ``value`` into a more
|
A hook allowing the expression to coerce ``value`` into a more
|
||||||
appropriate type.
|
appropriate type.
|
||||||
|
|
||||||
.. method:: get_group_by_cols()
|
.. method:: get_group_by_cols(alias=None)
|
||||||
|
|
||||||
Responsible for returning the list of columns references by
|
Responsible for returning the list of columns references by
|
||||||
this expression. ``get_group_by_cols()`` should be called on any
|
this expression. ``get_group_by_cols()`` should be called on any
|
||||||
nested expressions. ``F()`` objects, in particular, hold a reference
|
nested expressions. ``F()`` objects, in particular, hold a reference
|
||||||
to a column.
|
to a column. The ``alias`` parameter will be ``None`` unless the
|
||||||
|
expression has been annotated and is used for grouping.
|
||||||
|
|
||||||
|
.. versionchanged:: 3.0
|
||||||
|
|
||||||
|
The ``alias`` parameter was added.
|
||||||
|
|
||||||
.. method:: asc(nulls_first=False, nulls_last=False)
|
.. method:: asc(nulls_first=False, nulls_last=False)
|
||||||
|
|
||||||
|
|
|
@ -366,6 +366,9 @@ Miscellaneous
|
||||||
in the session in Django 4.0. Since Django 2.1, the language is always stored
|
in the session in Django 4.0. Since Django 2.1, the language is always stored
|
||||||
in the :setting:`LANGUAGE_COOKIE_NAME` cookie.
|
in the :setting:`LANGUAGE_COOKIE_NAME` cookie.
|
||||||
|
|
||||||
|
* ``alias=None`` is added to the signature of
|
||||||
|
:meth:`.Expression.get_group_by_cols`.
|
||||||
|
|
||||||
.. _removed-features-3.0:
|
.. _removed-features-3.0:
|
||||||
|
|
||||||
Features removed in 3.0
|
Features removed in 3.0
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
from django.db.models import Count, Func
|
||||||
|
from django.test import SimpleTestCase
|
||||||
|
from django.utils.deprecation import RemovedInDjango40Warning
|
||||||
|
|
||||||
|
from .models import Employee
|
||||||
|
|
||||||
|
|
||||||
|
class MissingAliasFunc(Func):
|
||||||
|
template = '1'
|
||||||
|
|
||||||
|
def get_group_by_cols(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class GetGroupByColsTest(SimpleTestCase):
|
||||||
|
def test_missing_alias(self):
|
||||||
|
msg = (
|
||||||
|
'`alias=None` must be added to the signature of '
|
||||||
|
'expressions.test_deprecation.MissingAliasFunc.get_group_by_cols().'
|
||||||
|
)
|
||||||
|
with self.assertRaisesMessage(RemovedInDjango40Warning, msg):
|
||||||
|
Employee.objects.values(
|
||||||
|
one=MissingAliasFunc(),
|
||||||
|
).annotate(cnt=Count('company_ceo_set'))
|
Loading…
Reference in New Issue