from copy import copy

from django.conf import settings
from django.db.models.expressions import Func, Value
from django.db.models.fields import (
    DateField, DateTimeField, Field, IntegerField, TimeField,
)
from django.db.models.query_utils import RegisterLookupMixin
from django.utils import timezone
from django.utils.functional import cached_property
from django.utils.six.moves import range


class Lookup(object):
    """
    Base class for a comparison between a left-hand side (``lhs``) and a
    right-hand side (``rhs``).

    Subclasses set ``lookup_name`` and implement ``as_sql()``. The rhs may be
    a plain Python value or something that can produce SQL itself (an object
    with ``as_sql``, ``_as_sql`` or ``get_compiler``).
    """
    # Name under which the lookup is registered; set by subclasses.
    lookup_name = None

    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs
        # Let the lhs output field prepare the raw rhs for this lookup type.
        self.rhs = self.get_prep_lookup()
        if hasattr(self.lhs, 'get_bilateral_transforms'):
            bilateral_transforms = self.lhs.get_bilateral_transforms()
        else:
            bilateral_transforms = []
        if bilateral_transforms:
            # We should warn the user as soon as possible if he is trying to apply
            # a bilateral transformation on a nested QuerySet: that won't work.
            # QuerySet is imported here, rather than at module level, so as to
            # avoid a circular import.
            from django.db.models.query import QuerySet
            if isinstance(rhs, QuerySet):
                raise NotImplementedError("Bilateral transformations on nested querysets are not supported.")
        # Transform classes that must also be applied to the rhs (possibly empty).
        self.bilateral_transforms = bilateral_transforms

    def apply_bilateral_transforms(self, value):
        """
        Wrap ``value`` in each class of ``self.bilateral_transforms`` in turn;
        the first class in the list ends up innermost.
        """
        for transform in self.bilateral_transforms:
            value = transform(value)
        return value

    def batch_process_rhs(self, compiler, connection, rhs=None):
        """
        Compile every element of an iterable rhs (``self.rhs`` when ``rhs`` is
        None).

        Return ``(sqls, sqls_params)``: a list of SQL fragments and a flat list
        of their parameters.
        """
        if rhs is None:
            rhs = self.rhs
        if self.bilateral_transforms:
            sqls, sqls_params = [], []
            for p in rhs:
                # Wrap each raw value in a Value expression so the bilateral
                # transforms can be applied to it in SQL, as they are to lhs.
                value = Value(p, output_field=self.lhs.output_field)
                value = self.apply_bilateral_transforms(value)
                value = value.resolve_expression(compiler.query)
                sql, sql_params = compiler.compile(value)
                sqls.append(sql)
                sqls_params.extend(sql_params)
        else:
            # No transforms: let the field produce db values and emit one
            # '%s' placeholder per resulting parameter.
            params = self.lhs.output_field.get_db_prep_lookup(
                self.lookup_name, rhs, connection, prepared=True)
            sqls, sqls_params = ['%s'] * len(params), params
        return sqls, sqls_params

    def get_prep_lookup(self):
        """Return ``self.rhs`` as prepared by the lhs output field."""
        return self.lhs.output_field.get_prep_lookup(self.lookup_name, self.rhs)

    def get_db_prep_lookup(self, value, connection):
        """Return a ``('%s', params)`` pair for a plain Python rhs ``value``."""
        return (
            '%s',
            self.lhs.output_field.get_db_prep_lookup(
                self.lookup_name, value, connection, prepared=True))

    def process_lhs(self, compiler, connection, lhs=None):
        """Compile ``lhs`` (``self.lhs`` when not given) into ``(sql, params)``."""
        lhs = lhs or self.lhs
        return compiler.compile(lhs)

    def process_rhs(self, compiler, connection):
        """
        Compile ``self.rhs`` into ``(sql, params)``.

        SQL-producing values are wrapped in parentheses; plain values go
        through ``get_db_prep_lookup()``.
        """
        value = self.rhs
        if self.bilateral_transforms:
            if self.rhs_is_direct_value():
                # Do not call get_db_prep_lookup here as the value will be
                # transformed before being used for lookup
                value = Value(value, output_field=self.lhs.output_field)
            value = self.apply_bilateral_transforms(value)
            value = value.resolve_expression(compiler.query)
        # Due to historical reasons there are a couple of different
        # ways to produce sql here. get_compiler is likely a Query
        # instance, _as_sql QuerySet and as_sql just something with
        # as_sql. Finally the value can of course be just plain
        # Python value.
        if hasattr(value, 'get_compiler'):
            value = value.get_compiler(connection=connection)
        if hasattr(value, 'as_sql'):
            sql, params = compiler.compile(value)
            return '(' + sql + ')', params
        if hasattr(value, '_as_sql'):
            sql, params = value._as_sql(connection=connection)
            return '(' + sql + ')', params
        else:
            return self.get_db_prep_lookup(value, connection)

    def rhs_is_direct_value(self):
        """Return True if ``self.rhs`` has none of the SQL-producing hooks."""
        return not(
            hasattr(self.rhs, 'as_sql') or
            hasattr(self.rhs, '_as_sql') or
            hasattr(self.rhs, 'get_compiler'))

    def relabeled_clone(self, relabels):
        """Return a shallow copy with lhs (and rhs, if supported) relabeled."""
        new = copy(self)
        new.lhs = new.lhs.relabeled_clone(relabels)
        if hasattr(new.rhs, 'relabeled_clone'):
            new.rhs = new.rhs.relabeled_clone(relabels)
        return new

    def get_group_by_cols(self):
        """Return the lhs GROUP BY columns, extended with the rhs ones if any."""
        # NOTE(review): extends the list returned by lhs in place; assumes
        # get_group_by_cols() returns a fresh list.
        cols = self.lhs.get_group_by_cols()
        if hasattr(self.rhs, 'get_group_by_cols'):
            cols.extend(self.rhs.get_group_by_cols())
        return cols

    def as_sql(self, compiler, connection):
        # Subclasses must build the SQL for the comparison.
        raise NotImplementedError

    @cached_property
    def contains_aggregate(self):
        return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)


