564 lines
19 KiB
Python
564 lines
19 KiB
Python
from copy import copy
|
|
|
|
from django.conf import settings
|
|
from django.db.models.expressions import Func, Value
|
|
from django.db.models.fields import (
|
|
DateField, DateTimeField, Field, IntegerField, TimeField,
|
|
)
|
|
from django.db.models.query_utils import RegisterLookupMixin
|
|
from django.utils import timezone
|
|
from django.utils.functional import cached_property
|
|
from django.utils.six.moves import range
|
|
|
|
|
|
class Lookup(object):
    """
    Base class for a lookup: a condition comparing ``lhs`` (an expression)
    with ``rhs`` (either a plain Python value or an object able to render
    SQL through ``as_sql``/``_as_sql``/``get_compiler``).

    Subclasses set ``lookup_name`` and implement ``as_sql()``; the helpers
    below turn each side into ``(sql, params)`` pairs.
    """
    lookup_name = None

    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs
        # Let the lhs output field normalize the raw rhs first.
        self.rhs = self.get_prep_lookup()
        transforms = (
            self.lhs.get_bilateral_transforms()
            if hasattr(self.lhs, 'get_bilateral_transforms')
            else []
        )
        if transforms:
            # Fail early: bilateral transforms cannot be applied to a nested
            # QuerySet. QuerySet is imported locally to avoid a circular
            # import between this module and django.db.models.query.
            from django.db.models.query import QuerySet
            if isinstance(rhs, QuerySet):
                raise NotImplementedError("Bilateral transformations on nested querysets are not supported.")
        self.bilateral_transforms = transforms

    def apply_bilateral_transforms(self, value):
        """Wrap ``value`` in every bilateral transform, first one innermost."""
        wrapped = value
        for transform_class in self.bilateral_transforms:
            wrapped = transform_class(wrapped)
        return wrapped

    def batch_process_rhs(self, compiler, connection, rhs=None):
        """
        Compile each item of an iterable rhs (``self.rhs`` when ``rhs`` is
        None). Returns ``(list_of_sql_fragments, flat_params)``.
        """
        values = self.rhs if rhs is None else rhs
        if not self.bilateral_transforms:
            db_params = self.lhs.output_field.get_db_prep_lookup(
                self.lookup_name, values, connection, prepared=True)
            return ['%s'] * len(db_params), db_params
        fragments, all_params = [], []
        for item in values:
            expression = self.apply_bilateral_transforms(
                Value(item, output_field=self.lhs.output_field))
            fragment, fragment_params = compiler.compile(
                expression.resolve_expression(compiler.query))
            fragments.append(fragment)
            all_params.extend(fragment_params)
        return fragments, all_params

    def get_prep_lookup(self):
        """Ask the lhs output field to prepare ``self.rhs`` for this lookup."""
        field = self.lhs.output_field
        return field.get_prep_lookup(self.lookup_name, self.rhs)

    def get_db_prep_lookup(self, value, connection):
        """Return ``('%s', db_params)`` for an already-prepared plain value."""
        field = self.lhs.output_field
        db_params = field.get_db_prep_lookup(
            self.lookup_name, value, connection, prepared=True)
        return ('%s', db_params)

    def process_lhs(self, compiler, connection, lhs=None):
        """Compile ``lhs`` (falling back to ``self.lhs``) into ``(sql, params)``."""
        return compiler.compile(lhs or self.lhs)

    def process_rhs(self, compiler, connection):
        """
        Turn ``self.rhs`` into ``(sql, params)``.

        For historical reasons the rhs may be an object with
        ``get_compiler`` (likely a Query), ``_as_sql`` (QuerySet), or
        ``as_sql``; SQL from those is parenthesized. Anything else is
        treated as a plain Python value.
        """
        value = self.rhs
        if self.bilateral_transforms:
            if self.rhs_is_direct_value():
                # Skip get_db_prep_lookup: the value is transformed in SQL
                # before being compared.
                value = Value(value, output_field=self.lhs.output_field)
            value = self.apply_bilateral_transforms(value)
            value = value.resolve_expression(compiler.query)
        if hasattr(value, 'get_compiler'):
            value = value.get_compiler(connection=connection)
        if hasattr(value, 'as_sql'):
            sql, params = compiler.compile(value)
        elif hasattr(value, '_as_sql'):
            sql, params = value._as_sql(connection=connection)
        else:
            return self.get_db_prep_lookup(value, connection)
        return '(' + sql + ')', params

    def rhs_is_direct_value(self):
        """True unless ``self.rhs`` exposes one of the SQL-producing hooks."""
        return not any(
            hasattr(self.rhs, hook)
            for hook in ('as_sql', '_as_sql', 'get_compiler')
        )

    def relabeled_clone(self, relabels):
        """Shallow-copy self with lhs (and rhs, if supported) relabeled."""
        clone = copy(self)
        clone.lhs = clone.lhs.relabeled_clone(relabels)
        if hasattr(clone.rhs, 'relabeled_clone'):
            clone.rhs = clone.rhs.relabeled_clone(relabels)
        return clone

    def get_group_by_cols(self):
        """Group-by columns of the lhs, plus the rhs's when it provides any."""
        cols = self.lhs.get_group_by_cols()
        if not hasattr(self.rhs, 'get_group_by_cols'):
            return cols
        cols.extend(self.rhs.get_group_by_cols())
        return cols

    def as_sql(self, compiler, connection):
        raise NotImplementedError

    @cached_property
    def contains_aggregate(self):
        lhs_flag = self.lhs.contains_aggregate
        return lhs_flag or getattr(self.rhs, 'contains_aggregate', False)
