diff --git a/django/contrib/postgres/fields/hstore.py b/django/contrib/postgres/fields/hstore.py index 366c40052e..1fb9138e9a 100644 --- a/django/contrib/postgres/fields/hstore.py +++ b/django/contrib/postgres/fields/hstore.py @@ -81,14 +81,14 @@ class KeyTransformFactory(object): @HStoreField.register_lookup -class KeysTransform(lookups.FunctionTransform): +class KeysTransform(Transform): lookup_name = 'keys' function = 'akeys' output_field = ArrayField(TextField()) @HStoreField.register_lookup -class ValuesTransform(lookups.FunctionTransform): +class ValuesTransform(Transform): lookup_name = 'values' function = 'avals' output_field = ArrayField(TextField()) diff --git a/django/contrib/postgres/fields/ranges.py b/django/contrib/postgres/fields/ranges.py index 5f9e13508e..bf3f8cb9a4 100644 --- a/django/contrib/postgres/fields/ranges.py +++ b/django/contrib/postgres/fields/ranges.py @@ -173,7 +173,7 @@ class AdjacentToLookup(lookups.PostgresSimpleLookup): @RangeField.register_lookup -class RangeStartsWith(lookups.FunctionTransform): +class RangeStartsWith(models.Transform): lookup_name = 'startswith' function = 'lower' @@ -183,7 +183,7 @@ class RangeStartsWith(lookups.FunctionTransform): @RangeField.register_lookup -class RangeEndsWith(lookups.FunctionTransform): +class RangeEndsWith(models.Transform): lookup_name = 'endswith' function = 'upper' @@ -193,7 +193,7 @@ class RangeEndsWith(lookups.FunctionTransform): @RangeField.register_lookup -class IsEmpty(lookups.FunctionTransform): +class IsEmpty(models.Transform): lookup_name = 'isempty' function = 'isempty' output_field = models.BooleanField() diff --git a/django/contrib/postgres/lookups.py b/django/contrib/postgres/lookups.py index 887337861a..cdecd8d6ba 100644 --- a/django/contrib/postgres/lookups.py +++ b/django/contrib/postgres/lookups.py @@ -9,12 +9,6 @@ class PostgresSimpleLookup(Lookup): return '%s %s %s' % (lhs, self.operator, rhs), params -class FunctionTransform(Transform): - def as_sql(self, qn, connection): - lhs, params = qn.compile(self.lhs) - return "%s(%s)" % (self.function, lhs), params - - class DataContains(PostgresSimpleLookup): lookup_name = 'contains' operator = '@>' @@ -45,7 +39,7 @@ class HasAnyKeys(PostgresSimpleLookup): operator = '?|' -class Unaccent(FunctionTransform): +class Unaccent(Transform): bilateral = True lookup_name = 'unaccent' function = 'UNACCENT' diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index a210ffc0b7..7aa18ee509 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -20,10 +20,7 @@ from django.core import checks, exceptions, validators # purposes. from django.core.exceptions import FieldDoesNotExist # NOQA from django.db import connection, connections, router -from django.db.models.lookups import ( - Lookup, RegisterLookupMixin, Transform, default_lookups, -) -from django.db.models.query_utils import QueryWrapper +from django.db.models.query_utils import QueryWrapper, RegisterLookupMixin from django.utils import six, timezone from django.utils.datastructures import DictWrapper from django.utils.dateparse import ( @@ -120,7 +117,6 @@ class Field(RegisterLookupMixin): 'unique_for_date': _("%(field_label)s must be unique for " "%(date_field_label)s %(lookup_type)s."), } - class_lookups = default_lookups.copy() system_check_deprecated_details = None system_check_removed_details = None @@ -1492,22 +1488,6 @@ class DateTimeField(DateField): return super(DateTimeField, self).formfield(**defaults) -@DateTimeField.register_lookup -class DateTimeDateTransform(Transform): - lookup_name = 'date' - - @cached_property - def output_field(self): - return DateField() - - def as_sql(self, compiler, connection): - lhs, lhs_params = compiler.compile(self.lhs) - tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None - sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname) - lhs_params.extend(tz_params) - return sql, lhs_params - - class DecimalField(Field): empty_strings_allowed = False default_error_messages = { @@ -2450,146 +2430,3 @@ class UUIDField(Field): } defaults.update(kwargs) return super(UUIDField, self).formfield(**defaults) - - -class DateTransform(Transform): - def as_sql(self, compiler, connection): - sql, params = compiler.compile(self.lhs) - lhs_output_field = self.lhs.output_field - if isinstance(lhs_output_field, DateTimeField): - tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None - sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname) - params.extend(tz_params) - elif isinstance(lhs_output_field, DateField): - sql = connection.ops.date_extract_sql(self.lookup_name, sql) - elif isinstance(lhs_output_field, TimeField): - sql = connection.ops.time_extract_sql(self.lookup_name, sql) - else: - raise ValueError('DateTransform only valid on Date/Time/DateTimeFields') - return sql, params - - @cached_property - def output_field(self): - return IntegerField() - - -class YearTransform(DateTransform): - lookup_name = 'year' - - -class YearLookup(Lookup): - def year_lookup_bounds(self, connection, year): - output_field = self.lhs.lhs.output_field - if isinstance(output_field, DateTimeField): - bounds = connection.ops.year_lookup_bounds_for_datetime_field(year) - else: - bounds = connection.ops.year_lookup_bounds_for_date_field(year) - return bounds - - -@YearTransform.register_lookup -class YearExact(YearLookup): - lookup_name = 'exact' - - def as_sql(self, compiler, connection): - # We will need to skip the extract part and instead go - # directly with the originating field, that is self.lhs.lhs. - lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs) - rhs_sql, rhs_params = self.process_rhs(compiler, connection) - bounds = self.year_lookup_bounds(connection, rhs_params[0]) - params.extend(bounds) - return '%s BETWEEN %%s AND %%s' % lhs_sql, params - - -class YearComparisonLookup(YearLookup): - def as_sql(self, compiler, connection): - # We will need to skip the extract part and instead go - # directly with the originating field, that is self.lhs.lhs. - lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs) - rhs_sql, rhs_params = self.process_rhs(compiler, connection) - rhs_sql = self.get_rhs_op(connection, rhs_sql) - start, finish = self.year_lookup_bounds(connection, rhs_params[0]) - params.append(self.get_bound(start, finish)) - return '%s %s' % (lhs_sql, rhs_sql), params - - def get_rhs_op(self, connection, rhs): - return connection.operators[self.lookup_name] % rhs - - def get_bound(self): - raise NotImplementedError( - 'subclasses of YearComparisonLookup must provide a get_bound() method' - ) - - -@YearTransform.register_lookup -class YearGt(YearComparisonLookup): - lookup_name = 'gt' - - def get_bound(self, start, finish): - return finish - - -@YearTransform.register_lookup -class YearGte(YearComparisonLookup): - lookup_name = 'gte' - - def get_bound(self, start, finish): - return start - - -@YearTransform.register_lookup -class YearLt(YearComparisonLookup): - lookup_name = 'lt' - - def get_bound(self, start, finish): - return start - - -@YearTransform.register_lookup -class YearLte(YearComparisonLookup): - lookup_name = 'lte' - - def get_bound(self, start, finish): - return finish - - -class MonthTransform(DateTransform): - lookup_name = 'month' - - -class DayTransform(DateTransform): - lookup_name = 'day' - - -class WeekDayTransform(DateTransform): - lookup_name = 'week_day' - - -class HourTransform(DateTransform): - lookup_name = 'hour' - - -class MinuteTransform(DateTransform): - lookup_name = 'minute' - - -class SecondTransform(DateTransform): - lookup_name = 'second' - - -DateField.register_lookup(YearTransform) -DateField.register_lookup(MonthTransform) -DateField.register_lookup(DayTransform) -DateField.register_lookup(WeekDayTransform) - -TimeField.register_lookup(HourTransform) -TimeField.register_lookup(MinuteTransform) -TimeField.register_lookup(SecondTransform) - -DateTimeField.register_lookup(YearTransform) -DateTimeField.register_lookup(MonthTransform) -DateTimeField.register_lookup(DayTransform) -DateTimeField.register_lookup(WeekDayTransform) -DateTimeField.register_lookup(HourTransform) -DateTimeField.register_lookup(MinuteTransform) -DateTimeField.register_lookup(SecondTransform) diff --git a/django/db/models/functions.py b/django/db/models/functions.py index dcc54e6fa5..ac24aa7bdc 100644 --- a/django/db/models/functions.py +++ b/django/db/models/functions.py @@ -1,8 +1,9 @@ """ Classes that represent database functions. """ -from django.db.models import DateTimeField, IntegerField -from django.db.models.expressions import Func, Value +from django.db.models import ( + DateTimeField, Func, IntegerField, Transform, Value, +) class Coalesce(Func): @@ -123,9 +124,10 @@ class Least(Func): return super(Least, self).as_sql(compiler, connection, function='MIN') -class Length(Func): +class Length(Transform): """Returns the number of characters in the expression""" function = 'LENGTH' + lookup_name = 'length' def __init__(self, expression, **extra): output_field = extra.pop('output_field', IntegerField()) @@ -136,8 +138,9 @@ class Length(Func): return super(Length, self).as_sql(compiler, connection) -class Lower(Func): +class Lower(Transform): function = 'LOWER' + lookup_name = 'lower' def __init__(self, expression, **extra): super(Lower, self).__init__(expression, **extra) @@ -188,8 +191,9 @@ class Substr(Func): return super(Substr, self).as_sql(compiler, connection) -class Upper(Func): +class Upper(Transform): function = 'UPPER' + lookup_name = 'upper' def __init__(self, expression, **extra): super(Upper, self).__init__(expression, **extra) diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 9bcf8d6303..93053c8160 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -1,101 +1,17 @@ -import inspect from copy import copy +from django.conf import settings +from django.db.models.expressions import Func, Value +from django.db.models.fields import ( + DateField, DateTimeField, Field, IntegerField, TimeField, +) +from django.db.models.query_utils import RegisterLookupMixin +from django.utils import timezone from django.utils.functional import cached_property from django.utils.six.moves import range -from .query_utils import QueryWrapper - -class RegisterLookupMixin(object): - def _get_lookup(self, lookup_name): - try: - return self.class_lookups[lookup_name] - except KeyError: - # To allow for inheritance, check parent class' class_lookups. - for parent in inspect.getmro(self.__class__): - if 'class_lookups' not in parent.__dict__: - continue - if lookup_name in parent.class_lookups: - return parent.class_lookups[lookup_name] - except AttributeError: - # This class didn't have any class_lookups - pass - return None - - def get_lookup(self, lookup_name): - found = self._get_lookup(lookup_name) - if found is None and hasattr(self, 'output_field'): - return self.output_field.get_lookup(lookup_name) - if found is not None and not issubclass(found, Lookup): - return None - return found - - def get_transform(self, lookup_name): - found = self._get_lookup(lookup_name) - if found is None and hasattr(self, 'output_field'): - return self.output_field.get_transform(lookup_name) - if found is not None and not issubclass(found, Transform): - return None - return found - - @classmethod - def register_lookup(cls, lookup): - if 'class_lookups' not in cls.__dict__: - cls.class_lookups = {} - cls.class_lookups[lookup.lookup_name] = lookup - return lookup - - @classmethod - def _unregister_lookup(cls, lookup): - """ - Removes given lookup from cls lookups. Meant to be used in - tests only. - """ - del cls.class_lookups[lookup.lookup_name] - - -class Transform(RegisterLookupMixin): - - bilateral = False - - def __init__(self, lhs, lookups): - self.lhs = lhs - self.init_lookups = lookups[:] - - def as_sql(self, compiler, connection): - raise NotImplementedError - - @cached_property - def output_field(self): - return self.lhs.output_field - - def copy(self): - return copy(self) - - def relabeled_clone(self, relabels): - copy = self.copy() - copy.lhs = self.lhs.relabeled_clone(relabels) - return copy - - def get_group_by_cols(self): - return self.lhs.get_group_by_cols() - - def get_bilateral_transforms(self): - if hasattr(self.lhs, 'get_bilateral_transforms'): - bilateral_transforms = self.lhs.get_bilateral_transforms() - else: - bilateral_transforms = [] - if self.bilateral: - bilateral_transforms.append((self.__class__, self.init_lookups)) - return bilateral_transforms - - @cached_property - def contains_aggregate(self): - return self.lhs.contains_aggregate - - -class Lookup(RegisterLookupMixin): +class Lookup(object): lookup_name = None def __init__(self, lhs, rhs): @@ -115,8 +31,8 @@ class Lookup(RegisterLookupMixin): self.bilateral_transforms = bilateral_transforms def apply_bilateral_transforms(self, value): - for transform, lookups in self.bilateral_transforms: - value = transform(value, lookups) + for transform in self.bilateral_transforms: + value = transform(value) return value def batch_process_rhs(self, compiler, connection, rhs=None): @@ -125,9 +41,9 @@ class Lookup(RegisterLookupMixin): if self.bilateral_transforms: sqls, sqls_params = [], [] for p in rhs: - value = QueryWrapper('%s', - [self.lhs.output_field.get_db_prep_value(p, connection)]) + value = Value(p, output_field=self.lhs.output_field) value = self.apply_bilateral_transforms(value) + value = value.resolve_expression(compiler.query) sql, sql_params = compiler.compile(value) sqls.append(sql) sqls_params.extend(sql_params) @@ -155,9 +71,9 @@ class Lookup(RegisterLookupMixin): if self.rhs_is_direct_value(): # Do not call get_db_prep_lookup here as the value will be # transformed before being used for lookup - value = QueryWrapper("%s", - [self.lhs.output_field.get_db_prep_value(value, connection)]) + value = Value(value, output_field=self.lhs.output_field) value = self.apply_bilateral_transforms(value) + value = value.resolve_expression(compiler.query) # Due to historical reasons there are a couple of different # ways to produce sql here. get_compiler is likely a Query # instance, _as_sql QuerySet and as_sql just something with @@ -201,6 +117,31 @@ class Lookup(RegisterLookupMixin): return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False) +class Transform(RegisterLookupMixin, Func): + """ + RegisterLookupMixin() is first so that get_lookup() and get_transform() + first examine self and then check output_field. + """ + bilateral = False + + def __init__(self, expression, **extra): + # Restrict Transform to allow only a single expression. + super(Transform, self).__init__(expression, **extra) + + @property + def lhs(self): + return self.get_source_expressions()[0] + + def get_bilateral_transforms(self): + if hasattr(self.lhs, 'get_bilateral_transforms'): + bilateral_transforms = self.lhs.get_bilateral_transforms() + else: + bilateral_transforms = [] + if self.bilateral: + bilateral_transforms.append(self.__class__) + return bilateral_transforms + + class BuiltinLookup(Lookup): def process_lhs(self, compiler, connection, lhs=None): lhs_sql, params = super(BuiltinLookup, self).process_lhs( @@ -223,12 +164,9 @@ class BuiltinLookup(Lookup): return connection.operators[self.lookup_name] % rhs -default_lookups = {} - - class Exact(BuiltinLookup): lookup_name = 'exact' -default_lookups['exact'] = Exact +Field.register_lookup(Exact) class IExact(BuiltinLookup): @@ -241,27 +179,27 @@ class IExact(BuiltinLookup): return rhs, params -default_lookups['iexact'] = IExact +Field.register_lookup(IExact) class GreaterThan(BuiltinLookup): lookup_name = 'gt' -default_lookups['gt'] = GreaterThan +Field.register_lookup(GreaterThan) class GreaterThanOrEqual(BuiltinLookup): lookup_name = 'gte' -default_lookups['gte'] = GreaterThanOrEqual +Field.register_lookup(GreaterThanOrEqual) class LessThan(BuiltinLookup): lookup_name = 'lt' -default_lookups['lt'] = LessThan +Field.register_lookup(LessThan) class LessThanOrEqual(BuiltinLookup): lookup_name = 'lte' -default_lookups['lte'] = LessThanOrEqual +Field.register_lookup(LessThanOrEqual) class In(BuiltinLookup): @@ -286,32 +224,32 @@ class In(BuiltinLookup): def as_sql(self, compiler, connection): max_in_list_size = connection.ops.max_in_list_size() - if self.rhs_is_direct_value() and (max_in_list_size and - len(self.rhs) > max_in_list_size): - # This is a special case for Oracle which limits the number of elements - # which can appear in an 'IN' clause. - lhs, lhs_params = self.process_lhs(compiler, connection) - rhs, rhs_params = self.batch_process_rhs(compiler, connection) - in_clause_elements = ['('] - params = [] - for offset in range(0, len(rhs_params), max_in_list_size): - if offset > 0: - in_clause_elements.append(' OR ') - in_clause_elements.append('%s IN (' % lhs) - params.extend(lhs_params) - sqls = rhs[offset: offset + max_in_list_size] - sqls_params = rhs_params[offset: offset + max_in_list_size] - param_group = ', '.join(sqls) - in_clause_elements.append(param_group) - in_clause_elements.append(')') - params.extend(sqls_params) + if self.rhs_is_direct_value() and max_in_list_size and len(self.rhs) > max_in_list_size: + return self.split_parameter_list_as_sql(compiler, connection) + return super(In, self).as_sql(compiler, connection) + + def split_parameter_list_as_sql(self, compiler, connection): + # This is a special case for databases which limit the number of + # elements which can appear in an 'IN' clause. + max_in_list_size = connection.ops.max_in_list_size() + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.batch_process_rhs(compiler, connection) + in_clause_elements = ['('] + params = [] + for offset in range(0, len(rhs_params), max_in_list_size): + if offset > 0: + in_clause_elements.append(' OR ') + in_clause_elements.append('%s IN (' % lhs) + params.extend(lhs_params) + sqls = rhs[offset: offset + max_in_list_size] + sqls_params = rhs_params[offset: offset + max_in_list_size] + param_group = ', '.join(sqls) + in_clause_elements.append(param_group) in_clause_elements.append(')') - return ''.join(in_clause_elements), params - else: - return super(In, self).as_sql(compiler, connection) - - -default_lookups['in'] = In + params.extend(sqls_params) + in_clause_elements.append(')') + return ''.join(in_clause_elements), params +Field.register_lookup(In) class PatternLookup(BuiltinLookup): @@ -342,16 +280,12 @@ class Contains(PatternLookup): if params and not self.bilateral_transforms: params[0] = "%%%s%%" % connection.ops.prep_for_like_query(params[0]) return rhs, params - - -default_lookups['contains'] = Contains +Field.register_lookup(Contains) class IContains(Contains): lookup_name = 'icontains' - - -default_lookups['icontains'] = IContains +Field.register_lookup(IContains) class StartsWith(PatternLookup): @@ -362,9 +296,7 @@ class StartsWith(PatternLookup): if params and not self.bilateral_transforms: params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0]) return rhs, params - - -default_lookups['startswith'] = StartsWith +Field.register_lookup(StartsWith) class IStartsWith(PatternLookup): @@ -375,9 +307,7 @@ class IStartsWith(PatternLookup): if params and not self.bilateral_transforms: params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0]) return rhs, params - - -default_lookups['istartswith'] = IStartsWith +Field.register_lookup(IStartsWith) class EndsWith(PatternLookup): @@ -388,9 +318,7 @@ class EndsWith(PatternLookup): if params and not self.bilateral_transforms: params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0]) return rhs, params - - -default_lookups['endswith'] = EndsWith +Field.register_lookup(EndsWith) class IEndsWith(PatternLookup): @@ -401,9 +329,7 @@ class IEndsWith(PatternLookup): if params and not self.bilateral_transforms: params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0]) return rhs, params - - -default_lookups['iendswith'] = IEndsWith +Field.register_lookup(IEndsWith) class Between(BuiltinLookup): @@ -424,8 +350,7 @@ class Range(BuiltinLookup): return self.batch_process_rhs(compiler, connection) else: return super(Range, self).process_rhs(compiler, connection) - -default_lookups['range'] = Range +Field.register_lookup(Range) class IsNull(BuiltinLookup): @@ -437,7 +362,7 @@ class IsNull(BuiltinLookup): return "%s IS NULL" % sql, params else: return "%s IS NOT NULL" % sql, params -default_lookups['isnull'] = IsNull +Field.register_lookup(IsNull) class Search(BuiltinLookup): @@ -448,8 +373,7 @@ class Search(BuiltinLookup): rhs, rhs_params = self.process_rhs(compiler, connection) sql_template = connection.ops.fulltext_search_sql(field_name=lhs) return sql_template, lhs_params + rhs_params - -default_lookups['search'] = Search +Field.register_lookup(Search) class Regex(BuiltinLookup): @@ -463,9 +387,168 @@ class Regex(BuiltinLookup): rhs, rhs_params = self.process_rhs(compiler, connection) sql_template = connection.ops.regex_lookup(self.lookup_name) return sql_template % (lhs, rhs), lhs_params + rhs_params -default_lookups['regex'] = Regex +Field.register_lookup(Regex) class IRegex(Regex): lookup_name = 'iregex' -default_lookups['iregex'] = IRegex +Field.register_lookup(IRegex) + + +class DateTimeDateTransform(Transform): + lookup_name = 'date' + + @cached_property + def output_field(self): + return DateField() + + def as_sql(self, compiler, connection): + lhs, lhs_params = compiler.compile(self.lhs) + tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None + sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname) + lhs_params.extend(tz_params) + return sql, lhs_params + + +class DateTransform(Transform): + def as_sql(self, compiler, connection): + sql, params = compiler.compile(self.lhs) + lhs_output_field = self.lhs.output_field + if isinstance(lhs_output_field, DateTimeField): + tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None + sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname) + params.extend(tz_params) + elif isinstance(lhs_output_field, DateField): + sql = connection.ops.date_extract_sql(self.lookup_name, sql) + elif isinstance(lhs_output_field, TimeField): + sql = connection.ops.time_extract_sql(self.lookup_name, sql) + else: + raise ValueError('DateTransform only valid on Date/Time/DateTimeFields') + return sql, params + + @cached_property + def output_field(self): + return IntegerField() + + +class YearTransform(DateTransform): + lookup_name = 'year' + + +class YearLookup(Lookup): + def year_lookup_bounds(self, connection, year): + output_field = self.lhs.lhs.output_field + if isinstance(output_field, DateTimeField): + bounds = connection.ops.year_lookup_bounds_for_datetime_field(year) + else: + bounds = connection.ops.year_lookup_bounds_for_date_field(year) + return bounds + + +@YearTransform.register_lookup +class YearExact(YearLookup): + lookup_name = 'exact' + + def as_sql(self, compiler, connection): + # We will need to skip the extract part and instead go + # directly with the originating field, that is self.lhs.lhs. + lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs) + rhs_sql, rhs_params = self.process_rhs(compiler, connection) + bounds = self.year_lookup_bounds(connection, rhs_params[0]) + params.extend(bounds) + return '%s BETWEEN %%s AND %%s' % lhs_sql, params + + +class YearComparisonLookup(YearLookup): + def as_sql(self, compiler, connection): + # We will need to skip the extract part and instead go + # directly with the originating field, that is self.lhs.lhs. + lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs) + rhs_sql, rhs_params = self.process_rhs(compiler, connection) + rhs_sql = self.get_rhs_op(connection, rhs_sql) + start, finish = self.year_lookup_bounds(connection, rhs_params[0]) + params.append(self.get_bound(start, finish)) + return '%s %s' % (lhs_sql, rhs_sql), params + + def get_rhs_op(self, connection, rhs): + return connection.operators[self.lookup_name] % rhs + + def get_bound(self): + raise NotImplementedError( + 'subclasses of YearComparisonLookup must provide a get_bound() method' + ) + + +@YearTransform.register_lookup +class YearGt(YearComparisonLookup): + lookup_name = 'gt' + + def get_bound(self, start, finish): + return finish + + +@YearTransform.register_lookup +class YearGte(YearComparisonLookup): + lookup_name = 'gte' + + def get_bound(self, start, finish): + return start + + +@YearTransform.register_lookup +class YearLt(YearComparisonLookup): + lookup_name = 'lt' + + def get_bound(self, start, finish): + return start + + +@YearTransform.register_lookup +class YearLte(YearComparisonLookup): + lookup_name = 'lte' + + def get_bound(self, start, finish): + return finish + + +class MonthTransform(DateTransform): + lookup_name = 'month' + + +class DayTransform(DateTransform): + lookup_name = 'day' + + +class WeekDayTransform(DateTransform): + lookup_name = 'week_day' + + +class HourTransform(DateTransform): + lookup_name = 'hour' + + +class MinuteTransform(DateTransform): + lookup_name = 'minute' + + +class SecondTransform(DateTransform): + lookup_name = 'second' + + +DateField.register_lookup(YearTransform) +DateField.register_lookup(MonthTransform) +DateField.register_lookup(DayTransform) +DateField.register_lookup(WeekDayTransform) + +TimeField.register_lookup(HourTransform) +TimeField.register_lookup(MinuteTransform) +TimeField.register_lookup(SecondTransform) + +DateTimeField.register_lookup(DateTimeDateTransform) +DateTimeField.register_lookup(YearTransform) +DateTimeField.register_lookup(MonthTransform) +DateTimeField.register_lookup(DayTransform) +DateTimeField.register_lookup(WeekDayTransform) +DateTimeField.register_lookup(HourTransform) +DateTimeField.register_lookup(MinuteTransform) +DateTimeField.register_lookup(SecondTransform) diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index 4c2a4fe80b..efc995b35f 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -7,6 +7,7 @@ circular import difficulties. """ from __future__ import unicode_literals +import inspect from collections import namedtuple from django.apps import apps @@ -169,6 +170,60 @@ class DeferredAttribute(object): return None +class RegisterLookupMixin(object): + def _get_lookup(self, lookup_name): + try: + return self.class_lookups[lookup_name] + except KeyError: + # To allow for inheritance, check parent class' class_lookups. + for parent in inspect.getmro(self.__class__): + if 'class_lookups' not in parent.__dict__: + continue + if lookup_name in parent.class_lookups: + return parent.class_lookups[lookup_name] + except AttributeError: + # This class didn't have any class_lookups + pass + return None + + def get_lookup(self, lookup_name): + from django.db.models.lookups import Lookup + found = self._get_lookup(lookup_name) + if found is None and hasattr(self, 'output_field'): + return self.output_field.get_lookup(lookup_name) + if found is not None and not issubclass(found, Lookup): + return None + return found + + def get_transform(self, lookup_name): + from django.db.models.lookups import Transform + found = self._get_lookup(lookup_name) + if found is None and hasattr(self, 'output_field'): + return self.output_field.get_transform(lookup_name) + if found is not None and not issubclass(found, Transform): + return None + return found + + @classmethod + def register_lookup(cls, lookup, lookup_name=None): + if lookup_name is None: + lookup_name = lookup.lookup_name + if 'class_lookups' not in cls.__dict__: + cls.class_lookups = {} + cls.class_lookups[lookup_name] = lookup + return lookup + + @classmethod + def _unregister_lookup(cls, lookup, lookup_name=None): + """ + Remove given lookup from cls lookups. For use in tests only as it's + not thread-safe. + """ + if lookup_name is None: + lookup_name = lookup.lookup_name + del cls.class_lookups[lookup_name] + + def select_related_descend(field, restricted, requested, load_fields, reverse=False): """ Returns True if this field should be used to descend deeper for diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index 43d7ff1064..4f74a3b0c0 100644 --- a/django/db/models/sql/aggregates.py +++ b/django/db/models/sql/aggregates.py @@ -5,7 +5,7 @@ import copy import warnings from django.db.models.fields import FloatField, IntegerField -from django.db.models.lookups import RegisterLookupMixin +from django.db.models.query_utils import RegisterLookupMixin from django.utils.deprecation import RemovedInDjango110Warning from django.utils.functional import cached_property diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index f4f544843b..9dd666040f 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -1105,9 +1105,9 @@ class Query(object): Helper method for build_lookup. Tries to fetch and initialize a transform for name parameter from lhs. """ - next = lhs.get_transform(name) - if next: - return next(lhs, rest_of_lookups) + transform_class = lhs.get_transform(name) + if transform_class: + return transform_class(lhs) else: raise FieldError( "Unsupported lookup '%s' for %s or join on the field not " diff --git a/docs/howto/custom-lookups.txt b/docs/howto/custom-lookups.txt index 58b0215019..f398618e39 100644 --- a/docs/howto/custom-lookups.txt +++ b/docs/howto/custom-lookups.txt @@ -120,10 +120,7 @@ function ``ABS()`` to transform the value before comparison:: class AbsoluteValue(Transform): lookup_name = 'abs' - - def as_sql(self, compiler, connection): - lhs, params = compiler.compile(self.lhs) - return "ABS(%s)" % lhs, params + function = 'ABS' Next, let's register it for ``IntegerField``:: @@ -157,10 +154,7 @@ be done by adding an ``output_field`` attribute to the transform:: class AbsoluteValue(Transform): lookup_name = 'abs' - - def as_sql(self, compiler, connection): - lhs, params = compiler.compile(self.lhs) - return "ABS(%s)" % lhs, params + function = 'ABS' @property def output_field(self): @@ -243,12 +237,9 @@ this transformation should apply to both ``lhs`` and ``rhs``:: class UpperCase(Transform): lookup_name = 'upper' + function = 'UPPER' bilateral = True - def as_sql(self, compiler, connection): - lhs, params = compiler.compile(self.lhs) - return "UPPER(%s)" % lhs, params - Next, let's register it:: from django.db.models import CharField, TextField diff --git a/docs/ref/models/database-functions.txt b/docs/ref/models/database-functions.txt index 51a2e4b998..dd0bd59379 100644 --- a/docs/ref/models/database-functions.txt +++ b/docs/ref/models/database-functions.txt @@ -180,6 +180,18 @@ Usage example:: >>> print(author.name_length, author.goes_by_length) (14, None) +It can also be registered as a transform. For example:: + + >>> from django.db.models import CharField + >>> from django.db.models.functions import Length + >>> CharField.register_lookup(Length, 'length') + >>> # Get authors whose name is longer than 7 characters + >>> authors = Author.objects.filter(name__length__gt=7) + +.. versionchanged:: 1.9 + + The ability to register the function as a transform was added. + Lower ------ @@ -188,6 +200,8 @@ Lower Accepts a single text field or expression and returns the lowercase representation. +It can also be registered as a transform as described in :class:`Length`. + Usage example:: >>> from django.db.models.functions import Lower @@ -196,6 +210,10 @@ Usage example:: >>> print(author.name_lower) margaret smith +.. versionchanged:: 1.9 + + The ability to register the function as a transform was added. + Now --- @@ -246,6 +264,8 @@ Upper Accepts a single text field or expression and returns the uppercase representation. +It can also be registered as a transform as described in :class:`Length`. + Usage example:: >>> from django.db.models.functions import Upper @@ -253,3 +273,7 @@ Usage example:: >>> author = Author.objects.annotate(name_upper=Upper('name')).get() >>> print(author.name_upper) MARGARET SMITH + +.. versionchanged:: 1.9 + + The ability to register the function as a transform was added. diff --git a/docs/ref/models/lookups.txt b/docs/ref/models/lookups.txt index c2304209d7..58e6e35bbf 100644 --- a/docs/ref/models/lookups.txt +++ b/docs/ref/models/lookups.txt @@ -42,12 +42,17 @@ register lookups on itself. The two prominent examples are A mixin that implements the lookup API on a class. - .. classmethod:: register_lookup(lookup) + .. classmethod:: register_lookup(lookup, lookup_name=None) Registers a new lookup in the class. For example ``DateField.register_lookup(YearExact)`` will register ``YearExact`` lookup on ``DateField``. It overrides a lookup that already exists with - the same name. + the same name. ``lookup_name`` will be used for this lookup if + provided, otherwise ``lookup.lookup_name`` will be used. + + .. versionchanged:: 1.9 + + The ``lookup_name`` parameter was added. .. method:: get_lookup(lookup_name) @@ -125,7 +130,14 @@ Transform reference ``__`` (e.g. ``date__year``). This class follows the :ref:`Query Expression API `, which - implies that you can use ``____``. + implies that you can use ``____``. It's + a specialized :ref:`Func() expression ` that only accepts + one argument. It can also be used on the right hand side of a filter or + directly as an annotation. + + .. versionchanged:: 1.9 + + ``Transform`` is now a subclass of ``Func``. .. attribute:: bilateral @@ -152,18 +164,6 @@ Transform reference :class:`~django.db.models.Field` instance. By default is the same as its ``lhs.output_field``. - .. method:: as_sql - - To be overridden; raises :exc:`NotImplementedError`. - - .. method:: get_lookup(lookup_name) - - Same as :meth:`~lookups.RegisterLookupMixin.get_lookup()`. - - .. method:: get_transform(transform_name) - - Same as :meth:`~lookups.RegisterLookupMixin.get_transform()`. - Lookup reference ~~~~~~~~~~~~~~~~ diff --git a/docs/releases/1.9.txt b/docs/releases/1.9.txt index 2d87d298a3..18fa8ef6bf 100644 --- a/docs/releases/1.9.txt +++ b/docs/releases/1.9.txt @@ -520,6 +520,14 @@ Models * Added the :class:`~django.db.models.functions.Now` database function, which returns the current date and time. +* :class:`~django.db.models.Transform` is now a subclass of + :ref:`Func() ` which allows ``Transform``\s to be used on + the right hand side of an expression, just like regular ``Func``\s. This + allows registering some database functions like + :class:`~django.db.models.functions.Length`, + :class:`~django.db.models.functions.Lower`, and + :class:`~django.db.models.functions.Upper` as transforms. + * :class:`~django.db.models.SlugField` now accepts an :attr:`~django.db.models.SlugField.allow_unicode` argument to allow Unicode characters in slugs. diff --git a/tests/custom_lookups/tests.py b/tests/custom_lookups/tests.py index d8852e1e3c..c538d23b76 100644 --- a/tests/custom_lookups/tests.py +++ b/tests/custom_lookups/tests.py @@ -126,11 +126,17 @@ class YearLte(models.lookups.LessThanOrEqual): return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params -class SQLFunc(models.Lookup): - def __init__(self, name, *args, **kwargs): - super(SQLFunc, self).__init__(*args, **kwargs) - self.name = name +class Exactly(models.lookups.Exact): + """ + This lookup is used to test lookup registration. + """ + lookup_name = 'exactly' + def get_rhs_op(self, connection, rhs): + return connection.operators['exact'] % rhs + + +class SQLFuncMixin(object): def as_sql(self, compiler, connection): return '%s()', [self.name] @@ -139,13 +145,28 @@ class SQLFunc(models.Lookup): return CustomField() +class SQLFuncLookup(SQLFuncMixin, models.Lookup): + def __init__(self, name, *args, **kwargs): + super(SQLFuncLookup, self).__init__(*args, **kwargs) + self.name = name + + +class SQLFuncTransform(SQLFuncMixin, models.Transform): + def __init__(self, name, *args, **kwargs): + super(SQLFuncTransform, self).__init__(*args, **kwargs) + self.name = name + + class SQLFuncFactory(object): - def __init__(self, name): + def __init__(self, key, name): + self.key = key self.name = name def __call__(self, *args, **kwargs): - return SQLFunc(self.name, *args, **kwargs) + if self.key == 'lookupfunc': + return SQLFuncLookup(self.name, *args, **kwargs) + return SQLFuncTransform(self.name, *args, **kwargs) class CustomField(models.TextField): @@ -153,13 +174,13 @@ class CustomField(models.TextField): def get_lookup(self, lookup_name): if lookup_name.startswith('lookupfunc_'): key, name = lookup_name.split('_', 1) - return SQLFuncFactory(name) + return SQLFuncFactory(key, name) return super(CustomField, self).get_lookup(lookup_name) def get_transform(self, lookup_name): if lookup_name.startswith('transformfunc_'): key, name = lookup_name.split('_', 1) - return SQLFuncFactory(name) + return SQLFuncFactory(key, name) return super(CustomField, self).get_transform(lookup_name) @@ -200,6 +221,27 @@ class DateTimeTransform(models.Transform): class LookupTests(TestCase): + + def test_custom_name_lookup(self): + a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16)) + Author.objects.create(name='a2', birthdate=date(2012, 2, 29)) + custom_lookup_name = 'isactually' + custom_transform_name = 'justtheyear' + try: + models.DateField.register_lookup(YearTransform) + models.DateField.register_lookup(YearTransform, custom_transform_name) + YearTransform.register_lookup(Exactly) + YearTransform.register_lookup(Exactly, custom_lookup_name) + qs1 = Author.objects.filter(birthdate__testyear__exactly=1981) + qs2 = Author.objects.filter(birthdate__justtheyear__isactually=1981) + self.assertQuerysetEqual(qs1, [a1], lambda x: x) + self.assertQuerysetEqual(qs2, [a1], lambda x: x) + finally: + YearTransform._unregister_lookup(Exactly) + YearTransform._unregister_lookup(Exactly, custom_lookup_name) + models.DateField._unregister_lookup(YearTransform) + models.DateField._unregister_lookup(YearTransform, custom_transform_name) + def test_basic_lookup(self): a1 = Author.objects.create(name='a1', age=1) a2 = Author.objects.create(name='a2', age=2) @@ -299,6 +341,19 @@ class BilateralTransformTests(TestCase): with self.assertRaises(NotImplementedError): Author.objects.filter(name__upper__in=Author.objects.values_list('name')) + def test_bilateral_multi_value(self): + with register_lookup(models.CharField, UpperBilateralTransform): + Author.objects.bulk_create([ + Author(name='Foo'), + Author(name='Bar'), + Author(name='Ray'), + ]) + self.assertQuerysetEqual( + Author.objects.filter(name__upper__in=['foo', 'bar', 'doe']).order_by('name'), + ['Bar', 'Foo'], + lambda a: a.name + ) + def test_div3_bilateral_extract(self): with register_lookup(models.IntegerField, Div3BilateralTransform): a1 = Author.objects.create(name='a1', age=1) diff --git a/tests/db_functions/tests.py b/tests/db_functions/tests.py index a401c550fb..00aac82a7b 100644 --- a/tests/db_functions/tests.py +++ b/tests/db_functions/tests.py @@ -547,3 +547,97 @@ class FunctionTests(TestCase): ['How to Time Travel'], lambda a: a.title ) + + def test_length_transform(self): + try: + CharField.register_lookup(Length, 'length') + Author.objects.create(name='John Smith', alias='smithj') + Author.objects.create(name='Rhonda') + authors = Author.objects.filter(name__length__gt=7) + self.assertQuerysetEqual( + authors.order_by('name'), [ + 'John Smith', + ], + lambda a: a.name + ) + finally: + CharField._unregister_lookup(Length, 'length') + + def test_lower_transform(self): + try: + CharField.register_lookup(Lower, 'lower') + Author.objects.create(name='John Smith', alias='smithj') + Author.objects.create(name='Rhonda') + authors = Author.objects.filter(name__lower__exact='john smith') + self.assertQuerysetEqual( + authors.order_by('name'), [ + 'John Smith', + ], + lambda a: a.name + ) + finally: + CharField._unregister_lookup(Lower, 'lower') + + def test_upper_transform(self): + try: + CharField.register_lookup(Upper, 'upper') + Author.objects.create(name='John Smith', alias='smithj') + Author.objects.create(name='Rhonda') + authors = Author.objects.filter(name__upper__exact='JOHN SMITH') + self.assertQuerysetEqual( + authors.order_by('name'), [ + 'John Smith', + ], + lambda a: a.name + ) + finally: + CharField._unregister_lookup(Upper, 'upper') + + def test_func_transform_bilateral(self): + class UpperBilateral(Upper): + bilateral = True + + try: + CharField.register_lookup(UpperBilateral, 'upper') + Author.objects.create(name='John Smith', alias='smithj') + Author.objects.create(name='Rhonda') + authors = Author.objects.filter(name__upper__exact='john smith') + self.assertQuerysetEqual( + authors.order_by('name'), [ + 'John Smith', + ], + lambda a: a.name + ) + finally: + CharField._unregister_lookup(UpperBilateral, 'upper') + + def test_func_transform_bilateral_multivalue(self): + class UpperBilateral(Upper): + bilateral = True + + try: + CharField.register_lookup(UpperBilateral, 'upper') + Author.objects.create(name='John Smith', alias='smithj') + Author.objects.create(name='Rhonda') + authors = Author.objects.filter(name__upper__in=['john smith', 'rhonda']) + self.assertQuerysetEqual( + authors.order_by('name'), [ + 'John Smith', + 'Rhonda', + ], + lambda a: a.name + ) + finally: + CharField._unregister_lookup(UpperBilateral, 'upper') + + def test_function_as_filter(self): + Author.objects.create(name='John Smith', alias='SMITHJ') + Author.objects.create(name='Rhonda') + self.assertQuerysetEqual( + Author.objects.filter(alias=Upper(V('smithj'))), + ['John Smith'], lambda x: x.name + ) + self.assertQuerysetEqual( + Author.objects.exclude(alias=Upper(V('smithj'))), + ['Rhonda'], lambda x: x.name + )