Fixed #31667 -- Made __in lookup ignore None values.

This commit is contained in:
Adam Johnson 2020-06-05 23:49:08 +01:00 committed by Mariusz Felisiak
parent 36db4dd937
commit 5776a1660e
2 changed files with 31 additions and 3 deletions

View File

@ -366,10 +366,12 @@ class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
)
if self.rhs_is_direct_value():
# Remove None from the list as NULL is never equal to anything.
try:
rhs = OrderedSet(self.rhs)
rhs.discard(None)
except TypeError: # Unhashable items in self.rhs
rhs = self.rhs
rhs = [r for r in self.rhs if r is not None]
if not rhs:
raise EmptyResultSet

View File

@ -576,8 +576,6 @@ class LookupTests(TestCase):
self.assertQuerysetEqual(Article.objects.none().iterator(), [])
def test_in(self):
# using __in with an empty list should return an empty query set
self.assertQuerysetEqual(Article.objects.filter(id__in=[]), [])
self.assertQuerysetEqual(
Article.objects.exclude(id__in=[]),
[
@ -591,6 +589,9 @@ class LookupTests(TestCase):
]
)
def test_in_empty_list(self):
self.assertSequenceEqual(Article.objects.filter(id__in=[]), [])
def test_in_different_database(self):
with self.assertRaisesMessage(
ValueError,
@ -603,6 +604,31 @@ class LookupTests(TestCase):
query = Article.objects.filter(slug__in=['a%d' % i for i in range(1, 8)]).values('pk').query
self.assertIn(' IN (a1, a2, a3, a4, a5, a6, a7) ', str(query))
def test_in_ignore_none(self):
with self.assertNumQueries(1) as ctx:
self.assertSequenceEqual(
Article.objects.filter(id__in=[None, self.a1.id]),
[self.a1],
)
sql = ctx.captured_queries[0]['sql']
self.assertIn('IN (%s)' % self.a1.pk, sql)
def test_in_ignore_solo_none(self):
with self.assertNumQueries(0):
self.assertSequenceEqual(Article.objects.filter(id__in=[None]), [])
def test_in_ignore_none_with_unhashable_items(self):
class UnhashableInt(int):
__hash__ = None
with self.assertNumQueries(1) as ctx:
self.assertSequenceEqual(
Article.objects.filter(id__in=[None, UnhashableInt(self.a1.id)]),
[self.a1],
)
sql = ctx.captured_queries[0]['sql']
self.assertIn('IN (%s)' % self.a1.pk, sql)
def test_error_messages(self):
# Programming errors are pointed out with nice error messages
with self.assertRaisesMessage(