|
|
|
|
|
|
class Transform(RegisterLookupMixin, Func):
    """
    Single-argument expression that can appear in a lookup path
    (for example the ``year`` transform in ``field__year``).

    RegisterLookupMixin() is first so that get_lookup() and get_transform()
    first examine self and then check output_field.
    """
    # When True, Lookup also applies this transform to the rhs value.
    bilateral = False
    arity = 1

    @property
    def lhs(self):
        """The single source expression wrapped by this transform."""
        return self.get_source_expressions()[0]

    def get_bilateral_transforms(self):
        """
        Collect the bilateral transform classes along the lhs chain ending
        at self, innermost first.
        """
        transforms = (
            self.lhs.get_bilateral_transforms()
            if hasattr(self.lhs, 'get_bilateral_transforms')
            else []
        )
        if self.bilateral:
            transforms.append(self.__class__)
        return transforms
|
|
|
|
|
|
class BuiltinLookup(Lookup):
    """
    Lookup rendered as ``<lhs sql> <operator>`` where the operator template
    comes from ``connection.operators[lookup_name]``.
    """

    def process_lhs(self, compiler, connection, lhs=None):
        """
        Compile the lhs, then wrap it in the backend's field-cast and
        lookup-cast templates. Params are returned as a list.
        """
        sql, params = super(BuiltinLookup, self).process_lhs(
            compiler, connection, lhs)
        output_field = self.lhs.output_field
        internal_type = output_field.get_internal_type()
        db_type = output_field.db_type(connection=connection)
        sql = connection.ops.field_cast_sql(db_type, internal_type) % sql
        sql = connection.ops.lookup_cast(self.lookup_name, internal_type) % sql
        return sql, list(params)

    def as_sql(self, compiler, connection):
        lhs_sql, params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        params.extend(rhs_params)
        return '%s %s' % (lhs_sql, self.get_rhs_op(connection, rhs_sql)), params

    def get_rhs_op(self, connection, rhs):
        """Interpolate the rhs SQL into this lookup's backend operator."""
        template = connection.operators[self.lookup_name]
        return template % rhs
