import itertools
import math
from collections import OrderedDict
from copy import copy

from django.core.exceptions import EmptyResultSet
from django.db.models.expressions import Func, Value
from django.db.models.fields import DateTimeField, Field, IntegerField
from django.db.models.query_utils import RegisterLookupMixin
from django.utils.functional import cached_property


class Lookup:
    """
    Base class for a comparison between a left-hand side expression (``lhs``)
    and a right-hand side value or expression (``rhs``).

    Subclasses must implement as_sql(); the base implementation raises
    NotImplementedError.
    """
    # Key used by BuiltinLookup to select an entry in connection.operators.
    lookup_name = None
    # When True, get_prep_lookup() passes rhs through
    # lhs.output_field.get_prep_value() (if the field provides it).
    prepare_rhs = True
    # Not read anywhere in this module.
    # NOTE(review): presumably consulted by the query builder -- confirm.
    can_use_none_as_rhs = False

    def __init__(self, lhs, rhs):
        """
        Store both sides, prepare the rhs via get_prep_lookup() and collect
        the bilateral transforms advertised by the lhs (if any).

        Raise NotImplementedError if bilateral transforms are present and the
        rhs passed in is a Query.
        """
        self.lhs, self.rhs = lhs, rhs
        self.rhs = self.get_prep_lookup()
        if hasattr(self.lhs, 'get_bilateral_transforms'):
            bilateral_transforms = self.lhs.get_bilateral_transforms()
        else:
            bilateral_transforms = []
        if bilateral_transforms:
            # Warn the user as soon as possible if they are trying to apply
            # a bilateral transformation on a nested QuerySet: that won't work.
            from django.db.models.sql.query import Query  # avoid circular import
            if isinstance(rhs, Query):
                raise NotImplementedError("Bilateral transformations on nested querysets are not implemented.")
        self.bilateral_transforms = bilateral_transforms

    def apply_bilateral_transforms(self, value):
        """Wrap value in each bilateral transform, in order, and return it."""
        for transform in self.bilateral_transforms:
            value = transform(value)
        return value

    def batch_process_rhs(self, compiler, connection, rhs=None):
        """
        Process an iterable rhs (defaulting to self.rhs) and return a
        ``(sqls, sqls_params)`` pair: the SQL fragment for each item and the
        flattened parameters.
        """
        if rhs is None:
            rhs = self.rhs
        if self.bilateral_transforms:
            # Each item is wrapped in a Value, transformed and compiled
            # individually.
            sqls, sqls_params = [], []
            for p in rhs:
                value = Value(p, output_field=self.lhs.output_field)
                value = self.apply_bilateral_transforms(value)
                value = value.resolve_expression(compiler.query)
                sql, sql_params = compiler.compile(value)
                sqls.append(sql)
                sqls_params.extend(sql_params)
        else:
            # One '%s' placeholder per prepared parameter.
            _, params = self.get_db_prep_lookup(rhs, connection)
            sqls, sqls_params = ['%s'] * len(params), params
        return sqls, sqls_params

    def get_source_expressions(self):
        """Return [lhs], plus rhs when the rhs is itself an expression."""
        if self.rhs_is_direct_value():
            return [self.lhs]
        return [self.lhs, self.rhs]

    def set_source_expressions(self, new_exprs):
        """Counterpart of get_source_expressions(): accept [lhs] or [lhs, rhs]."""
        if len(new_exprs) == 1:
            self.lhs = new_exprs[0]
        else:
            self.lhs, self.rhs = new_exprs

    def get_prep_lookup(self):
        """
        Return the prepared rhs.

        An rhs exposing ``_prepare`` prepares itself against the lhs output
        field. Otherwise, if prepare_rhs is set and the output field provides
        get_prep_value(), that is used. Failing both, rhs is returned as is.
        """
        if hasattr(self.rhs, '_prepare'):
            return self.rhs._prepare(self.lhs.output_field)
        if self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'):
            return self.lhs.output_field.get_prep_value(self.rhs)
        return self.rhs

    def get_db_prep_lookup(self, value, connection):
        """Return a ``('%s', [value])`` placeholder/params pair."""
        return ('%s', [value])

    def process_lhs(self, compiler, connection, lhs=None):
        """
        Compile the lhs (or the override passed as ``lhs``) and return the
        ``(sql, params)`` pair produced by the compiler.
        """
        lhs = lhs or self.lhs
        if hasattr(lhs, 'resolve_expression'):
            lhs = lhs.resolve_expression(compiler.query)
        return compiler.compile(lhs)

    def process_rhs(self, compiler, connection):
        """
        Return the ``(sql, params)`` pair for the rhs.

        Expressions are compiled and wrapped in parentheses; direct values go
        through get_db_prep_lookup().
        """
        value = self.rhs
        if self.bilateral_transforms:
            if self.rhs_is_direct_value():
                # Do not call get_db_prep_lookup here as the value will be
                # transformed before being used for lookup
                value = Value(value, output_field=self.lhs.output_field)
            value = self.apply_bilateral_transforms(value)
            value = value.resolve_expression(compiler.query)
        if hasattr(value, 'as_sql'):
            sql, params = compiler.compile(value)
            return '(' + sql + ')', params
        else:
            return self.get_db_prep_lookup(value, connection)

    def rhs_is_direct_value(self):
        """Return True when the rhs has no as_sql(), i.e. is a plain value."""
        return not hasattr(self.rhs, 'as_sql')

    def relabeled_clone(self, relabels):
        """Return a shallow copy whose lhs (and rhs, if supported) is relabeled."""
        new = copy(self)
        new.lhs = new.lhs.relabeled_clone(relabels)
        if hasattr(new.rhs, 'relabeled_clone'):
            new.rhs = new.rhs.relabeled_clone(relabels)
        return new

    def get_group_by_cols(self):
        """Return the GROUP BY columns of the lhs, extended by those of the rhs."""
        cols = self.lhs.get_group_by_cols()
        if hasattr(self.rhs, 'get_group_by_cols'):
            cols.extend(self.rhs.get_group_by_cols())
        return cols

    def as_sql(self, compiler, connection):
        """Return ``(sql, params)`` for the lookup. Must be overridden."""
        raise NotImplementedError

    @cached_property
    def contains_aggregate(self):
        return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)

    @cached_property
    def contains_over_clause(self):
        return self.lhs.contains_over_clause or getattr(self.rhs, 'contains_over_clause', False)

    @property
    def is_summary(self):
        return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False)


