diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 04e430a42e..3360f9c806 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -1596,8 +1596,11 @@ class SQLAggregateCompiler(SQLCompiler): sql = ', '.join(sql) params = tuple(params) - sql = 'SELECT %s FROM (%s) subquery' % (sql, self.query.subquery) - params = params + self.query.sub_params + inner_query_sql, inner_query_params = self.query.inner_query.get_compiler( + self.using + ).as_sql(with_col_aliases=True) + sql = 'SELECT %s FROM (%s) subquery' % (sql, inner_query_sql) + params = params + inner_query_params return sql, params diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index d34b9da601..a7dadf5a40 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -17,9 +17,7 @@ from collections.abc import Iterator, Mapping from itertools import chain, count, product from string import ascii_uppercase -from django.core.exceptions import ( - EmptyResultSet, FieldDoesNotExist, FieldError, -) +from django.core.exceptions import FieldDoesNotExist, FieldError from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections from django.db.models.aggregates import Count from django.db.models.constants import LOOKUP_SEP @@ -449,8 +447,9 @@ class Query(BaseExpression): if (isinstance(self.group_by, tuple) or self.is_sliced or existing_annotations or self.distinct or self.combinator): from django.db.models.sql.subqueries import AggregateQuery - outer_query = AggregateQuery(self.model) inner_query = self.clone() + inner_query.subquery = True + outer_query = AggregateQuery(self.model, inner_query) inner_query.select_for_update = False inner_query.select_related = False inner_query.set_annotation_mask(self.annotation_select) @@ -492,13 +491,6 @@ class Query(BaseExpression): # field selected in the inner query, yet we must use a subquery. # So, make sure at least one field is selected. inner_query.select = (self.model._meta.pk.get_col(inner_query.get_initial_alias()),) - try: - outer_query.add_subquery(inner_query, using) - except EmptyResultSet: - return { - alias: None - for alias in outer_query.annotation_select - } else: outer_query = self self.select = () diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 72b6712864..e83112b046 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -157,6 +157,6 @@ class AggregateQuery(Query): compiler = 'SQLAggregateCompiler' - def add_subquery(self, query, using): - query.subquery = True - self.subquery, self.sub_params = query.get_compiler(using).as_sql(with_col_aliases=True) + def __init__(self, model, inner_query): + self.inner_query = inner_query + super().__init__(model) diff --git a/tests/aggregation_regress/tests.py b/tests/aggregation_regress/tests.py index 7604335257..48187ee00b 100644 --- a/tests/aggregation_regress/tests.py +++ b/tests/aggregation_regress/tests.py @@ -974,7 +974,7 @@ class AggregationTests(TestCase): def test_empty_filter_aggregate(self): self.assertEqual( Author.objects.filter(id__in=[]).annotate(Count("friends")).aggregate(Count("pk")), - {"pk__count": None} + {"pk__count": 0} ) def test_none_call_before_aggregate(self): diff --git a/tests/gis_tests/geoapp/tests.py b/tests/gis_tests/geoapp/tests.py index 98523c2add..8dfe3d02a1 100644 --- a/tests/gis_tests/geoapp/tests.py +++ b/tests/gis_tests/geoapp/tests.py @@ -12,6 +12,7 @@ from django.core.management import call_command from django.db import DatabaseError, NotSupportedError, connection from django.db.models import F, OuterRef, Subquery from django.test import TestCase, skipUnlessDBFeature +from django.test.utils import CaptureQueriesContext from ..utils import ( mariadb, mysql, oracle, postgis, skipUnlessGISLookup, spatialite, @@ -593,6 +594,19 @@ class GeoQuerySetTest(TestCase): qs = City.objects.filter(name='NotACity') self.assertIsNone(qs.aggregate(Union('point'))['point__union']) + @skipUnlessDBFeature('supports_union_aggr') + def test_geoagg_subquery(self): + ks = State.objects.get(name='Kansas') + union = GEOSGeometry('MULTIPOINT(-95.235060 38.971823)') + # Use distinct() to force the usage of a subquery for aggregation. + with CaptureQueriesContext(connection) as ctx: + self.assertIs(union.equals( + City.objects.filter(point__within=ks.poly).distinct().aggregate( + Union('point'), + )['point__union'], + ), True) + self.assertIn('subquery', ctx.captured_queries[0]['sql']) + @unittest.skipUnless( connection.vendor == 'oracle', 'Oracle supports tolerance parameter.',