|
|
|
|
|
|
class Exact(BuiltinLookup):
    """``exact`` lookup, rendered via the backend's ``operators['exact']``."""
    lookup_name = 'exact'


Field.register_lookup(Exact)
|
|
|
|
|
|
class IExact(BuiltinLookup):
    """``iexact`` lookup; the backend may adapt the parameter first."""
    lookup_name = 'iexact'

    def process_rhs(self, qn, connection):
        sql, rhs_params = super(IExact, self).process_rhs(qn, connection)
        if rhs_params:
            # Only the first parameter is passed through the backend hook.
            adapted = connection.ops.prep_for_iexact_query(rhs_params[0])
            rhs_params[0] = adapted
        return sql, rhs_params


Field.register_lookup(IExact)
|
|
|
|
|
|
class GreaterThan(BuiltinLookup):
    """``gt`` lookup, rendered via the backend's ``operators['gt']``."""
    lookup_name = 'gt'


Field.register_lookup(GreaterThan)
|
|
|
|
|
|
class GreaterThanOrEqual(BuiltinLookup):
    """``gte`` lookup, rendered via the backend's ``operators['gte']``."""
    lookup_name = 'gte'


Field.register_lookup(GreaterThanOrEqual)
|
|
|
|
|
|
class LessThan(BuiltinLookup):
    """``lt`` lookup, rendered via the backend's ``operators['lt']``."""
    lookup_name = 'lt'


Field.register_lookup(LessThan)
|
|
|
|
|
|
class LessThanOrEqual(BuiltinLookup):
    """``lte`` lookup, rendered via the backend's ``operators['lte']``."""
    lookup_name = 'lte'


Field.register_lookup(LessThanOrEqual)
|
|
|
|
|
|
class In(BuiltinLookup):
    """
    ``in`` lookup: ``<lhs> IN (...)``. Long direct-value lists are split into
    OR-ed IN clauses when the backend reports a ``max_in_list_size()``.
    """
    lookup_name = 'in'

    def process_rhs(self, compiler, connection):
        """
        Return ``(sql, params)`` for the right-hand side.

        Raises ValueError when the rhs is bound (via ``_db``) to a database
        alias other than ``connection.alias``, and EmptyResultSet when a
        direct-value rhs is empty.
        """
        db_rhs = getattr(self.rhs, '_db', None)
        if db_rhs is not None and db_rhs != connection.alias:
            raise ValueError(
                "Subqueries aren't allowed across different databases. Force "
                "the inner query to be evaluated using `list(inner_query)`."
            )

        if self.rhs_is_direct_value():
            from collections import OrderedDict

            # Drop duplicates but keep first-seen order. Iterating a plain
            # set() orders values by hash (randomized for str between
            # processes), which made the generated SQL/params order
            # nondeterministic for the same queryset.
            try:
                rhs = list(OrderedDict.fromkeys(self.rhs))
            except TypeError:  # Unhashable items in self.rhs
                rhs = self.rhs

            if not rhs:
                from django.db.models.sql.datastructures import EmptyResultSet
                raise EmptyResultSet

            # rhs should be an iterable; use batch_process_rhs() to
            # prepare/transform those values.
            sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
            placeholder = '(' + ', '.join(sqls) + ')'
            return (placeholder, sqls_params)
        else:
            return super(In, self).process_rhs(compiler, connection)

    def get_rhs_op(self, connection, rhs):
        return 'IN %s' % rhs

    def as_sql(self, compiler, connection):
        """Delegate to the splitting variant when a direct-value rhs is too long."""
        max_in_list_size = connection.ops.max_in_list_size()
        if self.rhs_is_direct_value() and max_in_list_size and len(self.rhs) > max_in_list_size:
            return self.split_parameter_list_as_sql(compiler, connection)
        return super(In, self).as_sql(compiler, connection)

    def split_parameter_list_as_sql(self, compiler, connection):
        """
        Render ``(lhs IN (...) OR lhs IN (...) ...)`` with at most
        ``max_in_list_size()`` items per IN clause.
        """
        # This is a special case for databases which limit the number of
        # elements which can appear in an 'IN' clause.
        max_in_list_size = connection.ops.max_in_list_size()
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.batch_process_rhs(compiler, connection)
        in_clause_elements = ['(']
        params = []
        # NOTE(review): slicing rhs and rhs_params with the same offsets
        # assumes each SQL fragment consumes exactly one parameter; confirm
        # this holds for bilateral transforms that add parameters.
        for offset in range(0, len(rhs_params), max_in_list_size):
            if offset > 0:
                in_clause_elements.append(' OR ')
            in_clause_elements.append('%s IN (' % lhs)
            params.extend(lhs_params)
            sqls = rhs[offset: offset + max_in_list_size]
            sqls_params = rhs_params[offset: offset + max_in_list_size]
            param_group = ', '.join(sqls)
            in_clause_elements.append(param_group)
            in_clause_elements.append(')')
            params.extend(sqls_params)
        in_clause_elements.append(')')
        return ''.join(in_clause_elements), params


