Implemented nested lookups

However, using lookups outside of filtering is not supported yet.
This commit is contained in:
Anssi Kääriäinen 2013-11-30 23:04:34 +02:00
parent 4d219d4cde
commit 7c8b3a32cc
8 changed files with 235 additions and 57 deletions

View File

@ -1136,11 +1136,14 @@ class ForeignObject(RelatedField):
pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)] pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)]
return pathinfos return pathinfos
def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type, def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookups,
raw_value): raw_value):
from django.db.models.sql.where import SubqueryConstraint, Constraint, AND, OR from django.db.models.sql.where import SubqueryConstraint, Constraint, AND, OR
root_constraint = constraint_class() root_constraint = constraint_class()
assert len(targets) == len(sources) assert len(targets) == len(sources)
if len(lookups) > 1:
raise exceptions.FieldError('Relation fields do not support nested lookups')
lookup_type = lookups[0]
def get_normalized_value(value): def get_normalized_value(value):

View File

@ -1,27 +1,58 @@
from copy import copy from copy import copy
from django.core.exceptions import FieldError
from django.conf import settings from django.conf import settings
from django.utils import timezone from django.utils import timezone
from django.utils.functional import cached_property
class Extract(object):
    """
    An expression built on top of another expression (``lhs``).

    ``Lookup.get_extract()`` creates instances of this class so that the
    result can serve as the left-hand side of the next lookup in a nested
    lookup chain. Subclasses are expected to provide ``as_sql()``.
    """

    def __init__(self, constraint_class, lhs):
        self.constraint_class = constraint_class
        self.lhs = lhs

    def get_lookup(self, lookup):
        # Lookup resolution is delegated to this extract's output type.
        output_type = self.output_type
        return output_type.get_lookup(lookup)

    def as_sql(self, qn, connection):
        # There is no default SQL; concrete extracts override this.
        raise NotImplementedError

    @cached_property
    def output_type(self):
        # By default the extract has the same output type as the
        # expression it wraps.
        return self.lhs.output_type

    def relabeled_clone(self, relabels):
        # Build a new instance of the same class around the relabeled lhs.
        relabeled_lhs = self.lhs.relabeled_clone(relabels)
        return self.__class__(self.constraint_class, relabeled_lhs)
class Lookup(object): class Lookup(object):
lookup_name = None
extract_class = None
def __init__(self, constraint_class, lhs, rhs): def __init__(self, constraint_class, lhs, rhs):
self.constraint_class, self.lhs, self.rhs = constraint_class, lhs, rhs self.constraint_class, self.lhs, self.rhs = constraint_class, lhs, rhs
if rhs is None:
if not self.extract_class:
raise FieldError("Lookup '%s' doesn't support nesting." % self.lookup_name)
else:
self.rhs = self.get_prep_lookup() self.rhs = self.get_prep_lookup()
def get_extract(self):
return self.extract_class(self.constraint_class, self.lhs)
def get_prep_lookup(self):
return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs)
def get_db_prep_lookup(self, value, connection): def get_db_prep_lookup(self, value, connection):
return ( return (
'%s', self.lhs.output_type.get_db_prep_lookup( '%s', self.lhs.output_type.get_db_prep_lookup(
self.lookup_name, value, connection, prepared=True)) self.lookup_name, value, connection, prepared=True))
def get_prep_lookup(self): def process_lhs(self, qn, connection, lhs=None):
return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs) lhs = lhs or self.lhs
return qn.compile(lhs)
def process_lhs(self, qn, connection): def process_rhs(self, qn, connection, rhs=None):
return qn.compile(self.lhs) value = rhs or self.rhs
def process_rhs(self, qn, connection):
value = self.rhs
# Due to historical reasons there are a couple of different # Due to historical reasons there are a couple of different
# ways to produce sql here. get_compiler is likely a Query # ways to produce sql here. get_compiler is likely a Query
# instance, _as_sql QuerySet and as_sql just something with # instance, _as_sql QuerySet and as_sql just something with
@ -118,7 +149,7 @@ class In(DjangoLookup):
lookup_name = 'in' lookup_name = 'in'
def get_db_prep_lookup(self, value, connection): def get_db_prep_lookup(self, value, connection):
params = self.lhs.field.get_db_prep_lookup( params = self.lhs.output_type.get_db_prep_lookup(
self.lookup_name, value, connection, prepared=True) self.lookup_name, value, connection, prepared=True)
if not params: if not params:
# TODO: check why this leads to circular import # TODO: check why this leads to circular import

