Fixed #28293 -- Fixed union(), intersection(), and difference() when combining with an EmptyQuerySet.
Thanks Jon Dufresne for the report and Tim Graham for the review.
This commit is contained in:
parent
9dc83c356d
commit
82175ead72
|
@ -818,12 +818,25 @@ class QuerySet:
|
||||||
return clone
|
return clone
|
||||||
|
|
||||||
def union(self, *other_qs, all=False):
|
def union(self, *other_qs, all=False):
|
||||||
|
# If the query is an EmptyQuerySet, combine all nonempty querysets.
|
||||||
|
if isinstance(self, EmptyQuerySet):
|
||||||
|
qs = [q for q in other_qs if not isinstance(q, EmptyQuerySet)]
|
||||||
|
return qs[0]._combinator_query('union', *qs[1:], all=all) if qs else self
|
||||||
return self._combinator_query('union', *other_qs, all=all)
|
return self._combinator_query('union', *other_qs, all=all)
|
||||||
|
|
||||||
def intersection(self, *other_qs):
|
def intersection(self, *other_qs):
|
||||||
|
# If any query is an EmptyQuerySet, return it.
|
||||||
|
if isinstance(self, EmptyQuerySet):
|
||||||
|
return self
|
||||||
|
for other in other_qs:
|
||||||
|
if isinstance(other, EmptyQuerySet):
|
||||||
|
return other
|
||||||
return self._combinator_query('intersection', *other_qs)
|
return self._combinator_query('intersection', *other_qs)
|
||||||
|
|
||||||
def difference(self, *other_qs):
|
def difference(self, *other_qs):
|
||||||
|
# If the query is an EmptyQuerySet, return it.
|
||||||
|
if isinstance(self, EmptyQuerySet):
|
||||||
|
return self
|
||||||
return self._combinator_query('difference', *other_qs)
|
return self._combinator_query('difference', *other_qs)
|
||||||
|
|
||||||
def select_for_update(self, nowait=False, skip_locked=False):
|
def select_for_update(self, nowait=False, skip_locked=False):
|
||||||
|
|
|
@ -390,7 +390,7 @@ class SQLCompiler:
|
||||||
features = self.connection.features
|
features = self.connection.features
|
||||||
compilers = [
|
compilers = [
|
||||||
query.get_compiler(self.using, self.connection)
|
query.get_compiler(self.using, self.connection)
|
||||||
for query in self.query.combined_queries
|
for query in self.query.combined_queries if not query.is_empty()
|
||||||
]
|
]
|
||||||
if not features.supports_slicing_ordering_in_compound:
|
if not features.supports_slicing_ordering_in_compound:
|
||||||
for query, compiler in zip(self.query.combined_queries, compilers):
|
for query, compiler in zip(self.query.combined_queries, compilers):
|
||||||
|
|
|
@ -29,3 +29,6 @@ Bugfixes
|
||||||
|
|
||||||
* Fixed crash in admin's inlines when a model has an inherited non-editable
|
* Fixed crash in admin's inlines when a model has an inherited non-editable
|
||||||
primary key (:ticket:`27967`).
|
primary key (:ticket:`27967`).
|
||||||
|
|
||||||
|
* Fixed ``QuerySet.union()``, ``intersection()``, and ``difference()`` when
|
||||||
|
combining with an ``EmptyQuerySet`` (:ticket:`28293`).
|
||||||
|
|
|
@ -42,6 +42,31 @@ class QuerySetSetOperationTests(TestCase):
|
||||||
self.assertEqual(len(list(qs1.union(qs2, all=True))), 20)
|
self.assertEqual(len(list(qs1.union(qs2, all=True))), 20)
|
||||||
self.assertEqual(len(list(qs1.union(qs2))), 10)
|
self.assertEqual(len(list(qs1.union(qs2))), 10)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature('supports_select_intersection')
|
||||||
|
def test_intersection_with_empty_qs(self):
|
||||||
|
qs1 = Number.objects.all()
|
||||||
|
qs2 = Number.objects.none()
|
||||||
|
self.assertEqual(len(qs1.intersection(qs2)), 0)
|
||||||
|
self.assertEqual(len(qs2.intersection(qs1)), 0)
|
||||||
|
self.assertEqual(len(qs2.intersection(qs2)), 0)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature('supports_select_difference')
|
||||||
|
def test_difference_with_empty_qs(self):
|
||||||
|
qs1 = Number.objects.all()
|
||||||
|
qs2 = Number.objects.none()
|
||||||
|
self.assertEqual(len(qs1.difference(qs2)), 10)
|
||||||
|
self.assertEqual(len(qs2.difference(qs1)), 0)
|
||||||
|
self.assertEqual(len(qs2.difference(qs2)), 0)
|
||||||
|
|
||||||
|
def test_union_with_empty_qs(self):
|
||||||
|
qs1 = Number.objects.all()
|
||||||
|
qs2 = Number.objects.none()
|
||||||
|
self.assertEqual(len(qs1.union(qs2)), 10)
|
||||||
|
self.assertEqual(len(qs2.union(qs1)), 10)
|
||||||
|
self.assertEqual(len(qs2.union(qs1, qs1, qs1)), 10)
|
||||||
|
self.assertEqual(len(qs2.union(qs1, qs1, all=True)), 20)
|
||||||
|
self.assertEqual(len(qs2.union(qs2)), 0)
|
||||||
|
|
||||||
def test_limits(self):
|
def test_limits(self):
|
||||||
qs1 = Number.objects.all()
|
qs1 = Number.objects.all()
|
||||||
qs2 = Number.objects.all()
|
qs2 = Number.objects.all()
|
||||||
|
|
Loading…
Reference in New Issue