Initial implementation of custom lookups

This commit is contained in:
Anssi Kääriäinen 2013-11-27 22:07:30 +02:00
parent 01e8ac47b3
commit 4d219d4cde
16 changed files with 594 additions and 97 deletions

View File

@ -67,6 +67,9 @@ class BaseDatabaseWrapper(object):
self.allow_thread_sharing = allow_thread_sharing
self._thread_ident = thread.get_ident()
# Compile implementations, used by compiler.compile(someelem)
self.compile_implementations = utils.get_implementations(self.vendor)
def __eq__(self, other):
if isinstance(other, BaseDatabaseWrapper):
return self.alias == other.alias

View File

@ -194,3 +194,27 @@ def format_number(value, max_digits, decimal_places):
return "{0:f}".format(value.quantize(decimal.Decimal(".1") ** decimal_places, context=context))
else:
return "%.*f" % (decimal_places, value)
# Map of vendor name -> map of query element class -> implementation function
compile_implementations = {}
def get_implementations(vendor):
try:
implementation = compile_implementations[vendor]
except KeyError:
# TODO: do we need thread safety here? We could easily use an lock...
implementation = {}
compile_implementations[vendor] = implementation
return implementation
class add_implementation(object):
def __init__(self, klass, vendor):
self.klass = klass
self.vendor = vendor
def __call__(self, func):
implementations = get_implementations(self.vendor)
implementations[self.klass] = func
return func

View File