View File

@ -100,6 +100,9 @@ class Aggregate(object):
def output_type(self): def output_type(self):
return self.field return self.field
def get_lookup(self, lookup):
    """Return whatever this object's output type resolves ``lookup`` to."""
    output_type = self.output_type
    return output_type.get_lookup(lookup)
class Avg(Aggregate): class Avg(Aggregate):
is_computed = True is_computed = True

View File

@ -25,6 +25,9 @@ class Col(object):
def get_cols(self): def get_cols(self):
return [(self.alias, self.target.column)] return [(self.alias, self.target.column)]
def get_lookup(self, name):
    """Return whatever this object's output type resolves ``name`` to."""
    output_type = self.output_type
    return output_type.get_lookup(name)
class EmptyResultSet(Exception): class EmptyResultSet(Exception):
pass pass

View File

@ -1027,19 +1027,16 @@ class Query(object):
# Add the aggregate to the query # Add the aggregate to the query
aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
def prepare_lookup_value(self, value, lookup_type, can_reuse): def prepare_lookup_value(self, value, lookups, can_reuse):
# Default lookup if none given is exact.
if len(lookups) == 0:
lookups = ['exact']
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
# uses of None as a query value. # uses of None as a query value.
if len(lookup_type) > 1:
raise FieldError('Nested lookups not allowed')
elif len(lookup_type) == 0:
lookup_type = 'exact'
else:
lookup_type = lookup_type[0]
if value is None: if value is None:
if lookup_type != 'exact': if lookups[-1] != 'exact':
raise ValueError("Cannot use None as a query value") raise ValueError("Cannot use None as a query value")
lookup_type = 'isnull' lookups[-1] = 'isnull'
value = True value = True
elif callable(value): elif callable(value):
value = value() value = value()
@ -1057,10 +1054,10 @@ class Query(object):
# stage. Using DEFAULT_DB_ALIAS isn't nice, but it is the best we # stage. Using DEFAULT_DB_ALIAS isn't nice, but it is the best we
# can do here. Similar thing is done in is_nullable(), too. # can do here. Similar thing is done in is_nullable(), too.
if (connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls and if (connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls and
lookup_type == 'exact' and value == ''): lookups[-1] == 'exact' and value == ''):
value = True value = True
lookup_type = 'isnull' lookups[-1] = ['isnull']
return value, lookup_type return value, lookups
def solve_lookup_type(self, lookup): def solve_lookup_type(self, lookup):
""" """
@ -1069,35 +1066,36 @@ class Query(object):
lookup_splitted = lookup.split(LOOKUP_SEP) lookup_splitted = lookup.split(LOOKUP_SEP)
aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates) aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates)
if aggregate: if aggregate:
if len(aggregate_lookups) > 1:
raise FieldError("Nested lookups not allowed.")
return aggregate_lookups, (), aggregate return aggregate_lookups, (), aggregate
_, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta()) _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta())
field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)] field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)]
if len(lookup_parts) == 0: if len(lookup_parts) == 0:
lookup_parts = ['exact'] lookup_parts = ['exact']
elif len(lookup_parts) > 1: elif len(lookup_parts) > 1:
if field_parts: if not field_parts:
raise FieldError(
'Only one lookup part allowed (found path "%s" from "%s").' %
(LOOKUP_SEP.join(field_parts), lookup))
else:
raise FieldError( raise FieldError(
'Invalid lookup "%s" for model %s".' % 'Invalid lookup "%s" for model %s".' %
(lookup, self.get_meta().model.__name__)) (lookup, self.get_meta().model.__name__))
else:
if not hasattr(field, 'get_lookup_constraint'):
lookup_class = field.get_lookup(lookup_parts[0])
if lookup_class is None and lookup_parts[0] not in self.query_terms:
raise FieldError(
'Invalid lookup name %s' % lookup_parts[0])
return lookup_parts, field_parts, False return lookup_parts, field_parts, False
def build_lookup(self, lookup_type, lhs, rhs): def build_lookup(self, lookups, lhs, rhs):
if hasattr(lhs.output_type, 'get_lookup'): lookups = lookups[:]
lookup = lhs.output_type.get_lookup(lookup_type) lookups.reverse()
if lookup: while lookups:
return lookup(self.where_class, lhs, rhs) lookup = lookups.pop()
next = lhs.get_lookup(lookup)
if next:
if not lookups:
# This was the last lookup, so return value lookup.
return next(self.where_class, lhs, rhs)
else:
lhs = next(self.where_class, lhs, None).get_extract()
# A field's get_lookup() can return None to opt for backwards
# compatibility path.
elif len(lookups) > 1:
raise FieldError(
"Unsupported lookup for field '%s'" % lhs.output_type.name)
else:
return None return None
def build_filter(self, filter_expr, branch_negated=False, current_negated=False, def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
@ -1130,19 +1128,20 @@ class Query(object):
arg, value = filter_expr arg, value = filter_expr
if not arg: if not arg:
raise FieldError("Cannot parse keyword query %r" % arg) raise FieldError("Cannot parse keyword query %r" % arg)
lookup_type, parts, reffed_aggregate = self.solve_lookup_type(arg) lookups, parts, reffed_aggregate = self.solve_lookup_type(arg)
# Work out the lookup type and remove it from the end of 'parts', # Work out the lookup type and remove it from the end of 'parts',
# if necessary. # if necessary.
value, lookup_type = self.prepare_lookup_value(value, lookup_type, can_reuse) value, lookups = self.prepare_lookup_value(value, lookups, can_reuse)
used_joins = getattr(value, '_used_joins', []) used_joins = getattr(value, '_used_joins', [])
clause = self.where_class() clause = self.where_class()
if reffed_aggregate: if reffed_aggregate:
condition = self.build_lookup(lookup_type, reffed_aggregate, value) condition = self.build_lookup(lookups, reffed_aggregate, value)
if not condition: if not condition:
# Backwards compat for custom lookups # Backwards compat for custom lookups
condition = (reffed_aggregate, lookup_type, value) assert len(lookups) == 1
condition = (reffed_aggregate, lookups[0], value)
clause.add(condition, AND) clause.add(condition, AND)
return clause, [] return clause, []
@ -1169,14 +1168,27 @@ class Query(object):
# For now foreign keys get special treatment. This should be # For now foreign keys get special treatment. This should be
# refactored when composite fields lands. # refactored when composite fields lands.
condition = field.get_lookup_constraint(self.where_class, alias, targets, sources, condition = field.get_lookup_constraint(self.where_class, alias, targets, sources,
lookup_type, value) lookups, value)
lookup_type = lookups[-1]
else: else:
assert(len(targets) == 1) assert(len(targets) == 1)
col = Col(alias, targets[0], field) col = Col(alias, targets[0], field)
condition = self.build_lookup(lookup_type, col, value) condition = self.build_lookup(lookups, col, value)
if not condition: if not condition:
# Backwards compat for custom lookups # Backwards compat for custom lookups
condition = (Constraint(alias, targets[0].column, field), lookup_type, value) if lookups[0] not in self.query_terms:
raise FieldError(
"Join on field '%s' not permitted. Did you "
"misspell '%s' for the lookup type?" %
(col.output_type.name, lookups[0]))
if len(lookups) > 1:
raise FieldError("Nested lookup '%s' not supported." %
LOOKUP_SEP.join(lookups))
condition = (Constraint(alias, targets[0].column, field), lookups[0], value)
lookup_type = lookups[-1]
else:
lookup_type = condition.lookup_name
clause.add(condition, AND) clause.add(condition, AND)
require_outer = lookup_type == 'isnull' and value is True and not current_negated require_outer = lookup_type == 'isnull' and value is True and not current_negated
@ -1296,7 +1308,7 @@ class Query(object):
needed_inner = joinpromoter.update_join_types(self) needed_inner = joinpromoter.update_join_types(self)
return target_clause, needed_inner return target_clause, needed_inner
def names_to_path(self, names, opts, allow_many=True): def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False):
""" """
Walks the names path and turns them PathInfo tuples. Note that a Walks the names path and turns them PathInfo tuples. Note that a
single name in 'names' can generate multiple PathInfos (m2m for single name in 'names' can generate multiple PathInfos (m2m for
@ -1354,10 +1366,15 @@ class Query(object):
final_field = field final_field = field
targets = (field,) targets = (field,)
break break
if pos == -1: if pos == -1 or (fail_on_missing and pos + 1 != len(names)):
raise FieldError('Whazaa') self.raise_field_error(opts, name)
return path, final_field, targets, names[pos + 1:] return path, final_field, targets, names[pos + 1:]
def raise_field_error(self, opts, name):
    """
    Raise a FieldError reporting that ``name`` cannot be resolved.

    The message lists the names returned by ``opts.get_all_field_names()``
    followed by the keys of ``self.aggregate_select`` as the choices.
    """
    choices = opts.get_all_field_names() + list(self.aggregate_select)
    template = ("Cannot resolve keyword %r into field. "
                "Choices are: %s")
    raise FieldError(template % (name, ", ".join(choices)))
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True): def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
""" """
Compute the necessary table joins for the passage through the fields Compute the necessary table joins for the passage through the fields
@ -1386,9 +1403,8 @@ class Query(object):
joins = [alias] joins = [alias]
# First, generate the path for the names # First, generate the path for the names
path, final_field, targets, rest = self.names_to_path( path, final_field, targets, rest = self.names_to_path(
names, opts, allow_many) names, opts, allow_many, fail_on_missing=True)
if rest:
raise FieldError('Invalid lookup')
# Then, add the path to the query's joins. Note that we can't trim # Then, add the path to the query's joins. Note that we can't trim
# joins at this stage - we will need the information about join type # joins at this stage - we will need the information about join type
# of the trimmed joins. # of the trimmed joins.

View File

@ -1,7 +1,13 @@
from django.db import models from django.db import models
from django.utils.encoding import python_2_unicode_compatible
@python_2_unicode_compatible
class Author(models.Model): class Author(models.Model):
name = models.CharField(max_length=20) name = models.CharField(max_length=20)
age = models.IntegerField(null=True) age = models.IntegerField(null=True)
birthdate = models.DateField(null=True) birthdate = models.DateField(null=True)
average_rating = models.FloatField(null=True)
def __str__(self):
return self.name

View File

@ -19,6 +19,56 @@ class Div3Lookup(models.lookups.Lookup):
return '%s %%%% 3 = %s' % (lhs, rhs), params return '%s %%%% 3 = %s' % (lhs, rhs), params
class Div3Extract(models.lookups.Extract):
    """Extract whose SQL is the wrapped expression followed by a modulo-3."""

    def as_sql(self, qn, connection):
        inner_sql, inner_params = qn.compile(self.lhs)
        # The formatting below leaves a doubled '%%' in the returned SQL;
        # presumably it is collapsed to a single '%' when the backend
        # interpolates the params -- confirm against the backend in use.
        return '%s %%%% 3' % (inner_sql,), inner_params
class Div3LookupWithExtract(Div3Lookup):
    """
    Variant of Div3Lookup that also declares an extract class, so that
    further lookups can be nested after ``div3`` (e.g. ``age__div3__lte``).
    """
    lookup_name = 'div3'
    extract_class = Div3Extract
class YearLte(models.lookups.LessThanOrEqual):
    """
    The purpose of this lookup is to efficiently compare the year of the field.
    """

    def as_sql(self, qn, connection):
        # self.lhs is the YearExtract this lookup was nested under; step
        # past it to the expression it wraps, since comparing the
        # extracted year leaves no possibility for an efficient lookup.
        wrapped = self.lhs.lhs
        lhs_sql, params = self.process_lhs(qn, connection, wrapped)
        rhs_sql, rhs_params = self.process_rhs(qn, connection)
        params.extend(rhs_params)
        # The integer year is concatenated with the last month and day and
        # the result is cast to a date, aiming for SQL such as
        #     WHERE somecol <= '2013-12-31'
        # while still working when rhs_sql is a field reference.
        template = "%s <= (%s || '-12-31')::date"
        return template % (lhs_sql, rhs_sql), params
class YearExtract(models.lookups.Extract):
    """
    Extract the year of the wrapped expression.

    The output type is an IntegerField, and ``lte`` is special-cased to
    resolve to YearLte.
    """

    def as_sql(self, qn, connection):
        inner_sql, params = qn.compile(self.lhs)
        year_sql = connection.ops.date_extract_sql('year', inner_sql)
        return year_sql, params

    @property
    def output_type(self):
        # A fresh IntegerField instance is returned on every access.
        return models.IntegerField()

    def get_lookup(self, lookup):
        # Everything except 'lte' goes through the regular resolution.
        if lookup != 'lte':
            return super(YearExtract, self).get_lookup(lookup)
        return YearLte
class YearWithExtract(models.lookups.Year):
    """
    Year lookup that declares YearExtract as its extract class, so that
    further lookups can be nested after ``year`` (e.g. ``year__lte``).
    """
    extract_class = YearExtract
class InMonth(models.lookups.Lookup): class InMonth(models.lookups.Lookup):
""" """
InMonth matches if the column's month is contained in the value's month. InMonth matches if the column's month is contained in the value's month.
@ -134,3 +184,72 @@ class LookupTests(TestCase):
) )
finally: finally:
models.Field._unregister_lookup(AnotherEqual) models.Field._unregister_lookup(AnotherEqual)
def test_div3_extract(self):
    """Lookups nested after ``div3`` filter on the extracted value."""
    models.IntegerField.register_lookup(Div3LookupWithExtract)
    try:
        # Authors a1..a4 with ages 1..4, created in that order.
        a1, a2, a3, a4 = [
            Author.objects.create(name='a%d' % age, age=age)
            for age in (1, 2, 3, 4)
        ]
        authors = Author.objects.order_by('name')
        identity = lambda x: x
        # Every author is expected to satisfy div3 <= 3.
        self.assertQuerysetEqual(
            authors.filter(age__div3__lte=3),
            [a1, a2, a3, a4], identity)
        # Only ages 2 and 3 are expected to give a div3 value of 2 or 0.
        self.assertQuerysetEqual(
            authors.filter(age__div3__in=[0, 2]),
            [a2, a3], identity)
    finally:
        # Always unregister so other tests are not affected.
        models.IntegerField._unregister_lookup(Div3LookupWithExtract)