class Transform(RegisterLookupMixin, Func):
    """
    A single-expression Func that can itself have lookups and transforms
    registered on it.

    RegisterLookupMixin() is first so that get_lookup() and get_transform()
    first examine self and then check output_field.
    """
    # When True, this transform is also applied to the rhs of a lookup.
    bilateral = False

    def __init__(self, expression, **extra):
        # Restrict Transform to allow only a single expression.
        super(Transform, self).__init__(expression, **extra)

    @property
    def lhs(self):
        # The single wrapped source expression.
        return self.get_source_expressions()[0]

    def get_bilateral_transforms(self):
        """
        Return the bilateral transform classes in the chain ending here,
        innermost first (this class last, if it is bilateral).
        """
        if hasattr(self.lhs, 'get_bilateral_transforms'):
            bilateral_transforms = self.lhs.get_bilateral_transforms()
        else:
            bilateral_transforms = []
        if self.bilateral:
            bilateral_transforms.append(self.__class__)
        return bilateral_transforms


class BuiltinLookup(Lookup):
    """
    Lookup rendered as ``<lhs> <operator rhs>``, using the backend's
    ``connection.operators[lookup_name]`` template.
    """

    def process_lhs(self, compiler, connection, lhs=None):
        """Compile lhs and apply the backend's field and lookup cast templates."""
        lhs_sql, params = super(BuiltinLookup, self).process_lhs(
            compiler, connection, lhs)
        field_internal_type = self.lhs.output_field.get_internal_type()
        db_type = self.lhs.output_field.db_type(connection=connection)
        lhs_sql = connection.ops.field_cast_sql(
            db_type, field_internal_type) % lhs_sql
        lhs_sql = connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
        # Always hand back a list so callers may extend it.
        return lhs_sql, list(params)

    def as_sql(self, compiler, connection):
        lhs_sql, params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        params.extend(rhs_params)
        rhs_sql = self.get_rhs_op(connection, rhs_sql)
        return '%s %s' % (lhs_sql, rhs_sql), params

    def get_rhs_op(self, connection, rhs):
        """Return the operator-plus-rhs SQL from the backend operator template."""
        return connection.operators[self.lookup_name] % rhs