@ -17,8 +17,8 @@ def refs_aggregate(lookup_parts, aggregates):
"""
for i in range(len(lookup_parts) + 1):
if LOOKUP_SEP.join(lookup_parts[0:i]) in aggregates:
return True
return False
return aggregates[LOOKUP_SEP.join(lookup_parts[0:i])], lookup_parts[i:]
return False, ()
class Aggregate(object):

View File

@ -4,6 +4,7 @@ import collections
import copy
import datetime
import decimal
import inspect
import math
import warnings
from base64 import b64decode, b64encode
@ -11,6 +12,7 @@ from itertools import tee
from django.db import connection
from django.db.models.loading import get_model
from django.db.models.lookups import default_lookups
from django.db.models.query_utils import QueryWrapper
from django.conf import settings
from django import forms
@ -101,6 +103,7 @@ class Field(object):
'unique': _('%(model_name)s with this %(field_label)s '
'already exists.'),
}
class_lookups = default_lookups.copy()
# Generic field type description, usually overridden by subclasses
def _description(self):
@ -446,6 +449,30 @@ class Field(object):
def get_internal_type(self):
return self.__class__.__name__
def get_lookup(self, lookup_name):
try:
return self.class_lookups[lookup_name]
except KeyError:
for parent in inspect.getmro(self.__class__):
if not 'class_lookups' in parent.__dict__:
continue
if lookup_name in parent.class_lookups:
return parent.class_lookups[lookup_name]
@classmethod
def register_lookup(cls, lookup):
if not 'class_lookups' in cls.__dict__:
cls.class_lookups = {}
cls.class_lookups[lookup.lookup_name] = lookup
@classmethod
def _unregister_lookup(cls, lookup):
"""
Removes given lookup from cls lookups. Meant to be used in
tests only.
"""
del cls.class_lookups[lookup.lookup_name]
def pre_save(self, model_instance, add):
"""
Returns field's value just before saving.
@ -504,8 +531,7 @@ class Field(object):
except ValueError:
raise ValueError("The __year lookup type requires an integer "
"argument")
raise TypeError("Field has invalid lookup: %s" % lookup_type)
return self.get_prep_value(value)
def get_db_prep_lookup(self, lookup_type, value, connection,
prepared=False):
@ -554,6 +580,8 @@ class Field(object):
return connection.ops.year_lookup_bounds_for_date_field(value)
else:
return [value] # this isn't supposed to happen
else:
return [value]
def has_default(self):
"""

View File

@ -934,6 +934,11 @@ class ForeignObjectRel(object):
# example custom multicolumn joins currently have no remote field).
self.field_name = None
def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type,
raw_value):
return self.field.get_lookup_constraint(constraint_class, alias, targets, sources,
lookup_type, raw_value)
class ManyToOneRel(ForeignObjectRel):
def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None,

242
django/db/models/lookups.py Normal file
View File

@ -0,0 +1,242 @@
from copy import copy
from django.conf import settings
from django.utils import timezone
class Lookup(object):
def __init__(self, constraint_class, lhs, rhs):
self.constraint_class, self.lhs, self.rhs = constraint_class, lhs, rhs
self.rhs = self.get_prep_lookup()
def get_db_prep_lookup(self, value, connection):
return (
'%s', self.lhs.output_type.get_db_prep_lookup(
self.lookup_name, value, connection, prepared=True))
def get_prep_lookup(self):
return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs)
def process_lhs(self, qn, connection):
return qn.compile(self.lhs)
def process_rhs(self, qn, connection):
value = self.rhs
# Due to historical reasons there are a couple of different
# ways to produce sql here. get_compiler is likely a Query
# instance, _as_sql QuerySet and as_sql just something with
# as_sql. Finally the value can of course be just plain
# Python value.
if hasattr(value, 'get_compiler'):
value = value.get_compiler(connection=connection)
if hasattr(value, 'as_sql'):
sql, params = qn.compile(value)
return '(' + sql + ')', params
if hasattr(value, '_as_sql'):
sql, params = value._as_sql(connection=connection)
return '(' + sql + ')', params
else:
return self.get_db_prep_lookup(value, connection)
def relabeled_clone(self, relabels):
new = copy(self)
new.lhs = new.lhs.relabeled_clone(relabels)
if hasattr(new.rhs, 'relabeled_clone'):
new.rhs = new.rhs.relabeled_clone(relabels)
return new
def get_cols(self):
cols = self.lhs.get_cols()
if hasattr(self.rhs, 'get_cols'):
cols.extend(self.rhs.get_cols())
return cols
def as_sql(self, qn, connection):
raise NotImplementedError
class DjangoLookup(Lookup):
def as_sql(self, qn, connection):
lhs_sql, params = self.process_lhs(qn, connection)
field_internal_type = self.lhs.output_type.get_internal_type()
db_type = self.lhs.output_type
lhs_sql = connection.ops.field_cast_sql(db_type, field_internal_type) % lhs_sql
lhs_sql = connection.ops.lookup_cast(self.lookup_name) % lhs_sql
rhs_sql, rhs_params = self.process_rhs(qn, connection)
params.extend(rhs_params)
operator_plus_rhs = self.get_rhs_op(connection, rhs_sql)
return '%s %s' % (lhs_sql, operator_plus_rhs), params
def get_rhs_op(self, connection, rhs):
return connection.operators[self.lookup_name] % rhs
default_lookups = {}
class Exact(DjangoLookup):
lookup_name = 'exact'
default_lookups['exact'] = Exact
class IExact(DjangoLookup):
lookup_name = 'iexact'
default_lookups['iexact'] = IExact
class Contains(DjangoLookup):
lookup_name = 'contains'
default_lookups['contains'] = Contains
class IContains(DjangoLookup):
lookup_name = 'icontains'
default_lookups['icontains'] = IContains
class GreaterThan(DjangoLookup):
lookup_name = 'gt'
default_lookups['gt'] = GreaterThan
class GreaterThanOrEqual(DjangoLookup):
lookup_name = 'gte'
default_lookups['gte'] = GreaterThanOrEqual
class LessThan(DjangoLookup):
lookup_name = 'lt'
default_lookups['lt'] = LessThan
class LessThanOrEqual(DjangoLookup):
lookup_name = 'lte'
default_lookups['lte'] = LessThanOrEqual
class In(DjangoLookup):
lookup_name = 'in'
def get_db_prep_lookup(self, value, connection):
params = self.lhs.field.get_db_prep_lookup(
self.lookup_name, value, connection, prepared=True)
if not params:
# TODO: check why this leads to circular import
from django.db.models.sql.datastructures import EmptyResultSet
raise EmptyResultSet
placeholder = '(' + ', '.join('%s' for p in params) + ')'
return (placeholder, params)
def get_rhs_op(self, connection, rhs):
return 'IN %s' % rhs
default_lookups['in'] = In
class StartsWith(DjangoLookup):
lookup_name = 'startswith'
default_lookups['startswith'] = StartsWith
class IStartsWith(DjangoLookup):
lookup_name = 'istartswith'
default_lookups['istartswith'] = IStartsWith
class EndsWith(DjangoLookup):
lookup_name = 'endswith'
default_lookups['endswith'] = EndsWith
class IEndsWith(DjangoLookup):
lookup_name = 'iendswith'
default_lookups['iendswith'] = IEndsWith
class Between(DjangoLookup):
def get_rhs_op(self, connection, rhs):
return "BETWEEN %s AND %s" % (rhs, rhs)
class Year(Between):
lookup_name = 'year'
default_lookups['year'] = Year
class Range(Between):
lookup_name = 'range'
default_lookups['range'] = Range
class DateLookup(DjangoLookup):
def process_lhs(self, qn, connection):
lhs, params = super(DateLookup, self).process_lhs(qn, connection)
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
sql, tz_params = connection.ops.datetime_extract_sql(self.extract_type, lhs, tzname)
return connection.ops.lookup_cast(self.lookup_name) % sql, tz_params
def get_rhs_op(self, connection, rhs):
return '= %s' % rhs
class Month(DateLookup):
lookup_name = 'month'
extract_type = 'month'
default_lookups['month'] = Month
class Day(DateLookup):
lookup_name = 'day'
extract_type = 'day'
default_lookups['day'] = Day
class WeekDay(DateLookup):
lookup_name = 'week_day'
extract_type = 'week_day'
default_lookups['week_day'] = WeekDay
class Hour(DateLookup):
lookup_name = 'hour'
extract_type = 'hour'
default_lookups['hour'] = Hour
class Minute(DateLookup):
lookup_name = 'minute'
extract_type = 'minute'
default_lookups['minute'] = Minute
class Second(DateLookup):
lookup_name = 'second'
extract_type = 'second'
default_lookups['second'] = Second
class IsNull(DjangoLookup):
lookup_name = 'isnull'
def as_sql(self, qn, connection):
sql, params = qn.compile(self.lhs)
if self.rhs:
return "%s IS NULL" % sql, params
else:
return "%s IS NOT NULL" % sql, params
default_lookups['isnull'] = IsNull
class Search(DjangoLookup):
lookup_name = 'search'
default_lookups['search'] = Search
class Regex(DjangoLookup):
lookup_name = 'regex'
default_lookups['regex'] = Regex
class IRegex(DjangoLookup):
lookup_name = 'iregex'
default_lookups['iregex'] = IRegex

View File

@ -93,6 +93,13 @@ class Aggregate(object):
return self.sql_template % substitutions, params
def get_cols(self):
return []
@property
def output_type(self):
return self.field
class Avg(Aggregate):
is_computed = True

View File

@ -45,7 +45,7 @@ class SQLCompiler(object):
if self.query.select_related and not self.query.related_select_cols:
self.fill_related_selections()
def quote_name_unless_alias(self, name):
def __call__(self, name):
"""
A wrapper around connection.ops.quote_name that doesn't quote aliases
for table names. This avoids problems with some SQL dialects that treat
@ -61,6 +61,20 @@ class SQLCompiler(object):
self.quote_cache[name] = r
return r
def quote_name_unless_alias(self, name):
"""
A wrapper around connection.ops.quote_name that doesn't quote aliases
for table names. This avoids problems with some SQL dialects that treat
quoted strings specially (e.g. PostgreSQL).
"""
return self(name)
def compile(self, node):
if node.__class__ in self.connection.compile_implementations:
return self.connection.compile_implementations[node.__class__](node, self)
else:
return node.as_sql(self, self.connection)
def as_sql(self, with_limits=True, with_col_aliases=False):
"""
Creates the SQL for this query. Returns the SQL string and list of
@ -88,10 +102,8 @@ class SQLCompiler(object):
# docstring of get_from_clause() for details.
from_, f_params = self.get_from_clause()
qn = self.quote_name_unless_alias
where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection)
having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection)
where, w_params = self.compile(self.query.where)
having, h_params = self.compile(self.query.having)
having_group_by = self.query.having.get_cols()
params = []
for val in six.itervalues(self.query.extra_select):
@ -180,7 +192,7 @@ class SQLCompiler(object):
(without the table names) are given unique aliases. This is needed in
some cases to avoid ambiguity with nested queries.
"""
qn = self.quote_name_unless_alias
qn = self
qn2 = self.connection.ops.quote_name
result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)]
params = []
@ -213,7 +225,7 @@ class SQLCompiler(object):
aliases.add(r)
col_aliases.add(col[1])
else:
col_sql, col_params = col.as_sql(qn, self.connection)
col_sql, col_params = self.compile(col)
result.append(col_sql)
params.extend(col_params)
@ -229,7 +241,7 @@ class SQLCompiler(object):
max_name_length = self.connection.ops.max_name_length()
for alias, aggregate in self.query.aggregate_select.items():
agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
agg_sql, agg_params = self.compile(aggregate)
if alias is None:
result.append(agg_sql)
else:
@ -267,7 +279,7 @@ class SQLCompiler(object):
result = []
if opts is None:
opts = self.query.get_meta()
qn = self.quote_name_unless_alias
qn = self
qn2 = self.connection.ops.quote_name
aliases = set()
only_load = self.deferred_to_columns()
@ -312,7 +324,7 @@ class SQLCompiler(object):
Note that this method can alter the tables in the query, and thus it
must be called before get_from_clause().
"""
qn = self.quote_name_unless_alias
qn = self
qn2 = self.connection.ops.quote_name
result = []
opts = self.query.get_meta()
@ -345,7 +357,7 @@ class SQLCompiler(object):
ordering = (self.query.order_by
or self.query.get_meta().ordering
or [])
qn = self.quote_name_unless_alias
qn = self
qn2 = self.connection.ops.quote_name
distinct = self.query.distinct
select_aliases = self._select_aliases
@ -483,7 +495,7 @@ class SQLCompiler(object):
ordering and distinct must be done first.
"""
result = []
qn = self.quote_name_unless_alias
qn = self
qn2 = self.connection.ops.quote_name
first = True
from_params = []
@ -501,8 +513,7 @@ class SQLCompiler(object):
extra_cond = join_field.get_extra_restriction(
self.query.where_class, alias, lhs)
if extra_cond:
extra_sql, extra_params = extra_cond.as_sql(
qn, self.connection)
extra_sql, extra_params = self.compile(extra_cond)
extra_sql = 'AND (%s)' % extra_sql
from_params.extend(extra_params)
else:
@ -534,7 +545,7 @@ class SQLCompiler(object):
"""
Returns a tuple representing the SQL elements in the "group by" clause.
"""
qn = self.quote_name_unless_alias
qn = self
result, params = [], []
if self.query.group_by is not None:
select_cols = self.query.select + self.query.related_select_cols
@ -553,7 +564,7 @@ class SQLCompiler(object):
if isinstance(col, (list, tuple)):
sql = '%s.%s' % (qn(col[0]), qn(col[1]))
elif hasattr(col, 'as_sql'):
sql, col_params = col.as_sql(qn, self.connection)
self.compile(col)
else:
sql = '(%s)' % str(col)
if sql not in seen:
@ -776,7 +787,7 @@ class SQLCompiler(object):
return result
def as_subquery_condition(self, alias, columns, qn):
inner_qn = self.quote_name_unless_alias
inner_qn = self
qn2 = self.connection.ops.quote_name
if len(columns) == 1:
sql, params = self.as_sql()
@ -887,9 +898,9 @@ class SQLDeleteCompiler(SQLCompiler):
"""
assert len(self.query.tables) == 1, \
"Can only delete from one table at a time."
qn = self.quote_name_unless_alias
qn = self
result = ['DELETE FROM %s' % qn(self.query.tables[0])]
where, params = self.query.where.as_sql(qn=qn, connection=self.connection)
where, params = self.compile(self.query.where)
if where:
result.append('WHERE %s' % where)
return ' '.join(result), tuple(params)
@ -905,7 +916,7 @@ class SQLUpdateCompiler(SQLCompiler):
if not self.query.values:
return '', ()
table = self.query.tables[0]
qn = self.quote_name_unless_alias
qn = self
result = ['UPDATE %s' % qn(table)]
result.append('SET')
values, update_params = [], []
@ -925,7 +936,7 @@ class SQLUpdateCompiler(SQLCompiler):
val = SQLEvaluator(val, self.query, allow_joins=False)
name = field.column
if hasattr(val, 'as_sql'):
sql, params = val.as_sql(qn, self.connection)
sql, params = self.compile(val)
values.append('%s = %s' % (qn(name), sql))
update_params.extend(params)
elif val is not None:
@ -936,7 +947,7 @@ class SQLUpdateCompiler(SQLCompiler):
if not values:
return '', ()
result.append(', '.join(values))
where, params = self.query.where.as_sql(qn=qn, connection=self.connection)
where, params = self.compile(self.query.where)
if where:
result.append('WHERE %s' % where)
return ' '.join(result), tuple(update_params + params)
@ -1016,11 +1027,11 @@ class SQLAggregateCompiler(SQLCompiler):
parameters.
"""
if qn is None:
qn = self.quote_name_unless_alias
qn = self
sql, params = [], []
for aggregate in self.query.aggregate_select.values():
agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
agg_sql, agg_params = self.compile(aggregate)
sql.append(agg_sql)
params.extend(agg_params)
sql = ', '.join(sql)

