diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index cb0402007e..29016301ed 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -1136,11 +1136,14 @@ class ForeignObject(RelatedField): pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)] return pathinfos - def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type, + def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookups, raw_value): from django.db.models.sql.where import SubqueryConstraint, Constraint, AND, OR root_constraint = constraint_class() assert len(targets) == len(sources) + if len(lookups) > 1: + raise exceptions.FieldError('Relation fields do not support nested lookups') + lookup_type = lookups[0] def get_normalized_value(value): diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 85acb24a73..b216e9eb8d 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -1,27 +1,58 @@ from copy import copy +from django.core.exceptions import FieldError from django.conf import settings from django.utils import timezone +from django.utils.functional import cached_property + + +class Extract(object): + def __init__(self, constraint_class, lhs): + self.constraint_class, self.lhs = constraint_class, lhs + + def get_lookup(self, lookup): + return self.output_type.get_lookup(lookup) + + def as_sql(self, qn, connection): + raise NotImplementedError + + @cached_property + def output_type(self): + return self.lhs.output_type + + def relabeled_clone(self, relabels): + return self.__class__(self.constraint_class, self.lhs.relabeled_clone(relabels)) class Lookup(object): + lookup_name = None + extract_class = None + def __init__(self, constraint_class, lhs, rhs): self.constraint_class, self.lhs, self.rhs = constraint_class, lhs, rhs - self.rhs = self.get_prep_lookup() + if rhs is None: + if not self.extract_class: + raise FieldError("Lookup '%s' doesn't support nesting." % self.lookup_name) + else: + self.rhs = self.get_prep_lookup() + + def get_extract(self): + return self.extract_class(self.constraint_class, self.lhs) + + def get_prep_lookup(self): + return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs) def get_db_prep_lookup(self, value, connection): return ( '%s', self.lhs.output_type.get_db_prep_lookup( self.lookup_name, value, connection, prepared=True)) - def get_prep_lookup(self): - return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs) + def process_lhs(self, qn, connection, lhs=None): + lhs = lhs or self.lhs + return qn.compile(lhs) - def process_lhs(self, qn, connection): - return qn.compile(self.lhs) - - def process_rhs(self, qn, connection): - value = self.rhs + def process_rhs(self, qn, connection, rhs=None): + value = rhs or self.rhs # Due to historical reasons there are a couple of different # ways to produce sql here. get_compiler is likely a Query # instance, _as_sql QuerySet and as_sql just something with @@ -118,7 +149,7 @@ class In(DjangoLookup): lookup_name = 'in' def get_db_prep_lookup(self, value, connection): - params = self.lhs.field.get_db_prep_lookup( + params = self.lhs.output_type.get_db_prep_lookup( self.lookup_name, value, connection, prepared=True) if not params: # TODO: check why this leads to circular import diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index dcf04d4b78..7c4ec71be0 100644 --- a/django/db/models/sql/aggregates.py +++ b/django/db/models/sql/aggregates.py @@ -100,6 +100,9 @@ class Aggregate(object): def output_type(self): return self.field + def get_lookup(self, lookup): + return self.output_type.get_lookup(lookup) + class Avg(Aggregate): is_computed = True diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index dd273f51c7..1e1573f1e6 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -25,6 +25,9 @@ class Col(object): def get_cols(self): return [(self.alias, self.target.column)] + def get_lookup(self, name): + return self.output_type.get_lookup(name) + class EmptyResultSet(Exception): pass diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index a7fafd88b5..e6667bf26b 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -1027,19 +1027,16 @@ class Query(object): # Add the aggregate to the query aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) - def prepare_lookup_value(self, value, lookup_type, can_reuse): + def prepare_lookup_value(self, value, lookups, can_reuse): + # Default lookup if none given is exact. + if len(lookups) == 0: + lookups = ['exact'] # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all # uses of None as a query value. - if len(lookup_type) > 1: - raise FieldError('Nested lookups not allowed') - elif len(lookup_type) == 0: - lookup_type = 'exact' - else: - lookup_type = lookup_type[0] if value is None: - if lookup_type != 'exact': + if lookups[-1] != 'exact': raise ValueError("Cannot use None as a query value") - lookup_type = 'isnull' + lookups[-1] = 'isnull' value = True elif callable(value): value = value() @@ -1057,10 +1054,10 @@ class Query(object): # stage. Using DEFAULT_DB_ALIAS isn't nice, but it is the best we # can do here. Similar thing is done in is_nullable(), too. if (connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls and - lookup_type == 'exact' and value == ''): + lookups[-1] == 'exact' and value == ''): value = True - lookup_type = 'isnull' - return value, lookup_type + lookups[-1] = ['isnull'] + return value, lookups def solve_lookup_type(self, lookup): """ @@ -1069,36 +1066,37 @@ class Query(object): lookup_splitted = lookup.split(LOOKUP_SEP) aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates) if aggregate: - if len(aggregate_lookups) > 1: - raise FieldError("Nested lookups not allowed.") return aggregate_lookups, (), aggregate _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta()) field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)] if len(lookup_parts) == 0: lookup_parts = ['exact'] elif len(lookup_parts) > 1: - if field_parts: - raise FieldError( - 'Only one lookup part allowed (found path "%s" from "%s").' % - (LOOKUP_SEP.join(field_parts), lookup)) - else: + if not field_parts: raise FieldError( 'Invalid lookup "%s" for model %s".' % (lookup, self.get_meta().model.__name__)) - else: - if not hasattr(field, 'get_lookup_constraint'): - lookup_class = field.get_lookup(lookup_parts[0]) - if lookup_class is None and lookup_parts[0] not in self.query_terms: - raise FieldError( - 'Invalid lookup name %s' % lookup_parts[0]) return lookup_parts, field_parts, False - def build_lookup(self, lookup_type, lhs, rhs): - if hasattr(lhs.output_type, 'get_lookup'): - lookup = lhs.output_type.get_lookup(lookup_type) - if lookup: - return lookup(self.where_class, lhs, rhs) - return None + def build_lookup(self, lookups, lhs, rhs): + lookups = lookups[:] + lookups.reverse() + while lookups: + lookup = lookups.pop() + next = lhs.get_lookup(lookup) + if next: + if not lookups: + # This was the last lookup, so return value lookup. + return next(self.where_class, lhs, rhs) + else: + lhs = next(self.where_class, lhs, None).get_extract() + # A field's get_lookup() can return None to opt for backwards + # compatibility path. + elif len(lookups) > 1: + raise FieldError( + "Unsupported lookup for field '%s'" % lhs.output_type.name) + else: + return None def build_filter(self, filter_expr, branch_negated=False, current_negated=False, can_reuse=None, connector=AND): @@ -1130,19 +1128,20 @@ class Query(object): arg, value = filter_expr if not arg: raise FieldError("Cannot parse keyword query %r" % arg) - lookup_type, parts, reffed_aggregate = self.solve_lookup_type(arg) + lookups, parts, reffed_aggregate = self.solve_lookup_type(arg) # Work out the lookup type and remove it from the end of 'parts', # if necessary. - value, lookup_type = self.prepare_lookup_value(value, lookup_type, can_reuse) + value, lookups = self.prepare_lookup_value(value, lookups, can_reuse) used_joins = getattr(value, '_used_joins', []) clause = self.where_class() if reffed_aggregate: - condition = self.build_lookup(lookup_type, reffed_aggregate, value) + condition = self.build_lookup(lookups, reffed_aggregate, value) if not condition: # Backwards compat for custom lookups - condition = (reffed_aggregate, lookup_type, value) + assert len(lookups) == 1 + condition = (reffed_aggregate, lookups[0], value) clause.add(condition, AND) return clause, [] @@ -1169,14 +1168,27 @@ class Query(object): # For now foreign keys get special treatment. This should be # refactored when composite fields lands. condition = field.get_lookup_constraint(self.where_class, alias, targets, sources, - lookup_type, value) + lookups, value) + lookup_type = lookups[-1] else: assert(len(targets) == 1) col = Col(alias, targets[0], field) - condition = self.build_lookup(lookup_type, col, value) + condition = self.build_lookup(lookups, col, value) if not condition: # Backwards compat for custom lookups - condition = (Constraint(alias, targets[0].column, field), lookup_type, value) + if lookups[0] not in self.query_terms: + raise FieldError( + "Join on field '%s' not permitted. Did you " + "misspell '%s' for the lookup type?" % + (col.output_type.name, lookups[0])) + if len(lookups) > 1: + raise FieldError("Nested lookup '%s' not supported." % + LOOKUP_SEP.join(lookups)) + condition = (Constraint(alias, targets[0].column, field), lookups[0], value) + lookup_type = lookups[-1] + else: + lookup_type = condition.lookup_name + clause.add(condition, AND) require_outer = lookup_type == 'isnull' and value is True and not current_negated @@ -1296,7 +1308,7 @@ class Query(object): needed_inner = joinpromoter.update_join_types(self) return target_clause, needed_inner - def names_to_path(self, names, opts, allow_many=True): + def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False): """ Walks the names path and turns them PathInfo tuples. Note that a single name in 'names' can generate multiple PathInfos (m2m for @@ -1354,10 +1366,15 @@ class Query(object): final_field = field targets = (field,) break - if pos == -1: - raise FieldError('Whazaa') + if pos == -1 or (fail_on_missing and pos + 1 != len(names)): + self.raise_field_error(opts, name) return path, final_field, targets, names[pos + 1:] + def raise_field_error(self, opts, name): + available = opts.get_all_field_names() + list(self.aggregate_select) + raise FieldError("Cannot resolve keyword %r into field. " + "Choices are: %s" % (name, ", ".join(available))) + def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True): """ Compute the necessary table joins for the passage through the fields @@ -1386,9 +1403,8 @@ class Query(object): joins = [alias] # First, generate the path for the names path, final_field, targets, rest = self.names_to_path( - names, opts, allow_many) - if rest: - raise FieldError('Invalid lookup') + names, opts, allow_many, fail_on_missing=True) + # Then, add the path to the query's joins. Note that we can't trim # joins at this stage - we will need the information about join type # of the trimmed joins. diff --git a/tests/custom_lookups/models.py b/tests/custom_lookups/models.py index 5152bc6502..9841b36ce5 100644 --- a/tests/custom_lookups/models.py +++ b/tests/custom_lookups/models.py @@ -1,7 +1,13 @@ from django.db import models +from django.utils.encoding import python_2_unicode_compatible +@python_2_unicode_compatible class Author(models.Model): name = models.CharField(max_length=20) age = models.IntegerField(null=True) birthdate = models.DateField(null=True) + average_rating = models.FloatField(null=True) + + def __str__(self): + return self.name diff --git a/tests/custom_lookups/tests.py b/tests/custom_lookups/tests.py index a608693137..5864bf3546 100644 --- a/tests/custom_lookups/tests.py +++ b/tests/custom_lookups/tests.py @@ -19,6 +19,56 @@ class Div3Lookup(models.lookups.Lookup): return '%s %%%% 3 = %s' % (lhs, rhs), params +class Div3Extract(models.lookups.Extract): + def as_sql(self, qn, connection): + lhs, lhs_params = qn.compile(self.lhs) + return '%s %%%% 3' % (lhs,), lhs_params + + +class Div3LookupWithExtract(Div3Lookup): + lookup_name = 'div3' + extract_class = Div3Extract + + +class YearLte(models.lookups.LessThanOrEqual): + """ + The purpose of this lookup is to efficiently compare the year of the field. + """ + + def as_sql(self, qn, connection): + # Skip the YearExtract above us (no possibility for efficient + # lookup otherwise). + real_lhs = self.lhs.lhs + lhs_sql, params = self.process_lhs(qn, connection, real_lhs) + rhs_sql, rhs_params = self.process_rhs(qn, connection) + params.extend(rhs_params) + # Build SQL where the integer year is concatenated with last month + # and day, then convert that to date. (We try to have SQL like: + # WHERE somecol <= '2013-12-31') + # but also make it work if the rhs_sql is field reference. + return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params + + +class YearExtract(models.lookups.Extract): + def as_sql(self, qn, connection): + lhs_sql, params = qn.compile(self.lhs) + return connection.ops.date_extract_sql('year', lhs_sql), params + + @property + def output_type(self): + return models.IntegerField() + + def get_lookup(self, lookup): + if lookup == 'lte': + return YearLte + else: + return super(YearExtract, self).get_lookup(lookup) + + +class YearWithExtract(models.lookups.Year): + extract_class = YearExtract + + class InMonth(models.lookups.Lookup): """ InMonth matches if the column's month is contained in the value's month. @@ -134,3 +184,72 @@ class LookupTests(TestCase): ) finally: models.Field._unregister_lookup(AnotherEqual) + + def test_div3_extract(self): + models.IntegerField.register_lookup(Div3LookupWithExtract) + try: + a1 = Author.objects.create(name='a1', age=1) + a2 = Author.objects.create(name='a2', age=2) + a3 = Author.objects.create(name='a3', age=3) + a4 = Author.objects.create(name='a4', age=4) + baseqs = Author.objects.order_by('name') + self.assertQuerysetEqual( + baseqs.filter(age__div3__lte=3), + [a1, a2, a3, a4], lambda x: x) + self.assertQuerysetEqual( + baseqs.filter(age__div3__in=[0, 2]), + [a2, a3], lambda x: x) + finally: + models.IntegerField._unregister_lookup(Div3LookupWithExtract) + + +class YearLteTests(TestCase): + def setUp(self): + models.DateField.register_lookup(YearWithExtract) + self.a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16)) + self.a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29)) + self.a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31)) + self.a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1)) + + def tearDown(self): + models.DateField._unregister_lookup(YearWithExtract) + + @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used") + def test_year_lte(self): + baseqs = Author.objects.order_by('name') + self.assertQuerysetEqual( + baseqs.filter(birthdate__year__lte=2012), + [self.a1, self.a2, self.a3, self.a4], lambda x: x) + self.assertQuerysetEqual( + baseqs.filter(birthdate__year__lte=2011), + [self.a1], lambda x: x) + # The non-optimized version works, too. + self.assertQuerysetEqual( + baseqs.filter(birthdate__year__lt=2012), + [self.a1], lambda x: x) + + @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used") + def test_year_lte_fexpr(self): + self.a2.age = 2011 + self.a2.save() + self.a3.age = 2012 + self.a3.save() + self.a4.age = 2013 + self.a4.save() + baseqs = Author.objects.order_by('name') + self.assertQuerysetEqual( + baseqs.filter(birthdate__year__lte=models.F('age')), + [self.a3, self.a4], lambda x: x) + self.assertQuerysetEqual( + baseqs.filter(birthdate__year__lt=models.F('age')), + [self.a4], lambda x: x) + + def test_year_lte_sql(self): + # This test will just check the generated SQL for __lte. This + # doesn't require running on PostgreSQL and spots the most likely + # error - not running YearLte SQL at all. + baseqs = Author.objects.order_by('name') + self.assertIn( + '<= (2011 || ', str(baseqs.filter(birthdate__year__lte=2011).query)) + self.assertIn( + '-12-31', str(baseqs.filter(birthdate__year__lte=2011).query)) diff --git a/tests/null_queries/tests.py b/tests/null_queries/tests.py index e442479cd7..16b6a908d9 100644 --- a/tests/null_queries/tests.py +++ b/tests/null_queries/tests.py @@ -41,9 +41,6 @@ class NullQueriesTests(TestCase): # Can't use None on anything other than __exact self.assertRaises(ValueError, Choice.objects.filter, id__gt=None) - # Can't use None on anything other than __exact - self.assertRaises(ValueError, Choice.objects.filter, foo__gt=None) - # Related managers use __exact=None implicitly if the object hasn't been saved. p2 = Poll(question="How?") self.assertEqual(repr(p2.choice_set.all()), '[]')