class Transform(RegisterLookupMixin, Func):
    """
    RegisterLookupMixin() is first so that get_lookup() and get_transform()
    first examine self and then check output_field.
    """
    # When True, this transform's class is added to the list returned by
    # get_bilateral_transforms().
    bilateral = False
    arity = 1

    @property
    def lhs(self):
        """The first (and, given arity = 1, only) source expression."""
        return self.get_source_expressions()[0]

    def get_bilateral_transforms(self):
        """
        Return the bilateral transform classes collected from the lhs chain,
        with this class appended when ``bilateral`` is set.
        """
        if hasattr(self.lhs, 'get_bilateral_transforms'):
            bilateral_transforms = self.lhs.get_bilateral_transforms()
        else:
            bilateral_transforms = []
        if self.bilateral:
            bilateral_transforms.append(self.__class__)
        return bilateral_transforms


class BuiltinLookup(Lookup):
    """Lookup whose operator comes from ``connection.operators[lookup_name]``."""

    def process_lhs(self, compiler, connection, lhs=None):
        """
        Compile the lhs and wrap it in the backend's field cast and lookup
        cast templates. The params are returned as a list so that as_sql()
        can extend them in place.
        """
        lhs_sql, params = super().process_lhs(compiler, connection, lhs)
        field_internal_type = self.lhs.output_field.get_internal_type()
        db_type = self.lhs.output_field.db_type(connection=connection)
        lhs_sql = connection.ops.field_cast_sql(
            db_type, field_internal_type) % lhs_sql
        lhs_sql = connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
        return lhs_sql, list(params)

    def as_sql(self, compiler, connection):
        """Return ``'<lhs> <operator applied to rhs>'`` and the combined params."""
        lhs_sql, params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        params.extend(rhs_params)
        rhs_sql = self.get_rhs_op(connection, rhs_sql)
        return '%s %s' % (lhs_sql, rhs_sql), params

    def get_rhs_op(self, connection, rhs):
        """Interpolate the rhs SQL into the backend's operator template."""
        return connection.operators[self.lookup_name] % rhs


