mirror of https://github.com/django/django.git
Initial implementation of custom lookups
This commit is contained in:
parent
01e8ac47b3
commit
4d219d4cde
|
@ -67,6 +67,9 @@ class BaseDatabaseWrapper(object):
|
|||
self.allow_thread_sharing = allow_thread_sharing
|
||||
self._thread_ident = thread.get_ident()
|
||||
|
||||
# Compile implementations, used by compiler.compile(someelem)
|
||||
self.compile_implementations = utils.get_implementations(self.vendor)
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, BaseDatabaseWrapper):
|
||||
return self.alias == other.alias
|
||||
|
|
|
@ -194,3 +194,27 @@ def format_number(value, max_digits, decimal_places):
|
|||
return "{0:f}".format(value.quantize(decimal.Decimal(".1") ** decimal_places, context=context))
|
||||
else:
|
||||
return "%.*f" % (decimal_places, value)
|
||||
|
||||
# Map of vendor name -> map of query element class -> implementation function
|
||||
compile_implementations = {}
|
||||
|
||||
|
||||
def get_implementations(vendor):
|
||||
try:
|
||||
implementation = compile_implementations[vendor]
|
||||
except KeyError:
|
||||
# TODO: do we need thread safety here? We could easily use an lock...
|
||||
implementation = {}
|
||||
compile_implementations[vendor] = implementation
|
||||
return implementation
|
||||
|
||||
|
||||
class add_implementation(object):
|
||||
def __init__(self, klass, vendor):
|
||||
self.klass = klass
|
||||
self.vendor = vendor
|
||||
|
||||
def __call__(self, func):
|
||||
implementations = get_implementations(self.vendor)
|
||||
implementations[self.klass] = func
|
||||
return func
|
||||
|
|
|
@ -17,8 +17,8 @@ def refs_aggregate(lookup_parts, aggregates):
|
|||
"""
|
||||
for i in range(len(lookup_parts) + 1):
|
||||
if LOOKUP_SEP.join(lookup_parts[0:i]) in aggregates:
|
||||
return True
|
||||
return False
|
||||
return aggregates[LOOKUP_SEP.join(lookup_parts[0:i])], lookup_parts[i:]
|
||||
return False, ()
|
||||
|
||||
|
||||
class Aggregate(object):
|
||||
|
|
|
@ -4,6 +4,7 @@ import collections
|
|||
import copy
|
||||
import datetime
|
||||
import decimal
|
||||
import inspect
|
||||
import math
|
||||
import warnings
|
||||
from base64 import b64decode, b64encode
|
||||
|
@ -11,6 +12,7 @@ from itertools import tee
|
|||
|
||||
from django.db import connection
|
||||
from django.db.models.loading import get_model
|
||||
from django.db.models.lookups import default_lookups
|
||||
from django.db.models.query_utils import QueryWrapper
|
||||
from django.conf import settings
|
||||
from django import forms
|
||||
|
@ -101,6 +103,7 @@ class Field(object):
|
|||
'unique': _('%(model_name)s with this %(field_label)s '
|
||||
'already exists.'),
|
||||
}
|
||||
class_lookups = default_lookups.copy()
|
||||
|
||||
# Generic field type description, usually overridden by subclasses
|
||||
def _description(self):
|
||||
|
@ -446,6 +449,30 @@ class Field(object):
|
|||
def get_internal_type(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
def get_lookup(self, lookup_name):
|
||||
try:
|
||||
return self.class_lookups[lookup_name]
|
||||
except KeyError:
|
||||
for parent in inspect.getmro(self.__class__):
|
||||
if not 'class_lookups' in parent.__dict__:
|
||||
continue
|
||||
if lookup_name in parent.class_lookups:
|
||||
return parent.class_lookups[lookup_name]
|
||||
|
||||
@classmethod
|
||||
def register_lookup(cls, lookup):
|
||||
if not 'class_lookups' in cls.__dict__:
|
||||
cls.class_lookups = {}
|
||||
cls.class_lookups[lookup.lookup_name] = lookup
|
||||
|
||||
@classmethod
|
||||
def _unregister_lookup(cls, lookup):
|
||||
"""
|
||||
Removes given lookup from cls lookups. Meant to be used in
|
||||
tests only.
|
||||
"""
|
||||
del cls.class_lookups[lookup.lookup_name]
|
||||
|
||||
def pre_save(self, model_instance, add):
|
||||
"""
|
||||
Returns field's value just before saving.
|
||||
|
@ -504,8 +531,7 @@ class Field(object):
|
|||
except ValueError:
|
||||
raise ValueError("The __year lookup type requires an integer "
|
||||
"argument")
|
||||
|
||||
raise TypeError("Field has invalid lookup: %s" % lookup_type)
|
||||
return self.get_prep_value(value)
|
||||
|
||||
def get_db_prep_lookup(self, lookup_type, value, connection,
|
||||
prepared=False):
|
||||
|
@ -554,6 +580,8 @@ class Field(object):
|
|||
return connection.ops.year_lookup_bounds_for_date_field(value)
|
||||
else:
|
||||
return [value] # this isn't supposed to happen
|
||||
else:
|
||||
return [value]
|
||||
|
||||
def has_default(self):
|
||||
"""
|
||||
|
|
|
@ -934,6 +934,11 @@ class ForeignObjectRel(object):
|
|||
# example custom multicolumn joins currently have no remote field).
|
||||
self.field_name = None
|
||||
|
||||
def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type,
|
||||
raw_value):
|
||||
return self.field.get_lookup_constraint(constraint_class, alias, targets, sources,
|
||||
lookup_type, raw_value)
|
||||
|
||||
|
||||
class ManyToOneRel(ForeignObjectRel):
|
||||
def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None,
|
||||
|
|
|
@ -0,0 +1,242 @@
|
|||
from copy import copy
|
||||
|
||||
from django.conf import settings
|
||||
from django.utils import timezone
|
||||
|
||||
|
||||
class Lookup(object):
|
||||
def __init__(self, constraint_class, lhs, rhs):
|
||||
self.constraint_class, self.lhs, self.rhs = constraint_class, lhs, rhs
|
||||
self.rhs = self.get_prep_lookup()
|
||||
|
||||
def get_db_prep_lookup(self, value, connection):
|
||||
return (
|
||||
'%s', self.lhs.output_type.get_db_prep_lookup(
|
||||
self.lookup_name, value, connection, prepared=True))
|
||||
|
||||
def get_prep_lookup(self):
|
||||
return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs)
|
||||
|
||||
def process_lhs(self, qn, connection):
|
||||
return qn.compile(self.lhs)
|
||||
|
||||
def process_rhs(self, qn, connection):
|
||||
value = self.rhs
|
||||
# Due to historical reasons there are a couple of different
|
||||
# ways to produce sql here. get_compiler is likely a Query
|
||||
# instance, _as_sql QuerySet and as_sql just something with
|
||||
# as_sql. Finally the value can of course be just plain
|
||||
# Python value.
|
||||
if hasattr(value, 'get_compiler'):
|
||||
value = value.get_compiler(connection=connection)
|
||||
if hasattr(value, 'as_sql'):
|
||||
sql, params = qn.compile(value)
|
||||
return '(' + sql + ')', params
|
||||
if hasattr(value, '_as_sql'):
|
||||
sql, params = value._as_sql(connection=connection)
|
||||
return '(' + sql + ')', params
|
||||
else:
|
||||
return self.get_db_prep_lookup(value, connection)
|
||||
|
||||
def relabeled_clone(self, relabels):
|
||||
new = copy(self)
|
||||
new.lhs = new.lhs.relabeled_clone(relabels)
|
||||
if hasattr(new.rhs, 'relabeled_clone'):
|
||||
new.rhs = new.rhs.relabeled_clone(relabels)
|
||||
return new
|
||||
|
||||
def get_cols(self):
|
||||
cols = self.lhs.get_cols()
|
||||
if hasattr(self.rhs, 'get_cols'):
|
||||
cols.extend(self.rhs.get_cols())
|
||||
return cols
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DjangoLookup(Lookup):
|
||||
def as_sql(self, qn, connection):
|
||||
lhs_sql, params = self.process_lhs(qn, connection)
|
||||
field_internal_type = self.lhs.output_type.get_internal_type()
|
||||
db_type = self.lhs.output_type
|
||||
lhs_sql = connection.ops.field_cast_sql(db_type, field_internal_type) % lhs_sql
|
||||
lhs_sql = connection.ops.lookup_cast(self.lookup_name) % lhs_sql
|
||||
rhs_sql, rhs_params = self.process_rhs(qn, connection)
|
||||
params.extend(rhs_params)
|
||||
operator_plus_rhs = self.get_rhs_op(connection, rhs_sql)
|
||||
return '%s %s' % (lhs_sql, operator_plus_rhs), params
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return connection.operators[self.lookup_name] % rhs
|
||||
|
||||
|
||||
default_lookups = {}
|
||||
|
||||
|
||||
class Exact(DjangoLookup):
|
||||
lookup_name = 'exact'
|
||||
default_lookups['exact'] = Exact
|
||||
|
||||
|
||||
class IExact(DjangoLookup):
|
||||
lookup_name = 'iexact'
|
||||
default_lookups['iexact'] = IExact
|
||||
|
||||
|
||||
class Contains(DjangoLookup):
|
||||
lookup_name = 'contains'
|
||||
default_lookups['contains'] = Contains
|
||||
|
||||
|
||||
class IContains(DjangoLookup):
|
||||
lookup_name = 'icontains'
|
||||
default_lookups['icontains'] = IContains
|
||||
|
||||
|
||||
class GreaterThan(DjangoLookup):
|
||||
lookup_name = 'gt'
|
||||
default_lookups['gt'] = GreaterThan
|
||||
|
||||
|
||||
class GreaterThanOrEqual(DjangoLookup):
|
||||
lookup_name = 'gte'
|
||||
default_lookups['gte'] = GreaterThanOrEqual
|
||||
|
||||
|
||||
class LessThan(DjangoLookup):
|
||||
lookup_name = 'lt'
|
||||
default_lookups['lt'] = LessThan
|
||||
|
||||
|
||||
class LessThanOrEqual(DjangoLookup):
|
||||
lookup_name = 'lte'
|
||||
default_lookups['lte'] = LessThanOrEqual
|
||||
|
||||
|
||||
class In(DjangoLookup):
|
||||
lookup_name = 'in'
|
||||
|
||||
def get_db_prep_lookup(self, value, connection):
|
||||
params = self.lhs.field.get_db_prep_lookup(
|
||||
self.lookup_name, value, connection, prepared=True)
|
||||
if not params:
|
||||
# TODO: check why this leads to circular import
|
||||
from django.db.models.sql.datastructures import EmptyResultSet
|
||||
raise EmptyResultSet
|
||||
placeholder = '(' + ', '.join('%s' for p in params) + ')'
|
||||
return (placeholder, params)
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return 'IN %s' % rhs
|
||||
default_lookups['in'] = In
|
||||
|
||||
|
||||
class StartsWith(DjangoLookup):
|
||||
lookup_name = 'startswith'
|
||||
default_lookups['startswith'] = StartsWith
|
||||
|
||||
|
||||
class IStartsWith(DjangoLookup):
|
||||
lookup_name = 'istartswith'
|
||||
default_lookups['istartswith'] = IStartsWith
|
||||
|
||||
|
||||
class EndsWith(DjangoLookup):
|
||||
lookup_name = 'endswith'
|
||||
default_lookups['endswith'] = EndsWith
|
||||
|
||||
|
||||
class IEndsWith(DjangoLookup):
|
||||
lookup_name = 'iendswith'
|
||||
default_lookups['iendswith'] = IEndsWith
|
||||
|
||||
|
||||
class Between(DjangoLookup):
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return "BETWEEN %s AND %s" % (rhs, rhs)
|
||||
|
||||
|
||||
class Year(Between):
|
||||
lookup_name = 'year'
|
||||
default_lookups['year'] = Year
|
||||
|
||||
|
||||
class Range(Between):
|
||||
lookup_name = 'range'
|
||||
default_lookups['range'] = Range
|
||||
|
||||
|
||||
class DateLookup(DjangoLookup):
|
||||
|
||||
def process_lhs(self, qn, connection):
|
||||
lhs, params = super(DateLookup, self).process_lhs(qn, connection)
|
||||
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
|
||||
sql, tz_params = connection.ops.datetime_extract_sql(self.extract_type, lhs, tzname)
|
||||
return connection.ops.lookup_cast(self.lookup_name) % sql, tz_params
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return '= %s' % rhs
|
||||
|
||||
|
||||
class Month(DateLookup):
|
||||
lookup_name = 'month'
|
||||
extract_type = 'month'
|
||||
default_lookups['month'] = Month
|
||||
|
||||
|
||||
class Day(DateLookup):
|
||||
lookup_name = 'day'
|
||||
extract_type = 'day'
|
||||
default_lookups['day'] = Day
|
||||
|
||||
|
||||
class WeekDay(DateLookup):
|
||||
lookup_name = 'week_day'
|
||||
extract_type = 'week_day'
|
||||
default_lookups['week_day'] = WeekDay
|
||||
|
||||
|
||||
class Hour(DateLookup):
|
||||
lookup_name = 'hour'
|
||||
extract_type = 'hour'
|
||||
default_lookups['hour'] = Hour
|
||||
|
||||
|
||||
class Minute(DateLookup):
|
||||
lookup_name = 'minute'
|
||||
extract_type = 'minute'
|
||||
default_lookups['minute'] = Minute
|
||||
|
||||
|
||||
class Second(DateLookup):
|
||||
lookup_name = 'second'
|
||||
extract_type = 'second'
|
||||
default_lookups['second'] = Second
|
||||
|
||||
|
||||
class IsNull(DjangoLookup):
|
||||
lookup_name = 'isnull'
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
sql, params = qn.compile(self.lhs)
|
||||
if self.rhs:
|
||||
return "%s IS NULL" % sql, params
|
||||
else:
|
||||
return "%s IS NOT NULL" % sql, params
|
||||
default_lookups['isnull'] = IsNull
|
||||
|
||||
|
||||
class Search(DjangoLookup):
|
||||
lookup_name = 'search'
|
||||
default_lookups['search'] = Search
|
||||
|
||||
|
||||
class Regex(DjangoLookup):
|
||||
lookup_name = 'regex'
|
||||
default_lookups['regex'] = Regex
|
||||
|
||||
|
||||
class IRegex(DjangoLookup):
|
||||
lookup_name = 'iregex'
|
||||
default_lookups['iregex'] = IRegex
|
|
@ -93,6 +93,13 @@ class Aggregate(object):
|
|||
|
||||
return self.sql_template % substitutions, params
|
||||
|
||||
def get_cols(self):
|
||||
return []
|
||||
|
||||
@property
|
||||
def output_type(self):
|
||||
return self.field
|
||||
|
||||
|
||||
class Avg(Aggregate):
|
||||
is_computed = True
|
||||
|
|
|
@ -45,7 +45,7 @@ class SQLCompiler(object):
|
|||
if self.query.select_related and not self.query.related_select_cols:
|
||||
self.fill_related_selections()
|
||||
|
||||
def quote_name_unless_alias(self, name):
|
||||
def __call__(self, name):
|
||||
"""
|
||||
A wrapper around connection.ops.quote_name that doesn't quote aliases
|
||||
for table names. This avoids problems with some SQL dialects that treat
|
||||
|
@ -61,6 +61,20 @@ class SQLCompiler(object):
|
|||
self.quote_cache[name] = r
|
||||
return r
|
||||
|
||||
def quote_name_unless_alias(self, name):
|
||||
"""
|
||||
A wrapper around connection.ops.quote_name that doesn't quote aliases
|
||||
for table names. This avoids problems with some SQL dialects that treat
|
||||
quoted strings specially (e.g. PostgreSQL).
|
||||
"""
|
||||
return self(name)
|
||||
|
||||
def compile(self, node):
|
||||
if node.__class__ in self.connection.compile_implementations:
|
||||
return self.connection.compile_implementations[node.__class__](node, self)
|
||||
else:
|
||||
return node.as_sql(self, self.connection)
|
||||
|
||||
def as_sql(self, with_limits=True, with_col_aliases=False):
|
||||
"""
|
||||
Creates the SQL for this query. Returns the SQL string and list of
|
||||
|
@ -88,10 +102,8 @@ class SQLCompiler(object):
|
|||
# docstring of get_from_clause() for details.
|
||||
from_, f_params = self.get_from_clause()
|
||||
|
||||
qn = self.quote_name_unless_alias
|
||||
|
||||
where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection)
|
||||
having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection)
|
||||
where, w_params = self.compile(self.query.where)
|
||||
having, h_params = self.compile(self.query.having)
|
||||
having_group_by = self.query.having.get_cols()
|
||||
params = []
|
||||
for val in six.itervalues(self.query.extra_select):
|
||||
|
@ -180,7 +192,7 @@ class SQLCompiler(object):
|
|||
(without the table names) are given unique aliases. This is needed in
|
||||
some cases to avoid ambiguity with nested queries.
|
||||
"""
|
||||
qn = self.quote_name_unless_alias
|
||||
qn = self
|
||||
qn2 = self.connection.ops.quote_name
|
||||
result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)]
|
||||
params = []
|
||||
|
@ -213,7 +225,7 @@ class SQLCompiler(object):
|
|||
aliases.add(r)
|
||||
col_aliases.add(col[1])
|
||||
else:
|
||||
col_sql, col_params = col.as_sql(qn, self.connection)
|
||||
col_sql, col_params = self.compile(col)
|
||||
result.append(col_sql)
|
||||
params.extend(col_params)
|
||||
|
||||
|
@ -229,7 +241,7 @@ class SQLCompiler(object):
|
|||
|
||||
max_name_length = self.connection.ops.max_name_length()
|
||||
for alias, aggregate in self.query.aggregate_select.items():
|
||||
agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
|
||||
agg_sql, agg_params = self.compile(aggregate)
|
||||
if alias is None:
|
||||
result.append(agg_sql)
|
||||
else:
|
||||
|
@ -267,7 +279,7 @@ class SQLCompiler(object):
|
|||
result = []
|
||||
if opts is None:
|
||||
opts = self.query.get_meta()
|
||||
qn = self.quote_name_unless_alias
|
||||
qn = self
|
||||
qn2 = self.connection.ops.quote_name
|
||||
aliases = set()
|
||||
only_load = self.deferred_to_columns()
|
||||
|
@ -312,7 +324,7 @@ class SQLCompiler(object):
|
|||
Note that this method can alter the tables in the query, and thus it
|
||||
must be called before get_from_clause().
|
||||
"""
|
||||
qn = self.quote_name_unless_alias
|
||||
qn = self
|
||||
qn2 = self.connection.ops.quote_name
|
||||
result = []
|
||||
opts = self.query.get_meta()
|
||||
|
@ -345,7 +357,7 @@ class SQLCompiler(object):
|
|||
ordering = (self.query.order_by
|
||||
or self.query.get_meta().ordering
|
||||
or [])
|
||||
qn = self.quote_name_unless_alias
|
||||
qn = self
|
||||
qn2 = self.connection.ops.quote_name
|
||||
distinct = self.query.distinct
|
||||
select_aliases = self._select_aliases
|
||||
|
@ -483,7 +495,7 @@ class SQLCompiler(object):
|
|||
ordering and distinct must be done first.
|
||||
"""
|
||||
result = []
|
||||
qn = self.quote_name_unless_alias
|
||||
qn = self
|
||||
qn2 = self.connection.ops.quote_name
|
||||
first = True
|
||||
from_params = []
|
||||
|
@ -501,8 +513,7 @@ class SQLCompiler(object):
|
|||
extra_cond = join_field.get_extra_restriction(
|
||||
self.query.where_class, alias, lhs)
|
||||
if extra_cond:
|
||||
extra_sql, extra_params = extra_cond.as_sql(
|
||||
qn, self.connection)
|
||||
extra_sql, extra_params = self.compile(extra_cond)
|
||||
extra_sql = 'AND (%s)' % extra_sql
|
||||
from_params.extend(extra_params)
|
||||
else:
|
||||
|
@ -534,7 +545,7 @@ class SQLCompiler(object):
|
|||
"""
|
||||
Returns a tuple representing the SQL elements in the "group by" clause.
|
||||
"""
|
||||
qn = self.quote_name_unless_alias
|
||||
qn = self
|
||||
result, params = [], []
|
||||
if self.query.group_by is not None:
|
||||
select_cols = self.query.select + self.query.related_select_cols
|
||||
|
@ -553,7 +564,7 @@ class SQLCompiler(object):
|
|||
if isinstance(col, (list, tuple)):
|
||||
sql = '%s.%s' % (qn(col[0]), qn(col[1]))
|
||||
elif hasattr(col, 'as_sql'):
|
||||
sql, col_params = col.as_sql(qn, self.connection)
|
||||
self.compile(col)
|
||||
else:
|
||||
sql = '(%s)' % str(col)
|
||||
if sql not in seen:
|
||||
|
@ -776,7 +787,7 @@ class SQLCompiler(object):
|
|||
return result
|
||||
|
||||
def as_subquery_condition(self, alias, columns, qn):
|
||||
inner_qn = self.quote_name_unless_alias
|
||||
inner_qn = self
|
||||
qn2 = self.connection.ops.quote_name
|
||||
if len(columns) == 1:
|
||||
sql, params = self.as_sql()
|
||||
|
@ -887,9 +898,9 @@ class SQLDeleteCompiler(SQLCompiler):
|
|||
"""
|
||||
assert len(self.query.tables) == 1, \
|
||||
"Can only delete from one table at a time."
|
||||
qn = self.quote_name_unless_alias
|
||||
qn = self
|
||||
result = ['DELETE FROM %s' % qn(self.query.tables[0])]
|
||||
where, params = self.query.where.as_sql(qn=qn, connection=self.connection)
|
||||
where, params = self.compile(self.query.where)
|
||||
if where:
|
||||
result.append('WHERE %s' % where)
|
||||
return ' '.join(result), tuple(params)
|
||||
|
@ -905,7 +916,7 @@ class SQLUpdateCompiler(SQLCompiler):
|
|||
if not self.query.values:
|
||||
return '', ()
|
||||
table = self.query.tables[0]
|
||||
qn = self.quote_name_unless_alias
|
||||
qn = self
|
||||
result = ['UPDATE %s' % qn(table)]
|
||||
result.append('SET')
|
||||
values, update_params = [], []
|
||||
|
@ -925,7 +936,7 @@ class SQLUpdateCompiler(SQLCompiler):
|
|||
val = SQLEvaluator(val, self.query, allow_joins=False)
|
||||
name = field.column
|
||||
if hasattr(val, 'as_sql'):
|
||||
sql, params = val.as_sql(qn, self.connection)
|
||||
sql, params = self.compile(val)
|
||||
values.append('%s = %s' % (qn(name), sql))
|
||||
update_params.extend(params)
|
||||
elif val is not None:
|
||||
|
@ -936,7 +947,7 @@ class SQLUpdateCompiler(SQLCompiler):
|
|||
if not values:
|
||||
return '', ()
|
||||
result.append(', '.join(values))
|
||||
where, params = self.query.where.as_sql(qn=qn, connection=self.connection)
|
||||
where, params = self.compile(self.query.where)
|
||||
if where:
|
||||
result.append('WHERE %s' % where)
|
||||
return ' '.join(result), tuple(update_params + params)
|
||||
|
@ -1016,11 +1027,11 @@ class SQLAggregateCompiler(SQLCompiler):
|
|||
parameters.
|
||||
"""
|
||||
if qn is None:
|
||||
qn = self.quote_name_unless_alias
|
||||
qn = self
|
||||
|
||||
sql, params = [], []
|
||||
for aggregate in self.query.aggregate_select.values():
|
||||
agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
|
||||
agg_sql, agg_params = self.compile(aggregate)
|
||||
sql.append(agg_sql)
|
||||
params.extend(agg_params)
|
||||
sql = ', '.join(sql)
|
||||
|
|
|
@ -5,18 +5,25 @@ the SQL domain.
|
|||
|
||||
|
||||
class Col(object):
|
||||
def __init__(self, alias, col):
|
||||
self.alias = alias
|
||||
self.col = col
|
||||
def __init__(self, alias, target, source):
|
||||
self.alias, self.target, self.source = alias, target, source
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
return '%s.%s' % (qn(self.alias), self.col), []
|
||||
return "%s.%s" % (qn(self.alias), qn(self.target.column)), []
|
||||
|
||||
def prepare(self):
|
||||
return self
|
||||
@property
|
||||
def output_type(self):
|
||||
return self.source
|
||||
|
||||
@property
|
||||
def field(self):
|
||||
return self.source
|
||||
|
||||
def relabeled_clone(self, relabels):
|
||||
return self.__class__(relabels.get(self.alias, self.alias), self.col)
|
||||
return self.__class__(relabels.get(self.alias, self.alias), self.target, self.source)
|
||||
|
||||
def get_cols(self):
|
||||
return [(self.alias, self.target.column)]
|
||||
|
||||
|
||||
class EmptyResultSet(Exception):
|
||||
|
|
|
@ -1030,6 +1030,12 @@ class Query(object):
|
|||
def prepare_lookup_value(self, value, lookup_type, can_reuse):
|
||||
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
|
||||
# uses of None as a query value.
|
||||
if len(lookup_type) > 1:
|
||||
raise FieldError('Nested lookups not allowed')
|
||||
elif len(lookup_type) == 0:
|
||||
lookup_type = 'exact'
|
||||
else:
|
||||
lookup_type = lookup_type[0]
|
||||
if value is None:
|
||||
if lookup_type != 'exact':
|
||||
raise ValueError("Cannot use None as a query value")
|
||||
|
@ -1060,31 +1066,39 @@ class Query(object):
|
|||
"""
|
||||
Solve the lookup type from the lookup (eg: 'foobar__id__icontains')
|
||||
"""
|
||||
lookup_type = 'exact' # Default lookup type
|
||||
lookup_parts = lookup.split(LOOKUP_SEP)
|
||||
num_parts = len(lookup_parts)
|
||||
if (len(lookup_parts) > 1 and lookup_parts[-1] in self.query_terms
|
||||
and (not self._aggregates or lookup not in self._aggregates)):
|
||||
# Traverse the lookup query to distinguish related fields from
|
||||
# lookup types.
|
||||
lookup_model = self.model
|
||||
for counter, field_name in enumerate(lookup_parts):
|
||||
try:
|
||||
lookup_field = lookup_model._meta.get_field(field_name)
|
||||
except FieldDoesNotExist:
|
||||
# Not a field. Bail out.
|
||||
lookup_type = lookup_parts.pop()
|
||||
break
|
||||
# Unless we're at the end of the list of lookups, let's attempt
|
||||
# to continue traversing relations.
|
||||
if (counter + 1) < num_parts:
|
||||
try:
|
||||
lookup_model = lookup_field.rel.to
|
||||
except AttributeError:
|
||||
# Not a related field. Bail out.
|
||||
lookup_type = lookup_parts.pop()
|
||||
break
|
||||
return lookup_type, lookup_parts
|
||||
lookup_splitted = lookup.split(LOOKUP_SEP)
|
||||
aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates)
|
||||
if aggregate:
|
||||
if len(aggregate_lookups) > 1:
|
||||
raise FieldError("Nested lookups not allowed.")
|
||||
return aggregate_lookups, (), aggregate
|
||||
_, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta())
|
||||
field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)]
|
||||
if len(lookup_parts) == 0:
|
||||
lookup_parts = ['exact']
|
||||
elif len(lookup_parts) > 1:
|
||||
if field_parts:
|
||||
raise FieldError(
|
||||
'Only one lookup part allowed (found path "%s" from "%s").' %
|
||||
(LOOKUP_SEP.join(field_parts), lookup))
|
||||
else:
|
||||
raise FieldError(
|
||||
'Invalid lookup "%s" for model %s".' %
|
||||
(lookup, self.get_meta().model.__name__))
|
||||
else:
|
||||
if not hasattr(field, 'get_lookup_constraint'):
|
||||
lookup_class = field.get_lookup(lookup_parts[0])
|
||||
if lookup_class is None and lookup_parts[0] not in self.query_terms:
|
||||
raise FieldError(
|
||||
'Invalid lookup name %s' % lookup_parts[0])
|
||||
return lookup_parts, field_parts, False
|
||||
|
||||
def build_lookup(self, lookup_type, lhs, rhs):
|
||||
if hasattr(lhs.output_type, 'get_lookup'):
|
||||
lookup = lhs.output_type.get_lookup(lookup_type)
|
||||
if lookup:
|
||||
return lookup(self.where_class, lhs, rhs)
|
||||
return None
|
||||
|
||||
def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
|
||||
can_reuse=None, connector=AND):
|
||||
|
@ -1114,9 +1128,9 @@ class Query(object):
|
|||
is responsible for unreffing the joins used.
|
||||
"""
|
||||
arg, value = filter_expr
|
||||
lookup_type, parts = self.solve_lookup_type(arg)
|
||||
if not parts:
|
||||
if not arg:
|
||||
raise FieldError("Cannot parse keyword query %r" % arg)
|
||||
lookup_type, parts, reffed_aggregate = self.solve_lookup_type(arg)
|
||||
|
||||
# Work out the lookup type and remove it from the end of 'parts',
|
||||
# if necessary.
|
||||
|
@ -1124,11 +1138,13 @@ class Query(object):
|
|||
used_joins = getattr(value, '_used_joins', [])
|
||||
|
||||
clause = self.where_class()
|
||||
if self._aggregates:
|
||||
for alias, aggregate in self.aggregates.items():
|
||||
if alias in (parts[0], LOOKUP_SEP.join(parts)):
|
||||
clause.add((aggregate, lookup_type, value), AND)
|
||||
return clause, []
|
||||
if reffed_aggregate:
|
||||
condition = self.build_lookup(lookup_type, reffed_aggregate, value)
|
||||
if not condition:
|
||||
# Backwards compat for custom lookups
|
||||
condition = (reffed_aggregate, lookup_type, value)
|
||||
clause.add(condition, AND)
|
||||
return clause, []
|
||||
|
||||
opts = self.get_meta()
|
||||
alias = self.get_initial_alias()
|
||||
|
@ -1150,11 +1166,18 @@ class Query(object):
|
|||
targets, alias, join_list = self.trim_joins(sources, join_list, path)
|
||||
|
||||
if hasattr(field, 'get_lookup_constraint'):
|
||||
constraint = field.get_lookup_constraint(self.where_class, alias, targets, sources,
|
||||
lookup_type, value)
|
||||
# For now foreign keys get special treatment. This should be
|
||||
# refactored when composite fields lands.
|
||||
condition = field.get_lookup_constraint(self.where_class, alias, targets, sources,
|
||||
lookup_type, value)
|
||||
else:
|
||||
constraint = (Constraint(alias, targets[0].column, field), lookup_type, value)
|
||||
clause.add(constraint, AND)
|
||||
assert(len(targets) == 1)
|
||||
col = Col(alias, targets[0], field)
|
||||
condition = self.build_lookup(lookup_type, col, value)
|
||||
if not condition:
|
||||
# Backwards compat for custom lookups
|
||||
condition = (Constraint(alias, targets[0].column, field), lookup_type, value)
|
||||
clause.add(condition, AND)
|
||||
|
||||
require_outer = lookup_type == 'isnull' and value is True and not current_negated
|
||||
if current_negated and (lookup_type != 'isnull' or value is False):
|
||||
|
@ -1185,7 +1208,7 @@ class Query(object):
|
|||
if not self._aggregates:
|
||||
return False
|
||||
if not isinstance(obj, Node):
|
||||
return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)
|
||||
return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)[0]
|
||||
or (hasattr(obj[1], 'contains_aggregate')
|
||||
and obj[1].contains_aggregate(self.aggregates)))
|
||||
return any(self.need_having(c) for c in obj.children)
|
||||
|
@ -1273,7 +1296,7 @@ class Query(object):
|
|||
needed_inner = joinpromoter.update_join_types(self)
|
||||
return target_clause, needed_inner
|
||||
|
||||
def names_to_path(self, names, opts, allow_many):
|
||||
def names_to_path(self, names, opts, allow_many=True):
|
||||
"""
|
||||
Walks the names path and turns them PathInfo tuples. Note that a
|
||||
single name in 'names' can generate multiple PathInfos (m2m for
|
||||
|
@ -1293,9 +1316,10 @@ class Query(object):
|
|||
try:
|
||||
field, model, direct, m2m = opts.get_field_by_name(name)
|
||||
except FieldDoesNotExist:
|
||||
available = opts.get_all_field_names() + list(self.aggregate_select)
|
||||
raise FieldError("Cannot resolve keyword %r into field. "
|
||||
"Choices are: %s" % (name, ", ".join(available)))
|
||||
# We didn't found the current field, so move position back
|
||||
# one step.
|
||||
pos -= 1
|
||||
break
|
||||
# Check if we need any joins for concrete inheritance cases (the
|
||||
# field lives in parent, but we are currently in one of its
|
||||
# children)
|
||||
|
@ -1330,15 +1354,9 @@ class Query(object):
|
|||
final_field = field
|
||||
targets = (field,)
|
||||
break
|
||||
|
||||
if pos != len(names) - 1:
|
||||
if pos == len(names) - 2:
|
||||
raise FieldError(
|
||||
"Join on field %r not permitted. Did you misspell %r for "
|
||||
"the lookup type?" % (name, names[pos + 1]))
|
||||
else:
|
||||
raise FieldError("Join on field %r not permitted." % name)
|
||||
return path, final_field, targets
|
||||
if pos == -1:
|
||||
raise FieldError('Whazaa')
|
||||
return path, final_field, targets, names[pos + 1:]
|
||||
|
||||
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
|
||||
"""
|
||||
|
@ -1367,8 +1385,10 @@ class Query(object):
|
|||
"""
|
||||
joins = [alias]
|
||||
# First, generate the path for the names
|
||||
path, final_field, targets = self.names_to_path(
|
||||
path, final_field, targets, rest = self.names_to_path(
|
||||
names, opts, allow_many)
|
||||
if rest:
|
||||
raise FieldError('Invalid lookup')
|
||||
# Then, add the path to the query's joins. Note that we can't trim
|
||||
# joins at this stage - we will need the information about join type
|
||||
# of the trimmed joins.
|
||||
|
@ -1383,8 +1403,6 @@ class Query(object):
|
|||
alias = self.join(
|
||||
connection, reuse=reuse, nullable=nullable, join_field=join.join_field)
|
||||
joins.append(alias)
|
||||
if hasattr(final_field, 'field'):
|
||||
final_field = final_field.field
|
||||
return final_field, targets, opts, joins, path
|
||||
|
||||
def trim_joins(self, targets, joins, path):
|
||||
|
@ -1455,7 +1473,7 @@ class Query(object):
|
|||
query.bump_prefix(self)
|
||||
query.where.add(
|
||||
(Constraint(query.select[0].col[0], pk.column, pk),
|
||||
'exact', Col(alias, pk.column)),
|
||||
'exact', Col(alias, pk, pk)),
|
||||
AND
|
||||
)
|
||||
|
||||
|
|
|
@ -101,7 +101,7 @@ class WhereNode(tree.Node):
|
|||
for child in self.children:
|
||||
try:
|
||||
if hasattr(child, 'as_sql'):
|
||||
sql, params = child.as_sql(qn=qn, connection=connection)
|
||||
sql, params = qn.compile(child)
|
||||
else:
|
||||
# A leaf node in the tree.
|
||||
sql, params = self.make_atom(child, qn, connection)
|
||||
|
@ -193,13 +193,13 @@ class WhereNode(tree.Node):
|
|||
field_sql, field_params = self.sql_for_columns(lvalue, qn, connection, field_internal_type), []
|
||||
else:
|
||||
# A smart object with an as_sql() method.
|
||||
field_sql, field_params = lvalue.as_sql(qn, connection)
|
||||
field_sql, field_params = qn.compile(lvalue)
|
||||
|
||||
is_datetime_field = value_annotation is datetime.datetime
|
||||
cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s'
|
||||
|
||||
if hasattr(params, 'as_sql'):
|
||||
extra, params = params.as_sql(qn, connection)
|
||||
extra, params = qn.compile(params)
|
||||
cast_sql = ''
|
||||
else:
|
||||
extra = ''
|
||||
|
@ -282,6 +282,8 @@ class WhereNode(tree.Node):
|
|||
if hasattr(child, 'relabel_aliases'):
|
||||
# For example another WhereNode
|
||||
child.relabel_aliases(change_map)
|
||||
elif hasattr(child, 'relabeled_clone'):
|
||||
self.children[pos] = child.relabeled_clone(change_map)
|
||||
elif isinstance(child, (list, tuple)):
|
||||
# tuple starting with Constraint
|
||||
child = (child[0].relabeled_clone(change_map),) + child[1:]
|
||||
|
@ -350,7 +352,7 @@ class Constraint(object):
|
|||
self.alias, self.col, self.field = alias, col, field
|
||||
|
||||
def prepare(self, lookup_type, value):
|
||||
if self.field:
|
||||
if self.field and not hasattr(value, 'as_sql'):
|
||||
return self.field.get_prep_lookup(lookup_type, value)
|
||||
return value
|
||||
|
||||
|
|
|
@ -443,7 +443,7 @@ class BaseAggregateTestCase(TestCase):
|
|||
vals = Author.objects.filter(pk=1).aggregate(Count("friends__id"))
|
||||
self.assertEqual(vals, {"friends__id__count": 2})
|
||||
|
||||
books = Book.objects.annotate(num_authors=Count("authors__name")).filter(num_authors__ge=2).order_by("pk")
|
||||
books = Book.objects.annotate(num_authors=Count("authors__name")).filter(num_authors__exact=2).order_by("pk")
|
||||
self.assertQuerysetEqual(
|
||||
books, [
|
||||
"The Definitive Guide to Django: Web Development Done Right",
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
from django.db import models
|
||||
|
||||
|
||||
class Author(models.Model):
|
||||
name = models.CharField(max_length=20)
|
||||
age = models.IntegerField(null=True)
|
||||
birthdate = models.DateField(null=True)
|
|
@ -0,0 +1,136 @@
|
|||
from copy import copy
|
||||
from datetime import date
|
||||
import unittest
|
||||
|
||||
from django.test import TestCase
|
||||
from .models import Author
|
||||
from django.db import models
|
||||
from django.db import connection
|
||||
from django.db.backends.utils import add_implementation
|
||||
|
||||
|
||||
class Div3Lookup(models.lookups.Lookup):
|
||||
lookup_name = 'div3'
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
lhs, params = self.process_lhs(qn, connection)
|
||||
rhs, rhs_params = self.process_rhs(qn, connection)
|
||||
params.extend(rhs_params)
|
||||
return '%s %%%% 3 = %s' % (lhs, rhs), params
|
||||
|
||||
|
||||
class InMonth(models.lookups.Lookup):
|
||||
"""
|
||||
InMonth matches if the column's month is contained in the value's month.
|
||||
"""
|
||||
lookup_name = 'inmonth'
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
lhs, params = self.process_lhs(qn, connection)
|
||||
rhs, rhs_params = self.process_rhs(qn, connection)
|
||||
# We need to be careful so that we get the params in right
|
||||
# places.
|
||||
full_params = params[:]
|
||||
full_params.extend(rhs_params)
|
||||
full_params.extend(params)
|
||||
full_params.extend(rhs_params)
|
||||
return ("%s >= date_trunc('month', %s) and "
|
||||
"%s < date_trunc('month', %s) + interval '1 months'" %
|
||||
(lhs, rhs, lhs, rhs), full_params)
|
||||
|
||||
|
||||
class LookupTests(TestCase):
|
||||
def test_basic_lookup(self):
|
||||
a1 = Author.objects.create(name='a1', age=1)
|
||||
a2 = Author.objects.create(name='a2', age=2)
|
||||
a3 = Author.objects.create(name='a3', age=3)
|
||||
a4 = Author.objects.create(name='a4', age=4)
|
||||
models.IntegerField.register_lookup(Div3Lookup)
|
||||
try:
|
||||
self.assertQuerysetEqual(
|
||||
Author.objects.filter(age__div3=0),
|
||||
[a3], lambda x: x
|
||||
)
|
||||
self.assertQuerysetEqual(
|
||||
Author.objects.filter(age__div3=1).order_by('age'),
|
||||
[a1, a4], lambda x: x
|
||||
)
|
||||
self.assertQuerysetEqual(
|
||||
Author.objects.filter(age__div3=2),
|
||||
[a2], lambda x: x
|
||||
)
|
||||
self.assertQuerysetEqual(
|
||||
Author.objects.filter(age__div3=3),
|
||||
[], lambda x: x
|
||||
)
|
||||
finally:
|
||||
models.IntegerField._unregister_lookup(Div3Lookup)
|
||||
|
||||
@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
|
||||
def test_birthdate_month(self):
|
||||
a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
|
||||
a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
|
||||
a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
|
||||
a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
|
||||
models.DateField.register_lookup(InMonth)
|
||||
try:
|
||||
self.assertQuerysetEqual(
|
||||
Author.objects.filter(birthdate__inmonth=date(2012, 1, 15)),
|
||||
[a3], lambda x: x
|
||||
)
|
||||
self.assertQuerysetEqual(
|
||||
Author.objects.filter(birthdate__inmonth=date(2012, 2, 1)),
|
||||
[a2], lambda x: x
|
||||
)
|
||||
self.assertQuerysetEqual(
|
||||
Author.objects.filter(birthdate__inmonth=date(1981, 2, 28)),
|
||||
[a1], lambda x: x
|
||||
)
|
||||
self.assertQuerysetEqual(
|
||||
Author.objects.filter(birthdate__inmonth=date(2012, 3, 12)),
|
||||
[a4], lambda x: x
|
||||
)
|
||||
self.assertQuerysetEqual(
|
||||
Author.objects.filter(birthdate__inmonth=date(2012, 4, 1)),
|
||||
[], lambda x: x
|
||||
)
|
||||
finally:
|
||||
models.DateField._unregister_lookup(InMonth)
|
||||
|
||||
def test_custom_compiles(self):
|
||||
a1 = Author.objects.create(name='a1', age=1)
|
||||
a2 = Author.objects.create(name='a2', age=2)
|
||||
a3 = Author.objects.create(name='a3', age=3)
|
||||
a4 = Author.objects.create(name='a4', age=4)
|
||||
|
||||
class AnotherEqual(models.lookups.Exact):
|
||||
lookup_name = 'anotherequal'
|
||||
models.Field.register_lookup(AnotherEqual)
|
||||
try:
|
||||
@add_implementation(AnotherEqual, connection.vendor)
|
||||
def custom_eq_sql(node, compiler):
|
||||
return '1 = 1', []
|
||||
|
||||
self.assertIn('1 = 1', str(Author.objects.filter(name__anotherequal='asdf').query))
|
||||
self.assertQuerysetEqual(
|
||||
Author.objects.filter(name__anotherequal='asdf').order_by('name'),
|
||||
[a1, a2, a3, a4], lambda x: x)
|
||||
|
||||
@add_implementation(AnotherEqual, connection.vendor)
|
||||
def another_custom_eq_sql(node, compiler):
|
||||
# If you need to override one method, it seems this is the best
|
||||
# option.
|
||||
node = copy(node)
|
||||
|
||||
class OverriddenAnotherEqual(AnotherEqual):
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return ' <> %s'
|
||||
node.__class__ = OverriddenAnotherEqual
|
||||
return node.as_sql(compiler, compiler.connection)
|
||||
self.assertIn(' <> ', str(Author.objects.filter(name__anotherequal='a1').query))
|
||||
self.assertQuerysetEqual(
|
||||
Author.objects.filter(name__anotherequal='a1').order_by('name'),
|
||||
[a2, a3, a4], lambda x: x
|
||||
)
|
||||
finally:
|
||||
models.Field._unregister_lookup(AnotherEqual)
|
|
@ -2620,8 +2620,15 @@ class WhereNodeTest(TestCase):
|
|||
def as_sql(self, qn, connection):
|
||||
return 'dummy', []
|
||||
|
||||
class MockCompiler(object):
|
||||
def compile(self, node):
|
||||
return node.as_sql(self, connection)
|
||||
|
||||
def __call__(self, name):
|
||||
return connection.ops.quote_name(name)
|
||||
|
||||
def test_empty_full_handling_conjunction(self):
|
||||
qn = connection.ops.quote_name
|
||||
qn = WhereNodeTest.MockCompiler()
|
||||
w = WhereNode(children=[EverythingNode()])
|
||||
self.assertEqual(w.as_sql(qn, connection), ('', []))
|
||||
w.negate()
|
||||
|
@ -2646,7 +2653,7 @@ class WhereNodeTest(TestCase):
|
|||
self.assertEqual(w.as_sql(qn, connection), ('', []))
|
||||
|
||||
def test_empty_full_handling_disjunction(self):
|
||||
qn = connection.ops.quote_name
|
||||
qn = WhereNodeTest.MockCompiler()
|
||||
w = WhereNode(children=[EverythingNode()], connector='OR')
|
||||
self.assertEqual(w.as_sql(qn, connection), ('', []))
|
||||
w.negate()
|
||||
|
@ -2673,7 +2680,7 @@ class WhereNodeTest(TestCase):
|
|||
self.assertEqual(w.as_sql(qn, connection), ('NOT (dummy)', []))
|
||||
|
||||
def test_empty_nodes(self):
|
||||
qn = connection.ops.quote_name
|
||||
qn = WhereNodeTest.MockCompiler()
|
||||
empty_w = WhereNode()
|
||||
w = WhereNode(children=[empty_w, empty_w])
|
||||
self.assertEqual(w.as_sql(qn, connection), (None, []))
|
||||
|
|
Loading…
Reference in New Issue