Refs #28477 -- Reduced complexity of aggregation over qualify queries.

This commit is contained in:
Simon Charette 2022-11-09 21:55:47 -05:00 committed by Mariusz Felisiak
parent 99b4f90ec6
commit a9d2d8d1c3
2 changed files with 34 additions and 22 deletions

View File

@@ -447,13 +447,15 @@ class Query(BaseExpression):
if alias not in added_aggregate_names if alias not in added_aggregate_names
} }
# Existing usage of aggregation can be determined by the presence of # Existing usage of aggregation can be determined by the presence of
# selected aggregate and window annotations but also by filters against # selected aggregates but also by filters against aliased aggregates.
# aliased aggregate and windows via HAVING / QUALIFY. _, having, qualify = self.where.split_having_qualify()
has_existing_aggregation = any( has_existing_aggregation = (
getattr(annotation, "contains_aggregate", True) any(
or getattr(annotation, "contains_over_clause", True) getattr(annotation, "contains_aggregate", True)
for annotation in existing_annotations.values() for annotation in existing_annotations.values()
) or any(self.where.split_having_qualify()[1:]) )
or having
)
# Decide if we need to use a subquery. # Decide if we need to use a subquery.
# #
# Existing aggregations would cause incorrect results as # Existing aggregations would cause incorrect results as
@@ -468,6 +470,7 @@ class Query(BaseExpression):
isinstance(self.group_by, tuple) isinstance(self.group_by, tuple)
or self.is_sliced or self.is_sliced
or has_existing_aggregation or has_existing_aggregation
or qualify
or self.distinct or self.distinct
or self.combinator or self.combinator
): ):
@@ -494,13 +497,16 @@ class Query(BaseExpression):
self.model._meta.pk.get_col(inner_query.get_initial_alias()), self.model._meta.pk.get_col(inner_query.get_initial_alias()),
) )
inner_query.default_cols = False inner_query.default_cols = False
# Mask existing annotations that are not referenced by if not qualify:
# aggregates to be pushed to the outer query. # Mask existing annotations that are not referenced by
annotation_mask = set() # aggregates to be pushed to the outer query unless
for name in added_aggregate_names: # filtering against window functions is involved as it
annotation_mask.add(name) # requires complex re-aliasing.
annotation_mask |= inner_query.annotations[name].get_refs() annotation_mask = set()
inner_query.set_annotation_mask(annotation_mask) for name in added_aggregate_names:
annotation_mask.add(name)
annotation_mask |= inner_query.annotations[name].get_refs()
inner_query.set_annotation_mask(annotation_mask)
relabels = {t: "subquery" for t in inner_query.alias_map} relabels = {t: "subquery" for t in inner_query.alias_map}
relabels[None] = "subquery" relabels[None] = "subquery"

View File

@@ -42,6 +42,7 @@ from django.db.models.functions import (
) )
from django.db.models.lookups import Exact from django.db.models.lookups import Exact
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
from django.test.utils import CaptureQueriesContext
from .models import Classification, Detail, Employee, PastEmployeeDepartment from .models import Classification, Detail, Employee, PastEmployeeDepartment
@@ -1157,16 +1158,21 @@ class WindowFunctionTests(TestCase):
) )
def test_filter_count(self): def test_filter_count(self):
self.assertEqual( with CaptureQueriesContext(connection) as ctx:
Employee.objects.annotate( self.assertEqual(
department_salary_rank=Window( Employee.objects.annotate(
Rank(), partition_by="department", order_by="-salary" department_salary_rank=Window(
Rank(), partition_by="department", order_by="-salary"
)
) )
.filter(department_salary_rank=1)
.count(),
5,
) )
.filter(department_salary_rank=1) self.assertEqual(len(ctx.captured_queries), 1)
.count(), sql = ctx.captured_queries[0]["sql"].lower()
5, self.assertEqual(sql.count("select"), 3)
) self.assertNotIn("group by", sql)
@skipUnlessDBFeature("supports_frame_range_fixed_distance") @skipUnlessDBFeature("supports_frame_range_fixed_distance")
def test_range_n_preceding_and_following(self): def test_range_n_preceding_and_following(self):