From f3112fde981052801e0c2121027424496c03efdf Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Fri, 21 May 2021 21:48:46 -0400 Subject: [PATCH] Fixed #26430 -- Fixed coalesced aggregation of empty result sets. Disable the EmptyResultSet optimization when performing aggregation as it might interfere with coalescence. --- django/db/models/sql/compiler.py | 17 +++++++++++++---- django/db/models/sql/query.py | 8 +++----- tests/aggregation/tests.py | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 9 deletions(-) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 6082017bbd..136732bd0b 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -26,10 +26,13 @@ class SQLCompiler: re.MULTILINE | re.DOTALL, ) - def __init__(self, query, connection, using): + def __init__(self, query, connection, using, elide_empty=True): self.query = query self.connection = connection self.using = using + # Some queries, e.g. coalesced aggregation, need to be executed even if + # they would return an empty result set. + self.elide_empty = elide_empty self.quote_cache = {'*': '*'} # The select, klass_info, and annotations are needed by QuerySet.iterator() # these are set as a side-effect of executing the query. Note that we calculate @@ -458,7 +461,7 @@ class SQLCompiler: def get_combinator_sql(self, combinator, all): features = self.connection.features compilers = [ - query.get_compiler(self.using, self.connection) + query.get_compiler(self.using, self.connection, self.elide_empty) for query in self.query.combined_queries if not query.is_empty() ] if not features.supports_slicing_ordering_in_compound: @@ -535,7 +538,13 @@ class SQLCompiler: # This must come after 'select', 'ordering', and 'distinct' # (see docstring of get_from_clause() for details). from_, f_params = self.get_from_clause() - where, w_params = self.compile(self.where) if self.where is not None else ("", []) + try: + where, w_params = self.compile(self.where) if self.where is not None else ('', []) + except EmptyResultSet: + if self.elide_empty: + raise + # Use a predicate that's always False. + where, w_params = '0 = 1', [] having, h_params = self.compile(self.having) if self.having is not None else ("", []) result = ['SELECT'] params = [] @@ -1652,7 +1661,7 @@ class SQLAggregateCompiler(SQLCompiler): params = tuple(params) inner_query_sql, inner_query_params = self.query.inner_query.get_compiler( - self.using + self.using, elide_empty=self.elide_empty, ).as_sql(with_col_aliases=True) sql = 'SELECT %s FROM (%s) subquery' % (sql, inner_query_sql) params = params + inner_query_params diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 814271a1f6..fabb346418 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -273,12 +273,12 @@ class Query(BaseExpression): memo[id(self)] = result return result - def get_compiler(self, using=None, connection=None): + def get_compiler(self, using=None, connection=None, elide_empty=True): if using is None and connection is None: raise ValueError("Need either using or connection") if using: connection = connections[using] - return connection.ops.compiler(self.compiler)(self, connection, using) + return connection.ops.compiler(self.compiler)(self, connection, using, elide_empty) def get_meta(self): """ @@ -494,10 +494,8 @@ class Query(BaseExpression): outer_query.clear_limits() outer_query.select_for_update = False outer_query.select_related = False - compiler = outer_query.get_compiler(using) + compiler = outer_query.get_compiler(using, elide_empty=False) result = compiler.execute_sql(SINGLE) - if result is None: - result = [None] * len(outer_query.annotation_select) converters = compiler.get_converters(outer_query.annotation_select.values()) result = next(compiler.apply_converters((result,), converters)) diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index 49123396dd..db24f36a79 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -8,6 +8,7 @@ from django.db.models import ( Avg, Case, Count, DecimalField, DurationField, Exists, F, FloatField, IntegerField, Max, Min, OuterRef, Subquery, Sum, Value, When, ) +from django.db.models.expressions import RawSQL from django.db.models.functions import Coalesce, Greatest from django.test import TestCase from django.test.testcases import skipUnlessDBFeature @@ -1340,3 +1341,34 @@ class AggregateTestCase(TestCase): ('Stuart Russell', 1), ('Peter Norvig', 2), ], lambda a: (a.name, a.contact_count), ordered=False) + + def test_coalesced_empty_result_set(self): + self.assertEqual( + Publisher.objects.none().aggregate( + sum_awards=Coalesce(Sum('num_awards'), 0), + )['sum_awards'], + 0, + ) + # Multiple expressions. + self.assertEqual( + Publisher.objects.none().aggregate( + sum_awards=Coalesce(Sum('num_awards'), None, 0), + )['sum_awards'], + 0, + ) + # Nested coalesce. + self.assertEqual( + Publisher.objects.none().aggregate( + sum_awards=Coalesce(Coalesce(Sum('num_awards'), None), 0), + )['sum_awards'], + 0, + ) + # Expression coalesce. + self.assertIsInstance( + Store.objects.none().aggregate( + latest_opening=Coalesce( + Max('original_opening'), RawSQL('CURRENT_TIMESTAMP', []), + ), + )['latest_opening'], + datetime.datetime, + )