class Exact(BuiltinLookup):
    lookup_name = 'exact'
Field.register_lookup(Exact)


class IExact(BuiltinLookup):
    lookup_name = 'iexact'

    def process_rhs(self, qn, connection):
        rhs, params = super(IExact, self).process_rhs(qn, connection)
        # Let the backend adapt the first parameter for iexact matching.
        if params:
            params[0] = connection.ops.prep_for_iexact_query(params[0])
        return rhs, params
Field.register_lookup(IExact)


class GreaterThan(BuiltinLookup):
    lookup_name = 'gt'
Field.register_lookup(GreaterThan)


class GreaterThanOrEqual(BuiltinLookup):
    lookup_name = 'gte'
Field.register_lookup(GreaterThanOrEqual)


class LessThan(BuiltinLookup):
    lookup_name = 'lt'
Field.register_lookup(LessThan)


class LessThanOrEqual(BuiltinLookup):
    lookup_name = 'lte'
Field.register_lookup(LessThanOrEqual)


class In(BuiltinLookup):
    """Membership test: ``<lhs> IN (<v1>, <v2>, ...)``."""
    lookup_name = 'in'

    def process_rhs(self, compiler, connection):
        """
        Render a direct iterable rhs as a parenthesized placeholder list.

        Raises EmptyResultSet for an empty iterable (the query can match
        nothing).
        """
        if self.rhs_is_direct_value():
            # rhs should be an iterable, we use batch_process_rhs
            # to prepare/transform those values
            rhs = list(self.rhs)
            if not rhs:
                from django.db.models.sql.datastructures import EmptyResultSet
                raise EmptyResultSet
            sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
            placeholder = '(' + ', '.join(sqls) + ')'
            return (placeholder, sqls_params)
        else:
            return super(In, self).process_rhs(compiler, connection)

    def get_rhs_op(self, connection, rhs):
        return 'IN %s' % rhs

    def as_sql(self, compiler, connection):
        max_in_list_size = connection.ops.max_in_list_size()
        # Split oversized literal lists for backends that cap IN-list length.
        if self.rhs_is_direct_value() and max_in_list_size and len(self.rhs) > max_in_list_size:
            return self.split_parameter_list_as_sql(compiler, connection)
        return super(In, self).as_sql(compiler, connection)

    def split_parameter_list_as_sql(self, compiler, connection):
        """
        Build ``(lhs IN (...) OR lhs IN (...) ...)`` with at most
        ``max_in_list_size`` elements per IN clause.
        """
        # This is a special case for databases which limit the number of
        # elements which can appear in an 'IN' clause.
        max_in_list_size = connection.ops.max_in_list_size()
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.batch_process_rhs(compiler, connection)
        in_clause_elements = ['(']
        params = []
        # NOTE(review): slices rhs and rhs_params with the same offsets, which
        # assumes each SQL fragment carries exactly one parameter (true for the
        # '%s' placeholders of the non-bilateral path) — confirm for bilateral
        # transforms that add their own params.
        for offset in range(0, len(rhs_params), max_in_list_size):
            if offset > 0:
                in_clause_elements.append(' OR ')
            in_clause_elements.append('%s IN (' % lhs)
            # lhs params are repeated once per IN clause.
            params.extend(lhs_params)
            sqls = rhs[offset: offset + max_in_list_size]
            sqls_params = rhs_params[offset: offset + max_in_list_size]
            param_group = ', '.join(sqls)
            in_clause_elements.append(param_group)
            in_clause_elements.append(')')
            params.extend(sqls_params)
        in_clause_elements.append(')')
        return ''.join(in_clause_elements), params
