Fixed #31679 -- Delayed annotating aggregations.

By not annotating aggregations that may be pushed to an outer query
until their references are resolved, it is possible to aggregate over
a query that already has an annotation with the same alias.

Even if #34176 is a convoluted case to support, this refactor seems
worth it given the reduction in complexity it brings with regard to
annotation removal when performing a subquery pushdown.
This commit is contained in:
Simon Charette 2022-11-22 21:49:12 -05:00 committed by Mariusz Felisiak
parent d526d1569c
commit 1297c0d0d7
4 changed files with 51 additions and 65 deletions

View File

@ -23,7 +23,7 @@ from django.db import (
from django.db.models import AutoField, DateField, DateTimeField, Field, sql
from django.db.models.constants import LOOKUP_SEP, OnConflict
from django.db.models.deletion import Collector
from django.db.models.expressions import Case, F, Ref, Value, When
from django.db.models.expressions import Case, F, Value, When
from django.db.models.functions import Cast, Trunc
from django.db.models.query_utils import FilteredRelation, Q
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
@ -589,24 +589,7 @@ class QuerySet(AltersData):
raise TypeError("Complex aggregates require an alias")
kwargs[arg.default_alias] = arg
query = self.query.chain()
for (alias, aggregate_expr) in kwargs.items():
query.add_annotation(aggregate_expr, alias, is_summary=True)
annotation = query.annotations[alias]
if not annotation.contains_aggregate:
raise TypeError("%s is not an aggregate expression" % alias)
for expr in annotation.get_source_expressions():
if (
expr.contains_aggregate
and isinstance(expr, Ref)
and expr.refs in kwargs
):
name = expr.refs
raise exceptions.FieldError(
"Cannot compute %s('%s'): '%s' is an aggregate"
% (annotation.name, name, name)
)
return query.get_aggregation(self.db, kwargs)
return self.query.chain().get_aggregation(self.db, kwargs)
async def aaggregate(self, *args, **kwargs):
return await sync_to_async(self.aggregate)(*args, **kwargs)
@ -1655,7 +1638,6 @@ class QuerySet(AltersData):
clone.query.add_annotation(
annotation,
alias,
is_summary=False,
select=select,
)
for alias, annotation in clone.query.annotations.items():

View File

@ -381,24 +381,28 @@ class Query(BaseExpression):
alias = None
return target.get_col(alias, field)
def get_aggregation(self, using, added_aggregate_names):
def get_aggregation(self, using, aggregate_exprs):
"""
Return the dictionary with the values of the existing aggregations.
"""
if not self.annotation_select:
if not aggregate_exprs:
return {}
existing_annotations = {
alias: annotation
for alias, annotation in self.annotations.items()
if alias not in added_aggregate_names
}
aggregates = {}
for alias, aggregate_expr in aggregate_exprs.items():
self.check_alias(alias)
aggregate = aggregate_expr.resolve_expression(
self, allow_joins=True, reuse=None, summarize=True
)
if not aggregate.contains_aggregate:
raise TypeError("%s is not an aggregate expression" % alias)
aggregates[alias] = aggregate
# Existing usage of aggregation can be determined by the presence of
# selected aggregates but also by filters against aliased aggregates.
_, having, qualify = self.where.split_having_qualify()
has_existing_aggregation = (
any(
getattr(annotation, "contains_aggregate", True)
for annotation in existing_annotations.values()
for annotation in self.annotations.values()
)
or having
)
@ -449,25 +453,19 @@ class Query(BaseExpression):
# filtering against window functions is involved as it
# requires complex realising.
annotation_mask = set()
for name in added_aggregate_names:
annotation_mask.add(name)
annotation_mask |= inner_query.annotations[name].get_refs()
for aggregate in aggregates.values():
annotation_mask |= aggregate.get_refs()
inner_query.set_annotation_mask(annotation_mask)
# Remove any aggregates marked for reduction from the subquery and
# move them to the outer AggregateQuery. This requires making sure
# all columns referenced by the aggregates are selected in the
# subquery. It is achieved by retrieving all column references from
# the aggregates, explicitly selecting them if they are not
# already, and making sure the aggregates are repointed to
# referenced to them.
# Add aggregates to the outer AggregateQuery. This requires making
# sure all columns referenced by the aggregates are selected in the
# inner query. It is achieved by retrieving all column references
# by the aggregates, explicitly selecting them in the inner query,
# and making sure the aggregates are repointed to them.
col_refs = {}
for alias, expression in list(inner_query.annotation_select.items()):
if not expression.is_summary:
continue
annotation_select_mask = inner_query.annotation_select_mask
for alias, aggregate in aggregates.items():
replacements = {}
for col in self._gen_cols([expression], resolve_refs=False):
for col in self._gen_cols([aggregate], resolve_refs=False):
if not (col_ref := col_refs.get(col)):
index = len(col_refs) + 1
col_alias = f"__col{index}"
@ -476,13 +474,9 @@ class Query(BaseExpression):
inner_query.annotations[col_alias] = col
inner_query.append_annotation_mask([col_alias])
replacements[col] = col_ref
outer_query.annotations[alias] = expression.replace_expressions(
outer_query.annotations[alias] = aggregate.replace_expressions(
replacements
)
del inner_query.annotations[alias]
annotation_select_mask.remove(alias)
# Make sure the annotation_select wont use cached results.
inner_query.set_annotation_mask(inner_query.annotation_select_mask)
if (
inner_query.select == ()
and not inner_query.default_cols
@ -499,19 +493,21 @@ class Query(BaseExpression):
self.select = ()
self.default_cols = False
self.extra = {}
if existing_annotations:
if self.annotations:
# Inline reference to existing annotations and mask them as
# they are unnecessary given only the summarized aggregations
# are requested.
replacements = {
Ref(alias, annotation): annotation
for alias, annotation in existing_annotations.items()
for alias, annotation in self.annotations.items()
}
for name in added_aggregate_names:
self.annotations[name] = self.annotations[name].replace_expressions(
replacements
)
self.set_annotation_mask(added_aggregate_names)
self.annotations = {
alias: aggregate.replace_expressions(replacements)
for alias, aggregate in aggregates.items()
}
else:
self.annotations = aggregates
self.set_annotation_mask(aggregates)
empty_set_result = [
expression.empty_result_set_value
@ -537,8 +533,7 @@ class Query(BaseExpression):
Perform a COUNT() query using the current filter constraints.
"""
obj = self.clone()
obj.add_annotation(Count("*"), alias="__count", is_summary=True)
return obj.get_aggregation(using, ["__count"])["__count"]
return obj.get_aggregation(using, {"__count": Count("*")})["__count"]
def has_filters(self):
return self.where
@ -1085,12 +1080,10 @@ class Query(BaseExpression):
"semicolons, or SQL comments."
)
def add_annotation(self, annotation, alias, is_summary=False, select=True):
def add_annotation(self, annotation, alias, select=True):
"""Add a single annotation expression to the Query."""
self.check_alias(alias)
annotation = annotation.resolve_expression(
self, allow_joins=True, reuse=None, summarize=is_summary
)
annotation = annotation.resolve_expression(self, allow_joins=True, reuse=None)
if select:
self.append_annotation_mask([alias])
else:

View File

@ -395,6 +395,9 @@ Miscellaneous
* The undocumented ``negated`` parameter of the
:class:`~django.db.models.Exists` expression is removed.
* The ``is_summary`` argument of the undocumented ``Query.add_annotation()``
method is removed.
.. _deprecated-features-4.2:
Features deprecated in 4.2

View File

@ -1258,11 +1258,11 @@ class AggregateTestCase(TestCase):
self.assertEqual(author.sum_age, other_author.sum_age)
def test_aggregate_over_aggregate(self):
msg = "Cannot compute Avg('age'): 'age' is an aggregate"
msg = "Cannot resolve keyword 'age_agg' into field."
with self.assertRaisesMessage(FieldError, msg):
Author.objects.annotate(age_alias=F("age"),).aggregate(
age=Sum(F("age")),
avg_age=Avg(F("age")),
Author.objects.aggregate(
age_agg=Sum(F("age")),
avg_age=Avg(F("age_agg")),
)
def test_annotated_aggregate_over_annotated_aggregate(self):
@ -2086,6 +2086,14 @@ class AggregateTestCase(TestCase):
)
self.assertEqual(len(qs), 6)
def test_aggregation_over_annotation_shared_alias(self):
self.assertEqual(
Publisher.objects.annotate(agg=Count("book__authors"),).aggregate(
agg=Count("agg"),
),
{"agg": 5},
)
class AggregateAnnotationPruningTests(TestCase):
@classmethod