class FieldGetDbPrepValueMixin:
    """
    Some lookups require Field.get_db_prep_value() to be called on their
    inputs.
    """
    # When True, get_db_prep_lookup() prepares each item of the value
    # separately instead of the value as a whole.
    get_db_prep_lookup_value_is_iterable = False

    def get_db_prep_lookup(self, value, connection):
        """
        Return ``('%s', params)`` where params holds the value (or each of its
        items) converted with get_db_prep_value(..., prepared=True).
        """
        # For relational fields, use the output_field of the 'field' attribute.
        field = getattr(self.lhs.output_field, 'field', None)
        get_db_prep_value = getattr(field, 'get_db_prep_value', None) or self.lhs.output_field.get_db_prep_value
        return (
            '%s',
            [get_db_prep_value(v, connection, prepared=True) for v in value]
            if self.get_db_prep_lookup_value_is_iterable else
            [get_db_prep_value(value, connection, prepared=True)]
        )


class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
    """
    Some lookups require Field.get_db_prep_value() to be called on each value
    in an iterable.
    """
    get_db_prep_lookup_value_is_iterable = True

    def get_prep_lookup(self):
        """
        Return the rhs with every plain item prepared by the lhs output field.
        Items that are expressions are kept untouched.
        """
        prepared_values = []
        if hasattr(self.rhs, '_prepare'):
            # A subquery is like an iterable but its items shouldn't be
            # prepared independently.
            return self.rhs._prepare(self.lhs.output_field)
        for rhs_value in self.rhs:
            if hasattr(rhs_value, 'resolve_expression'):
                # An expression will be handled by the database but can coexist
                # alongside real values.
                pass
            elif self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'):
                rhs_value = self.lhs.output_field.get_prep_value(rhs_value)
            prepared_values.append(rhs_value)
        return prepared_values

    def process_rhs(self, compiler, connection):
        if self.rhs_is_direct_value():
            # rhs should be an iterable of values. Use batch_process_rhs()
            # to prepare/transform those values.
            return self.batch_process_rhs(compiler, connection)
        else:
            return super().process_rhs(compiler, connection)

    def resolve_expression_parameter(self, compiler, connection, sql, param):
        """
        Return ``(sql, params)`` for a single parameter: the compiled form if
        param is an expression, otherwise the given sql and ``[param]``.
        """
        params = [param]
        if hasattr(param, 'resolve_expression'):
            param = param.resolve_expression(compiler.query)
        if hasattr(param, 'as_sql'):
            sql, params = param.as_sql(compiler, connection)
        return sql, params

    def batch_process_rhs(self, compiler, connection, rhs=None):
        """
        Like Lookup.batch_process_rhs(), but additionally compile any
        parameter that is an expression. Return a tuple of SQL fragments and
        a flat tuple of parameters.
        """
        pre_processed = super().batch_process_rhs(compiler, connection, rhs)
        # The params list may contain expressions which compile to a
        # sql/param pair. Zip them to get sql and param pairs that refer to the
        # same argument and attempt to replace them with the result of
        # compiling the param step.
        sql, params = zip(*(
            self.resolve_expression_parameter(compiler, connection, sql, param)
            for sql, param in zip(*pre_processed)
        ))
        params = itertools.chain.from_iterable(params)
        return sql, tuple(params)


