Fixed #27985 -- Fixed query for __exact=value when get_prep_value() converts value to None.

Also fixed crash of .filter(field__transform=None).
This commit is contained in:
Sergey Fedoseev 2017-07-30 22:00:00 +05:00 committed by Tim Graham
parent 6155bc4a51
commit 58da81a5a3
4 changed files with 66 additions and 58 deletions

View File

@ -19,6 +19,9 @@ class MultiColSource:
return self.__class__(relabels.get(self.alias, self.alias), return self.__class__(relabels.get(self.alias, self.alias),
self.targets, self.sources, self.field) self.targets, self.sources, self.field)
def get_lookup(self, lookup):
return self.output_field.get_lookup(lookup)
def get_normalized_value(value, lhs): def get_normalized_value(value, lhs):
from django.db.models import Model from django.db.models import Model

View File

@ -964,19 +964,9 @@ class Query:
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
return self.get_compiler(connection=connection).as_sql() return self.get_compiler(connection=connection).as_sql()
def prepare_lookup_value(self, value, lookups, can_reuse, allow_joins=True): def resolve_lookup_value(self, value, can_reuse, allow_joins):
# Default lookup if none given is exact.
used_joins = set() used_joins = set()
if len(lookups) == 0: if hasattr(value, 'resolve_expression'):
lookups = ['exact']
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
# uses of None as a query value.
if value is None:
if lookups[-1] not in ('exact', 'iexact'):
raise ValueError("Cannot use None as a query value")
lookups[-1] = 'isnull'
return True, lookups, used_joins
elif hasattr(value, 'resolve_expression'):
pre_joins = self.alias_refcount.copy() pre_joins = self.alias_refcount.copy()
value = value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins) value = value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins)
used_joins = {k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)} used_joins = {k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)}
@ -993,15 +983,7 @@ class Query:
# The used_joins for a tuple of expressions is the union of # The used_joins for a tuple of expressions is the union of
# the used_joins for the individual expressions. # the used_joins for the individual expressions.
used_joins.update(k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)) used_joins.update(k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0))
# For Oracle '' is equivalent to null. The check needs to be done return value, used_joins
# at this stage because join promotion can't be done at compiler
# stage. Using DEFAULT_DB_ALIAS isn't nice, but it is the best we
# can do here. Similar thing is done in is_nullable(), too.
if (connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls and
lookups[-1] == 'exact' and value == ''):
value = True
lookups[-1] = 'isnull'
return value, lookups, used_joins
def solve_lookup_type(self, lookup): def solve_lookup_type(self, lookup):
""" """
@ -1014,13 +996,11 @@ class Query:
return expression_lookups, (), expression return expression_lookups, (), expression
_, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta()) _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta())
field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)] field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)]
if len(lookup_parts) == 0: if len(lookup_parts) > 1 and not field_parts:
lookup_parts = ['exact'] raise FieldError(
elif len(lookup_parts) > 1: 'Invalid lookup "%s" for model %s".' %
if not field_parts: (lookup, self.get_meta().model.__name__)
raise FieldError( )
'Invalid lookup "%s" for model %s".' %
(lookup, self.get_meta().model.__name__))
return lookup_parts, field_parts, False return lookup_parts, field_parts, False
def check_query_object_type(self, value, opts, field): def check_query_object_type(self, value, opts, field):
@ -1063,23 +1043,43 @@ class Query:
The lookups is a list of names to extract using get_lookup() The lookups is a list of names to extract using get_lookup()
and get_transform(). and get_transform().
""" """
lookups = lookups[:] # __exact is the default lookup if one isn't given.
while lookups: if len(lookups) == 0:
name = lookups[0] lookups = ['exact']
# If there is just one part left, try first get_lookup() so
# that if the lhs supports both transform and lookup for the for name in lookups[:-1]:
# name, then lookup will be picked.
if len(lookups) == 1:
final_lookup = lhs.get_lookup(name)
if not final_lookup:
# We didn't find a lookup. We are going to interpret
# the name as transform, and do an Exact lookup against
# it.
lhs = self.try_transform(lhs, name)
final_lookup = lhs.get_lookup('exact')
return final_lookup(lhs, rhs)
lhs = self.try_transform(lhs, name) lhs = self.try_transform(lhs, name)
lookups = lookups[1:] # First try get_lookup() so that the lookup takes precedence if the lhs
# supports both transform and lookup for the name.
lookup_class = lhs.get_lookup(lookups[-1])
if not lookup_class:
if lhs.field.is_relation:
raise FieldError('Related Field got invalid lookup: {}'.format(lookups[-1]))
# A lookup wasn't found. Try to interpret the name as a transform
# and do an Exact lookup against it.
lhs = self.try_transform(lhs, lookups[-1])
lookup_class = lhs.get_lookup('exact')
if not lookup_class:
return
lookup = lookup_class(lhs, rhs)
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
# uses of None as a query value.
if lookup.rhs is None:
if lookup.lookup_name not in ('exact', 'iexact'):
raise ValueError("Cannot use None as a query value")
return lhs.get_lookup('isnull')(lhs, True)
# For Oracle '' is equivalent to null. The check must be done at this
# stage because join promotion can't be done in the compiler. Using
# DEFAULT_DB_ALIAS isn't nice but it's the best that can be done here.
# A similar thing is done in is_nullable(), too.
if (connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls and
lookup.lookup_name == 'exact' and lookup.rhs == ''):
return lhs.get_lookup('isnull')(lhs, True)
return lookup
def try_transform(self, lhs, name): def try_transform(self, lhs, name):
""" """
@ -1133,7 +1133,7 @@ class Query:
# Work out the lookup type and remove it from the end of 'parts', # Work out the lookup type and remove it from the end of 'parts',
# if necessary. # if necessary.
value, lookups, used_joins = self.prepare_lookup_value(value, lookups, can_reuse, allow_joins) value, used_joins = self.resolve_lookup_value(value, can_reuse, allow_joins)
clause = self.where_class() clause = self.where_class()
if reffed_expression: if reffed_expression:
@ -1173,25 +1173,19 @@ class Query:
num_lookups = len(lookups) num_lookups = len(lookups)
if num_lookups > 1: if num_lookups > 1:
raise FieldError('Related Field got invalid lookup: {}'.format(lookups[0])) raise FieldError('Related Field got invalid lookup: {}'.format(lookups[0]))
assert num_lookups > 0 # Likely a bug in Django if this fails.
lookup_class = field.get_lookup(lookups[0])
if lookup_class is None:
raise FieldError('Related Field got invalid lookup: {}'.format(lookups[0]))
if len(targets) == 1: if len(targets) == 1:
lhs = targets[0].get_col(alias, field) col = targets[0].get_col(alias, field)
else: else:
lhs = MultiColSource(alias, targets, sources, field) col = MultiColSource(alias, targets, sources, field)
condition = lookup_class(lhs, value)
lookup_type = lookup_class.lookup_name
else: else:
col = targets[0].get_col(alias, field) col = targets[0].get_col(alias, field)
condition = self.build_lookup(lookups, col, value)
lookup_type = condition.lookup_name
condition = self.build_lookup(lookups, col, value)
lookup_type = condition.lookup_name
clause.add(condition, AND) clause.add(condition, AND)
require_outer = lookup_type == 'isnull' and value is True and not current_negated require_outer = lookup_type == 'isnull' and condition.rhs is True and not current_negated
if current_negated and (lookup_type != 'isnull' or value is False): if current_negated and (lookup_type != 'isnull' or condition.rhs is False):
require_outer = True require_outer = True
if (lookup_type != 'isnull' and ( if (lookup_type != 'isnull' and (
self.is_nullable(targets[0]) or self.is_nullable(targets[0]) or

View File

@ -44,7 +44,8 @@ class Tag(models.Model):
class NulledTextField(models.TextField): class NulledTextField(models.TextField):
pass def get_prep_value(self, value):
return None if value == '' else value
@NulledTextField.register_lookup @NulledTextField.register_lookup

View File

@ -846,3 +846,13 @@ class LookupTests(TestCase):
self.assertFalse(Season.objects.filter(nulled_text_field__isnull=True)) self.assertFalse(Season.objects.filter(nulled_text_field__isnull=True))
self.assertTrue(Season.objects.filter(nulled_text_field__nulled__isnull=True)) self.assertTrue(Season.objects.filter(nulled_text_field__nulled__isnull=True))
self.assertTrue(Season.objects.filter(nulled_text_field__nulled__exact=None)) self.assertTrue(Season.objects.filter(nulled_text_field__nulled__exact=None))
self.assertTrue(Season.objects.filter(nulled_text_field__nulled=None))
def test_custom_field_none_rhs(self):
"""
__exact=value is transformed to __isnull=True if Field.get_prep_value()
converts value to None.
"""
season = Season.objects.create(year=2012, nulled_text_field=None)
self.assertTrue(Season.objects.filter(pk=season.pk, nulled_text_field__isnull=True))
self.assertTrue(Season.objects.filter(pk=season.pk, nulled_text_field=''))