Field.register_lookup(In)
|
|
|
|
|
|
class PatternLookup(BuiltinLookup):
    """
    Base for LIKE-style lookups.

    For a plain Python rhs, subclasses add the wildcard to the parameter
    itself (``col LIKE %s`` with ``['thevalue%']``). When the rhs is SQL
    (another column, a subquery, or a bilaterally transformed value) the
    wildcard must be added in SQL instead (``col LIKE othercol || '%%'``),
    so the backend's ``pattern_ops`` template is used.
    """

    def get_rhs_op(self, connection, rhs):
        rhs_is_sql = (
            hasattr(self.rhs, 'get_compiler')
            or hasattr(self.rhs, 'as_sql')
            or hasattr(self.rhs, '_as_sql')
            or self.bilateral_transforms
        )
        if not rhs_is_sql:
            return super(PatternLookup, self).get_rhs_op(connection, rhs)
        template = connection.pattern_ops[self.lookup_name]
        pattern = template.format(connection.pattern_esc)
        return pattern.format(rhs)
|
|
|
|
|
|
class Contains(PatternLookup):
    """``contains``: a plain value parameter is wrapped as ``'%<value>%'``."""
    lookup_name = 'contains'

    def process_rhs(self, qn, connection):
        sql, rhs_params = super(Contains, self).process_rhs(qn, connection)
        if not rhs_params or self.bilateral_transforms:
            return sql, rhs_params
        rhs_params[0] = "%%%s%%" % connection.ops.prep_for_like_query(rhs_params[0])
        return sql, rhs_params


Field.register_lookup(Contains)
|
|
|
|
|
|
class IContains(Contains):
    """
    Contains variant keyed on ``'icontains'``: its SQL comes from the
    backend's ``operators``/``pattern_ops`` entries for that name.
    """
    lookup_name = 'icontains'


Field.register_lookup(IContains)
|
|
|
|
|
|
class StartsWith(PatternLookup):
    """``startswith``: a plain value parameter becomes ``'<value>%'``."""
    lookup_name = 'startswith'

    def process_rhs(self, qn, connection):
        sql, rhs_params = super(StartsWith, self).process_rhs(qn, connection)
        if not rhs_params or self.bilateral_transforms:
            return sql, rhs_params
        rhs_params[0] = "%s%%" % connection.ops.prep_for_like_query(rhs_params[0])
        return sql, rhs_params


Field.register_lookup(StartsWith)
|
|
|
|
|
|
class IStartsWith(PatternLookup):
    """``istartswith``: a plain value parameter becomes ``'<value>%'``."""
    lookup_name = 'istartswith'

    def process_rhs(self, qn, connection):
        sql, rhs_params = super(IStartsWith, self).process_rhs(qn, connection)
        if not rhs_params or self.bilateral_transforms:
            return sql, rhs_params
        rhs_params[0] = "%s%%" % connection.ops.prep_for_like_query(rhs_params[0])
        return sql, rhs_params


