Fixed #31507 -- Added QuerySet.exists() optimizations to compound queries.

This commit is contained in:
David-Wobrock 2020-11-11 23:16:32 +01:00 committed by Mariusz Felisiak
parent 7b42d34646
commit ba42569d5c
3 changed files with 35 additions and 8 deletions

View File

@ -1116,10 +1116,11 @@ class Subquery(Expression):
def external_aliases(self): def external_aliases(self):
return self.query.external_aliases return self.query.external_aliases
def as_sql(self, compiler, connection, template=None, **extra_context): def as_sql(self, compiler, connection, template=None, query=None, **extra_context):
connection.ops.check_expression_support(self) connection.ops.check_expression_support(self)
template_params = {**self.extra, **extra_context} template_params = {**self.extra, **extra_context}
subquery_sql, sql_params = self.query.as_sql(compiler, connection) query = query or self.query
subquery_sql, sql_params = query.as_sql(compiler, connection)
template_params['subquery'] = subquery_sql[1:-1] template_params['subquery'] = subquery_sql[1:-1]
template = template or template_params.get('template', self.template) template = template or template_params.get('template', self.template)
@ -1142,7 +1143,6 @@ class Exists(Subquery):
def __init__(self, queryset, negated=False, **kwargs): def __init__(self, queryset, negated=False, **kwargs):
self.negated = negated self.negated = negated
super().__init__(queryset, **kwargs) super().__init__(queryset, **kwargs)
self.query = self.query.exists()
def __invert__(self): def __invert__(self):
clone = self.copy() clone = self.copy()
@ -1150,7 +1150,14 @@ class Exists(Subquery):
return clone return clone
def as_sql(self, compiler, connection, template=None, **extra_context): def as_sql(self, compiler, connection, template=None, **extra_context):
sql, params = super().as_sql(compiler, connection, template, **extra_context) query = self.query.exists(using=connection.alias)
sql, params = super().as_sql(
compiler,
connection,
template=template,
query=query,
**extra_context,
)
if self.negated: if self.negated:
sql = 'NOT {}'.format(sql) sql = 'NOT {}'.format(sql)
return sql, params return sql, params

View File

@ -525,7 +525,7 @@ class Query(BaseExpression):
def has_filters(self): def has_filters(self):
return self.where return self.where
def exists(self): def exists(self, using, limit=True):
q = self.clone() q = self.clone()
if not q.distinct: if not q.distinct:
if q.group_by is True: if q.group_by is True:
@ -534,14 +534,21 @@ class Query(BaseExpression):
# SELECT clause which is about to be cleared. # SELECT clause which is about to be cleared.
q.set_group_by(allow_aliases=False) q.set_group_by(allow_aliases=False)
q.clear_select_clause() q.clear_select_clause()
if q.combined_queries and q.combinator == 'union':
limit_combined = connections[using].features.supports_slicing_ordering_in_compound
q.combined_queries = tuple(
combined_query.exists(using, limit=limit_combined)
for combined_query in q.combined_queries
)
q.clear_ordering(True) q.clear_ordering(True)
q.set_limits(high=1) if limit:
q.set_limits(high=1)
q.add_extra({'a': 1}, None, None, None, None, None) q.add_extra({'a': 1}, None, None, None, None, None)
q.set_extra_mask(['a']) q.set_extra_mask(['a'])
return q return q
def has_results(self, using): def has_results(self, using):
q = self.exists() q = self.exists(using)
compiler = q.get_compiler(using=using) compiler = q.get_compiler(using=using)
return compiler.has_results() return compiler.has_results()

View File

@ -3,6 +3,7 @@ import operator
from django.db import DatabaseError, NotSupportedError, connection from django.db import DatabaseError, NotSupportedError, connection
from django.db.models import Exists, F, IntegerField, OuterRef, Value from django.db.models import Exists, F, IntegerField, OuterRef, Value
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import CaptureQueriesContext
from .models import Number, ReservedName from .models import Number, ReservedName
@ -257,7 +258,19 @@ class QuerySetSetOperationTests(TestCase):
def test_exists_union(self): def test_exists_union(self):
qs1 = Number.objects.filter(num__gte=5) qs1 = Number.objects.filter(num__gte=5)
qs2 = Number.objects.filter(num__lte=5) qs2 = Number.objects.filter(num__lte=5)
self.assertIs(qs1.union(qs2).exists(), True) with CaptureQueriesContext(connection) as context:
self.assertIs(qs1.union(qs2).exists(), True)
captured_queries = context.captured_queries
self.assertEqual(len(captured_queries), 1)
captured_sql = captured_queries[0]['sql']
self.assertNotIn(
connection.ops.quote_name(Number._meta.pk.column),
captured_sql,
)
self.assertEqual(
captured_sql.count(connection.ops.limit_offset_sql(None, 1)),
3 if connection.features.supports_slicing_ordering_in_compound else 1
)
def test_exists_union_empty_result(self): def test_exists_union_empty_result(self):
qs = Number.objects.filter(pk__in=[]) qs = Number.objects.filter(pk__in=[])