Field.register_lookup(In)


class PatternLookup(BuiltinLookup):
    """Base for LIKE-style lookups (contains / startswith / endswith)."""

    def get_rhs_op(self, connection, rhs):
        # Assume we are in startswith. We need to produce SQL like:
        #     col LIKE %s, ['thevalue%']
        # For python values we can (and should) do that directly in Python,
        # but if the value is for example reference to other column, then
        # we need to add the % pattern match to the lookup by something like
        #     col LIKE othercol || '%%'
        # So, for Python values we don't need any special pattern, but for
        # SQL reference values or SQL transformations we need the correct
        # pattern added.
        if (hasattr(self.rhs, 'get_compiler') or hasattr(self.rhs, 'as_sql')
                or hasattr(self.rhs, '_as_sql') or self.bilateral_transforms):
            pattern = connection.pattern_ops[self.lookup_name].format(connection.pattern_esc)
            return pattern.format(rhs)
        else:
            return super(PatternLookup, self).get_rhs_op(connection, rhs)


class Contains(PatternLookup):
    lookup_name = 'contains'

    def process_rhs(self, qn, connection):
        rhs, params = super(Contains, self).process_rhs(qn, connection)
        # For plain values, surround the (presumably LIKE-escaped) value with
        # '%' wildcards in Python; SQL rhs values get them via get_rhs_op().
        if params and not self.bilateral_transforms:
            params[0] = "%%%s%%" % connection.ops.prep_for_like_query(params[0])
        return rhs, params
Field.register_lookup(Contains)


class IContains(Contains):
    lookup_name = 'icontains'
Field.register_lookup(IContains)


class StartsWith(PatternLookup):
    lookup_name = 'startswith'

    def process_rhs(self, qn, connection):
        rhs, params = super(StartsWith, self).process_rhs(qn, connection)
        # Plain values: append a trailing '%' wildcard.
        if params and not self.bilateral_transforms:
            params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
        return rhs, params
Field.register_lookup(StartsWith)


class IStartsWith(PatternLookup):
    lookup_name = 'istartswith'

    def process_rhs(self, qn, connection):
        rhs, params = super(IStartsWith, self).process_rhs(qn, connection)
        # Plain values: append a trailing '%' wildcard.
        if params and not self.bilateral_transforms:
            params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
        return rhs, params
Field.register_lookup(IStartsWith)


class EndsWith(PatternLookup):
    lookup_name = 'endswith'

    def process_rhs(self, qn, connection):
        rhs, params = super(EndsWith, self).process_rhs(qn, connection)
        # Plain values: prepend a leading '%' wildcard.
        if params and not self.bilateral_transforms:
            params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
        return rhs, params
Field.register_lookup(EndsWith)


class IEndsWith(PatternLookup):
    lookup_name = 'iendswith'

    def process_rhs(self, qn, connection):
        rhs, params = super(IEndsWith, self).process_rhs(qn, connection)
        # Plain values: prepend a leading '%' wildcard.
        if params and not self.bilateral_transforms:
            params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
        return rhs, params
Field.register_lookup(IEndsWith)


class Between(BuiltinLookup):
    # NOTE(review): has no lookup_name and is not registered in this module;
    # it renders the same rhs as both bounds.
    def get_rhs_op(self, connection, rhs):
        return "BETWEEN %s AND %s" % (rhs, rhs)