Field.register_lookup(IStartsWith)
|
|
|
|
|
|
class EndsWith(PatternLookup):
    """``endswith``: a plain value parameter becomes ``'%<value>'``."""
    lookup_name = 'endswith'

    def process_rhs(self, qn, connection):
        sql, rhs_params = super(EndsWith, self).process_rhs(qn, connection)
        if not rhs_params or self.bilateral_transforms:
            return sql, rhs_params
        rhs_params[0] = "%%%s" % connection.ops.prep_for_like_query(rhs_params[0])
        return sql, rhs_params


Field.register_lookup(EndsWith)
|
|
|
|
|
|
class IEndsWith(PatternLookup):
    """``iendswith``: a plain value parameter becomes ``'%<value>'``."""
    lookup_name = 'iendswith'

    def process_rhs(self, qn, connection):
        sql, rhs_params = super(IEndsWith, self).process_rhs(qn, connection)
        if not rhs_params or self.bilateral_transforms:
            return sql, rhs_params
        rhs_params[0] = "%%%s" % connection.ops.prep_for_like_query(rhs_params[0])
        return sql, rhs_params


Field.register_lookup(IEndsWith)
|
|
|
|
|
|
class Between(BuiltinLookup):
    """
    Renders ``BETWEEN <rhs> AND <rhs>`` using the same rhs for both ends.

    NOTE(review): this class sets no ``lookup_name`` and is not registered
    in this module; it looks unused here. Confirm before relying on it.
    """

    def get_rhs_op(self, connection, rhs):
        return "BETWEEN %s AND %s" % ((rhs,) * 2)
|
|
|
|
|
|
class Range(BuiltinLookup):
    """``range`` lookup: ``BETWEEN`` the two items of an rhs pair."""
    lookup_name = 'range'

    def get_rhs_op(self, connection, rhs):
        lower, upper = rhs[0], rhs[1]
        return "BETWEEN %s AND %s" % (lower, upper)

    def process_rhs(self, compiler, connection):
        if not self.rhs_is_direct_value():
            return super(Range, self).process_rhs(compiler, connection)
        # A direct rhs is an iterable of two values; compile each one so
        # get_rhs_op() receives a pair of SQL fragments.
        return self.batch_process_rhs(compiler, connection)


Field.register_lookup(Range)
|
|
|
|
|
|
class IsNull(BuiltinLookup):
    """``isnull``: a truthy rhs yields ``IS NULL``, a falsy one ``IS NOT NULL``."""
    lookup_name = 'isnull'

    def as_sql(self, compiler, connection):
        sql, params = compiler.compile(self.lhs)
        template = "%s IS NULL" if self.rhs else "%s IS NOT NULL"
        return template % sql, params


Field.register_lookup(IsNull)
|
|
|
|
|
|
class Search(BuiltinLookup):
    """``search`` lookup; the SQL comes from ``ops.fulltext_search_sql``."""
    lookup_name = 'search'

    def as_sql(self, compiler, connection):
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        # Only the rhs parameters are used; the backend template supplies
        # its own placeholder(s).
        _, rhs_params = self.process_rhs(compiler, connection)
        template = connection.ops.fulltext_search_sql(field_name=lhs_sql)
        return template, lhs_params + rhs_params


Field.register_lookup(Search)
|
|
|
|
|
|
class Regex(BuiltinLookup):
    """
    ``regex`` lookup. Uses ``connection.operators`` when the backend defines
    the operator, otherwise the ``ops.regex_lookup`` template.
    """
    lookup_name = 'regex'

    def as_sql(self, compiler, connection):
        if self.lookup_name not in connection.operators:
            lhs_sql, lhs_params = self.process_lhs(compiler, connection)
            rhs_sql, rhs_params = self.process_rhs(compiler, connection)
            template = connection.ops.regex_lookup(self.lookup_name)
            return template % (lhs_sql, rhs_sql), lhs_params + rhs_params
        return super(Regex, self).as_sql(compiler, connection)