@Field.register_lookup
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
    lookup_name = 'exact'

    def process_rhs(self, compiler, connection):
        """
        Restrict a Query rhs to its pk column before compiling it.

        Raise ValueError if the Query is not limited to one result.
        """
        from django.db.models.sql.query import Query
        if isinstance(self.rhs, Query):
            if self.rhs.has_limit_one():
                # The subquery must select only the pk.
                self.rhs.clear_select_clause()
                self.rhs.add_fields(['pk'])
            else:
                raise ValueError(
                    'The QuerySet value for an exact lookup must be limited to '
                    'one result using slicing.'
                )
        return super().process_rhs(compiler, connection)


@Field.register_lookup
class IExact(BuiltinLookup):
    lookup_name = 'iexact'
    prepare_rhs = False

    def process_rhs(self, qn, connection):
        """Run the first parameter (if any) through prep_for_iexact_query()."""
        rhs, params = super().process_rhs(qn, connection)
        if params:
            params[0] = connection.ops.prep_for_iexact_query(params[0])
        return rhs, params


@Field.register_lookup
class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    lookup_name = 'gt'


@Field.register_lookup
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    lookup_name = 'gte'


@Field.register_lookup
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    lookup_name = 'lt'


@Field.register_lookup
class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    lookup_name = 'lte'


class IntegerFieldFloatRounding:
    """
    Allow floats to work as query values for IntegerField. Without this, the
    decimal portion of the float would always be discarded.
    """
    def get_prep_lookup(self):
        # Round a float rhs up to the next integer before normal preparation.
        if isinstance(self.rhs, float):
            self.rhs = math.ceil(self.rhs)
        return super().get_prep_lookup()


@IntegerField.register_lookup
class IntegerGreaterThanOrEqual(IntegerFieldFloatRounding, GreaterThanOrEqual):
    pass


@IntegerField.register_lookup
class IntegerLessThan(IntegerFieldFloatRounding, LessThan):
    pass


@Field.register_lookup
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    lookup_name = 'in'

    def process_rhs(self, compiler, connection):
        """
        Return the ``(sql, params)`` pair for the rhs of an IN clause.

        Direct values are de-duplicated (when hashable) while keeping their
        first-seen order, so the generated placeholders and parameters are
        deterministic for a given input.

        Raise ValueError if the rhs is bound (via ``_db``) to a database other
        than the one being queried, and EmptyResultSet if a direct-value rhs
        is empty.
        """
        db_rhs = getattr(self.rhs, '_db', None)
        if db_rhs is not None and db_rhs != connection.alias:
            raise ValueError(
                "Subqueries aren't allowed across different databases. Force "
                "the inner query to be evaluated using `list(inner_query)`."
            )

        if self.rhs_is_direct_value():
            try:
                # Order-preserving de-duplication; a plain set() would make
                # the parameter order (and thus the query) non-reproducible.
                rhs = list(OrderedDict.fromkeys(self.rhs))
            except TypeError:  # Unhashable items in self.rhs
                rhs = self.rhs

            if not rhs:
                raise EmptyResultSet

            # rhs should be an iterable; use batch_process_rhs() to
            # prepare/transform those values.
            sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
            placeholder = '(' + ', '.join(sqls) + ')'
            return (placeholder, sqls_params)
        else:
            if not getattr(self.rhs, 'has_select_fields', True):
                self.rhs.clear_select_clause()
                self.rhs.add_fields(['pk'])
            return super().process_rhs(compiler, connection)

    def get_rhs_op(self, connection, rhs):
        return 'IN %s' % rhs

    def as_sql(self, compiler, connection):
        """
        Build the IN clause, splitting it into OR-ed chunks when the backend
        reports a maximum IN-list size that the direct-value rhs exceeds.
        """
        max_in_list_size = connection.ops.max_in_list_size()
        if self.rhs_is_direct_value() and max_in_list_size and len(self.rhs) > max_in_list_size:
            return self.split_parameter_list_as_sql(compiler, connection)
        return super().as_sql(compiler, connection)

    def split_parameter_list_as_sql(self, compiler, connection):
        """
        Return ``(lhs IN (...) OR lhs IN (...) ...)`` with at most
        max_in_list_size elements per IN group, plus the matching params.
        """
        # This is a special case for databases which limit the number of
        # elements which can appear in an 'IN' clause.
        max_in_list_size = connection.ops.max_in_list_size()
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.batch_process_rhs(compiler, connection)
        in_clause_elements = ['(']
        params = []
        for offset in range(0, len(rhs_params), max_in_list_size):
            if offset > 0:
                in_clause_elements.append(' OR ')
            in_clause_elements.append('%s IN (' % lhs)
            # The lhs SQL is repeated per group, so its params are too.
            params.extend(lhs_params)
            sqls = rhs[offset: offset + max_in_list_size]
            sqls_params = rhs_params[offset: offset + max_in_list_size]
            param_group = ', '.join(sqls)
            in_clause_elements.append(param_group)
            in_clause_elements.append(')')
            params.extend(sqls_params)
        in_clause_elements.append(')')
        return ''.join(in_clause_elements), params


