Fixed #31910 -- Fixed crash of GIS aggregations over subqueries.

Regression was introduced by fff5186 but was due a long standing issue.

AggregateQuery was abusing Query.subquery: bool by stashing its
compiled inner query's SQL for later use in its compiler which made
select_format checks for Query.subquery wrongly assume the provide
query was a subquery.

This patch prevents that from happening by using a dedicated
inner_query attribute which is compiled at a later time by
SQLAggregateCompiler.

Moving the inner query's compilation to SQLAggregateCompiler.compile
had the side effect of addressing a long standing issue with
aggregation subquery pushdown which prevented converters from being
run. This is now fixed as the aggregation_regress adjustments
demonstrate.

Refs #25367.

Thanks Eran Keydar for the report.
This commit is contained in:
Simon Charette 2020-11-03 16:50:10 -05:00 committed by Mariusz Felisiak
parent 789c47e6de
commit c2d4926702
5 changed files with 26 additions and 17 deletions

View File

@ -1596,8 +1596,11 @@ class SQLAggregateCompiler(SQLCompiler):
sql = ', '.join(sql) sql = ', '.join(sql)
params = tuple(params) params = tuple(params)
sql = 'SELECT %s FROM (%s) subquery' % (sql, self.query.subquery) inner_query_sql, inner_query_params = self.query.inner_query.get_compiler(
params = params + self.query.sub_params self.using
).as_sql(with_col_aliases=True)
sql = 'SELECT %s FROM (%s) subquery' % (sql, inner_query_sql)
params = params + inner_query_params
return sql, params return sql, params

View File

@ -17,9 +17,7 @@ from collections.abc import Iterator, Mapping
from itertools import chain, count, product from itertools import chain, count, product
from string import ascii_uppercase from string import ascii_uppercase
from django.core.exceptions import ( from django.core.exceptions import FieldDoesNotExist, FieldError
EmptyResultSet, FieldDoesNotExist, FieldError,
)
from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
from django.db.models.aggregates import Count from django.db.models.aggregates import Count
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
@ -449,8 +447,9 @@ class Query(BaseExpression):
if (isinstance(self.group_by, tuple) or self.is_sliced or existing_annotations or if (isinstance(self.group_by, tuple) or self.is_sliced or existing_annotations or
self.distinct or self.combinator): self.distinct or self.combinator):
from django.db.models.sql.subqueries import AggregateQuery from django.db.models.sql.subqueries import AggregateQuery
outer_query = AggregateQuery(self.model)
inner_query = self.clone() inner_query = self.clone()
inner_query.subquery = True
outer_query = AggregateQuery(self.model, inner_query)
inner_query.select_for_update = False inner_query.select_for_update = False
inner_query.select_related = False inner_query.select_related = False
inner_query.set_annotation_mask(self.annotation_select) inner_query.set_annotation_mask(self.annotation_select)
@ -492,13 +491,6 @@ class Query(BaseExpression):
# field selected in the inner query, yet we must use a subquery. # field selected in the inner query, yet we must use a subquery.
# So, make sure at least one field is selected. # So, make sure at least one field is selected.
inner_query.select = (self.model._meta.pk.get_col(inner_query.get_initial_alias()),) inner_query.select = (self.model._meta.pk.get_col(inner_query.get_initial_alias()),)
try:
outer_query.add_subquery(inner_query, using)
except EmptyResultSet:
return {
alias: None
for alias in outer_query.annotation_select
}
else: else:
outer_query = self outer_query = self
self.select = () self.select = ()

View File

@ -157,6 +157,6 @@ class AggregateQuery(Query):
compiler = 'SQLAggregateCompiler' compiler = 'SQLAggregateCompiler'
def add_subquery(self, query, using): def __init__(self, model, inner_query):
query.subquery = True self.inner_query = inner_query
self.subquery, self.sub_params = query.get_compiler(using).as_sql(with_col_aliases=True) super().__init__(model)

View File

@ -974,7 +974,7 @@ class AggregationTests(TestCase):
def test_empty_filter_aggregate(self): def test_empty_filter_aggregate(self):
self.assertEqual( self.assertEqual(
Author.objects.filter(id__in=[]).annotate(Count("friends")).aggregate(Count("pk")), Author.objects.filter(id__in=[]).annotate(Count("friends")).aggregate(Count("pk")),
{"pk__count": None} {"pk__count": 0}
) )
def test_none_call_before_aggregate(self): def test_none_call_before_aggregate(self):

View File

@ -12,6 +12,7 @@ from django.core.management import call_command
from django.db import DatabaseError, NotSupportedError, connection from django.db import DatabaseError, NotSupportedError, connection
from django.db.models import F, OuterRef, Subquery from django.db.models import F, OuterRef, Subquery
from django.test import TestCase, skipUnlessDBFeature from django.test import TestCase, skipUnlessDBFeature
from django.test.utils import CaptureQueriesContext
from ..utils import ( from ..utils import (
mariadb, mysql, oracle, postgis, skipUnlessGISLookup, spatialite, mariadb, mysql, oracle, postgis, skipUnlessGISLookup, spatialite,
@ -593,6 +594,19 @@ class GeoQuerySetTest(TestCase):
qs = City.objects.filter(name='NotACity') qs = City.objects.filter(name='NotACity')
self.assertIsNone(qs.aggregate(Union('point'))['point__union']) self.assertIsNone(qs.aggregate(Union('point'))['point__union'])
@skipUnlessDBFeature('supports_union_aggr')
def test_geoagg_subquery(self):
ks = State.objects.get(name='Kansas')
union = GEOSGeometry('MULTIPOINT(-95.235060 38.971823)')
# Use distinct() to force the usage of a subquery for aggregation.
with CaptureQueriesContext(connection) as ctx:
self.assertIs(union.equals(
City.objects.filter(point__within=ks.poly).distinct().aggregate(
Union('point'),
)['point__union'],
), True)
self.assertIn('subquery', ctx.captured_queries[0]['sql'])
@unittest.skipUnless( @unittest.skipUnless(
connection.vendor == 'oracle', connection.vendor == 'oracle',
'Oracle supports tolerance parameter.', 'Oracle supports tolerance parameter.',