Field.register_lookup(Regex)
|
|
|
|
|
|
class IRegex(Regex):
    """Regex variant keyed on ``'iregex'`` in the backend's operator tables."""
    lookup_name = 'iregex'


Field.register_lookup(IRegex)
|
|
|
|
|
|
class DateTimeDateTransform(Transform):
    """
    ``date`` transform: casts the lhs with ``ops.datetime_cast_date_sql``,
    passing the current time zone name when ``settings.USE_TZ`` is on.
    """
    lookup_name = 'date'

    @cached_property
    def output_field(self):
        return DateField()

    def as_sql(self, compiler, connection):
        lhs_sql, params = compiler.compile(self.lhs)
        tzname = None
        if settings.USE_TZ:
            tzname = timezone.get_current_timezone_name()
        sql, tz_params = connection.ops.datetime_cast_date_sql(lhs_sql, tzname)
        params.extend(tz_params)
        return sql, params
|
|
|
|
|
|
class DateTransform(Transform):
    """
    Extracts the part named by ``lookup_name`` from a DateTimeField,
    DateField or TimeField lhs, producing an integer.

    Raises ValueError for any other lhs output field.
    """

    def as_sql(self, compiler, connection):
        sql, params = compiler.compile(self.lhs)
        field = self.lhs.output_field
        # DateTimeField is tested before DateField on purpose: the
        # datetime branch also needs the time zone parameters.
        if isinstance(field, DateTimeField):
            tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
            sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
            params.extend(tz_params)
            return sql, params
        if isinstance(field, DateField):
            return connection.ops.date_extract_sql(self.lookup_name, sql), params
        if isinstance(field, TimeField):
            return connection.ops.time_extract_sql(self.lookup_name, sql), params
        raise ValueError('DateTransform only valid on Date/Time/DateTimeFields')

    @cached_property
    def output_field(self):
        return IntegerField()
|
|
|
|
|
|
class YearTransform(DateTransform):
    """
    ``year`` part. Its ``exact``/``gt``/``gte``/``lt``/``lte`` lookups are
    the Year* classes in this module, which compare the untransformed column
    against year bounds instead of extracting the year.
    """
    lookup_name = 'year'
|
|
|
|
|
|
class YearLookup(Lookup):
    """Base for lookups registered on YearTransform."""

    def year_lookup_bounds(self, connection, year):
        """
        Return the backend's bounds for ``year``, choosing the datetime or
        date variant from the output field of the transform's source
        (``self.lhs.lhs``).
        """
        if isinstance(self.lhs.lhs.output_field, DateTimeField):
            return connection.ops.year_lookup_bounds_for_datetime_field(year)
        return connection.ops.year_lookup_bounds_for_date_field(year)
|
|
|
|
|
|
@YearTransform.register_lookup
class YearExact(YearLookup):
    """``__year=N`` rendered as ``<column> BETWEEN <start> AND <finish>``."""
    lookup_name = 'exact'

    def as_sql(self, compiler, connection):
        # Bypass the year extraction and compare the originating column
        # (self.lhs.lhs) directly against the year's bounds.
        source = self.lhs.lhs
        lhs_sql, params = self.process_lhs(compiler, connection, source)
        _, rhs_params = self.process_rhs(compiler, connection)
        params.extend(self.year_lookup_bounds(connection, rhs_params[0]))
        return '%s BETWEEN %%s AND %%s' % lhs_sql, params
