diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index c596ccfe5d..97045d2f49 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -1116,10 +1116,11 @@ class Subquery(Expression): def external_aliases(self): return self.query.external_aliases - def as_sql(self, compiler, connection, template=None, **extra_context): + def as_sql(self, compiler, connection, template=None, query=None, **extra_context): connection.ops.check_expression_support(self) template_params = {**self.extra, **extra_context} - subquery_sql, sql_params = self.query.as_sql(compiler, connection) + query = query or self.query + subquery_sql, sql_params = query.as_sql(compiler, connection) template_params['subquery'] = subquery_sql[1:-1] template = template or template_params.get('template', self.template) @@ -1142,7 +1143,6 @@ class Exists(Subquery): def __init__(self, queryset, negated=False, **kwargs): self.negated = negated super().__init__(queryset, **kwargs) - self.query = self.query.exists() def __invert__(self): clone = self.copy() @@ -1150,7 +1150,14 @@ class Exists(Subquery): return clone def as_sql(self, compiler, connection, template=None, **extra_context): - sql, params = super().as_sql(compiler, connection, template, **extra_context) + query = self.query.exists(using=connection.alias) + sql, params = super().as_sql( + compiler, + connection, + template=template, + query=query, + **extra_context, + ) if self.negated: sql = 'NOT {}'.format(sql) return sql, params diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 5a7dc25fd9..b7f8053cb1 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -525,7 +525,7 @@ class Query(BaseExpression): def has_filters(self): return self.where - def exists(self): + def exists(self, using, limit=True): q = self.clone() if not q.distinct: if q.group_by is True: @@ -534,14 +534,21 @@ class Query(BaseExpression): # SELECT clause which is about to be cleared. q.set_group_by(allow_aliases=False) q.clear_select_clause() + if q.combined_queries and q.combinator == 'union': + limit_combined = connections[using].features.supports_slicing_ordering_in_compound + q.combined_queries = tuple( + combined_query.exists(using, limit=limit_combined) + for combined_query in q.combined_queries + ) q.clear_ordering(True) - q.set_limits(high=1) + if limit: + q.set_limits(high=1) q.add_extra({'a': 1}, None, None, None, None, None) q.set_extra_mask(['a']) return q def has_results(self, using): - q = self.exists() + q = self.exists(using) compiler = q.get_compiler(using=using) return compiler.has_results() diff --git a/tests/queries/test_qs_combinators.py b/tests/queries/test_qs_combinators.py index b12277d34c..81c7b2e3a3 100644 --- a/tests/queries/test_qs_combinators.py +++ b/tests/queries/test_qs_combinators.py @@ -3,6 +3,7 @@ import operator from django.db import DatabaseError, NotSupportedError, connection from django.db.models import Exists, F, IntegerField, OuterRef, Value from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature +from django.test.utils import CaptureQueriesContext from .models import Number, ReservedName @@ -257,7 +258,19 @@ class QuerySetSetOperationTests(TestCase): def test_exists_union(self): qs1 = Number.objects.filter(num__gte=5) qs2 = Number.objects.filter(num__lte=5) - self.assertIs(qs1.union(qs2).exists(), True) + with CaptureQueriesContext(connection) as context: + self.assertIs(qs1.union(qs2).exists(), True) + captured_queries = context.captured_queries + self.assertEqual(len(captured_queries), 1) + captured_sql = captured_queries[0]['sql'] + self.assertNotIn( + connection.ops.quote_name(Number._meta.pk.column), + captured_sql, + ) + self.assertEqual( + captured_sql.count(connection.ops.limit_offset_sql(None, 1)), + 3 if connection.features.supports_slicing_ordering_in_compound else 1 + ) def test_exists_union_empty_result(self): qs = Number.objects.filter(pk__in=[])