View File

@ -5,18 +5,25 @@ the SQL domain.
class Col(object):
def __init__(self, alias, col):
self.alias = alias
self.col = col
def __init__(self, alias, target, source):
self.alias, self.target, self.source = alias, target, source
def as_sql(self, qn, connection):
return '%s.%s' % (qn(self.alias), self.col), []
return "%s.%s" % (qn(self.alias), qn(self.target.column)), []
def prepare(self):
return self
@property
def output_type(self):
return self.source
@property
def field(self):
return self.source
def relabeled_clone(self, relabels):
return self.__class__(relabels.get(self.alias, self.alias), self.col)
return self.__class__(relabels.get(self.alias, self.alias), self.target, self.source)
def get_cols(self):
return [(self.alias, self.target.column)]
class EmptyResultSet(Exception):

View File

@ -1030,6 +1030,12 @@ class Query(object):
def prepare_lookup_value(self, value, lookup_type, can_reuse):
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
# uses of None as a query value.
if len(lookup_type) > 1:
raise FieldError('Nested lookups not allowed')
elif len(lookup_type) == 0:
lookup_type = 'exact'
else:
lookup_type = lookup_type[0]
if value is None:
if lookup_type != 'exact':
raise ValueError("Cannot use None as a query value")
@ -1060,31 +1066,39 @@ class Query(object):
"""
Solve the lookup type from the lookup (eg: 'foobar__id__icontains')
"""
lookup_type = 'exact' # Default lookup type
lookup_parts = lookup.split(LOOKUP_SEP)
num_parts = len(lookup_parts)
if (len(lookup_parts) > 1 and lookup_parts[-1] in self.query_terms
and (not self._aggregates or lookup not in self._aggregates)):
# Traverse the lookup query to distinguish related fields from
# lookup types.
lookup_model = self.model
for counter, field_name in enumerate(lookup_parts):
try:
lookup_field = lookup_model._meta.get_field(field_name)
except FieldDoesNotExist:
# Not a field. Bail out.
lookup_type = lookup_parts.pop()
break
# Unless we're at the end of the list of lookups, let's attempt
# to continue traversing relations.
if (counter + 1) < num_parts:
try:
lookup_model = lookup_field.rel.to
except AttributeError:
# Not a related field. Bail out.
lookup_type = lookup_parts.pop()
break
return lookup_type, lookup_parts
lookup_splitted = lookup.split(LOOKUP_SEP)
aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates)
if aggregate:
if len(aggregate_lookups) > 1:
raise FieldError("Nested lookups not allowed.")
return aggregate_lookups, (), aggregate
_, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta())
field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)]
if len(lookup_parts) == 0:
lookup_parts = ['exact']
elif len(lookup_parts) > 1:
if field_parts:
raise FieldError(
'Only one lookup part allowed (found path "%s" from "%s").' %
(LOOKUP_SEP.join(field_parts), lookup))
else:
raise FieldError(
'Invalid lookup "%s" for model %s".' %
(lookup, self.get_meta().model.__name__))
else:
if not hasattr(field, 'get_lookup_constraint'):
lookup_class = field.get_lookup(lookup_parts[0])
if lookup_class is None and lookup_parts[0] not in self.query_terms:
raise FieldError(
'Invalid lookup name %s' % lookup_parts[0])
return lookup_parts, field_parts, False
def build_lookup(self, lookup_type, lhs, rhs):
if hasattr(lhs.output_type, 'get_lookup'):
lookup = lhs.output_type.get_lookup(lookup_type)
if lookup:
return lookup(self.where_class, lhs, rhs)
return None
def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
can_reuse=None, connector=AND):
@ -1114,9 +1128,9 @@ class Query(object):
is responsible for unreffing the joins used.
"""
arg, value = filter_expr
lookup_type, parts = self.solve_lookup_type(arg)
if not parts:
if not arg:
raise FieldError("Cannot parse keyword query %r" % arg)
lookup_type, parts, reffed_aggregate = self.solve_lookup_type(arg)
# Work out the lookup type and remove it from the end of 'parts',
# if necessary.
@ -1124,10 +1138,12 @@ class Query(object):
used_joins = getattr(value, '_used_joins', [])
clause = self.where_class()
if self._aggregates:
for alias, aggregate in self.aggregates.items():
if alias in (parts[0], LOOKUP_SEP.join(parts)):
clause.add((aggregate, lookup_type, value), AND)
if reffed_aggregate:
condition = self.build_lookup(lookup_type, reffed_aggregate, value)
if not condition:
# Backwards compat for custom lookups
condition = (reffed_aggregate, lookup_type, value)
clause.add(condition, AND)
return clause, []
opts = self.get_meta()
@ -1150,11 +1166,18 @@ class Query(object):
targets, alias, join_list = self.trim_joins(sources, join_list, path)
if hasattr(field, 'get_lookup_constraint'):
constraint = field.get_lookup_constraint(self.where_class, alias, targets, sources,
# For now foreign keys get special treatment. This should be
# refactored when composite fields lands.
condition = field.get_lookup_constraint(self.where_class, alias, targets, sources,
lookup_type, value)
else:
constraint = (Constraint(alias, targets[0].column, field), lookup_type, value)
clause.add(constraint, AND)
assert(len(targets) == 1)
col = Col(alias, targets[0], field)
condition = self.build_lookup(lookup_type, col, value)
if not condition:
# Backwards compat for custom lookups
condition = (Constraint(alias, targets[0].column, field), lookup_type, value)
clause.add(condition, AND)
require_outer = lookup_type == 'isnull' and value is True and not current_negated
if current_negated and (lookup_type != 'isnull' or value is False):
@ -1185,7 +1208,7 @@ class Query(object):
if not self._aggregates:
return False
if not isinstance(obj, Node):
return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)
return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)[0]
or (hasattr(obj[1], 'contains_aggregate')
and obj[1].contains_aggregate(self.aggregates)))
return any(self.need_having(c) for c in obj.children)
@ -1273,7 +1296,7 @@ class Query(object):
needed_inner = joinpromoter.update_join_types(self)
return target_clause, needed_inner
def names_to_path(self, names, opts, allow_many):
def names_to_path(self, names, opts, allow_many=True):
"""
Walks the names path and turns them PathInfo tuples. Note that a
single name in 'names' can generate multiple PathInfos (m2m for
@ -1293,9 +1316,10 @@ class Query(object):
try:
field, model, direct, m2m = opts.get_field_by_name(name)
except FieldDoesNotExist:
available = opts.get_all_field_names() + list(self.aggregate_select)
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(available)))
# We didn't found the current field, so move position back
# one step.
pos -= 1
break
# Check if we need any joins for concrete inheritance cases (the
# field lives in parent, but we are currently in one of its
# children)
@ -1330,15 +1354,9 @@ class Query(object):
final_field = field
targets = (field,)
break
if pos != len(names) - 1:
if pos == len(names) - 2:
raise FieldError(
"Join on field %r not permitted. Did you misspell %r for "
"the lookup type?" % (name, names[pos + 1]))
else:
raise FieldError("Join on field %r not permitted." % name)
return path, final_field, targets
if pos == -1:
raise FieldError('Whazaa')
return path, final_field, targets, names[pos + 1:]
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
"""
@ -1367,8 +1385,10 @@ class Query(object):
"""
joins = [alias]
# First, generate the path for the names
path, final_field, targets = self.names_to_path(
path, final_field, targets, rest = self.names_to_path(
names, opts, allow_many)
if rest:
raise FieldError('Invalid lookup')
# Then, add the path to the query's joins. Note that we can't trim
# joins at this stage - we will need the information about join type
# of the trimmed joins.
@ -1383,8 +1403,6 @@ class Query(object):
alias = self.join(
connection, reuse=reuse, nullable=nullable, join_field=join.join_field)
joins.append(alias)
if hasattr(final_field, 'field'):
final_field = final_field.field
return final_field, targets, opts, joins, path
def trim_joins(self, targets, joins, path):
@ -1455,7 +1473,7 @@ class Query(object):
query.bump_prefix(self)
query.where.add(
(Constraint(query.select[0].col[0], pk.column, pk),
'exact', Col(alias, pk.column)),
'exact', Col(alias, pk, pk)),
AND
)

