mirror of https://github.com/django/django.git
Implemented nested lookups
But there is no support of using lookups outside filtering yet.
This commit is contained in:
parent
4d219d4cde
commit
7c8b3a32cc
|
@ -1136,11 +1136,14 @@ class ForeignObject(RelatedField):
|
||||||
pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)]
|
pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)]
|
||||||
return pathinfos
|
return pathinfos
|
||||||
|
|
||||||
def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type,
|
def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookups,
|
||||||
raw_value):
|
raw_value):
|
||||||
from django.db.models.sql.where import SubqueryConstraint, Constraint, AND, OR
|
from django.db.models.sql.where import SubqueryConstraint, Constraint, AND, OR
|
||||||
root_constraint = constraint_class()
|
root_constraint = constraint_class()
|
||||||
assert len(targets) == len(sources)
|
assert len(targets) == len(sources)
|
||||||
|
if len(lookups) > 1:
|
||||||
|
raise exceptions.FieldError('Relation fields do not support nested lookups')
|
||||||
|
lookup_type = lookups[0]
|
||||||
|
|
||||||
def get_normalized_value(value):
|
def get_normalized_value(value):
|
||||||
|
|
||||||
|
|
|
@ -1,27 +1,58 @@
|
||||||
from copy import copy
|
from copy import copy
|
||||||
|
|
||||||
|
from django.core.exceptions import FieldError
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
from django.utils.functional import cached_property
|
||||||
|
|
||||||
|
|
||||||
|
class Extract(object):
|
||||||
|
def __init__(self, constraint_class, lhs):
|
||||||
|
self.constraint_class, self.lhs = constraint_class, lhs
|
||||||
|
|
||||||
|
def get_lookup(self, lookup):
|
||||||
|
return self.output_type.get_lookup(lookup)
|
||||||
|
|
||||||
|
def as_sql(self, qn, connection):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def output_type(self):
|
||||||
|
return self.lhs.output_type
|
||||||
|
|
||||||
|
def relabeled_clone(self, relabels):
|
||||||
|
return self.__class__(self.constraint_class, self.lhs.relabeled_clone(relabels))
|
||||||
|
|
||||||
|
|
||||||
class Lookup(object):
|
class Lookup(object):
|
||||||
|
lookup_name = None
|
||||||
|
extract_class = None
|
||||||
|
|
||||||
def __init__(self, constraint_class, lhs, rhs):
|
def __init__(self, constraint_class, lhs, rhs):
|
||||||
self.constraint_class, self.lhs, self.rhs = constraint_class, lhs, rhs
|
self.constraint_class, self.lhs, self.rhs = constraint_class, lhs, rhs
|
||||||
self.rhs = self.get_prep_lookup()
|
if rhs is None:
|
||||||
|
if not self.extract_class:
|
||||||
|
raise FieldError("Lookup '%s' doesn't support nesting." % self.lookup_name)
|
||||||
|
else:
|
||||||
|
self.rhs = self.get_prep_lookup()
|
||||||
|
|
||||||
|
def get_extract(self):
|
||||||
|
return self.extract_class(self.constraint_class, self.lhs)
|
||||||
|
|
||||||
|
def get_prep_lookup(self):
|
||||||
|
return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs)
|
||||||
|
|
||||||
def get_db_prep_lookup(self, value, connection):
|
def get_db_prep_lookup(self, value, connection):
|
||||||
return (
|
return (
|
||||||
'%s', self.lhs.output_type.get_db_prep_lookup(
|
'%s', self.lhs.output_type.get_db_prep_lookup(
|
||||||
self.lookup_name, value, connection, prepared=True))
|
self.lookup_name, value, connection, prepared=True))
|
||||||
|
|
||||||
def get_prep_lookup(self):
|
def process_lhs(self, qn, connection, lhs=None):
|
||||||
return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs)
|
lhs = lhs or self.lhs
|
||||||
|
return qn.compile(lhs)
|
||||||
|
|
||||||
def process_lhs(self, qn, connection):
|
def process_rhs(self, qn, connection, rhs=None):
|
||||||
return qn.compile(self.lhs)
|
value = rhs or self.rhs
|
||||||
|
|
||||||
def process_rhs(self, qn, connection):
|
|
||||||
value = self.rhs
|
|
||||||
# Due to historical reasons there are a couple of different
|
# Due to historical reasons there are a couple of different
|
||||||
# ways to produce sql here. get_compiler is likely a Query
|
# ways to produce sql here. get_compiler is likely a Query
|
||||||
# instance, _as_sql QuerySet and as_sql just something with
|
# instance, _as_sql QuerySet and as_sql just something with
|
||||||
|
@ -118,7 +149,7 @@ class In(DjangoLookup):
|
||||||
lookup_name = 'in'
|
lookup_name = 'in'
|
||||||
|
|
||||||
def get_db_prep_lookup(self, value, connection):
|
def get_db_prep_lookup(self, value, connection):
|
||||||
params = self.lhs.field.get_db_prep_lookup(
|
params = self.lhs.output_type.get_db_prep_lookup(
|
||||||
self.lookup_name, value, connection, prepared=True)
|
self.lookup_name, value, connection, prepared=True)
|
||||||
if not params:
|
if not params:
|
||||||
# TODO: check why this leads to circular import
|
# TODO: check why this leads to circular import
|
||||||
|
|
|
@ -100,6 +100,9 @@ class Aggregate(object):
|
||||||
def output_type(self):
|
def output_type(self):
|
||||||
return self.field
|
return self.field
|
||||||
|
|
||||||
|
def get_lookup(self, lookup):
|
||||||
|
return self.output_type.get_lookup(lookup)
|
||||||
|
|
||||||
|
|
||||||
class Avg(Aggregate):
|
class Avg(Aggregate):
|
||||||
is_computed = True
|
is_computed = True
|
||||||
|
|
|
@ -25,6 +25,9 @@ class Col(object):
|
||||||
def get_cols(self):
|
def get_cols(self):
|
||||||
return [(self.alias, self.target.column)]
|
return [(self.alias, self.target.column)]
|
||||||
|
|
||||||
|
def get_lookup(self, name):
|
||||||
|
return self.output_type.get_lookup(name)
|
||||||
|
|
||||||
|
|
||||||
class EmptyResultSet(Exception):
|
class EmptyResultSet(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -1027,19 +1027,16 @@ class Query(object):
|
||||||
# Add the aggregate to the query
|
# Add the aggregate to the query
|
||||||
aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
|
aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
|
||||||
|
|
||||||
def prepare_lookup_value(self, value, lookup_type, can_reuse):
|
def prepare_lookup_value(self, value, lookups, can_reuse):
|
||||||
|
# Default lookup if none given is exact.
|
||||||
|
if len(lookups) == 0:
|
||||||
|
lookups = ['exact']
|
||||||
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
|
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
|
||||||
# uses of None as a query value.
|
# uses of None as a query value.
|
||||||
if len(lookup_type) > 1:
|
|
||||||
raise FieldError('Nested lookups not allowed')
|
|
||||||
elif len(lookup_type) == 0:
|
|
||||||
lookup_type = 'exact'
|
|
||||||
else:
|
|
||||||
lookup_type = lookup_type[0]
|
|
||||||
if value is None:
|
if value is None:
|
||||||
if lookup_type != 'exact':
|
if lookups[-1] != 'exact':
|
||||||
raise ValueError("Cannot use None as a query value")
|
raise ValueError("Cannot use None as a query value")
|
||||||
lookup_type = 'isnull'
|
lookups[-1] = 'isnull'
|
||||||
value = True
|
value = True
|
||||||
elif callable(value):
|
elif callable(value):
|
||||||
value = value()
|
value = value()
|
||||||
|
@ -1057,10 +1054,10 @@ class Query(object):
|
||||||
# stage. Using DEFAULT_DB_ALIAS isn't nice, but it is the best we
|
# stage. Using DEFAULT_DB_ALIAS isn't nice, but it is the best we
|
||||||
# can do here. Similar thing is done in is_nullable(), too.
|
# can do here. Similar thing is done in is_nullable(), too.
|
||||||
if (connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls and
|
if (connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls and
|
||||||
lookup_type == 'exact' and value == ''):
|
lookups[-1] == 'exact' and value == ''):
|
||||||
value = True
|
value = True
|
||||||
lookup_type = 'isnull'
|
lookups[-1] = ['isnull']
|
||||||
return value, lookup_type
|
return value, lookups
|
||||||
|
|
||||||
def solve_lookup_type(self, lookup):
|
def solve_lookup_type(self, lookup):
|
||||||
"""
|
"""
|
||||||
|
@ -1069,36 +1066,37 @@ class Query(object):
|
||||||
lookup_splitted = lookup.split(LOOKUP_SEP)
|
lookup_splitted = lookup.split(LOOKUP_SEP)
|
||||||
aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates)
|
aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates)
|
||||||
if aggregate:
|
if aggregate:
|
||||||
if len(aggregate_lookups) > 1:
|
|
||||||
raise FieldError("Nested lookups not allowed.")
|
|
||||||
return aggregate_lookups, (), aggregate
|
return aggregate_lookups, (), aggregate
|
||||||
_, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta())
|
_, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta())
|
||||||
field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)]
|
field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)]
|
||||||
if len(lookup_parts) == 0:
|
if len(lookup_parts) == 0:
|
||||||
lookup_parts = ['exact']
|
lookup_parts = ['exact']
|
||||||
elif len(lookup_parts) > 1:
|
elif len(lookup_parts) > 1:
|
||||||
if field_parts:
|
if not field_parts:
|
||||||
raise FieldError(
|
|
||||||
'Only one lookup part allowed (found path "%s" from "%s").' %
|
|
||||||
(LOOKUP_SEP.join(field_parts), lookup))
|
|
||||||
else:
|
|
||||||
raise FieldError(
|
raise FieldError(
|
||||||
'Invalid lookup "%s" for model %s".' %
|
'Invalid lookup "%s" for model %s".' %
|
||||||
(lookup, self.get_meta().model.__name__))
|
(lookup, self.get_meta().model.__name__))
|
||||||
else:
|
|
||||||
if not hasattr(field, 'get_lookup_constraint'):
|
|
||||||
lookup_class = field.get_lookup(lookup_parts[0])
|
|
||||||
if lookup_class is None and lookup_parts[0] not in self.query_terms:
|
|
||||||
raise FieldError(
|
|
||||||
'Invalid lookup name %s' % lookup_parts[0])
|
|
||||||
return lookup_parts, field_parts, False
|
return lookup_parts, field_parts, False
|
||||||
|
|
||||||
def build_lookup(self, lookup_type, lhs, rhs):
|
def build_lookup(self, lookups, lhs, rhs):
|
||||||
if hasattr(lhs.output_type, 'get_lookup'):
|
lookups = lookups[:]
|
||||||
lookup = lhs.output_type.get_lookup(lookup_type)
|
lookups.reverse()
|
||||||
if lookup:
|
while lookups:
|
||||||
return lookup(self.where_class, lhs, rhs)
|
lookup = lookups.pop()
|
||||||
return None
|
next = lhs.get_lookup(lookup)
|
||||||
|
if next:
|
||||||
|
if not lookups:
|
||||||
|
# This was the last lookup, so return value lookup.
|
||||||
|
return next(self.where_class, lhs, rhs)
|
||||||
|
else:
|
||||||
|
lhs = next(self.where_class, lhs, None).get_extract()
|
||||||
|
# A field's get_lookup() can return None to opt for backwards
|
||||||
|
# compatibility path.
|
||||||
|
elif len(lookups) > 1:
|
||||||
|
raise FieldError(
|
||||||
|
"Unsupported lookup for field '%s'" % lhs.output_type.name)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
|
def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
|
||||||
can_reuse=None, connector=AND):
|
can_reuse=None, connector=AND):
|
||||||
|
@ -1130,19 +1128,20 @@ class Query(object):
|
||||||
arg, value = filter_expr
|
arg, value = filter_expr
|
||||||
if not arg:
|
if not arg:
|
||||||
raise FieldError("Cannot parse keyword query %r" % arg)
|
raise FieldError("Cannot parse keyword query %r" % arg)
|
||||||
lookup_type, parts, reffed_aggregate = self.solve_lookup_type(arg)
|
lookups, parts, reffed_aggregate = self.solve_lookup_type(arg)
|
||||||
|
|
||||||
# Work out the lookup type and remove it from the end of 'parts',
|
# Work out the lookup type and remove it from the end of 'parts',
|
||||||
# if necessary.
|
# if necessary.
|
||||||
value, lookup_type = self.prepare_lookup_value(value, lookup_type, can_reuse)
|
value, lookups = self.prepare_lookup_value(value, lookups, can_reuse)
|
||||||
used_joins = getattr(value, '_used_joins', [])
|
used_joins = getattr(value, '_used_joins', [])
|
||||||
|
|
||||||
clause = self.where_class()
|
clause = self.where_class()
|
||||||
if reffed_aggregate:
|
if reffed_aggregate:
|
||||||
condition = self.build_lookup(lookup_type, reffed_aggregate, value)
|
condition = self.build_lookup(lookups, reffed_aggregate, value)
|
||||||
if not condition:
|
if not condition:
|
||||||
# Backwards compat for custom lookups
|
# Backwards compat for custom lookups
|
||||||
condition = (reffed_aggregate, lookup_type, value)
|
assert len(lookups) == 1
|
||||||
|
condition = (reffed_aggregate, lookups[0], value)
|
||||||
clause.add(condition, AND)
|
clause.add(condition, AND)
|
||||||
return clause, []
|
return clause, []
|
||||||
|
|
||||||
|
@ -1169,14 +1168,27 @@ class Query(object):
|
||||||
# For now foreign keys get special treatment. This should be
|
# For now foreign keys get special treatment. This should be
|
||||||
# refactored when composite fields lands.
|
# refactored when composite fields lands.
|
||||||
condition = field.get_lookup_constraint(self.where_class, alias, targets, sources,
|
condition = field.get_lookup_constraint(self.where_class, alias, targets, sources,
|
||||||
lookup_type, value)
|
lookups, value)
|
||||||
|
lookup_type = lookups[-1]
|
||||||
else:
|
else:
|
||||||
assert(len(targets) == 1)
|
assert(len(targets) == 1)
|
||||||
col = Col(alias, targets[0], field)
|
col = Col(alias, targets[0], field)
|
||||||
condition = self.build_lookup(lookup_type, col, value)
|
condition = self.build_lookup(lookups, col, value)
|
||||||
if not condition:
|
if not condition:
|
||||||
# Backwards compat for custom lookups
|
# Backwards compat for custom lookups
|
||||||
condition = (Constraint(alias, targets[0].column, field), lookup_type, value)
|
if lookups[0] not in self.query_terms:
|
||||||
|
raise FieldError(
|
||||||
|
"Join on field '%s' not permitted. Did you "
|
||||||
|
"misspell '%s' for the lookup type?" %
|
||||||
|
(col.output_type.name, lookups[0]))
|
||||||
|
if len(lookups) > 1:
|
||||||
|
raise FieldError("Nested lookup '%s' not supported." %
|
||||||
|
LOOKUP_SEP.join(lookups))
|
||||||
|
condition = (Constraint(alias, targets[0].column, field), lookups[0], value)
|
||||||
|
lookup_type = lookups[-1]
|
||||||
|
else:
|
||||||
|
lookup_type = condition.lookup_name
|
||||||
|
|
||||||
clause.add(condition, AND)
|
clause.add(condition, AND)
|
||||||
|
|
||||||
require_outer = lookup_type == 'isnull' and value is True and not current_negated
|
require_outer = lookup_type == 'isnull' and value is True and not current_negated
|
||||||
|
@ -1296,7 +1308,7 @@ class Query(object):
|
||||||
needed_inner = joinpromoter.update_join_types(self)
|
needed_inner = joinpromoter.update_join_types(self)
|
||||||
return target_clause, needed_inner
|
return target_clause, needed_inner
|
||||||
|
|
||||||
def names_to_path(self, names, opts, allow_many=True):
|
def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False):
|
||||||
"""
|
"""
|
||||||
Walks the names path and turns them PathInfo tuples. Note that a
|
Walks the names path and turns them PathInfo tuples. Note that a
|
||||||
single name in 'names' can generate multiple PathInfos (m2m for
|
single name in 'names' can generate multiple PathInfos (m2m for
|
||||||
|
@ -1354,10 +1366,15 @@ class Query(object):
|
||||||
final_field = field
|
final_field = field
|
||||||
targets = (field,)
|
targets = (field,)
|
||||||
break
|
break
|
||||||
if pos == -1:
|
if pos == -1 or (fail_on_missing and pos + 1 != len(names)):
|
||||||
raise FieldError('Whazaa')
|
self.raise_field_error(opts, name)
|
||||||
return path, final_field, targets, names[pos + 1:]
|
return path, final_field, targets, names[pos + 1:]
|
||||||
|
|
||||||
|
def raise_field_error(self, opts, name):
|
||||||
|
available = opts.get_all_field_names() + list(self.aggregate_select)
|
||||||
|
raise FieldError("Cannot resolve keyword %r into field. "
|
||||||
|
"Choices are: %s" % (name, ", ".join(available)))
|
||||||
|
|
||||||
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
|
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
|
||||||
"""
|
"""
|
||||||
Compute the necessary table joins for the passage through the fields
|
Compute the necessary table joins for the passage through the fields
|
||||||
|
@ -1386,9 +1403,8 @@ class Query(object):
|
||||||
joins = [alias]
|
joins = [alias]
|
||||||
# First, generate the path for the names
|
# First, generate the path for the names
|
||||||
path, final_field, targets, rest = self.names_to_path(
|
path, final_field, targets, rest = self.names_to_path(
|
||||||
names, opts, allow_many)
|
names, opts, allow_many, fail_on_missing=True)
|
||||||
if rest:
|
|
||||||
raise FieldError('Invalid lookup')
|
|
||||||
# Then, add the path to the query's joins. Note that we can't trim
|
# Then, add the path to the query's joins. Note that we can't trim
|
||||||
# joins at this stage - we will need the information about join type
|
# joins at this stage - we will need the information about join type
|
||||||
# of the trimmed joins.
|
# of the trimmed joins.
|
||||||
|
|
|
@ -1,7 +1,13 @@
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
from django.utils.encoding import python_2_unicode_compatible
|
||||||
|
|
||||||
|
|
||||||
|
@python_2_unicode_compatible
|
||||||
class Author(models.Model):
|
class Author(models.Model):
|
||||||
name = models.CharField(max_length=20)
|
name = models.CharField(max_length=20)
|
||||||
age = models.IntegerField(null=True)
|
age = models.IntegerField(null=True)
|
||||||
birthdate = models.DateField(null=True)
|
birthdate = models.DateField(null=True)
|
||||||
|
average_rating = models.FloatField(null=True)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.name
|
||||||
|
|
|
@ -19,6 +19,56 @@ class Div3Lookup(models.lookups.Lookup):
|
||||||
return '%s %%%% 3 = %s' % (lhs, rhs), params
|
return '%s %%%% 3 = %s' % (lhs, rhs), params
|
||||||
|
|
||||||
|
|
||||||
|
class Div3Extract(models.lookups.Extract):
|
||||||
|
def as_sql(self, qn, connection):
|
||||||
|
lhs, lhs_params = qn.compile(self.lhs)
|
||||||
|
return '%s %%%% 3' % (lhs,), lhs_params
|
||||||
|
|
||||||
|
|
||||||
|
class Div3LookupWithExtract(Div3Lookup):
|
||||||
|
lookup_name = 'div3'
|
||||||
|
extract_class = Div3Extract
|
||||||
|
|
||||||
|
|
||||||
|
class YearLte(models.lookups.LessThanOrEqual):
|
||||||
|
"""
|
||||||
|
The purpose of this lookup is to efficiently compare the year of the field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def as_sql(self, qn, connection):
|
||||||
|
# Skip the YearExtract above us (no possibility for efficient
|
||||||
|
# lookup otherwise).
|
||||||
|
real_lhs = self.lhs.lhs
|
||||||
|
lhs_sql, params = self.process_lhs(qn, connection, real_lhs)
|
||||||
|
rhs_sql, rhs_params = self.process_rhs(qn, connection)
|
||||||
|
params.extend(rhs_params)
|
||||||
|
# Build SQL where the integer year is concatenated with last month
|
||||||
|
# and day, then convert that to date. (We try to have SQL like:
|
||||||
|
# WHERE somecol <= '2013-12-31')
|
||||||
|
# but also make it work if the rhs_sql is field reference.
|
||||||
|
return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params
|
||||||
|
|
||||||
|
|
||||||
|
class YearExtract(models.lookups.Extract):
|
||||||
|
def as_sql(self, qn, connection):
|
||||||
|
lhs_sql, params = qn.compile(self.lhs)
|
||||||
|
return connection.ops.date_extract_sql('year', lhs_sql), params
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_type(self):
|
||||||
|
return models.IntegerField()
|
||||||
|
|
||||||
|
def get_lookup(self, lookup):
|
||||||
|
if lookup == 'lte':
|
||||||
|
return YearLte
|
||||||
|
else:
|
||||||
|
return super(YearExtract, self).get_lookup(lookup)
|
||||||
|
|
||||||
|
|
||||||
|
class YearWithExtract(models.lookups.Year):
|
||||||
|
extract_class = YearExtract
|
||||||
|
|
||||||
|
|
||||||
class InMonth(models.lookups.Lookup):
|
class InMonth(models.lookups.Lookup):
|
||||||
"""
|
"""
|
||||||
InMonth matches if the column's month is contained in the value's month.
|
InMonth matches if the column's month is contained in the value's month.
|
||||||
|
@ -134,3 +184,72 @@ class LookupTests(TestCase):
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
models.Field._unregister_lookup(AnotherEqual)
|
models.Field._unregister_lookup(AnotherEqual)
|
||||||
|
|
||||||
|
def test_div3_extract(self):
|
||||||
|
models.IntegerField.register_lookup(Div3LookupWithExtract)
|
||||||
|
try:
|
||||||
|
a1 = Author.objects.create(name='a1', age=1)
|
||||||
|
a2 = Author.objects.create(name='a2', age=2)
|
||||||
|
a3 = Author.objects.create(name='a3', age=3)
|
||||||
|
a4 = Author.objects.create(name='a4', age=4)
|
||||||
|
baseqs = Author.objects.order_by('name')
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
baseqs.filter(age__div3__lte=3),
|
||||||
|
[a1, a2, a3, a4], lambda x: x)
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
baseqs.filter(age__div3__in=[0, 2]),
|
||||||
|
[a2, a3], lambda x: x)
|
||||||
|
finally:
|
||||||
|
models.IntegerField._unregister_lookup(Div3LookupWithExtract)
|
||||||
|
|
||||||
|
|
||||||
|
class YearLteTests(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
models.DateField.register_lookup(YearWithExtract)
|
||||||
|
self.a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
|
||||||
|
self.a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
|
||||||
|
self.a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
|
||||||
|
self.a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
models.DateField._unregister_lookup(YearWithExtract)
|
||||||
|
|
||||||
|
@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
|
||||||
|
def test_year_lte(self):
|
||||||
|
baseqs = Author.objects.order_by('name')
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
baseqs.filter(birthdate__year__lte=2012),
|
||||||
|
[self.a1, self.a2, self.a3, self.a4], lambda x: x)
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
baseqs.filter(birthdate__year__lte=2011),
|
||||||
|
[self.a1], lambda x: x)
|
||||||
|
# The non-optimized version works, too.
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
baseqs.filter(birthdate__year__lt=2012),
|
||||||
|
[self.a1], lambda x: x)
|
||||||
|
|
||||||
|
@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
|
||||||
|
def test_year_lte_fexpr(self):
|
||||||
|
self.a2.age = 2011
|
||||||
|
self.a2.save()
|
||||||
|
self.a3.age = 2012
|
||||||
|
self.a3.save()
|
||||||
|
self.a4.age = 2013
|
||||||
|
self.a4.save()
|
||||||
|
baseqs = Author.objects.order_by('name')
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
baseqs.filter(birthdate__year__lte=models.F('age')),
|
||||||
|
[self.a3, self.a4], lambda x: x)
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
baseqs.filter(birthdate__year__lt=models.F('age')),
|
||||||
|
[self.a4], lambda x: x)
|
||||||
|
|
||||||
|
def test_year_lte_sql(self):
|
||||||
|
# This test will just check the generated SQL for __lte. This
|
||||||
|
# doesn't require running on PostgreSQL and spots the most likely
|
||||||
|
# error - not running YearLte SQL at all.
|
||||||
|
baseqs = Author.objects.order_by('name')
|
||||||
|
self.assertIn(
|
||||||
|
'<= (2011 || ', str(baseqs.filter(birthdate__year__lte=2011).query))
|
||||||
|
self.assertIn(
|
||||||
|
'-12-31', str(baseqs.filter(birthdate__year__lte=2011).query))
|
||||||
|
|
|
@ -41,9 +41,6 @@ class NullQueriesTests(TestCase):
|
||||||
# Can't use None on anything other than __exact
|
# Can't use None on anything other than __exact
|
||||||
self.assertRaises(ValueError, Choice.objects.filter, id__gt=None)
|
self.assertRaises(ValueError, Choice.objects.filter, id__gt=None)
|
||||||
|
|
||||||
# Can't use None on anything other than __exact
|
|
||||||
self.assertRaises(ValueError, Choice.objects.filter, foo__gt=None)
|
|
||||||
|
|
||||||
# Related managers use __exact=None implicitly if the object hasn't been saved.
|
# Related managers use __exact=None implicitly if the object hasn't been saved.
|
||||||
p2 = Poll(question="How?")
|
p2 = Poll(question="How?")
|
||||||
self.assertEqual(repr(p2.choice_set.all()), '[]')
|
self.assertEqual(repr(p2.choice_set.all()), '[]')
|
||||||
|
|
Loading…
Reference in New Issue