Fixed #28378 -- Fixed union() and difference() when combining with a queryset raising EmptyResultSet.

Thanks Jon Dufresne for the report. Thanks Tim Graham and Simon Charette
for the reviews.
This commit is contained in:
Mariusz Felisiak 2017-07-10 19:45:09 +02:00
parent 9bca0d0b38
commit ca74e56350
3 changed files with 29 additions and 10 deletions

View File

@ -399,7 +399,18 @@ class SQLCompiler:
raise DatabaseError('LIMIT/OFFSET not allowed in subqueries of compound statements.') raise DatabaseError('LIMIT/OFFSET not allowed in subqueries of compound statements.')
if compiler.get_order_by(): if compiler.get_order_by():
raise DatabaseError('ORDER BY not allowed in subqueries of compound statements.') raise DatabaseError('ORDER BY not allowed in subqueries of compound statements.')
parts = (compiler.as_sql() for compiler in compilers) parts = ()
for compiler in compilers:
try:
parts += (compiler.as_sql(),)
except EmptyResultSet:
# Omit the empty queryset with UNION and with DIFFERENCE if the
# first queryset is nonempty.
if combinator == 'union' or (combinator == 'difference' and parts):
continue
raise
if not parts:
return [], []
combinator_sql = self.connection.ops.set_operators[combinator] combinator_sql = self.connection.ops.set_operators[combinator]
if all and combinator == 'union': if all and combinator == 'union':
combinator_sql += ' ALL' combinator_sql += ' ALL'
@ -422,16 +433,7 @@ class SQLCompiler:
refcounts_before = self.query.alias_refcount.copy() refcounts_before = self.query.alias_refcount.copy()
try: try:
extra_select, order_by, group_by = self.pre_sql_setup() extra_select, order_by, group_by = self.pre_sql_setup()
distinct_fields = self.get_distinct()
# This must come after 'select', 'ordering', and 'distinct' -- see
# docstring of get_from_clause() for details.
from_, f_params = self.get_from_clause()
for_update_part = None for_update_part = None
where, w_params = self.compile(self.where) if self.where is not None else ("", [])
having, h_params = self.compile(self.having) if self.having is not None else ("", [])
combinator = self.query.combinator combinator = self.query.combinator
features = self.connection.features features = self.connection.features
if combinator: if combinator:
@ -439,6 +441,12 @@ class SQLCompiler:
raise NotSupportedError('{} is not supported on this database backend.'.format(combinator)) raise NotSupportedError('{} is not supported on this database backend.'.format(combinator))
result, params = self.get_combinator_sql(combinator, self.query.combinator_all) result, params = self.get_combinator_sql(combinator, self.query.combinator_all)
else: else:
distinct_fields = self.get_distinct()
# This must come after 'select', 'ordering', and 'distinct'
# (see docstring of get_from_clause() for details).
from_, f_params = self.get_from_clause()
where, w_params = self.compile(self.where) if self.where is not None else ("", [])
having, h_params = self.compile(self.having) if self.having is not None else ("", [])
result = ['SELECT'] result = ['SELECT']
params = [] params = []

View File

@ -12,3 +12,6 @@ Bugfixes
* Fixed a regression in 1.11.3 on Python 2 where non-ASCII ``format`` values * Fixed a regression in 1.11.3 on Python 2 where non-ASCII ``format`` values
for date/time widgets results in an empty ``value`` in the widget's HTML for date/time widgets results in an empty ``value`` in the widget's HTML
(:ticket:`28355`). (:ticket:`28355`).
* Fixed ``QuerySet.union()`` and ``difference()`` when combining with
a queryset raising ``EmptyResultSet`` (:ticket:`28378`).

View File

@ -58,18 +58,26 @@ class QuerySetSetOperationTests(TestCase):
def test_difference_with_empty_qs(self): def test_difference_with_empty_qs(self):
qs1 = Number.objects.all() qs1 = Number.objects.all()
qs2 = Number.objects.none() qs2 = Number.objects.none()
qs3 = Number.objects.filter(pk__in=[])
self.assertEqual(len(qs1.difference(qs2)), 10) self.assertEqual(len(qs1.difference(qs2)), 10)
self.assertEqual(len(qs1.difference(qs3)), 10)
self.assertEqual(len(qs2.difference(qs1)), 0) self.assertEqual(len(qs2.difference(qs1)), 0)
self.assertEqual(len(qs3.difference(qs1)), 0)
self.assertEqual(len(qs2.difference(qs2)), 0) self.assertEqual(len(qs2.difference(qs2)), 0)
self.assertEqual(len(qs3.difference(qs3)), 0)
def test_union_with_empty_qs(self): def test_union_with_empty_qs(self):
qs1 = Number.objects.all() qs1 = Number.objects.all()
qs2 = Number.objects.none() qs2 = Number.objects.none()
qs3 = Number.objects.filter(pk__in=[])
self.assertEqual(len(qs1.union(qs2)), 10) self.assertEqual(len(qs1.union(qs2)), 10)
self.assertEqual(len(qs2.union(qs1)), 10) self.assertEqual(len(qs2.union(qs1)), 10)
self.assertEqual(len(qs1.union(qs3)), 10)
self.assertEqual(len(qs3.union(qs1)), 10)
self.assertEqual(len(qs2.union(qs1, qs1, qs1)), 10) self.assertEqual(len(qs2.union(qs1, qs1, qs1)), 10)
self.assertEqual(len(qs2.union(qs1, qs1, all=True)), 20) self.assertEqual(len(qs2.union(qs1, qs1, all=True)), 20)
self.assertEqual(len(qs2.union(qs2)), 0) self.assertEqual(len(qs2.union(qs2)), 0)
self.assertEqual(len(qs3.union(qs3)), 0)
def test_limits(self): def test_limits(self):
qs1 = Number.objects.all() qs1 = Number.objects.all()