class YearLteTests(TestCase):
    """Tests for nesting ``lte``/``lt`` after the ``year`` lookup."""

    def setUp(self):
        models.DateField.register_lookup(YearWithExtract)
        create = Author.objects.create
        self.a1 = create(name='a1', birthdate=date(1981, 2, 16))
        self.a2 = create(name='a2', birthdate=date(2012, 2, 29))
        self.a3 = create(name='a3', birthdate=date(2012, 1, 31))
        self.a4 = create(name='a4', birthdate=date(2012, 3, 1))

    def tearDown(self):
        models.DateField._unregister_lookup(YearWithExtract)

    @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
    def test_year_lte(self):
        authors = Author.objects.order_by('name')
        identity = lambda x: x
        self.assertQuerysetEqual(
            authors.filter(birthdate__year__lte=2012),
            [self.a1, self.a2, self.a3, self.a4], identity)
        self.assertQuerysetEqual(
            authors.filter(birthdate__year__lte=2011),
            [self.a1], identity)
        # The non-optimized version works, too.
        self.assertQuerysetEqual(
            authors.filter(birthdate__year__lt=2012),
            [self.a1], identity)

    @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
    def test_year_lte_fexpr(self):
        # Give a2..a4 ages that can be compared against their birth years.
        for author, age in ((self.a2, 2011), (self.a3, 2012), (self.a4, 2013)):
            author.age = age
            author.save()
        authors = Author.objects.order_by('name')
        identity = lambda x: x
        self.assertQuerysetEqual(
            authors.filter(birthdate__year__lte=models.F('age')),
            [self.a3, self.a4], identity)
        self.assertQuerysetEqual(
            authors.filter(birthdate__year__lt=models.F('age')),
            [self.a4], identity)

    def test_year_lte_sql(self):
        # This test will just check the generated SQL for __lte. This
        # doesn't require running on PostgreSQL and spots the most likely
        # error - not running YearLte SQL at all.
        authors = Author.objects.order_by('name')
        sql = str(authors.filter(birthdate__year__lte=2011).query)
        self.assertIn('<= (2011 || ', sql)
        self.assertIn('-12-31', sql)

View File

@ -41,9 +41,6 @@ class NullQueriesTests(TestCase):
# Can't use None on anything other than __exact # Can't use None on anything other than __exact
self.assertRaises(ValueError, Choice.objects.filter, id__gt=None) self.assertRaises(ValueError, Choice.objects.filter, id__gt=None)
# Can't use None on anything other than __exact
self.assertRaises(ValueError, Choice.objects.filter, foo__gt=None)
# Related managers use __exact=None implicitly if the object hasn't been saved. # Related managers use __exact=None implicitly if the object hasn't been saved.
p2 = Poll(question="How?") p2 = Poll(question="How?")
self.assertEqual(repr(p2.choice_set.all()), '[]') self.assertEqual(repr(p2.choice_set.all()), '[]')