Initial implementation of custom lookups
This commit is contained in:
parent
01e8ac47b3
commit
4d219d4cde
|
@ -67,6 +67,9 @@ class BaseDatabaseWrapper(object):
|
||||||
self.allow_thread_sharing = allow_thread_sharing
|
self.allow_thread_sharing = allow_thread_sharing
|
||||||
self._thread_ident = thread.get_ident()
|
self._thread_ident = thread.get_ident()
|
||||||
|
|
||||||
|
# Compile implementations, used by compiler.compile(someelem)
|
||||||
|
self.compile_implementations = utils.get_implementations(self.vendor)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if isinstance(other, BaseDatabaseWrapper):
|
if isinstance(other, BaseDatabaseWrapper):
|
||||||
return self.alias == other.alias
|
return self.alias == other.alias
|
||||||
|
|
|
@ -194,3 +194,27 @@ def format_number(value, max_digits, decimal_places):
|
||||||
return "{0:f}".format(value.quantize(decimal.Decimal(".1") ** decimal_places, context=context))
|
return "{0:f}".format(value.quantize(decimal.Decimal(".1") ** decimal_places, context=context))
|
||||||
else:
|
else:
|
||||||
return "%.*f" % (decimal_places, value)
|
return "%.*f" % (decimal_places, value)
|
||||||
|
|
||||||
|
# Map of vendor name -> map of query element class -> implementation function
|
||||||
|
compile_implementations = {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_implementations(vendor):
|
||||||
|
try:
|
||||||
|
implementation = compile_implementations[vendor]
|
||||||
|
except KeyError:
|
||||||
|
# TODO: do we need thread safety here? We could easily use an lock...
|
||||||
|
implementation = {}
|
||||||
|
compile_implementations[vendor] = implementation
|
||||||
|
return implementation
|
||||||
|
|
||||||
|
|
||||||
|
class add_implementation(object):
|
||||||
|
def __init__(self, klass, vendor):
|
||||||
|
self.klass = klass
|
||||||
|
self.vendor = vendor
|
||||||
|
|
||||||
|
def __call__(self, func):
|
||||||
|
implementations = get_implementations(self.vendor)
|
||||||
|
implementations[self.klass] = func
|
||||||
|
return func
|
||||||
|
|
|
@ -17,8 +17,8 @@ def refs_aggregate(lookup_parts, aggregates):
|
||||||
"""
|
"""
|
||||||
for i in range(len(lookup_parts) + 1):
|
for i in range(len(lookup_parts) + 1):
|
||||||
if LOOKUP_SEP.join(lookup_parts[0:i]) in aggregates:
|
if LOOKUP_SEP.join(lookup_parts[0:i]) in aggregates:
|
||||||
return True
|
return aggregates[LOOKUP_SEP.join(lookup_parts[0:i])], lookup_parts[i:]
|
||||||
return False
|
return False, ()
|
||||||
|
|
||||||
|
|
||||||
class Aggregate(object):
|
class Aggregate(object):
|
||||||
|
|
|
@ -4,6 +4,7 @@ import collections
|
||||||
import copy
|
import copy
|
||||||
import datetime
|
import datetime
|
||||||
import decimal
|
import decimal
|
||||||
|
import inspect
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
from base64 import b64decode, b64encode
|
from base64 import b64decode, b64encode
|
||||||
|
@ -11,6 +12,7 @@ from itertools import tee
|
||||||
|
|
||||||
from django.db import connection
|
from django.db import connection
|
||||||
from django.db.models.loading import get_model
|
from django.db.models.loading import get_model
|
||||||
|
from django.db.models.lookups import default_lookups
|
||||||
from django.db.models.query_utils import QueryWrapper
|
from django.db.models.query_utils import QueryWrapper
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django import forms
|
from django import forms
|
||||||
|
@ -101,6 +103,7 @@ class Field(object):
|
||||||
'unique': _('%(model_name)s with this %(field_label)s '
|
'unique': _('%(model_name)s with this %(field_label)s '
|
||||||
'already exists.'),
|
'already exists.'),
|
||||||
}
|
}
|
||||||
|
class_lookups = default_lookups.copy()
|
||||||
|
|
||||||
# Generic field type description, usually overridden by subclasses
|
# Generic field type description, usually overridden by subclasses
|
||||||
def _description(self):
|
def _description(self):
|
||||||
|
@ -446,6 +449,30 @@ class Field(object):
|
||||||
def get_internal_type(self):
|
def get_internal_type(self):
|
||||||
return self.__class__.__name__
|
return self.__class__.__name__
|
||||||
|
|
||||||
|
def get_lookup(self, lookup_name):
|
||||||
|
try:
|
||||||
|
return self.class_lookups[lookup_name]
|
||||||
|
except KeyError:
|
||||||
|
for parent in inspect.getmro(self.__class__):
|
||||||
|
if not 'class_lookups' in parent.__dict__:
|
||||||
|
continue
|
||||||
|
if lookup_name in parent.class_lookups:
|
||||||
|
return parent.class_lookups[lookup_name]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_lookup(cls, lookup):
|
||||||
|
if not 'class_lookups' in cls.__dict__:
|
||||||
|
cls.class_lookups = {}
|
||||||
|
cls.class_lookups[lookup.lookup_name] = lookup
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _unregister_lookup(cls, lookup):
|
||||||
|
"""
|
||||||
|
Removes given lookup from cls lookups. Meant to be used in
|
||||||
|
tests only.
|
||||||
|
"""
|
||||||
|
del cls.class_lookups[lookup.lookup_name]
|
||||||
|
|
||||||
def pre_save(self, model_instance, add):
|
def pre_save(self, model_instance, add):
|
||||||
"""
|
"""
|
||||||
Returns field's value just before saving.
|
Returns field's value just before saving.
|
||||||
|
@ -504,8 +531,7 @@ class Field(object):
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError("The __year lookup type requires an integer "
|
raise ValueError("The __year lookup type requires an integer "
|
||||||
"argument")
|
"argument")
|
||||||
|
return self.get_prep_value(value)
|
||||||
raise TypeError("Field has invalid lookup: %s" % lookup_type)
|
|
||||||
|
|
||||||
def get_db_prep_lookup(self, lookup_type, value, connection,
|
def get_db_prep_lookup(self, lookup_type, value, connection,
|
||||||
prepared=False):
|
prepared=False):
|
||||||
|
@ -554,6 +580,8 @@ class Field(object):
|
||||||
return connection.ops.year_lookup_bounds_for_date_field(value)
|
return connection.ops.year_lookup_bounds_for_date_field(value)
|
||||||
else:
|
else:
|
||||||
return [value] # this isn't supposed to happen
|
return [value] # this isn't supposed to happen
|
||||||
|
else:
|
||||||
|
return [value]
|
||||||
|
|
||||||
def has_default(self):
|
def has_default(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -934,6 +934,11 @@ class ForeignObjectRel(object):
|
||||||
# example custom multicolumn joins currently have no remote field).
|
# example custom multicolumn joins currently have no remote field).
|
||||||
self.field_name = None
|
self.field_name = None
|
||||||
|
|
||||||
|
def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type,
|
||||||
|
raw_value):
|
||||||
|
return self.field.get_lookup_constraint(constraint_class, alias, targets, sources,
|
||||||
|
lookup_type, raw_value)
|
||||||
|
|
||||||
|
|
||||||
class ManyToOneRel(ForeignObjectRel):
|
class ManyToOneRel(ForeignObjectRel):
|
||||||
def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None,
|
def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None,
|
||||||
|
|
|
@ -0,0 +1,242 @@
|
||||||
|
from copy import copy
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
from django.utils import timezone
|
||||||
|
|
||||||
|
|
||||||
|
class Lookup(object):
|
||||||
|
def __init__(self, constraint_class, lhs, rhs):
|
||||||
|
self.constraint_class, self.lhs, self.rhs = constraint_class, lhs, rhs
|
||||||
|
self.rhs = self.get_prep_lookup()
|
||||||
|
|
||||||
|
def get_db_prep_lookup(self, value, connection):
|
||||||
|
return (
|
||||||
|
'%s', self.lhs.output_type.get_db_prep_lookup(
|
||||||
|
self.lookup_name, value, connection, prepared=True))
|
||||||
|
|
||||||
|
def get_prep_lookup(self):
|
||||||
|
return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs)
|
||||||
|
|
||||||
|
def process_lhs(self, qn, connection):
|
||||||
|
return qn.compile(self.lhs)
|
||||||
|
|
||||||
|
def process_rhs(self, qn, connection):
|
||||||
|
value = self.rhs
|
||||||
|
# Due to historical reasons there are a couple of different
|
||||||
|
# ways to produce sql here. get_compiler is likely a Query
|
||||||
|
# instance, _as_sql QuerySet and as_sql just something with
|
||||||
|
# as_sql. Finally the value can of course be just plain
|
||||||
|
# Python value.
|
||||||
|
if hasattr(value, 'get_compiler'):
|
||||||
|
value = value.get_compiler(connection=connection)
|
||||||
|
if hasattr(value, 'as_sql'):
|
||||||
|
sql, params = qn.compile(value)
|
||||||
|
return '(' + sql + ')', params
|
||||||
|
if hasattr(value, '_as_sql'):
|
||||||
|
sql, params = value._as_sql(connection=connection)
|
||||||
|
return '(' + sql + ')', params
|
||||||
|
else:
|
||||||
|
return self.get_db_prep_lookup(value, connection)
|
||||||
|
|
||||||
|
def relabeled_clone(self, relabels):
|
||||||
|
new = copy(self)
|
||||||
|
new.lhs = new.lhs.relabeled_clone(relabels)
|
||||||
|
if hasattr(new.rhs, 'relabeled_clone'):
|
||||||
|
new.rhs = new.rhs.relabeled_clone(relabels)
|
||||||
|
return new
|
||||||
|
|
||||||
|
def get_cols(self):
|
||||||
|
cols = self.lhs.get_cols()
|
||||||
|
if hasattr(self.rhs, 'get_cols'):
|
||||||
|
cols.extend(self.rhs.get_cols())
|
||||||
|
return cols
|
||||||
|
|
||||||
|
def as_sql(self, qn, connection):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class DjangoLookup(Lookup):
|
||||||
|
def as_sql(self, qn, connection):
|
||||||
|
lhs_sql, params = self.process_lhs(qn, connection)
|
||||||
|
field_internal_type = self.lhs.output_type.get_internal_type()
|
||||||
|
db_type = self.lhs.output_type
|
||||||
|
lhs_sql = connection.ops.field_cast_sql(db_type, field_internal_type) % lhs_sql
|
||||||
|
lhs_sql = connection.ops.lookup_cast(self.lookup_name) % lhs_sql
|
||||||
|
rhs_sql, rhs_params = self.process_rhs(qn, connection)
|
||||||
|
params.extend(rhs_params)
|
||||||
|
operator_plus_rhs = self.get_rhs_op(connection, rhs_sql)
|
||||||
|
return '%s %s' % (lhs_sql, operator_plus_rhs), params
|
||||||
|
|
||||||
|
def get_rhs_op(self, connection, rhs):
|
||||||
|
return connection.operators[self.lookup_name] % rhs
|
||||||
|
|
||||||
|
|
||||||
|
default_lookups = {}
|
||||||
|
|
||||||
|
|
||||||
|
class Exact(DjangoLookup):
|
||||||
|
lookup_name = 'exact'
|
||||||
|
default_lookups['exact'] = Exact
|
||||||
|
|
||||||
|
|
||||||
|
class IExact(DjangoLookup):
|
||||||
|
lookup_name = 'iexact'
|
||||||
|
default_lookups['iexact'] = IExact
|
||||||
|
|
||||||
|
|
||||||
|
class Contains(DjangoLookup):
|
||||||
|
lookup_name = 'contains'
|
||||||
|
default_lookups['contains'] = Contains
|
||||||
|
|
||||||
|
|
||||||
|
class IContains(DjangoLookup):
|
||||||
|
lookup_name = 'icontains'
|
||||||
|
default_lookups['icontains'] = IContains
|
||||||
|
|
||||||
|
|
||||||
|
class GreaterThan(DjangoLookup):
|
||||||
|
lookup_name = 'gt'
|
||||||
|
default_lookups['gt'] = GreaterThan
|
||||||
|
|
||||||
|
|
||||||
|
class GreaterThanOrEqual(DjangoLookup):
|
||||||
|
lookup_name = 'gte'
|
||||||
|
default_lookups['gte'] = GreaterThanOrEqual
|
||||||
|
|
||||||
|
|
||||||
|
class LessThan(DjangoLookup):
|
||||||
|
lookup_name = 'lt'
|
||||||
|
default_lookups['lt'] = LessThan
|
||||||
|
|
||||||
|
|
||||||
|
class LessThanOrEqual(DjangoLookup):
|
||||||
|
lookup_name = 'lte'
|
||||||
|
default_lookups['lte'] = LessThanOrEqual
|
||||||
|
|
||||||
|
|
||||||
|
class In(DjangoLookup):
|
||||||
|
lookup_name = 'in'
|
||||||
|
|
||||||
|
def get_db_prep_lookup(self, value, connection):
|
||||||
|
params = self.lhs.field.get_db_prep_lookup(
|
||||||
|
self.lookup_name, value, connection, prepared=True)
|
||||||
|
if not params:
|
||||||
|
# TODO: check why this leads to circular import
|
||||||
|
from django.db.models.sql.datastructures import EmptyResultSet
|
||||||
|
raise EmptyResultSet
|
||||||
|
placeholder = '(' + ', '.join('%s' for p in params) + ')'
|
||||||
|
return (placeholder, params)
|
||||||
|
|
||||||
|
def get_rhs_op(self, connection, rhs):
|
||||||
|
return 'IN %s' % rhs
|
||||||
|
default_lookups['in'] = In
|
||||||
|
|
||||||
|
|
||||||
|
class StartsWith(DjangoLookup):
|
||||||
|
lookup_name = 'startswith'
|
||||||
|
default_lookups['startswith'] = StartsWith
|
||||||
|
|
||||||
|
|
||||||
|
class IStartsWith(DjangoLookup):
|
||||||
|
lookup_name = 'istartswith'
|
||||||
|
default_lookups['istartswith'] = IStartsWith
|
||||||
|
|
||||||
|
|
||||||
|
class EndsWith(DjangoLookup):
|
||||||
|
lookup_name = 'endswith'
|
||||||
|
default_lookups['endswith'] = EndsWith
|
||||||
|
|
||||||
|
|
||||||
|
class IEndsWith(DjangoLookup):
|
||||||
|
lookup_name = 'iendswith'
|
||||||
|
default_lookups['iendswith'] = IEndsWith
|
||||||
|
|
||||||
|
|
||||||
|
class Between(DjangoLookup):
|
||||||
|
def get_rhs_op(self, connection, rhs):
|
||||||
|
return "BETWEEN %s AND %s" % (rhs, rhs)
|
||||||
|
|
||||||
|
|
||||||
|
class Year(Between):
|
||||||
|
lookup_name = 'year'
|
||||||
|
default_lookups['year'] = Year
|
||||||
|
|
||||||
|
|
||||||
|
class Range(Between):
|
||||||
|
lookup_name = 'range'
|
||||||
|
default_lookups['range'] = Range
|
||||||
|
|
||||||
|
|
||||||
|
class DateLookup(DjangoLookup):
|
||||||
|
|
||||||
|
def process_lhs(self, qn, connection):
|
||||||
|
lhs, params = super(DateLookup, self).process_lhs(qn, connection)
|
||||||
|
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
|
||||||
|
sql, tz_params = connection.ops.datetime_extract_sql(self.extract_type, lhs, tzname)
|
||||||
|
return connection.ops.lookup_cast(self.lookup_name) % sql, tz_params
|
||||||
|
|
||||||
|
def get_rhs_op(self, connection, rhs):
|
||||||
|
return '= %s' % rhs
|
||||||
|
|
||||||
|
|
||||||
|
class Month(DateLookup):
|
||||||
|
lookup_name = 'month'
|
||||||
|
extract_type = 'month'
|
||||||
|
default_lookups['month'] = Month
|
||||||
|
|
||||||
|
|
||||||
|
class Day(DateLookup):
|
||||||
|
lookup_name = 'day'
|
||||||
|
extract_type = 'day'
|
||||||
|
default_lookups['day'] = Day
|
||||||
|
|
||||||
|
|
||||||
|
class WeekDay(DateLookup):
|
||||||
|
lookup_name = 'week_day'
|
||||||
|
extract_type = 'week_day'
|
||||||
|
default_lookups['week_day'] = WeekDay
|
||||||
|
|
||||||
|
|
||||||
|
class Hour(DateLookup):
|
||||||
|
lookup_name = 'hour'
|
||||||
|
extract_type = 'hour'
|
||||||
|
default_lookups['hour'] = Hour
|
||||||
|
|
||||||
|
|
||||||
|
class Minute(DateLookup):
|
||||||
|
lookup_name = 'minute'
|
||||||
|
extract_type = 'minute'
|
||||||
|
default_lookups['minute'] = Minute
|
||||||
|
|
||||||
|
|
||||||
|
class Second(DateLookup):
|
||||||
|
lookup_name = 'second'
|
||||||
|
extract_type = 'second'
|
||||||
|
default_lookups['second'] = Second
|
||||||
|
|
||||||
|
|
||||||
|
class IsNull(DjangoLookup):
|
||||||
|
lookup_name = 'isnull'
|
||||||
|
|
||||||
|
def as_sql(self, qn, connection):
|
||||||
|
sql, params = qn.compile(self.lhs)
|
||||||
|
if self.rhs:
|
||||||
|
return "%s IS NULL" % sql, params
|
||||||
|
else:
|
||||||
|
return "%s IS NOT NULL" % sql, params
|
||||||
|
default_lookups['isnull'] = IsNull
|
||||||
|
|
||||||
|
|
||||||
|
class Search(DjangoLookup):
|
||||||
|
lookup_name = 'search'
|
||||||
|
default_lookups['search'] = Search
|
||||||
|
|
||||||
|
|
||||||
|
class Regex(DjangoLookup):
|
||||||
|
lookup_name = 'regex'
|
||||||
|
default_lookups['regex'] = Regex
|
||||||
|
|
||||||
|
|
||||||
|
class IRegex(DjangoLookup):
|
||||||
|
lookup_name = 'iregex'
|
||||||
|
default_lookups['iregex'] = IRegex
|
|
@ -93,6 +93,13 @@ class Aggregate(object):
|
||||||
|
|
||||||
return self.sql_template % substitutions, params
|
return self.sql_template % substitutions, params
|
||||||
|
|
||||||
|
def get_cols(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_type(self):
|
||||||
|
return self.field
|
||||||
|
|
||||||
|
|
||||||
class Avg(Aggregate):
|
class Avg(Aggregate):
|
||||||
is_computed = True
|
is_computed = True
|
||||||
|
|
|
@ -45,7 +45,7 @@ class SQLCompiler(object):
|
||||||
if self.query.select_related and not self.query.related_select_cols:
|
if self.query.select_related and not self.query.related_select_cols:
|
||||||
self.fill_related_selections()
|
self.fill_related_selections()
|
||||||
|
|
||||||
def quote_name_unless_alias(self, name):
|
def __call__(self, name):
|
||||||
"""
|
"""
|
||||||
A wrapper around connection.ops.quote_name that doesn't quote aliases
|
A wrapper around connection.ops.quote_name that doesn't quote aliases
|
||||||
for table names. This avoids problems with some SQL dialects that treat
|
for table names. This avoids problems with some SQL dialects that treat
|
||||||
|
@ -61,6 +61,20 @@ class SQLCompiler(object):
|
||||||
self.quote_cache[name] = r
|
self.quote_cache[name] = r
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
def quote_name_unless_alias(self, name):
|
||||||
|
"""
|
||||||
|
A wrapper around connection.ops.quote_name that doesn't quote aliases
|
||||||
|
for table names. This avoids problems with some SQL dialects that treat
|
||||||
|
quoted strings specially (e.g. PostgreSQL).
|
||||||
|
"""
|
||||||
|
return self(name)
|
||||||
|
|
||||||
|
def compile(self, node):
|
||||||
|
if node.__class__ in self.connection.compile_implementations:
|
||||||
|
return self.connection.compile_implementations[node.__class__](node, self)
|
||||||
|
else:
|
||||||
|
return node.as_sql(self, self.connection)
|
||||||
|
|
||||||
def as_sql(self, with_limits=True, with_col_aliases=False):
|
def as_sql(self, with_limits=True, with_col_aliases=False):
|
||||||
"""
|
"""
|
||||||
Creates the SQL for this query. Returns the SQL string and list of
|
Creates the SQL for this query. Returns the SQL string and list of
|
||||||
|
@ -88,10 +102,8 @@ class SQLCompiler(object):
|
||||||
# docstring of get_from_clause() for details.
|
# docstring of get_from_clause() for details.
|
||||||
from_, f_params = self.get_from_clause()
|
from_, f_params = self.get_from_clause()
|
||||||
|
|
||||||
qn = self.quote_name_unless_alias
|
where, w_params = self.compile(self.query.where)
|
||||||
|
having, h_params = self.compile(self.query.having)
|
||||||
where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection)
|
|
||||||
having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection)
|
|
||||||
having_group_by = self.query.having.get_cols()
|
having_group_by = self.query.having.get_cols()
|
||||||
params = []
|
params = []
|
||||||
for val in six.itervalues(self.query.extra_select):
|
for val in six.itervalues(self.query.extra_select):
|
||||||
|
@ -180,7 +192,7 @@ class SQLCompiler(object):
|
||||||
(without the table names) are given unique aliases. This is needed in
|
(without the table names) are given unique aliases. This is needed in
|
||||||
some cases to avoid ambiguity with nested queries.
|
some cases to avoid ambiguity with nested queries.
|
||||||
"""
|
"""
|
||||||
qn = self.quote_name_unless_alias
|
qn = self
|
||||||
qn2 = self.connection.ops.quote_name
|
qn2 = self.connection.ops.quote_name
|
||||||
result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)]
|
result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)]
|
||||||
params = []
|
params = []
|
||||||
|
@ -213,7 +225,7 @@ class SQLCompiler(object):
|
||||||
aliases.add(r)
|
aliases.add(r)
|
||||||
col_aliases.add(col[1])
|
col_aliases.add(col[1])
|
||||||
else:
|
else:
|
||||||
col_sql, col_params = col.as_sql(qn, self.connection)
|
col_sql, col_params = self.compile(col)
|
||||||
result.append(col_sql)
|
result.append(col_sql)
|
||||||
params.extend(col_params)
|
params.extend(col_params)
|
||||||
|
|
||||||
|
@ -229,7 +241,7 @@ class SQLCompiler(object):
|
||||||
|
|
||||||
max_name_length = self.connection.ops.max_name_length()
|
max_name_length = self.connection.ops.max_name_length()
|
||||||
for alias, aggregate in self.query.aggregate_select.items():
|
for alias, aggregate in self.query.aggregate_select.items():
|
||||||
agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
|
agg_sql, agg_params = self.compile(aggregate)
|
||||||
if alias is None:
|
if alias is None:
|
||||||
result.append(agg_sql)
|
result.append(agg_sql)
|
||||||
else:
|
else:
|
||||||
|
@ -267,7 +279,7 @@ class SQLCompiler(object):
|
||||||
result = []
|
result = []
|
||||||
if opts is None:
|
if opts is None:
|
||||||
opts = self.query.get_meta()
|
opts = self.query.get_meta()
|
||||||
qn = self.quote_name_unless_alias
|
qn = self
|
||||||
qn2 = self.connection.ops.quote_name
|
qn2 = self.connection.ops.quote_name
|
||||||
aliases = set()
|
aliases = set()
|
||||||
only_load = self.deferred_to_columns()
|
only_load = self.deferred_to_columns()
|
||||||
|
@ -312,7 +324,7 @@ class SQLCompiler(object):
|
||||||
Note that this method can alter the tables in the query, and thus it
|
Note that this method can alter the tables in the query, and thus it
|
||||||
must be called before get_from_clause().
|
must be called before get_from_clause().
|
||||||
"""
|
"""
|
||||||
qn = self.quote_name_unless_alias
|
qn = self
|
||||||
qn2 = self.connection.ops.quote_name
|
qn2 = self.connection.ops.quote_name
|
||||||
result = []
|
result = []
|
||||||
opts = self.query.get_meta()
|
opts = self.query.get_meta()
|
||||||
|
@ -345,7 +357,7 @@ class SQLCompiler(object):
|
||||||
ordering = (self.query.order_by
|
ordering = (self.query.order_by
|
||||||
or self.query.get_meta().ordering
|
or self.query.get_meta().ordering
|
||||||
or [])
|
or [])
|
||||||
qn = self.quote_name_unless_alias
|
qn = self
|
||||||
qn2 = self.connection.ops.quote_name
|
qn2 = self.connection.ops.quote_name
|
||||||
distinct = self.query.distinct
|
distinct = self.query.distinct
|
||||||
select_aliases = self._select_aliases
|
select_aliases = self._select_aliases
|
||||||
|
@ -483,7 +495,7 @@ class SQLCompiler(object):
|
||||||
ordering and distinct must be done first.
|
ordering and distinct must be done first.
|
||||||
"""
|
"""
|
||||||
result = []
|
result = []
|
||||||
qn = self.quote_name_unless_alias
|
qn = self
|
||||||
qn2 = self.connection.ops.quote_name
|
qn2 = self.connection.ops.quote_name
|
||||||
first = True
|
first = True
|
||||||
from_params = []
|
from_params = []
|
||||||
|
@ -501,8 +513,7 @@ class SQLCompiler(object):
|
||||||
extra_cond = join_field.get_extra_restriction(
|
extra_cond = join_field.get_extra_restriction(
|
||||||
self.query.where_class, alias, lhs)
|
self.query.where_class, alias, lhs)
|
||||||
if extra_cond:
|
if extra_cond:
|
||||||
extra_sql, extra_params = extra_cond.as_sql(
|
extra_sql, extra_params = self.compile(extra_cond)
|
||||||
qn, self.connection)
|
|
||||||
extra_sql = 'AND (%s)' % extra_sql
|
extra_sql = 'AND (%s)' % extra_sql
|
||||||
from_params.extend(extra_params)
|
from_params.extend(extra_params)
|
||||||
else:
|
else:
|
||||||
|
@ -534,7 +545,7 @@ class SQLCompiler(object):
|
||||||
"""
|
"""
|
||||||
Returns a tuple representing the SQL elements in the "group by" clause.
|
Returns a tuple representing the SQL elements in the "group by" clause.
|
||||||
"""
|
"""
|
||||||
qn = self.quote_name_unless_alias
|
qn = self
|
||||||
result, params = [], []
|
result, params = [], []
|
||||||
if self.query.group_by is not None:
|
if self.query.group_by is not None:
|
||||||
select_cols = self.query.select + self.query.related_select_cols
|
select_cols = self.query.select + self.query.related_select_cols
|
||||||
|
@ -553,7 +564,7 @@ class SQLCompiler(object):
|
||||||
if isinstance(col, (list, tuple)):
|
if isinstance(col, (list, tuple)):
|
||||||
sql = '%s.%s' % (qn(col[0]), qn(col[1]))
|
sql = '%s.%s' % (qn(col[0]), qn(col[1]))
|
||||||
elif hasattr(col, 'as_sql'):
|
elif hasattr(col, 'as_sql'):
|
||||||
sql, col_params = col.as_sql(qn, self.connection)
|
self.compile(col)
|
||||||
else:
|
else:
|
||||||
sql = '(%s)' % str(col)
|
sql = '(%s)' % str(col)
|
||||||
if sql not in seen:
|
if sql not in seen:
|
||||||
|
@ -776,7 +787,7 @@ class SQLCompiler(object):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def as_subquery_condition(self, alias, columns, qn):
|
def as_subquery_condition(self, alias, columns, qn):
|
||||||
inner_qn = self.quote_name_unless_alias
|
inner_qn = self
|
||||||
qn2 = self.connection.ops.quote_name
|
qn2 = self.connection.ops.quote_name
|
||||||
if len(columns) == 1:
|
if len(columns) == 1:
|
||||||
sql, params = self.as_sql()
|
sql, params = self.as_sql()
|
||||||
|
@ -887,9 +898,9 @@ class SQLDeleteCompiler(SQLCompiler):
|
||||||
"""
|
"""
|
||||||
assert len(self.query.tables) == 1, \
|
assert len(self.query.tables) == 1, \
|
||||||
"Can only delete from one table at a time."
|
"Can only delete from one table at a time."
|
||||||
qn = self.quote_name_unless_alias
|
qn = self
|
||||||
result = ['DELETE FROM %s' % qn(self.query.tables[0])]
|
result = ['DELETE FROM %s' % qn(self.query.tables[0])]
|
||||||
where, params = self.query.where.as_sql(qn=qn, connection=self.connection)
|
where, params = self.compile(self.query.where)
|
||||||
if where:
|
if where:
|
||||||
result.append('WHERE %s' % where)
|
result.append('WHERE %s' % where)
|
||||||
return ' '.join(result), tuple(params)
|
return ' '.join(result), tuple(params)
|
||||||
|
@ -905,7 +916,7 @@ class SQLUpdateCompiler(SQLCompiler):
|
||||||
if not self.query.values:
|
if not self.query.values:
|
||||||
return '', ()
|
return '', ()
|
||||||
table = self.query.tables[0]
|
table = self.query.tables[0]
|
||||||
qn = self.quote_name_unless_alias
|
qn = self
|
||||||
result = ['UPDATE %s' % qn(table)]
|
result = ['UPDATE %s' % qn(table)]
|
||||||
result.append('SET')
|
result.append('SET')
|
||||||
values, update_params = [], []
|
values, update_params = [], []
|
||||||
|
@ -925,7 +936,7 @@ class SQLUpdateCompiler(SQLCompiler):
|
||||||
val = SQLEvaluator(val, self.query, allow_joins=False)
|
val = SQLEvaluator(val, self.query, allow_joins=False)
|
||||||
name = field.column
|
name = field.column
|
||||||
if hasattr(val, 'as_sql'):
|
if hasattr(val, 'as_sql'):
|
||||||
sql, params = val.as_sql(qn, self.connection)
|
sql, params = self.compile(val)
|
||||||
values.append('%s = %s' % (qn(name), sql))
|
values.append('%s = %s' % (qn(name), sql))
|
||||||
update_params.extend(params)
|
update_params.extend(params)
|
||||||
elif val is not None:
|
elif val is not None:
|
||||||
|
@ -936,7 +947,7 @@ class SQLUpdateCompiler(SQLCompiler):
|
||||||
if not values:
|
if not values:
|
||||||
return '', ()
|
return '', ()
|
||||||
result.append(', '.join(values))
|
result.append(', '.join(values))
|
||||||
where, params = self.query.where.as_sql(qn=qn, connection=self.connection)
|
where, params = self.compile(self.query.where)
|
||||||
if where:
|
if where:
|
||||||
result.append('WHERE %s' % where)
|
result.append('WHERE %s' % where)
|
||||||
return ' '.join(result), tuple(update_params + params)
|
return ' '.join(result), tuple(update_params + params)
|
||||||
|
@ -1016,11 +1027,11 @@ class SQLAggregateCompiler(SQLCompiler):
|
||||||
parameters.
|
parameters.
|
||||||
"""
|
"""
|
||||||
if qn is None:
|
if qn is None:
|
||||||
qn = self.quote_name_unless_alias
|
qn = self
|
||||||
|
|
||||||
sql, params = [], []
|
sql, params = [], []
|
||||||
for aggregate in self.query.aggregate_select.values():
|
for aggregate in self.query.aggregate_select.values():
|
||||||
agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
|
agg_sql, agg_params = self.compile(aggregate)
|
||||||
sql.append(agg_sql)
|
sql.append(agg_sql)
|
||||||
params.extend(agg_params)
|
params.extend(agg_params)
|
||||||
sql = ', '.join(sql)
|
sql = ', '.join(sql)
|
||||||
|
|
|
@ -5,18 +5,25 @@ the SQL domain.
|
||||||
|
|
||||||
|
|
||||||
class Col(object):
|
class Col(object):
|
||||||
def __init__(self, alias, col):
|
def __init__(self, alias, target, source):
|
||||||
self.alias = alias
|
self.alias, self.target, self.source = alias, target, source
|
||||||
self.col = col
|
|
||||||
|
|
||||||
def as_sql(self, qn, connection):
|
def as_sql(self, qn, connection):
|
||||||
return '%s.%s' % (qn(self.alias), self.col), []
|
return "%s.%s" % (qn(self.alias), qn(self.target.column)), []
|
||||||
|
|
||||||
def prepare(self):
|
@property
|
||||||
return self
|
def output_type(self):
|
||||||
|
return self.source
|
||||||
|
|
||||||
|
@property
|
||||||
|
def field(self):
|
||||||
|
return self.source
|
||||||
|
|
||||||
def relabeled_clone(self, relabels):
|
def relabeled_clone(self, relabels):
|
||||||
return self.__class__(relabels.get(self.alias, self.alias), self.col)
|
return self.__class__(relabels.get(self.alias, self.alias), self.target, self.source)
|
||||||
|
|
||||||
|
def get_cols(self):
|
||||||
|
return [(self.alias, self.target.column)]
|
||||||
|
|
||||||
|
|
||||||
class EmptyResultSet(Exception):
|
class EmptyResultSet(Exception):
|
||||||
|
|
|
@ -1030,6 +1030,12 @@ class Query(object):
|
||||||
def prepare_lookup_value(self, value, lookup_type, can_reuse):
|
def prepare_lookup_value(self, value, lookup_type, can_reuse):
|
||||||
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
|
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
|
||||||
# uses of None as a query value.
|
# uses of None as a query value.
|
||||||
|
if len(lookup_type) > 1:
|
||||||
|
raise FieldError('Nested lookups not allowed')
|
||||||
|
elif len(lookup_type) == 0:
|
||||||
|
lookup_type = 'exact'
|
||||||
|
else:
|
||||||
|
lookup_type = lookup_type[0]
|
||||||
if value is None:
|
if value is None:
|
||||||
if lookup_type != 'exact':
|
if lookup_type != 'exact':
|
||||||
raise ValueError("Cannot use None as a query value")
|
raise ValueError("Cannot use None as a query value")
|
||||||
|
@ -1060,31 +1066,39 @@ class Query(object):
|
||||||
"""
|
"""
|
||||||
Solve the lookup type from the lookup (eg: 'foobar__id__icontains')
|
Solve the lookup type from the lookup (eg: 'foobar__id__icontains')
|
||||||
"""
|
"""
|
||||||
lookup_type = 'exact' # Default lookup type
|
lookup_splitted = lookup.split(LOOKUP_SEP)
|
||||||
lookup_parts = lookup.split(LOOKUP_SEP)
|
aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates)
|
||||||
num_parts = len(lookup_parts)
|
if aggregate:
|
||||||
if (len(lookup_parts) > 1 and lookup_parts[-1] in self.query_terms
|
if len(aggregate_lookups) > 1:
|
||||||
and (not self._aggregates or lookup not in self._aggregates)):
|
raise FieldError("Nested lookups not allowed.")
|
||||||
# Traverse the lookup query to distinguish related fields from
|
return aggregate_lookups, (), aggregate
|
||||||
# lookup types.
|
_, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta())
|
||||||
lookup_model = self.model
|
field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)]
|
||||||
for counter, field_name in enumerate(lookup_parts):
|
if len(lookup_parts) == 0:
|
||||||
try:
|
lookup_parts = ['exact']
|
||||||
lookup_field = lookup_model._meta.get_field(field_name)
|
elif len(lookup_parts) > 1:
|
||||||
except FieldDoesNotExist:
|
if field_parts:
|
||||||
# Not a field. Bail out.
|
raise FieldError(
|
||||||
lookup_type = lookup_parts.pop()
|
'Only one lookup part allowed (found path "%s" from "%s").' %
|
||||||
break
|
(LOOKUP_SEP.join(field_parts), lookup))
|
||||||
# Unless we're at the end of the list of lookups, let's attempt
|
else:
|
||||||
# to continue traversing relations.
|
raise FieldError(
|
||||||
if (counter + 1) < num_parts:
|
'Invalid lookup "%s" for model %s".' %
|
||||||
try:
|
(lookup, self.get_meta().model.__name__))
|
||||||
lookup_model = lookup_field.rel.to
|
else:
|
||||||
except AttributeError:
|
if not hasattr(field, 'get_lookup_constraint'):
|
||||||
# Not a related field. Bail out.
|
lookup_class = field.get_lookup(lookup_parts[0])
|
||||||
lookup_type = lookup_parts.pop()
|
if lookup_class is None and lookup_parts[0] not in self.query_terms:
|
||||||
break
|
raise FieldError(
|
||||||
return lookup_type, lookup_parts
|
'Invalid lookup name %s' % lookup_parts[0])
|
||||||
|
return lookup_parts, field_parts, False
|
||||||
|
|
||||||
|
def build_lookup(self, lookup_type, lhs, rhs):
|
||||||
|
if hasattr(lhs.output_type, 'get_lookup'):
|
||||||
|
lookup = lhs.output_type.get_lookup(lookup_type)
|
||||||
|
if lookup:
|
||||||
|
return lookup(self.where_class, lhs, rhs)
|
||||||
|
return None
|
||||||
|
|
||||||
def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
|
def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
|
||||||
can_reuse=None, connector=AND):
|
can_reuse=None, connector=AND):
|
||||||
|
@ -1114,9 +1128,9 @@ class Query(object):
|
||||||
is responsible for unreffing the joins used.
|
is responsible for unreffing the joins used.
|
||||||
"""
|
"""
|
||||||
arg, value = filter_expr
|
arg, value = filter_expr
|
||||||
lookup_type, parts = self.solve_lookup_type(arg)
|
if not arg:
|
||||||
if not parts:
|
|
||||||
raise FieldError("Cannot parse keyword query %r" % arg)
|
raise FieldError("Cannot parse keyword query %r" % arg)
|
||||||
|
lookup_type, parts, reffed_aggregate = self.solve_lookup_type(arg)
|
||||||
|
|
||||||
# Work out the lookup type and remove it from the end of 'parts',
|
# Work out the lookup type and remove it from the end of 'parts',
|
||||||
# if necessary.
|
# if necessary.
|
||||||
|
@ -1124,11 +1138,13 @@ class Query(object):
|
||||||
used_joins = getattr(value, '_used_joins', [])
|
used_joins = getattr(value, '_used_joins', [])
|
||||||
|
|
||||||
clause = self.where_class()
|
clause = self.where_class()
|
||||||
if self._aggregates:
|
if reffed_aggregate:
|
||||||
for alias, aggregate in self.aggregates.items():
|
condition = self.build_lookup(lookup_type, reffed_aggregate, value)
|
||||||
if alias in (parts[0], LOOKUP_SEP.join(parts)):
|
if not condition:
|
||||||
clause.add((aggregate, lookup_type, value), AND)
|
# Backwards compat for custom lookups
|
||||||
return clause, []
|
condition = (reffed_aggregate, lookup_type, value)
|
||||||
|
clause.add(condition, AND)
|
||||||
|
return clause, []
|
||||||
|
|
||||||
opts = self.get_meta()
|
opts = self.get_meta()
|
||||||
alias = self.get_initial_alias()
|
alias = self.get_initial_alias()
|
||||||
|
@ -1150,11 +1166,18 @@ class Query(object):
|
||||||
targets, alias, join_list = self.trim_joins(sources, join_list, path)
|
targets, alias, join_list = self.trim_joins(sources, join_list, path)
|
||||||
|
|
||||||
if hasattr(field, 'get_lookup_constraint'):
|
if hasattr(field, 'get_lookup_constraint'):
|
||||||
constraint = field.get_lookup_constraint(self.where_class, alias, targets, sources,
|
# For now foreign keys get special treatment. This should be
|
||||||
lookup_type, value)
|
# refactored when composite fields lands.
|
||||||
|
condition = field.get_lookup_constraint(self.where_class, alias, targets, sources,
|
||||||
|
lookup_type, value)
|
||||||
else:
|
else:
|
||||||
constraint = (Constraint(alias, targets[0].column, field), lookup_type, value)
|
assert(len(targets) == 1)
|
||||||
clause.add(constraint, AND)
|
col = Col(alias, targets[0], field)
|
||||||
|
condition = self.build_lookup(lookup_type, col, value)
|
||||||
|
if not condition:
|
||||||
|
# Backwards compat for custom lookups
|
||||||
|
condition = (Constraint(alias, targets[0].column, field), lookup_type, value)
|
||||||
|
clause.add(condition, AND)
|
||||||
|
|
||||||
require_outer = lookup_type == 'isnull' and value is True and not current_negated
|
require_outer = lookup_type == 'isnull' and value is True and not current_negated
|
||||||
if current_negated and (lookup_type != 'isnull' or value is False):
|
if current_negated and (lookup_type != 'isnull' or value is False):
|
||||||
|
@ -1185,7 +1208,7 @@ class Query(object):
|
||||||
if not self._aggregates:
|
if not self._aggregates:
|
||||||
return False
|
return False
|
||||||
if not isinstance(obj, Node):
|
if not isinstance(obj, Node):
|
||||||
return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)
|
return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)[0]
|
||||||
or (hasattr(obj[1], 'contains_aggregate')
|
or (hasattr(obj[1], 'contains_aggregate')
|
||||||
and obj[1].contains_aggregate(self.aggregates)))
|
and obj[1].contains_aggregate(self.aggregates)))
|
||||||
return any(self.need_having(c) for c in obj.children)
|
return any(self.need_having(c) for c in obj.children)
|
||||||
|
@ -1273,7 +1296,7 @@ class Query(object):
|
||||||
needed_inner = joinpromoter.update_join_types(self)
|
needed_inner = joinpromoter.update_join_types(self)
|
||||||
return target_clause, needed_inner
|
return target_clause, needed_inner
|
||||||
|
|
||||||
def names_to_path(self, names, opts, allow_many):
|
def names_to_path(self, names, opts, allow_many=True):
|
||||||
"""
|
"""
|
||||||
Walks the names path and turns them PathInfo tuples. Note that a
|
Walks the names path and turns them PathInfo tuples. Note that a
|
||||||
single name in 'names' can generate multiple PathInfos (m2m for
|
single name in 'names' can generate multiple PathInfos (m2m for
|
||||||
|
@ -1293,9 +1316,10 @@ class Query(object):
|
||||||
try:
|
try:
|
||||||
field, model, direct, m2m = opts.get_field_by_name(name)
|
field, model, direct, m2m = opts.get_field_by_name(name)
|
||||||
except FieldDoesNotExist:
|
except FieldDoesNotExist:
|
||||||
available = opts.get_all_field_names() + list(self.aggregate_select)
|
# We didn't found the current field, so move position back
|
||||||
raise FieldError("Cannot resolve keyword %r into field. "
|
# one step.
|
||||||
"Choices are: %s" % (name, ", ".join(available)))
|
pos -= 1
|
||||||
|
break
|
||||||
# Check if we need any joins for concrete inheritance cases (the
|
# Check if we need any joins for concrete inheritance cases (the
|
||||||
# field lives in parent, but we are currently in one of its
|
# field lives in parent, but we are currently in one of its
|
||||||
# children)
|
# children)
|
||||||
|
@ -1330,15 +1354,9 @@ class Query(object):
|
||||||
final_field = field
|
final_field = field
|
||||||
targets = (field,)
|
targets = (field,)
|
||||||
break
|
break
|
||||||
|
if pos == -1:
|
||||||
if pos != len(names) - 1:
|
raise FieldError('Whazaa')
|
||||||
if pos == len(names) - 2:
|
return path, final_field, targets, names[pos + 1:]
|
||||||
raise FieldError(
|
|
||||||
"Join on field %r not permitted. Did you misspell %r for "
|
|
||||||
"the lookup type?" % (name, names[pos + 1]))
|
|
||||||
else:
|
|
||||||
raise FieldError("Join on field %r not permitted." % name)
|
|
||||||
return path, final_field, targets
|
|
||||||
|
|
||||||
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
|
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
|
||||||
"""
|
"""
|
||||||
|
@ -1367,8 +1385,10 @@ class Query(object):
|
||||||
"""
|
"""
|
||||||
joins = [alias]
|
joins = [alias]
|
||||||
# First, generate the path for the names
|
# First, generate the path for the names
|
||||||
path, final_field, targets = self.names_to_path(
|
path, final_field, targets, rest = self.names_to_path(
|
||||||
names, opts, allow_many)
|
names, opts, allow_many)
|
||||||
|
if rest:
|
||||||
|
raise FieldError('Invalid lookup')
|
||||||
# Then, add the path to the query's joins. Note that we can't trim
|
# Then, add the path to the query's joins. Note that we can't trim
|
||||||
# joins at this stage - we will need the information about join type
|
# joins at this stage - we will need the information about join type
|
||||||
# of the trimmed joins.
|
# of the trimmed joins.
|
||||||
|
@ -1383,8 +1403,6 @@ class Query(object):
|
||||||
alias = self.join(
|
alias = self.join(
|
||||||
connection, reuse=reuse, nullable=nullable, join_field=join.join_field)
|
connection, reuse=reuse, nullable=nullable, join_field=join.join_field)
|
||||||
joins.append(alias)
|
joins.append(alias)
|
||||||
if hasattr(final_field, 'field'):
|
|
||||||
final_field = final_field.field
|
|
||||||
return final_field, targets, opts, joins, path
|
return final_field, targets, opts, joins, path
|
||||||
|
|
||||||
def trim_joins(self, targets, joins, path):
|
def trim_joins(self, targets, joins, path):
|
||||||
|
@ -1455,7 +1473,7 @@ class Query(object):
|
||||||
query.bump_prefix(self)
|
query.bump_prefix(self)
|
||||||
query.where.add(
|
query.where.add(
|
||||||
(Constraint(query.select[0].col[0], pk.column, pk),
|
(Constraint(query.select[0].col[0], pk.column, pk),
|
||||||
'exact', Col(alias, pk.column)),
|
'exact', Col(alias, pk, pk)),
|
||||||
AND
|
AND
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -101,7 +101,7 @@ class WhereNode(tree.Node):
|
||||||
for child in self.children:
|
for child in self.children:
|
||||||
try:
|
try:
|
||||||
if hasattr(child, 'as_sql'):
|
if hasattr(child, 'as_sql'):
|
||||||
sql, params = child.as_sql(qn=qn, connection=connection)
|
sql, params = qn.compile(child)
|
||||||
else:
|
else:
|
||||||
# A leaf node in the tree.
|
# A leaf node in the tree.
|
||||||
sql, params = self.make_atom(child, qn, connection)
|
sql, params = self.make_atom(child, qn, connection)
|
||||||
|
@ -193,13 +193,13 @@ class WhereNode(tree.Node):
|
||||||
field_sql, field_params = self.sql_for_columns(lvalue, qn, connection, field_internal_type), []
|
field_sql, field_params = self.sql_for_columns(lvalue, qn, connection, field_internal_type), []
|
||||||
else:
|
else:
|
||||||
# A smart object with an as_sql() method.
|
# A smart object with an as_sql() method.
|
||||||
field_sql, field_params = lvalue.as_sql(qn, connection)
|
field_sql, field_params = qn.compile(lvalue)
|
||||||
|
|
||||||
is_datetime_field = value_annotation is datetime.datetime
|
is_datetime_field = value_annotation is datetime.datetime
|
||||||
cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s'
|
cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s'
|
||||||
|
|
||||||
if hasattr(params, 'as_sql'):
|
if hasattr(params, 'as_sql'):
|
||||||
extra, params = params.as_sql(qn, connection)
|
extra, params = qn.compile(params)
|
||||||
cast_sql = ''
|
cast_sql = ''
|
||||||
else:
|
else:
|
||||||
extra = ''
|
extra = ''
|
||||||
|
@ -282,6 +282,8 @@ class WhereNode(tree.Node):
|
||||||
if hasattr(child, 'relabel_aliases'):
|
if hasattr(child, 'relabel_aliases'):
|
||||||
# For example another WhereNode
|
# For example another WhereNode
|
||||||
child.relabel_aliases(change_map)
|
child.relabel_aliases(change_map)
|
||||||
|
elif hasattr(child, 'relabeled_clone'):
|
||||||
|
self.children[pos] = child.relabeled_clone(change_map)
|
||||||
elif isinstance(child, (list, tuple)):
|
elif isinstance(child, (list, tuple)):
|
||||||
# tuple starting with Constraint
|
# tuple starting with Constraint
|
||||||
child = (child[0].relabeled_clone(change_map),) + child[1:]
|
child = (child[0].relabeled_clone(change_map),) + child[1:]
|
||||||
|
@ -350,7 +352,7 @@ class Constraint(object):
|
||||||
self.alias, self.col, self.field = alias, col, field
|
self.alias, self.col, self.field = alias, col, field
|
||||||
|
|
||||||
def prepare(self, lookup_type, value):
|
def prepare(self, lookup_type, value):
|
||||||
if self.field:
|
if self.field and not hasattr(value, 'as_sql'):
|
||||||
return self.field.get_prep_lookup(lookup_type, value)
|
return self.field.get_prep_lookup(lookup_type, value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
|
@ -443,7 +443,7 @@ class BaseAggregateTestCase(TestCase):
|
||||||
vals = Author.objects.filter(pk=1).aggregate(Count("friends__id"))
|
vals = Author.objects.filter(pk=1).aggregate(Count("friends__id"))
|
||||||
self.assertEqual(vals, {"friends__id__count": 2})
|
self.assertEqual(vals, {"friends__id__count": 2})
|
||||||
|
|
||||||
books = Book.objects.annotate(num_authors=Count("authors__name")).filter(num_authors__ge=2).order_by("pk")
|
books = Book.objects.annotate(num_authors=Count("authors__name")).filter(num_authors__exact=2).order_by("pk")
|
||||||
self.assertQuerysetEqual(
|
self.assertQuerysetEqual(
|
||||||
books, [
|
books, [
|
||||||
"The Definitive Guide to Django: Web Development Done Right",
|
"The Definitive Guide to Django: Web Development Done Right",
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
|
||||||
|
class Author(models.Model):
|
||||||
|
name = models.CharField(max_length=20)
|
||||||
|
age = models.IntegerField(null=True)
|
||||||
|
birthdate = models.DateField(null=True)
|
|
@ -0,0 +1,136 @@
|
||||||
|
from copy import copy
|
||||||
|
from datetime import date
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from django.test import TestCase
|
||||||
|
from .models import Author
|
||||||
|
from django.db import models
|
||||||
|
from django.db import connection
|
||||||
|
from django.db.backends.utils import add_implementation
|
||||||
|
|
||||||
|
|
||||||
|
class Div3Lookup(models.lookups.Lookup):
|
||||||
|
lookup_name = 'div3'
|
||||||
|
|
||||||
|
def as_sql(self, qn, connection):
|
||||||
|
lhs, params = self.process_lhs(qn, connection)
|
||||||
|
rhs, rhs_params = self.process_rhs(qn, connection)
|
||||||
|
params.extend(rhs_params)
|
||||||
|
return '%s %%%% 3 = %s' % (lhs, rhs), params
|
||||||
|
|
||||||
|
|
||||||
|
class InMonth(models.lookups.Lookup):
|
||||||
|
"""
|
||||||
|
InMonth matches if the column's month is contained in the value's month.
|
||||||
|
"""
|
||||||
|
lookup_name = 'inmonth'
|
||||||
|
|
||||||
|
def as_sql(self, qn, connection):
|
||||||
|
lhs, params = self.process_lhs(qn, connection)
|
||||||
|
rhs, rhs_params = self.process_rhs(qn, connection)
|
||||||
|
# We need to be careful so that we get the params in right
|
||||||
|
# places.
|
||||||
|
full_params = params[:]
|
||||||
|
full_params.extend(rhs_params)
|
||||||
|
full_params.extend(params)
|
||||||
|
full_params.extend(rhs_params)
|
||||||
|
return ("%s >= date_trunc('month', %s) and "
|
||||||
|
"%s < date_trunc('month', %s) + interval '1 months'" %
|
||||||
|
(lhs, rhs, lhs, rhs), full_params)
|
||||||
|
|
||||||
|
|
||||||
|
class LookupTests(TestCase):
|
||||||
|
def test_basic_lookup(self):
|
||||||
|
a1 = Author.objects.create(name='a1', age=1)
|
||||||
|
a2 = Author.objects.create(name='a2', age=2)
|
||||||
|
a3 = Author.objects.create(name='a3', age=3)
|
||||||
|
a4 = Author.objects.create(name='a4', age=4)
|
||||||
|
models.IntegerField.register_lookup(Div3Lookup)
|
||||||
|
try:
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
Author.objects.filter(age__div3=0),
|
||||||
|
[a3], lambda x: x
|
||||||
|
)
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
Author.objects.filter(age__div3=1).order_by('age'),
|
||||||
|
[a1, a4], lambda x: x
|
||||||
|
)
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
Author.objects.filter(age__div3=2),
|
||||||
|
[a2], lambda x: x
|
||||||
|
)
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
Author.objects.filter(age__div3=3),
|
||||||
|
[], lambda x: x
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
models.IntegerField._unregister_lookup(Div3Lookup)
|
||||||
|
|
||||||
|
@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
|
||||||
|
def test_birthdate_month(self):
|
||||||
|
a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
|
||||||
|
a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
|
||||||
|
a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
|
||||||
|
a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
|
||||||
|
models.DateField.register_lookup(InMonth)
|
||||||
|
try:
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
Author.objects.filter(birthdate__inmonth=date(2012, 1, 15)),
|
||||||
|
[a3], lambda x: x
|
||||||
|
)
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
Author.objects.filter(birthdate__inmonth=date(2012, 2, 1)),
|
||||||
|
[a2], lambda x: x
|
||||||
|
)
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
Author.objects.filter(birthdate__inmonth=date(1981, 2, 28)),
|
||||||
|
[a1], lambda x: x
|
||||||
|
)
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
Author.objects.filter(birthdate__inmonth=date(2012, 3, 12)),
|
||||||
|
[a4], lambda x: x
|
||||||
|
)
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
Author.objects.filter(birthdate__inmonth=date(2012, 4, 1)),
|
||||||
|
[], lambda x: x
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
models.DateField._unregister_lookup(InMonth)
|
||||||
|
|
||||||
|
def test_custom_compiles(self):
|
||||||
|
a1 = Author.objects.create(name='a1', age=1)
|
||||||
|
a2 = Author.objects.create(name='a2', age=2)
|
||||||
|
a3 = Author.objects.create(name='a3', age=3)
|
||||||
|
a4 = Author.objects.create(name='a4', age=4)
|
||||||
|
|
||||||
|
class AnotherEqual(models.lookups.Exact):
|
||||||
|
lookup_name = 'anotherequal'
|
||||||
|
models.Field.register_lookup(AnotherEqual)
|
||||||
|
try:
|
||||||
|
@add_implementation(AnotherEqual, connection.vendor)
|
||||||
|
def custom_eq_sql(node, compiler):
|
||||||
|
return '1 = 1', []
|
||||||
|
|
||||||
|
self.assertIn('1 = 1', str(Author.objects.filter(name__anotherequal='asdf').query))
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
Author.objects.filter(name__anotherequal='asdf').order_by('name'),
|
||||||
|
[a1, a2, a3, a4], lambda x: x)
|
||||||
|
|
||||||
|
@add_implementation(AnotherEqual, connection.vendor)
|
||||||
|
def another_custom_eq_sql(node, compiler):
|
||||||
|
# If you need to override one method, it seems this is the best
|
||||||
|
# option.
|
||||||
|
node = copy(node)
|
||||||
|
|
||||||
|
class OverriddenAnotherEqual(AnotherEqual):
|
||||||
|
def get_rhs_op(self, connection, rhs):
|
||||||
|
return ' <> %s'
|
||||||
|
node.__class__ = OverriddenAnotherEqual
|
||||||
|
return node.as_sql(compiler, compiler.connection)
|
||||||
|
self.assertIn(' <> ', str(Author.objects.filter(name__anotherequal='a1').query))
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
Author.objects.filter(name__anotherequal='a1').order_by('name'),
|
||||||
|
[a2, a3, a4], lambda x: x
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
models.Field._unregister_lookup(AnotherEqual)
|
|
@ -2620,8 +2620,15 @@ class WhereNodeTest(TestCase):
|
||||||
def as_sql(self, qn, connection):
|
def as_sql(self, qn, connection):
|
||||||
return 'dummy', []
|
return 'dummy', []
|
||||||
|
|
||||||
|
class MockCompiler(object):
|
||||||
|
def compile(self, node):
|
||||||
|
return node.as_sql(self, connection)
|
||||||
|
|
||||||
|
def __call__(self, name):
|
||||||
|
return connection.ops.quote_name(name)
|
||||||
|
|
||||||
def test_empty_full_handling_conjunction(self):
|
def test_empty_full_handling_conjunction(self):
|
||||||
qn = connection.ops.quote_name
|
qn = WhereNodeTest.MockCompiler()
|
||||||
w = WhereNode(children=[EverythingNode()])
|
w = WhereNode(children=[EverythingNode()])
|
||||||
self.assertEqual(w.as_sql(qn, connection), ('', []))
|
self.assertEqual(w.as_sql(qn, connection), ('', []))
|
||||||
w.negate()
|
w.negate()
|
||||||
|
@ -2646,7 +2653,7 @@ class WhereNodeTest(TestCase):
|
||||||
self.assertEqual(w.as_sql(qn, connection), ('', []))
|
self.assertEqual(w.as_sql(qn, connection), ('', []))
|
||||||
|
|
||||||
def test_empty_full_handling_disjunction(self):
|
def test_empty_full_handling_disjunction(self):
|
||||||
qn = connection.ops.quote_name
|
qn = WhereNodeTest.MockCompiler()
|
||||||
w = WhereNode(children=[EverythingNode()], connector='OR')
|
w = WhereNode(children=[EverythingNode()], connector='OR')
|
||||||
self.assertEqual(w.as_sql(qn, connection), ('', []))
|
self.assertEqual(w.as_sql(qn, connection), ('', []))
|
||||||
w.negate()
|
w.negate()
|
||||||
|
@ -2673,7 +2680,7 @@ class WhereNodeTest(TestCase):
|
||||||
self.assertEqual(w.as_sql(qn, connection), ('NOT (dummy)', []))
|
self.assertEqual(w.as_sql(qn, connection), ('NOT (dummy)', []))
|
||||||
|
|
||||||
def test_empty_nodes(self):
|
def test_empty_nodes(self):
|
||||||
qn = connection.ops.quote_name
|
qn = WhereNodeTest.MockCompiler()
|
||||||
empty_w = WhereNode()
|
empty_w = WhereNode()
|
||||||
w = WhereNode(children=[empty_w, empty_w])
|
w = WhereNode(children=[empty_w, empty_w])
|
||||||
self.assertEqual(w.as_sql(qn, connection), (None, []))
|
self.assertEqual(w.as_sql(qn, connection), (None, []))
|
||||||
|
|
Loading…
Reference in New Issue