class PatternLookup(BuiltinLookup):
    """Base class for the LIKE-style lookups (contains, startswith, endswith)."""
    # %-format template applied to a direct rhs value in process_rhs().
    param_pattern = '%%%s%%'
    prepare_rhs = False

    def get_rhs_op(self, connection, rhs):
        # Assume we are in startswith. We need to produce SQL like:
        #     col LIKE %s, ['thevalue%']
        # For python values we can (and should) do that directly in Python,
        # but if the value is for example reference to other column, then
        # we need to add the % pattern match to the lookup by something like
        #     col LIKE othercol || '%%'
        # So, for Python values we don't need any special pattern, but for
        # SQL reference values or SQL transformations we need the correct
        # pattern added.
        if hasattr(self.rhs, 'as_sql') or self.bilateral_transforms:
            pattern = connection.pattern_ops[self.lookup_name].format(connection.pattern_esc)
            return pattern.format(rhs)
        else:
            return super().get_rhs_op(connection, rhs)

    def process_rhs(self, qn, connection):
        """
        For a direct value without bilateral transforms, escape the first
        parameter with prep_for_like_query() and wrap it in param_pattern.
        """
        rhs, params = super().process_rhs(qn, connection)
        if self.rhs_is_direct_value() and params and not self.bilateral_transforms:
            params[0] = self.param_pattern % connection.ops.prep_for_like_query(params[0])
        return rhs, params


@Field.register_lookup
class Contains(PatternLookup):
    lookup_name = 'contains'


@Field.register_lookup
class IContains(Contains):
    lookup_name = 'icontains'


@Field.register_lookup
class StartsWith(PatternLookup):
    lookup_name = 'startswith'
    param_pattern = '%s%%'


@Field.register_lookup
class IStartsWith(StartsWith):
    lookup_name = 'istartswith'


@Field.register_lookup
class EndsWith(PatternLookup):
    lookup_name = 'endswith'
    param_pattern = '%%%s'


@Field.register_lookup
class IEndsWith(EndsWith):
    lookup_name = 'iendswith'


@Field.register_lookup
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    lookup_name = 'range'

    def get_rhs_op(self, connection, rhs):
        # rhs is the sequence of SQL fragments for the two bounds.
        return "BETWEEN %s AND %s" % (rhs[0], rhs[1])