View File

@ -101,7 +101,7 @@ class WhereNode(tree.Node):
for child in self.children:
try:
if hasattr(child, 'as_sql'):
sql, params = child.as_sql(qn=qn, connection=connection)
sql, params = qn.compile(child)
else:
# A leaf node in the tree.
sql, params = self.make_atom(child, qn, connection)
@ -193,13 +193,13 @@ class WhereNode(tree.Node):
field_sql, field_params = self.sql_for_columns(lvalue, qn, connection, field_internal_type), []
else:
# A smart object with an as_sql() method.
field_sql, field_params = lvalue.as_sql(qn, connection)
field_sql, field_params = qn.compile(lvalue)
is_datetime_field = value_annotation is datetime.datetime
cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s'
if hasattr(params, 'as_sql'):
extra, params = params.as_sql(qn, connection)
extra, params = qn.compile(params)
cast_sql = ''
else:
extra = ''
@ -282,6 +282,8 @@ class WhereNode(tree.Node):
if hasattr(child, 'relabel_aliases'):
# For example another WhereNode
child.relabel_aliases(change_map)
elif hasattr(child, 'relabeled_clone'):
self.children[pos] = child.relabeled_clone(change_map)
elif isinstance(child, (list, tuple)):
# tuple starting with Constraint
child = (child[0].relabeled_clone(change_map),) + child[1:]
@ -350,7 +352,7 @@ class Constraint(object):
self.alias, self.col, self.field = alias, col, field
def prepare(self, lookup_type, value):
if self.field:
if self.field and not hasattr(value, 'as_sql'):
return self.field.get_prep_lookup(lookup_type, value)
return value