|
|
|
|
|
|
class YearComparisonLookup(YearLookup):
    """
    Base for ``__year__gt/gte/lt/lte``.

    Instead of extracting the year, it compares the originating column
    (``self.lhs.lhs``) with a single bound of the requested year. Subclasses
    pick that bound by implementing ``get_bound(start, finish)``.
    """

    def as_sql(self, compiler, connection):
        # We will need to skip the extract part and instead go
        # directly with the originating field, that is self.lhs.lhs.
        lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        rhs_sql = self.get_rhs_op(connection, rhs_sql)
        start, finish = self.year_lookup_bounds(connection, rhs_params[0])
        # The year value itself is not a query parameter; the chosen bound
        # fills the rhs placeholder instead.
        params.append(self.get_bound(start, finish))
        return '%s %s' % (lhs_sql, rhs_sql), params

    def get_rhs_op(self, connection, rhs):
        return connection.operators[self.lookup_name] % rhs

    def get_bound(self, start, finish):
        """
        Return the bound (``start`` or ``finish`` from
        ``year_lookup_bounds()``) to compare the column against.

        The signature matches the ``self.get_bound(start, finish)`` call in
        ``as_sql()``. With the previous ``get_bound(self)``, a subclass that
        forgot to override it got a TypeError instead of this message.
        """
        raise NotImplementedError(
            'subclasses of YearComparisonLookup must provide a get_bound() method'
        )
|
|
|
|
|
|
@YearTransform.register_lookup
class YearGt(YearComparisonLookup):
    """``__year__gt``: compare with the year's ``finish`` bound."""
    lookup_name = 'gt'

    def get_bound(self, start, finish):
        return finish
|
|
|
|
|
|
@YearTransform.register_lookup
class YearGte(YearComparisonLookup):
    """``__year__gte``: compare with the year's ``start`` bound."""
    lookup_name = 'gte'

    def get_bound(self, start, finish):
        return start
|
|
|
|
|
|
@YearTransform.register_lookup
class YearLt(YearComparisonLookup):
    """``__year__lt``: compare with the year's ``start`` bound."""
    lookup_name = 'lt'

    def get_bound(self, start, finish):
        return start
|
|
|
|
|
|
@YearTransform.register_lookup
class YearLte(YearComparisonLookup):
    """``__year__lte``: compare with the year's ``finish`` bound."""
    lookup_name = 'lte'

    def get_bound(self, start, finish):
        return finish
|
|
|
|
|
|
class MonthTransform(DateTransform):
    """Extracts the ``month`` part through DateTransform."""
    lookup_name = 'month'
|
|
|
|
|
|
class DayTransform(DateTransform):
    """Extracts the ``day`` part through DateTransform."""
    lookup_name = 'day'
|
|
|
|
|
|
class WeekDayTransform(DateTransform):
    """Extracts the ``week_day`` part through DateTransform."""
    lookup_name = 'week_day'
|
|
|
|
|
|
class HourTransform(DateTransform):
    """Extracts the ``hour`` part through DateTransform."""
    lookup_name = 'hour'
|
|
|
|
|
|
class MinuteTransform(DateTransform):
    """Extracts the ``minute`` part through DateTransform."""
    lookup_name = 'minute'
|
|
|
|
|
|
class SecondTransform(DateTransform):
    """Extracts the ``second`` part through DateTransform."""
    lookup_name = 'second'
|
|
|
|
|
|
# Attach the date/time part transforms to each field class that supports
# them. Within each field class the transforms are registered in the order
# listed.
for _field_class, _transforms in (
    (DateField, (YearTransform, MonthTransform, DayTransform, WeekDayTransform)),
    (TimeField, (HourTransform, MinuteTransform, SecondTransform)),
    (DateTimeField, (
        DateTimeDateTransform, YearTransform, MonthTransform, DayTransform,
        WeekDayTransform, HourTransform, MinuteTransform, SecondTransform,
    )),
):
    for _transform in _transforms:
        _field_class.register_lookup(_transform)
# Keep the module namespace free of the loop temporaries.
del _field_class, _transforms, _transform
|