class Range(BuiltinLookup):
    """Inclusive range test: ``<lhs> BETWEEN <low> AND <high>``."""
    lookup_name = 'range'

    def get_rhs_op(self, connection, rhs):
        # rhs is the list of two SQL fragments from batch_process_rhs().
        return "BETWEEN %s AND %s" % (rhs[0], rhs[1])

    def process_rhs(self, compiler, connection):
        if self.rhs_is_direct_value():
            # rhs should be an iterable of 2 values, we use batch_process_rhs
            # to prepare/transform those values
            return self.batch_process_rhs(compiler, connection)
        else:
            return super(Range, self).process_rhs(compiler, connection)
Field.register_lookup(Range)


class IsNull(BuiltinLookup):
    """``IS NULL`` when the rhs is truthy, ``IS NOT NULL`` otherwise."""
    lookup_name = 'isnull'

    def as_sql(self, compiler, connection):
        sql, params = compiler.compile(self.lhs)
        if self.rhs:
            return "%s IS NULL" % sql, params
        else:
            return "%s IS NOT NULL" % sql, params
Field.register_lookup(IsNull)


class Search(BuiltinLookup):
    """Full-text search using the backend's ``fulltext_search_sql`` template."""
    lookup_name = 'search'

    def as_sql(self, compiler, connection):
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        sql_template = connection.ops.fulltext_search_sql(field_name=lhs)
        return sql_template, lhs_params + rhs_params
Field.register_lookup(Search)


class Regex(BuiltinLookup):
    """
    Regular-expression match. Uses the backend operator when one exists for
    ``lookup_name``, otherwise the backend's ``regex_lookup`` template.
    """
    lookup_name = 'regex'

    def as_sql(self, compiler, connection):
        if self.lookup_name in connection.operators:
            return super(Regex, self).as_sql(compiler, connection)
        else:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            rhs, rhs_params = self.process_rhs(compiler, connection)
            sql_template = connection.ops.regex_lookup(self.lookup_name)
            return sql_template % (lhs, rhs), lhs_params + rhs_params
Field.register_lookup(Regex)


class IRegex(Regex):
    lookup_name = 'iregex'
Field.register_lookup(IRegex)


class DateTimeDateTransform(Transform):
    """Cast a datetime expression to a date (in the current time zone if USE_TZ)."""
    lookup_name = 'date'

    @cached_property
    def output_field(self):
        return DateField()

    def as_sql(self, compiler, connection):
        lhs, lhs_params = compiler.compile(self.lhs)
        # Copy so the compiled expression's own params are never mutated
        # (and so a tuple is accepted).
        lhs_params = list(lhs_params)
        tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
        sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname)
        lhs_params.extend(tz_params)
        return sql, lhs_params


class DateTransform(Transform):
    """
    Extract the ``lookup_name`` component (year, month, ...) of a date, time
    or datetime expression as an integer.

    Raises ValueError for any other lhs output field.
    """

    def as_sql(self, compiler, connection):
        sql, params = compiler.compile(self.lhs)
        # Copy so the compiled expression's own params are never mutated.
        params = list(params)
        lhs_output_field = self.lhs.output_field
        # DateTimeField is tested before DateField — presumably because it
        # subclasses DateField; keep this order.
        if isinstance(lhs_output_field, DateTimeField):
            tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
            sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
            params.extend(tz_params)
        elif isinstance(lhs_output_field, DateField):
            sql = connection.ops.date_extract_sql(self.lookup_name, sql)
        elif isinstance(lhs_output_field, TimeField):
            sql = connection.ops.time_extract_sql(self.lookup_name, sql)
        else:
            raise ValueError('DateTransform only valid on Date/Time/DateTimeFields')
        return sql, params

    @cached_property
    def output_field(self):
        return IntegerField()


class YearTransform(DateTransform):
    lookup_name = 'year'


class YearLookup(Lookup):
    """Base for year lookups that compare the original field against bounds."""

    def year_lookup_bounds(self, connection, year):
        """
        Return the backend's ``(start, finish)`` bounds for ``year``, using the
        datetime variant when the field under the YearTransform is a
        DateTimeField and the date variant otherwise.
        """
        output_field = self.lhs.lhs.output_field
        if isinstance(output_field, DateTimeField):
            bounds = connection.ops.year_lookup_bounds_for_datetime_field(year)
        else:
            bounds = connection.ops.year_lookup_bounds_for_date_field(year)
        return bounds