@Field.register_lookup
class IsNull(BuiltinLookup):
    lookup_name = 'isnull'
    prepare_rhs = False

    def as_sql(self, compiler, connection):
        """Return ``IS NULL`` for a truthy rhs and ``IS NOT NULL`` otherwise."""
        sql, params = compiler.compile(self.lhs)
        if self.rhs:
            return "%s IS NULL" % sql, params
        else:
            return "%s IS NOT NULL" % sql, params


@Field.register_lookup
class Regex(BuiltinLookup):
    lookup_name = 'regex'
    prepare_rhs = False

    def as_sql(self, compiler, connection):
        """
        Use the backend operator when one is registered for this lookup;
        otherwise fall back to the template from connection.ops.regex_lookup().
        """
        if self.lookup_name in connection.operators:
            return super().as_sql(compiler, connection)
        else:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            rhs, rhs_params = self.process_rhs(compiler, connection)
            sql_template = connection.ops.regex_lookup(self.lookup_name)
            return sql_template % (lhs, rhs), lhs_params + rhs_params


@Field.register_lookup
class IRegex(Regex):
    lookup_name = 'iregex'


class YearLookup(Lookup):
    """Shared helper for lookups that compare the year of a date/datetime."""

    def year_lookup_bounds(self, connection, year):
        """
        Return the backend's bounds for ``year``, choosing the datetime or
        date variant based on the output field of the originating expression
        (``self.lhs.lhs``).
        """
        output_field = self.lhs.lhs.output_field
        if isinstance(output_field, DateTimeField):
            bounds = connection.ops.year_lookup_bounds_for_datetime_field(year)
        else:
            bounds = connection.ops.year_lookup_bounds_for_date_field(year)
        return bounds


class YearComparisonLookup(YearLookup):
    """
    Compare the originating field directly against one of the year bounds.
    Subclasses choose the bound through get_bound().
    """

    def as_sql(self, compiler, connection):
        # We will need to skip the extract part and instead go
        # directly with the originating field, that is self.lhs.lhs.
        lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        rhs_sql = self.get_rhs_op(connection, rhs_sql)
        # The first rhs param is taken as the year; the selected bound (not
        # the year itself) is what gets sent as the query parameter.
        start, finish = self.year_lookup_bounds(connection, rhs_params[0])
        params.append(self.get_bound(start, finish))
        return '%s %s' % (lhs_sql, rhs_sql), params

    def get_rhs_op(self, connection, rhs):
        return connection.operators[self.lookup_name] % rhs

    def get_bound(self, start, finish):
        """Return whichever of ``start``/``finish`` the comparison needs."""
        raise NotImplementedError(
            'subclasses of YearComparisonLookup must provide a get_bound() method'
        )


class YearExact(YearLookup, Exact):
    lookup_name = 'exact'

    def as_sql(self, compiler, connection):
        """
        Return a ``BETWEEN`` over the year's bounds on the originating field,
        or the standard exact comparison when the year cannot be determined
        from the rhs params.
        """
        # We will need to skip the extract part and instead go
        # directly with the originating field, that is self.lhs.lhs.
        lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        try:
            # Check that rhs_params[0] exists (IndexError),
            # it isn't None (TypeError), and is a number (ValueError)
            int(rhs_params[0])
        except (IndexError, TypeError, ValueError):
            # Can't determine the bounds before executing the query, so skip
            # optimizations by falling back to a standard exact comparison.
            return super().as_sql(compiler, connection)
        bounds = self.year_lookup_bounds(connection, rhs_params[0])
        params.extend(bounds)
        return '%s BETWEEN %%s AND %%s' % lhs_sql, params


class YearGt(YearComparisonLookup):
    lookup_name = 'gt'

    def get_bound(self, start, finish):
        return finish


class YearGte(YearComparisonLookup):
    lookup_name = 'gte'

    def get_bound(self, start, finish):
        return start


class YearLt(YearComparisonLookup):
    lookup_name = 'lt'

    def get_bound(self, start, finish):
        return start


class YearLte(YearComparisonLookup):
    lookup_name = 'lte'

    def get_bound(self, start, finish):
        return finish