Initial implementation of custom lookups

Anssi Kääriäinen 2013-11-27 22:07:30 +02:00
parent 01e8ac47b3
commit 4d219d4cde
16 changed files with 594 additions and 97 deletions

View File

@@ -67,6 +67,9 @@ class BaseDatabaseWrapper(object):
self.allow_thread_sharing = allow_thread_sharing
self._thread_ident = thread.get_ident()
# Compile implementations, used by compiler.compile(someelem)
self.compile_implementations = utils.get_implementations(self.vendor)
def __eq__(self, other):
if isinstance(other, BaseDatabaseWrapper):
return self.alias == other.alias

View File

@@ -194,3 +194,27 @@ def format_number(value, max_digits, decimal_places):
return "{0:f}".format(value.quantize(decimal.Decimal(".1") ** decimal_places, context=context))
else:
return "%.*f" % (decimal_places, value)
# Map of vendor name -> map of query element class -> implementation function
compile_implementations = {}
def get_implementations(vendor):
try:
implementation = compile_implementations[vendor]
except KeyError:
# TODO: do we need thread safety here? We could easily use a lock...
implementation = {}
compile_implementations[vendor] = implementation
return implementation
class add_implementation(object):
def __init__(self, klass, vendor):
self.klass = klass
self.vendor = vendor
def __call__(self, func):
implementations = get_implementations(self.vendor)
implementations[self.klass] = func
return func
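Taken together, get_implementations() and add_implementation give each database vendor its own class-to-implementation map that the compiler consults before falling back to as_sql(). A minimal sketch of how a backend-specific override would be registered (the function body is illustrative; Exact and connection come from elsewhere in this commit):

from django.db import connection, models
from django.db.backends.utils import add_implementation

@add_implementation(models.lookups.Exact, connection.vendor)
def custom_exact_sql(node, compiler):
    # SQLCompiler.compile() calls this instead of node.as_sql()
    # whenever the active connection's vendor matches.
    lhs_sql, params = node.process_lhs(compiler, compiler.connection)
    rhs_sql, rhs_params = node.process_rhs(compiler, compiler.connection)
    params.extend(rhs_params)
    return '%s = %s' % (lhs_sql, rhs_sql), params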

View File

@@ -17,8 +17,8 @@ def refs_aggregate(lookup_parts, aggregates):
"""
for i in range(len(lookup_parts) + 1):
if LOOKUP_SEP.join(lookup_parts[0:i]) in aggregates:
- return True
+ return aggregates[LOOKUP_SEP.join(lookup_parts[0:i])], lookup_parts[i:]
- return False
+ return False, ()
class Aggregate(object):

View File

@@ -4,6 +4,7 @@ import collections
import copy
import datetime
import decimal
import inspect
import math
import warnings
from base64 import b64decode, b64encode
@@ -11,6 +12,7 @@ from itertools import tee
from django.db import connection
from django.db.models.loading import get_model
from django.db.models.lookups import default_lookups
from django.db.models.query_utils import QueryWrapper
from django.conf import settings
from django import forms
@@ -101,6 +103,7 @@ class Field(object):
'unique': _('%(model_name)s with this %(field_label)s '
'already exists.'),
}
class_lookups = default_lookups.copy()
# Generic field type description, usually overridden by subclasses
def _description(self):
@@ -446,6 +449,30 @@ class Field(object):
def get_internal_type(self):
return self.__class__.__name__
def get_lookup(self, lookup_name):
try:
return self.class_lookups[lookup_name]
except KeyError:
for parent in inspect.getmro(self.__class__):
if not 'class_lookups' in parent.__dict__:
continue
if lookup_name in parent.class_lookups:
return parent.class_lookups[lookup_name]
@classmethod
def register_lookup(cls, lookup):
if not 'class_lookups' in cls.__dict__:
cls.class_lookups = {}
cls.class_lookups[lookup.lookup_name] = lookup
@classmethod
def _unregister_lookup(cls, lookup):
"""
Removes given lookup from cls lookups. Meant to be used in
tests only.
"""
del cls.class_lookups[lookup.lookup_name]
def pre_save(self, model_instance, add):
"""
Returns field's value just before saving.
@@ -504,8 +531,7 @@ class Field(object):
except ValueError:
raise ValueError("The __year lookup type requires an integer "
"argument")
- raise TypeError("Field has invalid lookup: %s" % lookup_type)
+ return self.get_prep_value(value)
def get_db_prep_lookup(self, lookup_type, value, connection,
prepared=False):
@@ -554,6 +580,8 @@ class Field(object):
return connection.ops.year_lookup_bounds_for_date_field(value)
else:
return [value] # this isn't supposed to happen
else:
return [value]
def has_default(self):
"""

View File

@@ -934,6 +934,11 @@ class ForeignObjectRel(object):
# example custom multicolumn joins currently have no remote field).
self.field_name = None
def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type,
raw_value):
return self.field.get_lookup_constraint(constraint_class, alias, targets, sources,
lookup_type, raw_value)
class ManyToOneRel(ForeignObjectRel):
def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None,

242
django/db/models/lookups.py Normal file
View File

@@ -0,0 +1,242 @@
from copy import copy
from django.conf import settings
from django.utils import timezone
class Lookup(object):
def __init__(self, constraint_class, lhs, rhs):
self.constraint_class, self.lhs, self.rhs = constraint_class, lhs, rhs
self.rhs = self.get_prep_lookup()
def get_db_prep_lookup(self, value, connection):
return (
'%s', self.lhs.output_type.get_db_prep_lookup(
self.lookup_name, value, connection, prepared=True))
def get_prep_lookup(self):
return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs)
def process_lhs(self, qn, connection):
return qn.compile(self.lhs)
def process_rhs(self, qn, connection):
value = self.rhs
# For historical reasons there are a couple of different
# ways to produce sql here. An object with get_compiler() is
# likely a Query instance, one with _as_sql() a QuerySet, and
# as_sql() just something with an as_sql() method. Finally,
# the value can of course be a plain Python value.
if hasattr(value, 'get_compiler'):
value = value.get_compiler(connection=connection)
if hasattr(value, 'as_sql'):
sql, params = qn.compile(value)
return '(' + sql + ')', params
if hasattr(value, '_as_sql'):
sql, params = value._as_sql(connection=connection)
return '(' + sql + ')', params
else:
return self.get_db_prep_lookup(value, connection)
def relabeled_clone(self, relabels):
new = copy(self)
new.lhs = new.lhs.relabeled_clone(relabels)
if hasattr(new.rhs, 'relabeled_clone'):
new.rhs = new.rhs.relabeled_clone(relabels)
return new
def get_cols(self):
cols = self.lhs.get_cols()
if hasattr(self.rhs, 'get_cols'):
cols.extend(self.rhs.get_cols())
return cols
def as_sql(self, qn, connection):
raise NotImplementedError
class DjangoLookup(Lookup):
def as_sql(self, qn, connection):
lhs_sql, params = self.process_lhs(qn, connection)
field_internal_type = self.lhs.output_type.get_internal_type()
db_type = self.lhs.output_type
lhs_sql = connection.ops.field_cast_sql(db_type, field_internal_type) % lhs_sql
lhs_sql = connection.ops.lookup_cast(self.lookup_name) % lhs_sql
rhs_sql, rhs_params = self.process_rhs(qn, connection)
params.extend(rhs_params)
operator_plus_rhs = self.get_rhs_op(connection, rhs_sql)
return '%s %s' % (lhs_sql, operator_plus_rhs), params
def get_rhs_op(self, connection, rhs):
return connection.operators[self.lookup_name] % rhs
default_lookups = {}
class Exact(DjangoLookup):
lookup_name = 'exact'
default_lookups['exact'] = Exact
class IExact(DjangoLookup):
lookup_name = 'iexact'
default_lookups['iexact'] = IExact
class Contains(DjangoLookup):
lookup_name = 'contains'
default_lookups['contains'] = Contains
class IContains(DjangoLookup):
lookup_name = 'icontains'
default_lookups['icontains'] = IContains
class GreaterThan(DjangoLookup):
lookup_name = 'gt'
default_lookups['gt'] = GreaterThan
class GreaterThanOrEqual(DjangoLookup):
lookup_name = 'gte'
default_lookups['gte'] = GreaterThanOrEqual
class LessThan(DjangoLookup):
lookup_name = 'lt'
default_lookups['lt'] = LessThan
class LessThanOrEqual(DjangoLookup):
lookup_name = 'lte'
default_lookups['lte'] = LessThanOrEqual
class In(DjangoLookup):
lookup_name = 'in'
def get_db_prep_lookup(self, value, connection):
params = self.lhs.field.get_db_prep_lookup(
self.lookup_name, value, connection, prepared=True)
if not params:
# TODO: check why this leads to circular import
from django.db.models.sql.datastructures import EmptyResultSet
raise EmptyResultSet
placeholder = '(' + ', '.join('%s' for p in params) + ')'
return (placeholder, params)
def get_rhs_op(self, connection, rhs):
return 'IN %s' % rhs
default_lookups['in'] = In
class StartsWith(DjangoLookup):
lookup_name = 'startswith'
default_lookups['startswith'] = StartsWith
class IStartsWith(DjangoLookup):
lookup_name = 'istartswith'
default_lookups['istartswith'] = IStartsWith
class EndsWith(DjangoLookup):
lookup_name = 'endswith'
default_lookups['endswith'] = EndsWith
class IEndsWith(DjangoLookup):
lookup_name = 'iendswith'
default_lookups['iendswith'] = IEndsWith
class Between(DjangoLookup):
def get_rhs_op(self, connection, rhs):
return "BETWEEN %s AND %s" % (rhs, rhs)
class Year(Between):
lookup_name = 'year'
default_lookups['year'] = Year
class Range(Between):
lookup_name = 'range'
default_lookups['range'] = Range
class DateLookup(DjangoLookup):
def process_lhs(self, qn, connection):
lhs, params = super(DateLookup, self).process_lhs(qn, connection)
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
sql, tz_params = connection.ops.datetime_extract_sql(self.extract_type, lhs, tzname)
return connection.ops.lookup_cast(self.lookup_name) % sql, tz_params
def get_rhs_op(self, connection, rhs):
return '= %s' % rhs
class Month(DateLookup):
lookup_name = 'month'
extract_type = 'month'
default_lookups['month'] = Month
class Day(DateLookup):
lookup_name = 'day'
extract_type = 'day'
default_lookups['day'] = Day
class WeekDay(DateLookup):
lookup_name = 'week_day'
extract_type = 'week_day'
default_lookups['week_day'] = WeekDay
class Hour(DateLookup):
lookup_name = 'hour'
extract_type = 'hour'
default_lookups['hour'] = Hour
class Minute(DateLookup):
lookup_name = 'minute'
extract_type = 'minute'
default_lookups['minute'] = Minute
class Second(DateLookup):
lookup_name = 'second'
extract_type = 'second'
default_lookups['second'] = Second
class IsNull(DjangoLookup):
lookup_name = 'isnull'
def as_sql(self, qn, connection):
sql, params = qn.compile(self.lhs)
if self.rhs:
return "%s IS NULL" % sql, params
else:
return "%s IS NOT NULL" % sql, params
default_lookups['isnull'] = IsNull
class Search(DjangoLookup):
lookup_name = 'search'
default_lookups['search'] = Search
class Regex(DjangoLookup):
lookup_name = 'regex'
default_lookups['regex'] = Regex
class IRegex(DjangoLookup):
lookup_name = 'iregex'
default_lookups['iregex'] = IRegex
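To make the flow concrete, here is roughly what DjangoLookup.as_sql() ends up producing for a built-in lookup; the Author model and the exact quoting are illustrative and backend dependent:

# For Author.objects.filter(age__gt=30), Query.build_lookup() (added below
# in sql/query.py) instantiates GreaterThan with a Col as lhs and 30 as rhs:
#
#   process_lhs()  -> '"author"."age"', []        (qn.compile() on the Col)
#   process_rhs()  -> '%s', [30]                  (via get_db_prep_lookup)
#   get_rhs_op()   -> connection.operators['gt'] % '%s'  ->  '> %s'
#   as_sql()       -> ('"author"."age" > %s', [30])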

View File

@@ -93,6 +93,13 @@ class Aggregate(object):
return self.sql_template % substitutions, params
def get_cols(self):
return []
@property
def output_type(self):
return self.field
class Avg(Aggregate):
is_computed = True

View File

@@ -45,7 +45,7 @@ class SQLCompiler(object):
if self.query.select_related and not self.query.related_select_cols:
self.fill_related_selections()
- def quote_name_unless_alias(self, name):
+ def __call__(self, name):
"""
A wrapper around connection.ops.quote_name that doesn't quote aliases
for table names. This avoids problems with some SQL dialects that treat
@@ -61,6 +61,20 @@ class SQLCompiler(object):
self.quote_cache[name] = r
return r
def quote_name_unless_alias(self, name):
"""
A wrapper around connection.ops.quote_name that doesn't quote aliases
for table names. This avoids problems with some SQL dialects that treat
quoted strings specially (e.g. PostgreSQL).
"""
return self(name)
def compile(self, node):
if node.__class__ in self.connection.compile_implementations:
return self.connection.compile_implementations[node.__class__](node, self)
else:
return node.as_sql(self, self.connection)
def as_sql(self, with_limits=True, with_col_aliases=False):
"""
Creates the SQL for this query. Returns the SQL string and list of
@@ -88,10 +102,8 @@ class SQLCompiler(object):
# docstring of get_from_clause() for details.
from_, f_params = self.get_from_clause()
- qn = self.quote_name_unless_alias
- where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection)
- having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection)
+ where, w_params = self.compile(self.query.where)
+ having, h_params = self.compile(self.query.having)
having_group_by = self.query.having.get_cols()
params = []
for val in six.itervalues(self.query.extra_select):
@@ -180,7 +192,7 @@ class SQLCompiler(object):
(without the table names) are given unique aliases. This is needed in
some cases to avoid ambiguity with nested queries.
"""
- qn = self.quote_name_unless_alias
+ qn = self
qn2 = self.connection.ops.quote_name
result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)]
params = []
@@ -213,7 +225,7 @@ class SQLCompiler(object):
aliases.add(r)
col_aliases.add(col[1])
else:
- col_sql, col_params = col.as_sql(qn, self.connection)
+ col_sql, col_params = self.compile(col)
result.append(col_sql)
params.extend(col_params)
@@ -229,7 +241,7 @@ class SQLCompiler(object):
max_name_length = self.connection.ops.max_name_length()
for alias, aggregate in self.query.aggregate_select.items():
- agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
+ agg_sql, agg_params = self.compile(aggregate)
if alias is None:
result.append(agg_sql)
else:
@@ -267,7 +279,7 @@ class SQLCompiler(object):
result = []
if opts is None:
opts = self.query.get_meta()
- qn = self.quote_name_unless_alias
+ qn = self
qn2 = self.connection.ops.quote_name
aliases = set()
only_load = self.deferred_to_columns()
@@ -312,7 +324,7 @@ class SQLCompiler(object):
Note that this method can alter the tables in the query, and thus it
must be called before get_from_clause().
"""
- qn = self.quote_name_unless_alias
+ qn = self
qn2 = self.connection.ops.quote_name
result = []
opts = self.query.get_meta()
@@ -345,7 +357,7 @@ class SQLCompiler(object):
ordering = (self.query.order_by
or self.query.get_meta().ordering
or [])
- qn = self.quote_name_unless_alias
+ qn = self
qn2 = self.connection.ops.quote_name
distinct = self.query.distinct
select_aliases = self._select_aliases
@@ -483,7 +495,7 @@ class SQLCompiler(object):
ordering and distinct must be done first.
"""
result = []
- qn = self.quote_name_unless_alias
+ qn = self
qn2 = self.connection.ops.quote_name
first = True
from_params = []
@@ -501,8 +513,7 @@ class SQLCompiler(object):
extra_cond = join_field.get_extra_restriction(
self.query.where_class, alias, lhs)
if extra_cond:
- extra_sql, extra_params = extra_cond.as_sql(
- qn, self.connection)
+ extra_sql, extra_params = self.compile(extra_cond)
extra_sql = 'AND (%s)' % extra_sql
from_params.extend(extra_params)
else:
@@ -534,7 +545,7 @@ class SQLCompiler(object):
"""
Returns a tuple representing the SQL elements in the "group by" clause.
"""
- qn = self.quote_name_unless_alias
+ qn = self
result, params = [], []
if self.query.group_by is not None:
select_cols = self.query.select + self.query.related_select_cols
@@ -553,7 +564,7 @@ class SQLCompiler(object):
if isinstance(col, (list, tuple)):
sql = '%s.%s' % (qn(col[0]), qn(col[1]))
elif hasattr(col, 'as_sql'):
- sql, col_params = col.as_sql(qn, self.connection)
+ self.compile(col)
else:
sql = '(%s)' % str(col)
if sql not in seen:
@@ -776,7 +787,7 @@ class SQLCompiler(object):
return result
def as_subquery_condition(self, alias, columns, qn):
- inner_qn = self.quote_name_unless_alias
+ inner_qn = self
qn2 = self.connection.ops.quote_name
if len(columns) == 1:
sql, params = self.as_sql()
@@ -887,9 +898,9 @@ class SQLDeleteCompiler(SQLCompiler):
"""
assert len(self.query.tables) == 1, \
"Can only delete from one table at a time."
- qn = self.quote_name_unless_alias
+ qn = self
result = ['DELETE FROM %s' % qn(self.query.tables[0])]
- where, params = self.query.where.as_sql(qn=qn, connection=self.connection)
+ where, params = self.compile(self.query.where)
if where:
result.append('WHERE %s' % where)
return ' '.join(result), tuple(params)
@@ -905,7 +916,7 @@ class SQLUpdateCompiler(SQLCompiler):
if not self.query.values:
return '', ()
table = self.query.tables[0]
- qn = self.quote_name_unless_alias
+ qn = self
result = ['UPDATE %s' % qn(table)]
result.append('SET')
values, update_params = [], []
@@ -925,7 +936,7 @@ class SQLUpdateCompiler(SQLCompiler):
val = SQLEvaluator(val, self.query, allow_joins=False)
name = field.column
if hasattr(val, 'as_sql'):
- sql, params = val.as_sql(qn, self.connection)
+ sql, params = self.compile(val)
values.append('%s = %s' % (qn(name), sql))
update_params.extend(params)
elif val is not None:
@@ -936,7 +947,7 @@ class SQLUpdateCompiler(SQLCompiler):
if not values:
return '', ()
result.append(', '.join(values))
- where, params = self.query.where.as_sql(qn=qn, connection=self.connection)
+ where, params = self.compile(self.query.where)
if where:
result.append('WHERE %s' % where)
return ' '.join(result), tuple(update_params + params)
@@ -1016,11 +1027,11 @@ class SQLAggregateCompiler(SQLCompiler):
parameters.
"""
if qn is None:
- qn = self.quote_name_unless_alias
+ qn = self
sql, params = [], []
for aggregate in self.query.aggregate_select.values():
- agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
+ agg_sql, agg_params = self.compile(aggregate)
sql.append(agg_sql)
params.extend(agg_params)
sql = ', '.join(sql)
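In short, the SQLCompiler instance itself now travels everywhere the old quote_name_unless_alias callable ("qn") used to be passed: calling it quotes a name, and its compile() method dispatches to a vendor-specific implementation when one is registered, falling back to node.as_sql(). A minimal sketch of the pattern (names illustrative):

qn = compiler                    # any SQLCompiler instance
qn('some_table')                 # quoting, via __call__ / quote_name_unless_alias
sql, params = qn.compile(node)   # vendor override if registered, else node.as_sql(qn, connection)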

View File

@@ -5,18 +5,25 @@ the SQL domain.
class Col(object):
- def __init__(self, alias, col):
- self.alias = alias
- self.col = col
+ def __init__(self, alias, target, source):
+ self.alias, self.target, self.source = alias, target, source
def as_sql(self, qn, connection):
- return '%s.%s' % (qn(self.alias), self.col), []
+ return "%s.%s" % (qn(self.alias), qn(self.target.column)), []
- def prepare(self):
- return self
+ @property
+ def output_type(self):
+ return self.source
+ @property
+ def field(self):
+ return self.source
def relabeled_clone(self, relabels):
- return self.__class__(relabels.get(self.alias, self.alias), self.col)
+ return self.__class__(relabels.get(self.alias, self.alias), self.target, self.source)
+ def get_cols(self):
+ return [(self.alias, self.target.column)]
class EmptyResultSet(Exception):

View File

@@ -1030,6 +1030,12 @@ class Query(object):
def prepare_lookup_value(self, value, lookup_type, can_reuse):
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
# uses of None as a query value.
if len(lookup_type) > 1:
raise FieldError('Nested lookups not allowed')
elif len(lookup_type) == 0:
lookup_type = 'exact'
else:
lookup_type = lookup_type[0]
if value is None:
if lookup_type != 'exact':
raise ValueError("Cannot use None as a query value")
@@ -1060,31 +1066,39 @@ class Query(object):
"""
Solve the lookup type from the lookup (eg: 'foobar__id__icontains')
"""
- lookup_type = 'exact' # Default lookup type
- lookup_parts = lookup.split(LOOKUP_SEP)
- num_parts = len(lookup_parts)
- if (len(lookup_parts) > 1 and lookup_parts[-1] in self.query_terms
- and (not self._aggregates or lookup not in self._aggregates)):
- # Traverse the lookup query to distinguish related fields from
- # lookup types.
- lookup_model = self.model
- for counter, field_name in enumerate(lookup_parts):
- try:
- lookup_field = lookup_model._meta.get_field(field_name)
- except FieldDoesNotExist:
- # Not a field. Bail out.
- lookup_type = lookup_parts.pop()
- break
- # Unless we're at the end of the list of lookups, let's attempt
- # to continue traversing relations.
- if (counter + 1) < num_parts:
- try:
- lookup_model = lookup_field.rel.to
- except AttributeError:
- # Not a related field. Bail out.
- lookup_type = lookup_parts.pop()
- break
- return lookup_type, lookup_parts
+ lookup_splitted = lookup.split(LOOKUP_SEP)
+ aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates)
+ if aggregate:
+ if len(aggregate_lookups) > 1:
+ raise FieldError("Nested lookups not allowed.")
+ return aggregate_lookups, (), aggregate
+ _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta())
+ field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)]
+ if len(lookup_parts) == 0:
+ lookup_parts = ['exact']
+ elif len(lookup_parts) > 1:
+ if field_parts:
+ raise FieldError(
+ 'Only one lookup part allowed (found path "%s" from "%s").' %
+ (LOOKUP_SEP.join(field_parts), lookup))
+ else:
+ raise FieldError(
+ 'Invalid lookup "%s" for model %s".' %
+ (lookup, self.get_meta().model.__name__))
+ else:
+ if not hasattr(field, 'get_lookup_constraint'):
+ lookup_class = field.get_lookup(lookup_parts[0])
+ if lookup_class is None and lookup_parts[0] not in self.query_terms:
+ raise FieldError(
+ 'Invalid lookup name %s' % lookup_parts[0])
+ return lookup_parts, field_parts, False
def build_lookup(self, lookup_type, lhs, rhs):
if hasattr(lhs.output_type, 'get_lookup'):
lookup = lhs.output_type.get_lookup(lookup_type)
if lookup:
return lookup(self.where_class, lhs, rhs)
return None
def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
can_reuse=None, connector=AND):
@@ -1114,9 +1128,9 @@ class Query(object):
is responsible for unreffing the joins used.
"""
arg, value = filter_expr
- lookup_type, parts = self.solve_lookup_type(arg)
- if not parts:
+ if not arg:
raise FieldError("Cannot parse keyword query %r" % arg)
+ lookup_type, parts, reffed_aggregate = self.solve_lookup_type(arg)
# Work out the lookup type and remove it from the end of 'parts',
# if necessary.
@@ -1124,11 +1138,13 @@ class Query(object):
used_joins = getattr(value, '_used_joins', [])
clause = self.where_class()
- if self._aggregates:
- for alias, aggregate in self.aggregates.items():
- if alias in (parts[0], LOOKUP_SEP.join(parts)):
- clause.add((aggregate, lookup_type, value), AND)
- return clause, []
+ if reffed_aggregate:
+ condition = self.build_lookup(lookup_type, reffed_aggregate, value)
+ if not condition:
+ # Backwards compat for custom lookups
+ condition = (reffed_aggregate, lookup_type, value)
+ clause.add(condition, AND)
+ return clause, []
opts = self.get_meta()
alias = self.get_initial_alias()
@@ -1150,11 +1166,18 @@ class Query(object):
targets, alias, join_list = self.trim_joins(sources, join_list, path)
if hasattr(field, 'get_lookup_constraint'):
- constraint = field.get_lookup_constraint(self.where_class, alias, targets, sources,
- lookup_type, value)
+ # For now foreign keys get special treatment. This should be
+ # refactored when composite fields lands.
+ condition = field.get_lookup_constraint(self.where_class, alias, targets, sources,
+ lookup_type, value)
else:
- constraint = (Constraint(alias, targets[0].column, field), lookup_type, value)
- clause.add(constraint, AND)
+ assert(len(targets) == 1)
+ col = Col(alias, targets[0], field)
+ condition = self.build_lookup(lookup_type, col, value)
+ if not condition:
+ # Backwards compat for custom lookups
+ condition = (Constraint(alias, targets[0].column, field), lookup_type, value)
+ clause.add(condition, AND)
require_outer = lookup_type == 'isnull' and value is True and not current_negated
if current_negated and (lookup_type != 'isnull' or value is False):
@@ -1185,7 +1208,7 @@ class Query(object):
if not self._aggregates:
return False
if not isinstance(obj, Node):
- return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)
+ return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)[0]
or (hasattr(obj[1], 'contains_aggregate')
and obj[1].contains_aggregate(self.aggregates)))
return any(self.need_having(c) for c in obj.children)
@@ -1273,7 +1296,7 @@ class Query(object):
needed_inner = joinpromoter.update_join_types(self)
return target_clause, needed_inner
- def names_to_path(self, names, opts, allow_many):
+ def names_to_path(self, names, opts, allow_many=True):
"""
Walks the names path and turns them into PathInfo tuples. Note that a
single name in 'names' can generate multiple PathInfos (m2m for
@@ -1293,9 +1316,10 @@ class Query(object):
try:
field, model, direct, m2m = opts.get_field_by_name(name)
except FieldDoesNotExist:
- available = opts.get_all_field_names() + list(self.aggregate_select)
- raise FieldError("Cannot resolve keyword %r into field. "
- "Choices are: %s" % (name, ", ".join(available)))
+ # We didn't find the current field, so move the position back one step.
+ pos -= 1
+ break
# Check if we need any joins for concrete inheritance cases (the
# field lives in parent, but we are currently in one of its
# children)
@@ -1330,15 +1354,9 @@ class Query(object):
final_field = field
targets = (field,)
break
- if pos != len(names) - 1:
- if pos == len(names) - 2:
- raise FieldError(
- "Join on field %r not permitted. Did you misspell %r for "
- "the lookup type?" % (name, names[pos + 1]))
- else:
- raise FieldError("Join on field %r not permitted." % name)
- return path, final_field, targets
+ if pos == -1:
+ raise FieldError('Whazaa')
+ return path, final_field, targets, names[pos + 1:]
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
"""
@@ -1367,8 +1385,10 @@ class Query(object):
"""
joins = [alias]
# First, generate the path for the names
- path, final_field, targets = self.names_to_path(
+ path, final_field, targets, rest = self.names_to_path(
names, opts, allow_many)
+ if rest:
+ raise FieldError('Invalid lookup')
# Then, add the path to the query's joins. Note that we can't trim
# joins at this stage - we will need the information about join type
# of the trimmed joins.
@@ -1383,8 +1403,6 @@ class Query(object):
alias = self.join(
connection, reuse=reuse, nullable=nullable, join_field=join.join_field)
joins.append(alias)
- if hasattr(final_field, 'field'):
- final_field = final_field.field
return final_field, targets, opts, joins, path
def trim_joins(self, targets, joins, path):
@@ -1455,7 +1473,7 @@ class Query(object):
query.bump_prefix(self)
query.where.add(
(Constraint(query.select[0].col[0], pk.column, pk),
- 'exact', Col(alias, pk.column)),
+ 'exact', Col(alias, pk, pk)),
AND
)
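A rough illustration of what the reworked solve_lookup_type() now returns, using the Author model added in the tests below; the return value is (lookup_parts, field_parts, aggregate), and the exact strings are illustrative:

query = Author.objects.all().query
query.solve_lookup_type('age')          # -> (['exact'], ['age'], False)
query.solve_lookup_type('age__gt')      # -> (['gt'], ['age'], False)
query.solve_lookup_type('age__gt__lt')  # -> FieldError: only one lookup part allowed
# A reference to an annotated aggregate returns the aggregate object itself
# as the third element instead of False.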

View File

@@ -101,7 +101,7 @@ class WhereNode(tree.Node):
for child in self.children:
try:
if hasattr(child, 'as_sql'):
- sql, params = child.as_sql(qn=qn, connection=connection)
+ sql, params = qn.compile(child)
else:
# A leaf node in the tree.
sql, params = self.make_atom(child, qn, connection)
@@ -193,13 +193,13 @@ class WhereNode(tree.Node):
field_sql, field_params = self.sql_for_columns(lvalue, qn, connection, field_internal_type), []
else:
# A smart object with an as_sql() method.
- field_sql, field_params = lvalue.as_sql(qn, connection)
+ field_sql, field_params = qn.compile(lvalue)
is_datetime_field = value_annotation is datetime.datetime
cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s'
if hasattr(params, 'as_sql'):
- extra, params = params.as_sql(qn, connection)
+ extra, params = qn.compile(params)
cast_sql = ''
else:
extra = ''
@@ -282,6 +282,8 @@ class WhereNode(tree.Node):
if hasattr(child, 'relabel_aliases'):
# For example another WhereNode
child.relabel_aliases(change_map)
elif hasattr(child, 'relabeled_clone'):
self.children[pos] = child.relabeled_clone(change_map)
elif isinstance(child, (list, tuple)):
# tuple starting with Constraint
child = (child[0].relabeled_clone(change_map),) + child[1:]
@@ -350,7 +352,7 @@ class Constraint(object):
self.alias, self.col, self.field = alias, col, field
def prepare(self, lookup_type, value):
- if self.field:
+ if self.field and not hasattr(value, 'as_sql'):
return self.field.get_prep_lookup(lookup_type, value)
return value

View File

@@ -443,7 +443,7 @@ class BaseAggregateTestCase(TestCase):
vals = Author.objects.filter(pk=1).aggregate(Count("friends__id"))
self.assertEqual(vals, {"friends__id__count": 2})
- books = Book.objects.annotate(num_authors=Count("authors__name")).filter(num_authors__ge=2).order_by("pk")
+ books = Book.objects.annotate(num_authors=Count("authors__name")).filter(num_authors__exact=2).order_by("pk")
self.assertQuerysetEqual(
books, [
"The Definitive Guide to Django: Web Development Done Right",

View File

View File

@@ -0,0 +1,7 @@
from django.db import models
class Author(models.Model):
name = models.CharField(max_length=20)
age = models.IntegerField(null=True)
birthdate = models.DateField(null=True)

View File

@@ -0,0 +1,136 @@
from copy import copy
from datetime import date
import unittest
from django.test import TestCase
from .models import Author
from django.db import models
from django.db import connection
from django.db.backends.utils import add_implementation
class Div3Lookup(models.lookups.Lookup):
lookup_name = 'div3'
def as_sql(self, qn, connection):
lhs, params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
params.extend(rhs_params)
return '%s %%%% 3 = %s' % (lhs, rhs), params
class InMonth(models.lookups.Lookup):
"""
InMonth matches if the column's month is contained in the value's month.
"""
lookup_name = 'inmonth'
def as_sql(self, qn, connection):
lhs, params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
# We need to be careful so that we get the params in the
# right places.
full_params = params[:]
full_params.extend(rhs_params)
full_params.extend(params)
full_params.extend(rhs_params)
return ("%s >= date_trunc('month', %s) and "
"%s < date_trunc('month', %s) + interval '1 months'" %
(lhs, rhs, lhs, rhs), full_params)
class LookupTests(TestCase):
def test_basic_lookup(self):
a1 = Author.objects.create(name='a1', age=1)
a2 = Author.objects.create(name='a2', age=2)
a3 = Author.objects.create(name='a3', age=3)
a4 = Author.objects.create(name='a4', age=4)
models.IntegerField.register_lookup(Div3Lookup)
try:
self.assertQuerysetEqual(
Author.objects.filter(age__div3=0),
[a3], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(age__div3=1).order_by('age'),
[a1, a4], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(age__div3=2),
[a2], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(age__div3=3),
[], lambda x: x
)
finally:
models.IntegerField._unregister_lookup(Div3Lookup)
@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
def test_birthdate_month(self):
a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
models.DateField.register_lookup(InMonth)
try:
self.assertQuerysetEqual(
Author.objects.filter(birthdate__inmonth=date(2012, 1, 15)),
[a3], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(birthdate__inmonth=date(2012, 2, 1)),
[a2], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(birthdate__inmonth=date(1981, 2, 28)),
[a1], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(birthdate__inmonth=date(2012, 3, 12)),
[a4], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(birthdate__inmonth=date(2012, 4, 1)),
[], lambda x: x
)
finally:
models.DateField._unregister_lookup(InMonth)
def test_custom_compiles(self):
a1 = Author.objects.create(name='a1', age=1)
a2 = Author.objects.create(name='a2', age=2)
a3 = Author.objects.create(name='a3', age=3)
a4 = Author.objects.create(name='a4', age=4)
class AnotherEqual(models.lookups.Exact):
lookup_name = 'anotherequal'
models.Field.register_lookup(AnotherEqual)
try:
@add_implementation(AnotherEqual, connection.vendor)
def custom_eq_sql(node, compiler):
return '1 = 1', []
self.assertIn('1 = 1', str(Author.objects.filter(name__anotherequal='asdf').query))
self.assertQuerysetEqual(
Author.objects.filter(name__anotherequal='asdf').order_by('name'),
[a1, a2, a3, a4], lambda x: x)
@add_implementation(AnotherEqual, connection.vendor)
def another_custom_eq_sql(node, compiler):
# If you need to override one method, it seems this is the best
# option.
node = copy(node)
class OverriddenAnotherEqual(AnotherEqual):
def get_rhs_op(self, connection, rhs):
return ' <> %s'
node.__class__ = OverriddenAnotherEqual
return node.as_sql(compiler, compiler.connection)
self.assertIn(' <> ', str(Author.objects.filter(name__anotherequal='a1').query))
self.assertQuerysetEqual(
Author.objects.filter(name__anotherequal='a1').order_by('name'),
[a2, a3, a4], lambda x: x
)
finally:
models.Field._unregister_lookup(AnotherEqual)

View File

@@ -2620,8 +2620,15 @@ class WhereNodeTest(TestCase):
def as_sql(self, qn, connection):
return 'dummy', []
+ class MockCompiler(object):
+ def compile(self, node):
+ return node.as_sql(self, connection)
+ def __call__(self, name):
+ return connection.ops.quote_name(name)
def test_empty_full_handling_conjunction(self):
- qn = connection.ops.quote_name
+ qn = WhereNodeTest.MockCompiler()
w = WhereNode(children=[EverythingNode()])
self.assertEqual(w.as_sql(qn, connection), ('', []))
w.negate()
@@ -2646,7 +2653,7 @@ class WhereNodeTest(TestCase):
self.assertEqual(w.as_sql(qn, connection), ('', []))
def test_empty_full_handling_disjunction(self):
- qn = connection.ops.quote_name
+ qn = WhereNodeTest.MockCompiler()
w = WhereNode(children=[EverythingNode()], connector='OR')
self.assertEqual(w.as_sql(qn, connection), ('', []))
w.negate()
@@ -2673,7 +2680,7 @@ class WhereNodeTest(TestCase):
self.assertEqual(w.as_sql(qn, connection), ('NOT (dummy)', []))
def test_empty_nodes(self):
- qn = connection.ops.quote_name
+ qn = WhereNodeTest.MockCompiler()
empty_w = WhereNode()
w = WhereNode(children=[empty_w, empty_w])
self.assertEqual(w.as_sql(qn, connection), (None, []))