Fixed #31679 -- Delayed annotating aggregations.

By avoiding to annotate aggregations meant to be possibly pushed to an
outer query until their references are resolved it is possible to
aggregate over a query with the same alias.

Even if #34176 is a convoluted case to support, this refactor seems
worth it given the reduction in complexity it brings with regards to
annotation removal when performing a subquery pushdown.
This commit is contained in:
Simon Charette 2022-11-22 21:49:12 -05:00 committed by Mariusz Felisiak
parent d526d1569c
commit 1297c0d0d7
4 changed files with 51 additions and 65 deletions

View File

@ -23,7 +23,7 @@ from django.db import (
from django.db.models import AutoField, DateField, DateTimeField, Field, sql from django.db.models import AutoField, DateField, DateTimeField, Field, sql
from django.db.models.constants import LOOKUP_SEP, OnConflict from django.db.models.constants import LOOKUP_SEP, OnConflict
from django.db.models.deletion import Collector from django.db.models.deletion import Collector
from django.db.models.expressions import Case, F, Ref, Value, When from django.db.models.expressions import Case, F, Value, When
from django.db.models.functions import Cast, Trunc from django.db.models.functions import Cast, Trunc
from django.db.models.query_utils import FilteredRelation, Q from django.db.models.query_utils import FilteredRelation, Q
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
@ -589,24 +589,7 @@ class QuerySet(AltersData):
raise TypeError("Complex aggregates require an alias") raise TypeError("Complex aggregates require an alias")
kwargs[arg.default_alias] = arg kwargs[arg.default_alias] = arg
query = self.query.chain() return self.query.chain().get_aggregation(self.db, kwargs)
for (alias, aggregate_expr) in kwargs.items():
query.add_annotation(aggregate_expr, alias, is_summary=True)
annotation = query.annotations[alias]
if not annotation.contains_aggregate:
raise TypeError("%s is not an aggregate expression" % alias)
for expr in annotation.get_source_expressions():
if (
expr.contains_aggregate
and isinstance(expr, Ref)
and expr.refs in kwargs
):
name = expr.refs
raise exceptions.FieldError(
"Cannot compute %s('%s'): '%s' is an aggregate"
% (annotation.name, name, name)
)
return query.get_aggregation(self.db, kwargs)
async def aaggregate(self, *args, **kwargs): async def aaggregate(self, *args, **kwargs):
return await sync_to_async(self.aggregate)(*args, **kwargs) return await sync_to_async(self.aggregate)(*args, **kwargs)
@ -1655,7 +1638,6 @@ class QuerySet(AltersData):
clone.query.add_annotation( clone.query.add_annotation(
annotation, annotation,
alias, alias,
is_summary=False,
select=select, select=select,
) )
for alias, annotation in clone.query.annotations.items(): for alias, annotation in clone.query.annotations.items():

View File

@ -381,24 +381,28 @@ class Query(BaseExpression):
alias = None alias = None
return target.get_col(alias, field) return target.get_col(alias, field)
def get_aggregation(self, using, added_aggregate_names): def get_aggregation(self, using, aggregate_exprs):
""" """
Return the dictionary with the values of the existing aggregations. Return the dictionary with the values of the existing aggregations.
""" """
if not self.annotation_select: if not aggregate_exprs:
return {} return {}
existing_annotations = { aggregates = {}
alias: annotation for alias, aggregate_expr in aggregate_exprs.items():
for alias, annotation in self.annotations.items() self.check_alias(alias)
if alias not in added_aggregate_names aggregate = aggregate_expr.resolve_expression(
} self, allow_joins=True, reuse=None, summarize=True
)
if not aggregate.contains_aggregate:
raise TypeError("%s is not an aggregate expression" % alias)
aggregates[alias] = aggregate
# Existing usage of aggregation can be determined by the presence of # Existing usage of aggregation can be determined by the presence of
# selected aggregates but also by filters against aliased aggregates. # selected aggregates but also by filters against aliased aggregates.
_, having, qualify = self.where.split_having_qualify() _, having, qualify = self.where.split_having_qualify()
has_existing_aggregation = ( has_existing_aggregation = (
any( any(
getattr(annotation, "contains_aggregate", True) getattr(annotation, "contains_aggregate", True)
for annotation in existing_annotations.values() for annotation in self.annotations.values()
) )
or having or having
) )
@ -449,25 +453,19 @@ class Query(BaseExpression):
# filtering against window functions is involved as it # filtering against window functions is involved as it
# requires complex realising. # requires complex realising.
annotation_mask = set() annotation_mask = set()
for name in added_aggregate_names: for aggregate in aggregates.values():
annotation_mask.add(name) annotation_mask |= aggregate.get_refs()
annotation_mask |= inner_query.annotations[name].get_refs()
inner_query.set_annotation_mask(annotation_mask) inner_query.set_annotation_mask(annotation_mask)
# Remove any aggregates marked for reduction from the subquery and # Add aggregates to the outer AggregateQuery. This requires making
# move them to the outer AggregateQuery. This requires making sure # sure all columns referenced by the aggregates are selected in the
# all columns referenced by the aggregates are selected in the # inner query. It is achieved by retrieving all column references
# subquery. It is achieved by retrieving all column references from # by the aggregates, explicitly selecting them in the inner query,
# the aggregates, explicitly selecting them if they are not # and making sure the aggregates are repointed to them.
# already, and making sure the aggregates are repointed to
# referenced to them.
col_refs = {} col_refs = {}
for alias, expression in list(inner_query.annotation_select.items()): for alias, aggregate in aggregates.items():
if not expression.is_summary:
continue
annotation_select_mask = inner_query.annotation_select_mask
replacements = {} replacements = {}
for col in self._gen_cols([expression], resolve_refs=False): for col in self._gen_cols([aggregate], resolve_refs=False):
if not (col_ref := col_refs.get(col)): if not (col_ref := col_refs.get(col)):
index = len(col_refs) + 1 index = len(col_refs) + 1
col_alias = f"__col{index}" col_alias = f"__col{index}"
@ -476,13 +474,9 @@ class Query(BaseExpression):
inner_query.annotations[col_alias] = col inner_query.annotations[col_alias] = col
inner_query.append_annotation_mask([col_alias]) inner_query.append_annotation_mask([col_alias])
replacements[col] = col_ref replacements[col] = col_ref
outer_query.annotations[alias] = expression.replace_expressions( outer_query.annotations[alias] = aggregate.replace_expressions(
replacements replacements
) )
del inner_query.annotations[alias]
annotation_select_mask.remove(alias)
# Make sure the annotation_select wont use cached results.
inner_query.set_annotation_mask(inner_query.annotation_select_mask)
if ( if (
inner_query.select == () inner_query.select == ()
and not inner_query.default_cols and not inner_query.default_cols
@ -499,19 +493,21 @@ class Query(BaseExpression):
self.select = () self.select = ()
self.default_cols = False self.default_cols = False
self.extra = {} self.extra = {}
if existing_annotations: if self.annotations:
# Inline reference to existing annotations and mask them as # Inline reference to existing annotations and mask them as
# they are unnecessary given only the summarized aggregations # they are unnecessary given only the summarized aggregations
# are requested. # are requested.
replacements = { replacements = {
Ref(alias, annotation): annotation Ref(alias, annotation): annotation
for alias, annotation in existing_annotations.items() for alias, annotation in self.annotations.items()
} }
for name in added_aggregate_names: self.annotations = {
self.annotations[name] = self.annotations[name].replace_expressions( alias: aggregate.replace_expressions(replacements)
replacements for alias, aggregate in aggregates.items()
) }
self.set_annotation_mask(added_aggregate_names) else:
self.annotations = aggregates
self.set_annotation_mask(aggregates)
empty_set_result = [ empty_set_result = [
expression.empty_result_set_value expression.empty_result_set_value
@ -537,8 +533,7 @@ class Query(BaseExpression):
Perform a COUNT() query using the current filter constraints. Perform a COUNT() query using the current filter constraints.
""" """
obj = self.clone() obj = self.clone()
obj.add_annotation(Count("*"), alias="__count", is_summary=True) return obj.get_aggregation(using, {"__count": Count("*")})["__count"]
return obj.get_aggregation(using, ["__count"])["__count"]
def has_filters(self): def has_filters(self):
return self.where return self.where
@ -1085,12 +1080,10 @@ class Query(BaseExpression):
"semicolons, or SQL comments." "semicolons, or SQL comments."
) )
def add_annotation(self, annotation, alias, is_summary=False, select=True): def add_annotation(self, annotation, alias, select=True):
"""Add a single annotation expression to the Query.""" """Add a single annotation expression to the Query."""
self.check_alias(alias) self.check_alias(alias)
annotation = annotation.resolve_expression( annotation = annotation.resolve_expression(self, allow_joins=True, reuse=None)
self, allow_joins=True, reuse=None, summarize=is_summary
)
if select: if select:
self.append_annotation_mask([alias]) self.append_annotation_mask([alias])
else: else:

View File

@ -395,6 +395,9 @@ Miscellaneous
* The undocumented ``negated`` parameter of the * The undocumented ``negated`` parameter of the
:class:`~django.db.models.Exists` expression is removed. :class:`~django.db.models.Exists` expression is removed.
* The ``is_summary`` argument of the undocumented ``Query.add_annotation()``
method is removed.
.. _deprecated-features-4.2: .. _deprecated-features-4.2:
Features deprecated in 4.2 Features deprecated in 4.2

View File

@ -1258,11 +1258,11 @@ class AggregateTestCase(TestCase):
self.assertEqual(author.sum_age, other_author.sum_age) self.assertEqual(author.sum_age, other_author.sum_age)
def test_aggregate_over_aggregate(self): def test_aggregate_over_aggregate(self):
msg = "Cannot compute Avg('age'): 'age' is an aggregate" msg = "Cannot resolve keyword 'age_agg' into field."
with self.assertRaisesMessage(FieldError, msg): with self.assertRaisesMessage(FieldError, msg):
Author.objects.annotate(age_alias=F("age"),).aggregate( Author.objects.aggregate(
age=Sum(F("age")), age_agg=Sum(F("age")),
avg_age=Avg(F("age")), avg_age=Avg(F("age_agg")),
) )
def test_annotated_aggregate_over_annotated_aggregate(self): def test_annotated_aggregate_over_annotated_aggregate(self):
@ -2086,6 +2086,14 @@ class AggregateTestCase(TestCase):
) )
self.assertEqual(len(qs), 6) self.assertEqual(len(qs), 6)
def test_aggregation_over_annotation_shared_alias(self):
self.assertEqual(
Publisher.objects.annotate(agg=Count("book__authors"),).aggregate(
agg=Count("agg"),
),
{"agg": 5},
)
class AggregateAnnotationPruningTests(TestCase): class AggregateAnnotationPruningTests(TestCase):
@classmethod @classmethod