@YearTransform.register_lookup
class YearExact(YearLookup):
    """``__year=Y`` rendered as ``<field> BETWEEN <start> AND <finish>``."""
    lookup_name = 'exact'

    def as_sql(self, compiler, connection):
        # We will need to skip the extract part and instead go
        # directly with the originating field, that is self.lhs.lhs.
        lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
        # Copy: Lookup.process_lhs returns the compiled expression's params
        # untouched, which may be a tuple or a list owned by that expression.
        params = list(params)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        bounds = self.year_lookup_bounds(connection, rhs_params[0])
        params.extend(bounds)
        return '%s BETWEEN %%s AND %%s' % lhs_sql, params


class YearComparisonLookup(YearLookup):
    """
    Year comparison (gt/gte/lt/lte) against one bound of the year.

    Subclasses choose the bound by implementing ``get_bound(start, finish)``.
    """

    def as_sql(self, compiler, connection):
        # We will need to skip the extract part and instead go
        # directly with the originating field, that is self.lhs.lhs.
        lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
        # Copy: the compiled expression's params may be a tuple or shared list.
        params = list(params)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        rhs_sql = self.get_rhs_op(connection, rhs_sql)
        start, finish = self.year_lookup_bounds(connection, rhs_params[0])
        params.append(self.get_bound(start, finish))
        return '%s %s' % (lhs_sql, rhs_sql), params

    def get_rhs_op(self, connection, rhs):
        return connection.operators[self.lookup_name] % rhs

    def get_bound(self, start, finish):
        # Signature matches the call in as_sql(), so a subclass that forgets to
        # override this gets NotImplementedError rather than a TypeError.
        raise NotImplementedError(
            'subclasses of YearComparisonLookup must provide a get_bound() method'
        )


@YearTransform.register_lookup
class YearGt(YearComparisonLookup):
    lookup_name = 'gt'

    def get_bound(self, start, finish):
        return finish


@YearTransform.register_lookup
class YearGte(YearComparisonLookup):
    lookup_name = 'gte'

    def get_bound(self, start, finish):
        return start


@YearTransform.register_lookup
class YearLt(YearComparisonLookup):
    lookup_name = 'lt'

    def get_bound(self, start, finish):
        return start


@YearTransform.register_lookup
class YearLte(YearComparisonLookup):
    lookup_name = 'lte'

    def get_bound(self, start, finish):
        return finish


class MonthTransform(DateTransform):
    lookup_name = 'month'


class DayTransform(DateTransform):
    lookup_name = 'day'


class WeekDayTransform(DateTransform):
    lookup_name = 'week_day'


class HourTransform(DateTransform):
    lookup_name = 'hour'


class MinuteTransform(DateTransform):
    lookup_name = 'minute'


class SecondTransform(DateTransform):
    lookup_name = 'second'


DateField.register_lookup(YearTransform)
DateField.register_lookup(MonthTransform)
DateField.register_lookup(DayTransform)
DateField.register_lookup(WeekDayTransform)

TimeField.register_lookup(HourTransform)
TimeField.register_lookup(MinuteTransform)
TimeField.register_lookup(SecondTransform)

DateTimeField.register_lookup(DateTimeDateTransform)
DateTimeField.register_lookup(YearTransform)
DateTimeField.register_lookup(MonthTransform)
DateTimeField.register_lookup(DayTransform)
DateTimeField.register_lookup(WeekDayTransform)
# DateTimeField also supports the time-of-day extract transforms.
for _time_transform in (HourTransform, MinuteTransform, SecondTransform):
    DateTimeField.register_lookup(_time_transform)
del _time_transform