View File

@ -443,7 +443,7 @@ class BaseAggregateTestCase(TestCase):
vals = Author.objects.filter(pk=1).aggregate(Count("friends__id"))
self.assertEqual(vals, {"friends__id__count": 2})
books = Book.objects.annotate(num_authors=Count("authors__name")).filter(num_authors__ge=2).order_by("pk")
books = Book.objects.annotate(num_authors=Count("authors__name")).filter(num_authors__exact=2).order_by("pk")
self.assertQuerysetEqual(
books, [
"The Definitive Guide to Django: Web Development Done Right",

View File

View File

@ -0,0 +1,7 @@
from django.db import models
class Author(models.Model):
name = models.CharField(max_length=20)
age = models.IntegerField(null=True)
birthdate = models.DateField(null=True)

View File

@ -0,0 +1,136 @@
from copy import copy
from datetime import date
import unittest
from django.test import TestCase
from .models import Author
from django.db import models
from django.db import connection
from django.db.backends.utils import add_implementation
class Div3Lookup(models.lookups.Lookup):
lookup_name = 'div3'
def as_sql(self, qn, connection):
lhs, params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
params.extend(rhs_params)
return '%s %%%% 3 = %s' % (lhs, rhs), params
class InMonth(models.lookups.Lookup):
"""
InMonth matches if the column's month is contained in the value's month.
"""
lookup_name = 'inmonth'
def as_sql(self, qn, connection):
lhs, params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
# We need to be careful so that we get the params in right
# places.
full_params = params[:]
full_params.extend(rhs_params)
full_params.extend(params)
full_params.extend(rhs_params)
return ("%s >= date_trunc('month', %s) and "
"%s < date_trunc('month', %s) + interval '1 months'" %
(lhs, rhs, lhs, rhs), full_params)
class LookupTests(TestCase):
def test_basic_lookup(self):
a1 = Author.objects.create(name='a1', age=1)
a2 = Author.objects.create(name='a2', age=2)
a3 = Author.objects.create(name='a3', age=3)
a4 = Author.objects.create(name='a4', age=4)
models.IntegerField.register_lookup(Div3Lookup)
try:
self.assertQuerysetEqual(
Author.objects.filter(age__div3=0),
[a3], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(age__div3=1).order_by('age'),
[a1, a4], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(age__div3=2),
[a2], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(age__div3=3),
[], lambda x: x
)
finally:
models.IntegerField._unregister_lookup(Div3Lookup)
@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
def test_birthdate_month(self):
a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
models.DateField.register_lookup(InMonth)
try:
self.assertQuerysetEqual(
Author.objects.filter(birthdate__inmonth=date(2012, 1, 15)),
[a3], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(birthdate__inmonth=date(2012, 2, 1)),
[a2], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(birthdate__inmonth=date(1981, 2, 28)),
[a1], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(birthdate__inmonth=date(2012, 3, 12)),
[a4], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(birthdate__inmonth=date(2012, 4, 1)),
[], lambda x: x
)
finally:
models.DateField._unregister_lookup(InMonth)
def test_custom_compiles(self):
a1 = Author.objects.create(name='a1', age=1)
a2 = Author.objects.create(name='a2', age=2)
a3 = Author.objects.create(name='a3', age=3)
a4 = Author.objects.create(name='a4', age=4)
class AnotherEqual(models.lookups.Exact):
lookup_name = 'anotherequal'
models.Field.register_lookup(AnotherEqual)
try:
@add_implementation(AnotherEqual, connection.vendor)
def custom_eq_sql(node, compiler):
return '1 = 1', []
self.assertIn('1 = 1', str(Author.objects.filter(name__anotherequal='asdf').query))
self.assertQuerysetEqual(
Author.objects.filter(name__anotherequal='asdf').order_by('name'),
[a1, a2, a3, a4], lambda x: x)
@add_implementation(AnotherEqual, connection.vendor)
def another_custom_eq_sql(node, compiler):
# If you need to override one method, it seems this is the best
# option.
node = copy(node)
class OverriddenAnotherEqual(AnotherEqual):
def get_rhs_op(self, connection, rhs):
return ' <> %s'
node.__class__ = OverriddenAnotherEqual
return node.as_sql(compiler, compiler.connection)
self.assertIn(' <> ', str(Author.objects.filter(name__anotherequal='a1').query))
self.assertQuerysetEqual(
Author.objects.filter(name__anotherequal='a1').order_by('name'),
[a2, a3, a4], lambda x: x
)
finally:
models.Field._unregister_lookup(AnotherEqual)

View File

@ -2620,8 +2620,15 @@ class WhereNodeTest(TestCase):
def as_sql(self, qn, connection):
return 'dummy', []
class MockCompiler(object):
def compile(self, node):
return node.as_sql(self, connection)
def __call__(self, name):
return connection.ops.quote_name(name)
def test_empty_full_handling_conjunction(self):
qn = connection.ops.quote_name
qn = WhereNodeTest.MockCompiler()
w = WhereNode(children=[EverythingNode()])
self.assertEqual(w.as_sql(qn, connection), ('', []))
w.negate()
@ -2646,7 +2653,7 @@ class WhereNodeTest(TestCase):
self.assertEqual(w.as_sql(qn, connection), ('', []))
def test_empty_full_handling_disjunction(self):
qn = connection.ops.quote_name
qn = WhereNodeTest.MockCompiler()
w = WhereNode(children=[EverythingNode()], connector='OR')
self.assertEqual(w.as_sql(qn, connection), ('', []))
w.negate()
@ -2673,7 +2680,7 @@ class WhereNodeTest(TestCase):
self.assertEqual(w.as_sql(qn, connection), ('NOT (dummy)', []))
def test_empty_nodes(self):
qn = connection.ops.quote_name
qn = WhereNodeTest.MockCompiler()
empty_w = WhereNode()
w = WhereNode(children=[empty_w, empty_w])
self.assertEqual(w.as_sql(qn, connection), (None, []))