Fixed #24629 -- Unified Transform and Expression APIs

This commit is contained in:
Josh Smeaton 2015-08-03 12:30:06 +10:00
parent 8dc3ba5ceb
commit 534aaf56f4
15 changed files with 522 additions and 377 deletions

View File

@ -81,14 +81,14 @@ class KeyTransformFactory(object):
@HStoreField.register_lookup @HStoreField.register_lookup
class KeysTransform(lookups.FunctionTransform): class KeysTransform(Transform):
lookup_name = 'keys' lookup_name = 'keys'
function = 'akeys' function = 'akeys'
output_field = ArrayField(TextField()) output_field = ArrayField(TextField())
@HStoreField.register_lookup @HStoreField.register_lookup
class ValuesTransform(lookups.FunctionTransform): class ValuesTransform(Transform):
lookup_name = 'values' lookup_name = 'values'
function = 'avals' function = 'avals'
output_field = ArrayField(TextField()) output_field = ArrayField(TextField())

View File

@ -173,7 +173,7 @@ class AdjacentToLookup(lookups.PostgresSimpleLookup):
@RangeField.register_lookup @RangeField.register_lookup
class RangeStartsWith(lookups.FunctionTransform): class RangeStartsWith(models.Transform):
lookup_name = 'startswith' lookup_name = 'startswith'
function = 'lower' function = 'lower'
@ -183,7 +183,7 @@ class RangeStartsWith(lookups.FunctionTransform):
@RangeField.register_lookup @RangeField.register_lookup
class RangeEndsWith(lookups.FunctionTransform): class RangeEndsWith(models.Transform):
lookup_name = 'endswith' lookup_name = 'endswith'
function = 'upper' function = 'upper'
@ -193,7 +193,7 @@ class RangeEndsWith(lookups.FunctionTransform):
@RangeField.register_lookup @RangeField.register_lookup
class IsEmpty(lookups.FunctionTransform): class IsEmpty(models.Transform):
lookup_name = 'isempty' lookup_name = 'isempty'
function = 'isempty' function = 'isempty'
output_field = models.BooleanField() output_field = models.BooleanField()

View File

@ -9,12 +9,6 @@ class PostgresSimpleLookup(Lookup):
return '%s %s %s' % (lhs, self.operator, rhs), params return '%s %s %s' % (lhs, self.operator, rhs), params
class FunctionTransform(Transform):
def as_sql(self, qn, connection):
lhs, params = qn.compile(self.lhs)
return "%s(%s)" % (self.function, lhs), params
class DataContains(PostgresSimpleLookup): class DataContains(PostgresSimpleLookup):
lookup_name = 'contains' lookup_name = 'contains'
operator = '@>' operator = '@>'
@ -45,7 +39,7 @@ class HasAnyKeys(PostgresSimpleLookup):
operator = '?|' operator = '?|'
class Unaccent(FunctionTransform): class Unaccent(Transform):
bilateral = True bilateral = True
lookup_name = 'unaccent' lookup_name = 'unaccent'
function = 'UNACCENT' function = 'UNACCENT'

View File

@ -20,10 +20,7 @@ from django.core import checks, exceptions, validators
# purposes. # purposes.
from django.core.exceptions import FieldDoesNotExist # NOQA from django.core.exceptions import FieldDoesNotExist # NOQA
from django.db import connection, connections, router from django.db import connection, connections, router
from django.db.models.lookups import ( from django.db.models.query_utils import QueryWrapper, RegisterLookupMixin
Lookup, RegisterLookupMixin, Transform, default_lookups,
)
from django.db.models.query_utils import QueryWrapper
from django.utils import six, timezone from django.utils import six, timezone
from django.utils.datastructures import DictWrapper from django.utils.datastructures import DictWrapper
from django.utils.dateparse import ( from django.utils.dateparse import (
@ -120,7 +117,6 @@ class Field(RegisterLookupMixin):
'unique_for_date': _("%(field_label)s must be unique for " 'unique_for_date': _("%(field_label)s must be unique for "
"%(date_field_label)s %(lookup_type)s."), "%(date_field_label)s %(lookup_type)s."),
} }
class_lookups = default_lookups.copy()
system_check_deprecated_details = None system_check_deprecated_details = None
system_check_removed_details = None system_check_removed_details = None
@ -1492,22 +1488,6 @@ class DateTimeField(DateField):
return super(DateTimeField, self).formfield(**defaults) return super(DateTimeField, self).formfield(**defaults)
@DateTimeField.register_lookup
class DateTimeDateTransform(Transform):
lookup_name = 'date'
@cached_property
def output_field(self):
return DateField()
def as_sql(self, compiler, connection):
lhs, lhs_params = compiler.compile(self.lhs)
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname)
lhs_params.extend(tz_params)
return sql, lhs_params
class DecimalField(Field): class DecimalField(Field):
empty_strings_allowed = False empty_strings_allowed = False
default_error_messages = { default_error_messages = {
@ -2450,146 +2430,3 @@ class UUIDField(Field):
} }
defaults.update(kwargs) defaults.update(kwargs)
return super(UUIDField, self).formfield(**defaults) return super(UUIDField, self).formfield(**defaults)
class DateTransform(Transform):
def as_sql(self, compiler, connection):
sql, params = compiler.compile(self.lhs)
lhs_output_field = self.lhs.output_field
if isinstance(lhs_output_field, DateTimeField):
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
params.extend(tz_params)
elif isinstance(lhs_output_field, DateField):
sql = connection.ops.date_extract_sql(self.lookup_name, sql)
elif isinstance(lhs_output_field, TimeField):
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
else:
raise ValueError('DateTransform only valid on Date/Time/DateTimeFields')
return sql, params
@cached_property
def output_field(self):
return IntegerField()
class YearTransform(DateTransform):
lookup_name = 'year'
class YearLookup(Lookup):
def year_lookup_bounds(self, connection, year):
output_field = self.lhs.lhs.output_field
if isinstance(output_field, DateTimeField):
bounds = connection.ops.year_lookup_bounds_for_datetime_field(year)
else:
bounds = connection.ops.year_lookup_bounds_for_date_field(year)
return bounds
@YearTransform.register_lookup
class YearExact(YearLookup):
lookup_name = 'exact'
def as_sql(self, compiler, connection):
# We will need to skip the extract part and instead go
# directly with the originating field, that is self.lhs.lhs.
lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
bounds = self.year_lookup_bounds(connection, rhs_params[0])
params.extend(bounds)
return '%s BETWEEN %%s AND %%s' % lhs_sql, params
class YearComparisonLookup(YearLookup):
def as_sql(self, compiler, connection):
# We will need to skip the extract part and instead go
# directly with the originating field, that is self.lhs.lhs.
lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
rhs_sql = self.get_rhs_op(connection, rhs_sql)
start, finish = self.year_lookup_bounds(connection, rhs_params[0])
params.append(self.get_bound(start, finish))
return '%s %s' % (lhs_sql, rhs_sql), params
def get_rhs_op(self, connection, rhs):
return connection.operators[self.lookup_name] % rhs
def get_bound(self):
raise NotImplementedError(
'subclasses of YearComparisonLookup must provide a get_bound() method'
)
@YearTransform.register_lookup
class YearGt(YearComparisonLookup):
lookup_name = 'gt'
def get_bound(self, start, finish):
return finish
@YearTransform.register_lookup
class YearGte(YearComparisonLookup):
lookup_name = 'gte'
def get_bound(self, start, finish):
return start
@YearTransform.register_lookup
class YearLt(YearComparisonLookup):
lookup_name = 'lt'
def get_bound(self, start, finish):
return start
@YearTransform.register_lookup
class YearLte(YearComparisonLookup):
lookup_name = 'lte'
def get_bound(self, start, finish):
return finish
class MonthTransform(DateTransform):
lookup_name = 'month'
class DayTransform(DateTransform):
lookup_name = 'day'
class WeekDayTransform(DateTransform):
lookup_name = 'week_day'
class HourTransform(DateTransform):
lookup_name = 'hour'
class MinuteTransform(DateTransform):
lookup_name = 'minute'
class SecondTransform(DateTransform):
lookup_name = 'second'
DateField.register_lookup(YearTransform)
DateField.register_lookup(MonthTransform)
DateField.register_lookup(DayTransform)
DateField.register_lookup(WeekDayTransform)
TimeField.register_lookup(HourTransform)
TimeField.register_lookup(MinuteTransform)
TimeField.register_lookup(SecondTransform)
DateTimeField.register_lookup(YearTransform)
DateTimeField.register_lookup(MonthTransform)
DateTimeField.register_lookup(DayTransform)
DateTimeField.register_lookup(WeekDayTransform)
DateTimeField.register_lookup(HourTransform)
DateTimeField.register_lookup(MinuteTransform)
DateTimeField.register_lookup(SecondTransform)

View File

@ -1,8 +1,9 @@
""" """
Classes that represent database functions. Classes that represent database functions.
""" """
from django.db.models import DateTimeField, IntegerField from django.db.models import (
from django.db.models.expressions import Func, Value DateTimeField, Func, IntegerField, Transform, Value,
)
class Coalesce(Func): class Coalesce(Func):
@ -123,9 +124,10 @@ class Least(Func):
return super(Least, self).as_sql(compiler, connection, function='MIN') return super(Least, self).as_sql(compiler, connection, function='MIN')
class Length(Func): class Length(Transform):
"""Returns the number of characters in the expression""" """Returns the number of characters in the expression"""
function = 'LENGTH' function = 'LENGTH'
lookup_name = 'length'
def __init__(self, expression, **extra): def __init__(self, expression, **extra):
output_field = extra.pop('output_field', IntegerField()) output_field = extra.pop('output_field', IntegerField())
@ -136,8 +138,9 @@ class Length(Func):
return super(Length, self).as_sql(compiler, connection) return super(Length, self).as_sql(compiler, connection)
class Lower(Func): class Lower(Transform):
function = 'LOWER' function = 'LOWER'
lookup_name = 'lower'
def __init__(self, expression, **extra): def __init__(self, expression, **extra):
super(Lower, self).__init__(expression, **extra) super(Lower, self).__init__(expression, **extra)
@ -188,8 +191,9 @@ class Substr(Func):
return super(Substr, self).as_sql(compiler, connection) return super(Substr, self).as_sql(compiler, connection)
class Upper(Func): class Upper(Transform):
function = 'UPPER' function = 'UPPER'
lookup_name = 'upper'
def __init__(self, expression, **extra): def __init__(self, expression, **extra):
super(Upper, self).__init__(expression, **extra) super(Upper, self).__init__(expression, **extra)

View File

@ -1,101 +1,17 @@
import inspect
from copy import copy from copy import copy
from django.conf import settings
from django.db.models.expressions import Func, Value
from django.db.models.fields import (
DateField, DateTimeField, Field, IntegerField, TimeField,
)
from django.db.models.query_utils import RegisterLookupMixin
from django.utils import timezone
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.six.moves import range from django.utils.six.moves import range
from .query_utils import QueryWrapper
class Lookup(object):
class RegisterLookupMixin(object):
def _get_lookup(self, lookup_name):
try:
return self.class_lookups[lookup_name]
except KeyError:
# To allow for inheritance, check parent class' class_lookups.
for parent in inspect.getmro(self.__class__):
if 'class_lookups' not in parent.__dict__:
continue
if lookup_name in parent.class_lookups:
return parent.class_lookups[lookup_name]
except AttributeError:
# This class didn't have any class_lookups
pass
return None
def get_lookup(self, lookup_name):
found = self._get_lookup(lookup_name)
if found is None and hasattr(self, 'output_field'):
return self.output_field.get_lookup(lookup_name)
if found is not None and not issubclass(found, Lookup):
return None
return found
def get_transform(self, lookup_name):
found = self._get_lookup(lookup_name)
if found is None and hasattr(self, 'output_field'):
return self.output_field.get_transform(lookup_name)
if found is not None and not issubclass(found, Transform):
return None
return found
@classmethod
def register_lookup(cls, lookup):
if 'class_lookups' not in cls.__dict__:
cls.class_lookups = {}
cls.class_lookups[lookup.lookup_name] = lookup
return lookup
@classmethod
def _unregister_lookup(cls, lookup):
"""
Removes given lookup from cls lookups. Meant to be used in
tests only.
"""
del cls.class_lookups[lookup.lookup_name]
class Transform(RegisterLookupMixin):
bilateral = False
def __init__(self, lhs, lookups):
self.lhs = lhs
self.init_lookups = lookups[:]
def as_sql(self, compiler, connection):
raise NotImplementedError
@cached_property
def output_field(self):
return self.lhs.output_field
def copy(self):
return copy(self)
def relabeled_clone(self, relabels):
copy = self.copy()
copy.lhs = self.lhs.relabeled_clone(relabels)
return copy
def get_group_by_cols(self):
return self.lhs.get_group_by_cols()
def get_bilateral_transforms(self):
if hasattr(self.lhs, 'get_bilateral_transforms'):
bilateral_transforms = self.lhs.get_bilateral_transforms()
else:
bilateral_transforms = []
if self.bilateral:
bilateral_transforms.append((self.__class__, self.init_lookups))
return bilateral_transforms
@cached_property
def contains_aggregate(self):
return self.lhs.contains_aggregate
class Lookup(RegisterLookupMixin):
lookup_name = None lookup_name = None
def __init__(self, lhs, rhs): def __init__(self, lhs, rhs):
@ -115,8 +31,8 @@ class Lookup(RegisterLookupMixin):
self.bilateral_transforms = bilateral_transforms self.bilateral_transforms = bilateral_transforms
def apply_bilateral_transforms(self, value): def apply_bilateral_transforms(self, value):
for transform, lookups in self.bilateral_transforms: for transform in self.bilateral_transforms:
value = transform(value, lookups) value = transform(value)
return value return value
def batch_process_rhs(self, compiler, connection, rhs=None): def batch_process_rhs(self, compiler, connection, rhs=None):
@ -125,9 +41,9 @@ class Lookup(RegisterLookupMixin):
if self.bilateral_transforms: if self.bilateral_transforms:
sqls, sqls_params = [], [] sqls, sqls_params = [], []
for p in rhs: for p in rhs:
value = QueryWrapper('%s', value = Value(p, output_field=self.lhs.output_field)
[self.lhs.output_field.get_db_prep_value(p, connection)])
value = self.apply_bilateral_transforms(value) value = self.apply_bilateral_transforms(value)
value = value.resolve_expression(compiler.query)
sql, sql_params = compiler.compile(value) sql, sql_params = compiler.compile(value)
sqls.append(sql) sqls.append(sql)
sqls_params.extend(sql_params) sqls_params.extend(sql_params)
@ -155,9 +71,9 @@ class Lookup(RegisterLookupMixin):
if self.rhs_is_direct_value(): if self.rhs_is_direct_value():
# Do not call get_db_prep_lookup here as the value will be # Do not call get_db_prep_lookup here as the value will be
# transformed before being used for lookup # transformed before being used for lookup
value = QueryWrapper("%s", value = Value(value, output_field=self.lhs.output_field)
[self.lhs.output_field.get_db_prep_value(value, connection)])
value = self.apply_bilateral_transforms(value) value = self.apply_bilateral_transforms(value)
value = value.resolve_expression(compiler.query)
# Due to historical reasons there are a couple of different # Due to historical reasons there are a couple of different
# ways to produce sql here. get_compiler is likely a Query # ways to produce sql here. get_compiler is likely a Query
# instance, _as_sql QuerySet and as_sql just something with # instance, _as_sql QuerySet and as_sql just something with
@ -201,6 +117,31 @@ class Lookup(RegisterLookupMixin):
return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False) return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)
class Transform(RegisterLookupMixin, Func):
"""
RegisterLookupMixin() is first so that get_lookup() and get_transform()
first examine self and then check output_field.
"""
bilateral = False
def __init__(self, expression, **extra):
# Restrict Transform to allow only a single expression.
super(Transform, self).__init__(expression, **extra)
@property
def lhs(self):
return self.get_source_expressions()[0]
def get_bilateral_transforms(self):
if hasattr(self.lhs, 'get_bilateral_transforms'):
bilateral_transforms = self.lhs.get_bilateral_transforms()
else:
bilateral_transforms = []
if self.bilateral:
bilateral_transforms.append(self.__class__)
return bilateral_transforms
class BuiltinLookup(Lookup): class BuiltinLookup(Lookup):
def process_lhs(self, compiler, connection, lhs=None): def process_lhs(self, compiler, connection, lhs=None):
lhs_sql, params = super(BuiltinLookup, self).process_lhs( lhs_sql, params = super(BuiltinLookup, self).process_lhs(
@ -223,12 +164,9 @@ class BuiltinLookup(Lookup):
return connection.operators[self.lookup_name] % rhs return connection.operators[self.lookup_name] % rhs
default_lookups = {}
class Exact(BuiltinLookup): class Exact(BuiltinLookup):
lookup_name = 'exact' lookup_name = 'exact'
default_lookups['exact'] = Exact Field.register_lookup(Exact)
class IExact(BuiltinLookup): class IExact(BuiltinLookup):
@ -241,27 +179,27 @@ class IExact(BuiltinLookup):
return rhs, params return rhs, params
default_lookups['iexact'] = IExact Field.register_lookup(IExact)
class GreaterThan(BuiltinLookup): class GreaterThan(BuiltinLookup):
lookup_name = 'gt' lookup_name = 'gt'
default_lookups['gt'] = GreaterThan Field.register_lookup(GreaterThan)
class GreaterThanOrEqual(BuiltinLookup): class GreaterThanOrEqual(BuiltinLookup):
lookup_name = 'gte' lookup_name = 'gte'
default_lookups['gte'] = GreaterThanOrEqual Field.register_lookup(GreaterThanOrEqual)
class LessThan(BuiltinLookup): class LessThan(BuiltinLookup):
lookup_name = 'lt' lookup_name = 'lt'
default_lookups['lt'] = LessThan Field.register_lookup(LessThan)
class LessThanOrEqual(BuiltinLookup): class LessThanOrEqual(BuiltinLookup):
lookup_name = 'lte' lookup_name = 'lte'
default_lookups['lte'] = LessThanOrEqual Field.register_lookup(LessThanOrEqual)
class In(BuiltinLookup): class In(BuiltinLookup):
@ -286,10 +224,14 @@ class In(BuiltinLookup):
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
max_in_list_size = connection.ops.max_in_list_size() max_in_list_size = connection.ops.max_in_list_size()
if self.rhs_is_direct_value() and (max_in_list_size and if self.rhs_is_direct_value() and max_in_list_size and len(self.rhs) > max_in_list_size:
len(self.rhs) > max_in_list_size): return self.split_parameter_list_as_sql(compiler, connection)
# This is a special case for Oracle which limits the number of elements return super(In, self).as_sql(compiler, connection)
# which can appear in an 'IN' clause.
def split_parameter_list_as_sql(self, compiler, connection):
# This is a special case for databases which limit the number of
# elements which can appear in an 'IN' clause.
max_in_list_size = connection.ops.max_in_list_size()
lhs, lhs_params = self.process_lhs(compiler, connection) lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.batch_process_rhs(compiler, connection) rhs, rhs_params = self.batch_process_rhs(compiler, connection)
in_clause_elements = ['('] in_clause_elements = ['(']
@ -307,11 +249,7 @@ class In(BuiltinLookup):
params.extend(sqls_params) params.extend(sqls_params)
in_clause_elements.append(')') in_clause_elements.append(')')
return ''.join(in_clause_elements), params return ''.join(in_clause_elements), params
else: Field.register_lookup(In)
return super(In, self).as_sql(compiler, connection)
default_lookups['in'] = In
class PatternLookup(BuiltinLookup): class PatternLookup(BuiltinLookup):
@ -342,16 +280,12 @@ class Contains(PatternLookup):
if params and not self.bilateral_transforms: if params and not self.bilateral_transforms:
params[0] = "%%%s%%" % connection.ops.prep_for_like_query(params[0]) params[0] = "%%%s%%" % connection.ops.prep_for_like_query(params[0])
return rhs, params return rhs, params
Field.register_lookup(Contains)
default_lookups['contains'] = Contains
class IContains(Contains): class IContains(Contains):
lookup_name = 'icontains' lookup_name = 'icontains'
Field.register_lookup(IContains)
default_lookups['icontains'] = IContains
class StartsWith(PatternLookup): class StartsWith(PatternLookup):
@ -362,9 +296,7 @@ class StartsWith(PatternLookup):
if params and not self.bilateral_transforms: if params and not self.bilateral_transforms:
params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0]) params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
return rhs, params return rhs, params
Field.register_lookup(StartsWith)
default_lookups['startswith'] = StartsWith
class IStartsWith(PatternLookup): class IStartsWith(PatternLookup):
@ -375,9 +307,7 @@ class IStartsWith(PatternLookup):
if params and not self.bilateral_transforms: if params and not self.bilateral_transforms:
params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0]) params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
return rhs, params return rhs, params
Field.register_lookup(IStartsWith)
default_lookups['istartswith'] = IStartsWith
class EndsWith(PatternLookup): class EndsWith(PatternLookup):
@ -388,9 +318,7 @@ class EndsWith(PatternLookup):
if params and not self.bilateral_transforms: if params and not self.bilateral_transforms:
params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0]) params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
return rhs, params return rhs, params
Field.register_lookup(EndsWith)
default_lookups['endswith'] = EndsWith
class IEndsWith(PatternLookup): class IEndsWith(PatternLookup):
@ -401,9 +329,7 @@ class IEndsWith(PatternLookup):
if params and not self.bilateral_transforms: if params and not self.bilateral_transforms:
params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0]) params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
return rhs, params return rhs, params
Field.register_lookup(IEndsWith)
default_lookups['iendswith'] = IEndsWith
class Between(BuiltinLookup): class Between(BuiltinLookup):
@ -424,8 +350,7 @@ class Range(BuiltinLookup):
return self.batch_process_rhs(compiler, connection) return self.batch_process_rhs(compiler, connection)
else: else:
return super(Range, self).process_rhs(compiler, connection) return super(Range, self).process_rhs(compiler, connection)
Field.register_lookup(Range)
default_lookups['range'] = Range
class IsNull(BuiltinLookup): class IsNull(BuiltinLookup):
@ -437,7 +362,7 @@ class IsNull(BuiltinLookup):
return "%s IS NULL" % sql, params return "%s IS NULL" % sql, params
else: else:
return "%s IS NOT NULL" % sql, params return "%s IS NOT NULL" % sql, params
default_lookups['isnull'] = IsNull Field.register_lookup(IsNull)
class Search(BuiltinLookup): class Search(BuiltinLookup):
@ -448,8 +373,7 @@ class Search(BuiltinLookup):
rhs, rhs_params = self.process_rhs(compiler, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
sql_template = connection.ops.fulltext_search_sql(field_name=lhs) sql_template = connection.ops.fulltext_search_sql(field_name=lhs)
return sql_template, lhs_params + rhs_params return sql_template, lhs_params + rhs_params
Field.register_lookup(Search)
default_lookups['search'] = Search
class Regex(BuiltinLookup): class Regex(BuiltinLookup):
@ -463,9 +387,168 @@ class Regex(BuiltinLookup):
rhs, rhs_params = self.process_rhs(compiler, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
sql_template = connection.ops.regex_lookup(self.lookup_name) sql_template = connection.ops.regex_lookup(self.lookup_name)
return sql_template % (lhs, rhs), lhs_params + rhs_params return sql_template % (lhs, rhs), lhs_params + rhs_params
default_lookups['regex'] = Regex Field.register_lookup(Regex)
class IRegex(Regex): class IRegex(Regex):
lookup_name = 'iregex' lookup_name = 'iregex'
default_lookups['iregex'] = IRegex Field.register_lookup(IRegex)
class DateTimeDateTransform(Transform):
lookup_name = 'date'
@cached_property
def output_field(self):
return DateField()
def as_sql(self, compiler, connection):
lhs, lhs_params = compiler.compile(self.lhs)
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname)
lhs_params.extend(tz_params)
return sql, lhs_params
class DateTransform(Transform):
def as_sql(self, compiler, connection):
sql, params = compiler.compile(self.lhs)
lhs_output_field = self.lhs.output_field
if isinstance(lhs_output_field, DateTimeField):
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
params.extend(tz_params)
elif isinstance(lhs_output_field, DateField):
sql = connection.ops.date_extract_sql(self.lookup_name, sql)
elif isinstance(lhs_output_field, TimeField):
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
else:
raise ValueError('DateTransform only valid on Date/Time/DateTimeFields')
return sql, params
@cached_property
def output_field(self):
return IntegerField()
class YearTransform(DateTransform):
lookup_name = 'year'
class YearLookup(Lookup):
def year_lookup_bounds(self, connection, year):
output_field = self.lhs.lhs.output_field
if isinstance(output_field, DateTimeField):
bounds = connection.ops.year_lookup_bounds_for_datetime_field(year)
else:
bounds = connection.ops.year_lookup_bounds_for_date_field(year)
return bounds
@YearTransform.register_lookup
class YearExact(YearLookup):
lookup_name = 'exact'
def as_sql(self, compiler, connection):
# We will need to skip the extract part and instead go
# directly with the originating field, that is self.lhs.lhs.
lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
bounds = self.year_lookup_bounds(connection, rhs_params[0])
params.extend(bounds)
return '%s BETWEEN %%s AND %%s' % lhs_sql, params
class YearComparisonLookup(YearLookup):
def as_sql(self, compiler, connection):
# We will need to skip the extract part and instead go
# directly with the originating field, that is self.lhs.lhs.
lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
rhs_sql = self.get_rhs_op(connection, rhs_sql)
start, finish = self.year_lookup_bounds(connection, rhs_params[0])
params.append(self.get_bound(start, finish))
return '%s %s' % (lhs_sql, rhs_sql), params
def get_rhs_op(self, connection, rhs):
return connection.operators[self.lookup_name] % rhs
def get_bound(self):
raise NotImplementedError(
'subclasses of YearComparisonLookup must provide a get_bound() method'
)
@YearTransform.register_lookup
class YearGt(YearComparisonLookup):
lookup_name = 'gt'
def get_bound(self, start, finish):
return finish
@YearTransform.register_lookup
class YearGte(YearComparisonLookup):
lookup_name = 'gte'
def get_bound(self, start, finish):
return start
@YearTransform.register_lookup
class YearLt(YearComparisonLookup):
lookup_name = 'lt'
def get_bound(self, start, finish):
return start
@YearTransform.register_lookup
class YearLte(YearComparisonLookup):
lookup_name = 'lte'
def get_bound(self, start, finish):
return finish
class MonthTransform(DateTransform):
lookup_name = 'month'
class DayTransform(DateTransform):
lookup_name = 'day'
class WeekDayTransform(DateTransform):
lookup_name = 'week_day'
class HourTransform(DateTransform):
lookup_name = 'hour'
class MinuteTransform(DateTransform):
lookup_name = 'minute'
class SecondTransform(DateTransform):
lookup_name = 'second'
DateField.register_lookup(YearTransform)
DateField.register_lookup(MonthTransform)
DateField.register_lookup(DayTransform)
DateField.register_lookup(WeekDayTransform)
TimeField.register_lookup(HourTransform)
TimeField.register_lookup(MinuteTransform)
TimeField.register_lookup(SecondTransform)
DateTimeField.register_lookup(DateTimeDateTransform)
DateTimeField.register_lookup(YearTransform)
DateTimeField.register_lookup(MonthTransform)
DateTimeField.register_lookup(DayTransform)
DateTimeField.register_lookup(WeekDayTransform)
DateTimeField.register_lookup(HourTransform)
DateTimeField.register_lookup(MinuteTransform)
DateTimeField.register_lookup(SecondTransform)

View File

@ -7,6 +7,7 @@ circular import difficulties.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
import inspect
from collections import namedtuple from collections import namedtuple
from django.apps import apps from django.apps import apps
@ -169,6 +170,60 @@ class DeferredAttribute(object):
return None return None
class RegisterLookupMixin(object):
def _get_lookup(self, lookup_name):
try:
return self.class_lookups[lookup_name]
except KeyError:
# To allow for inheritance, check parent class' class_lookups.
for parent in inspect.getmro(self.__class__):
if 'class_lookups' not in parent.__dict__:
continue
if lookup_name in parent.class_lookups:
return parent.class_lookups[lookup_name]
except AttributeError:
# This class didn't have any class_lookups
pass
return None
def get_lookup(self, lookup_name):
from django.db.models.lookups import Lookup
found = self._get_lookup(lookup_name)
if found is None and hasattr(self, 'output_field'):
return self.output_field.get_lookup(lookup_name)
if found is not None and not issubclass(found, Lookup):
return None
return found
def get_transform(self, lookup_name):
from django.db.models.lookups import Transform
found = self._get_lookup(lookup_name)
if found is None and hasattr(self, 'output_field'):
return self.output_field.get_transform(lookup_name)
if found is not None and not issubclass(found, Transform):
return None
return found
@classmethod
def register_lookup(cls, lookup, lookup_name=None):
if lookup_name is None:
lookup_name = lookup.lookup_name
if 'class_lookups' not in cls.__dict__:
cls.class_lookups = {}
cls.class_lookups[lookup_name] = lookup
return lookup
@classmethod
def _unregister_lookup(cls, lookup, lookup_name=None):
"""
Remove given lookup from cls lookups. For use in tests only as it's
not thread-safe.
"""
if lookup_name is None:
lookup_name = lookup.lookup_name
del cls.class_lookups[lookup_name]
def select_related_descend(field, restricted, requested, load_fields, reverse=False): def select_related_descend(field, restricted, requested, load_fields, reverse=False):
""" """
Returns True if this field should be used to descend deeper for Returns True if this field should be used to descend deeper for

View File

@ -5,7 +5,7 @@ import copy
import warnings import warnings
from django.db.models.fields import FloatField, IntegerField from django.db.models.fields import FloatField, IntegerField
from django.db.models.lookups import RegisterLookupMixin from django.db.models.query_utils import RegisterLookupMixin
from django.utils.deprecation import RemovedInDjango110Warning from django.utils.deprecation import RemovedInDjango110Warning
from django.utils.functional import cached_property from django.utils.functional import cached_property

View File

@ -1105,9 +1105,9 @@ class Query(object):
Helper method for build_lookup. Tries to fetch and initialize Helper method for build_lookup. Tries to fetch and initialize
a transform for name parameter from lhs. a transform for name parameter from lhs.
""" """
next = lhs.get_transform(name) transform_class = lhs.get_transform(name)
if next: if transform_class:
return next(lhs, rest_of_lookups) return transform_class(lhs)
else: else:
raise FieldError( raise FieldError(
"Unsupported lookup '%s' for %s or join on the field not " "Unsupported lookup '%s' for %s or join on the field not "

View File

@ -120,10 +120,7 @@ function ``ABS()`` to transform the value before comparison::
class AbsoluteValue(Transform): class AbsoluteValue(Transform):
lookup_name = 'abs' lookup_name = 'abs'
function = 'ABS'
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
return "ABS(%s)" % lhs, params
Next, let's register it for ``IntegerField``:: Next, let's register it for ``IntegerField``::
@ -157,10 +154,7 @@ be done by adding an ``output_field`` attribute to the transform::
class AbsoluteValue(Transform): class AbsoluteValue(Transform):
lookup_name = 'abs' lookup_name = 'abs'
function = 'ABS'
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
return "ABS(%s)" % lhs, params
@property @property
def output_field(self): def output_field(self):
@ -243,12 +237,9 @@ this transformation should apply to both ``lhs`` and ``rhs``::
class UpperCase(Transform): class UpperCase(Transform):
lookup_name = 'upper' lookup_name = 'upper'
function = 'UPPER'
bilateral = True bilateral = True
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
return "UPPER(%s)" % lhs, params
Next, let's register it:: Next, let's register it::
from django.db.models import CharField, TextField from django.db.models import CharField, TextField

View File

@ -180,6 +180,18 @@ Usage example::
>>> print(author.name_length, author.goes_by_length) >>> print(author.name_length, author.goes_by_length)
(14, None) (14, None)
It can also be registered as a transform. For example::
>>> from django.db.models import CharField
>>> from django.db.models.functions import Length
>>> CharField.register_lookup(Length, 'length')
>>> # Get authors whose name is longer than 7 characters
>>> authors = Author.objects.filter(name__length__gt=7)
.. versionchanged:: 1.9
The ability to register the function as a transform was added.
Lower Lower
------ ------
@ -188,6 +200,8 @@ Lower
Accepts a single text field or expression and returns the lowercase Accepts a single text field or expression and returns the lowercase
representation. representation.
It can also be registered as a transform as described in :class:`Length`.
Usage example:: Usage example::
>>> from django.db.models.functions import Lower >>> from django.db.models.functions import Lower
@ -196,6 +210,10 @@ Usage example::
>>> print(author.name_lower) >>> print(author.name_lower)
margaret smith margaret smith
.. versionchanged:: 1.9
The ability to register the function as a transform was added.
Now Now
--- ---
@ -246,6 +264,8 @@ Upper
Accepts a single text field or expression and returns the uppercase Accepts a single text field or expression and returns the uppercase
representation. representation.
It can also be registered as a transform as described in :class:`Length`.
Usage example:: Usage example::
>>> from django.db.models.functions import Upper >>> from django.db.models.functions import Upper
@ -253,3 +273,7 @@ Usage example::
>>> author = Author.objects.annotate(name_upper=Upper('name')).get() >>> author = Author.objects.annotate(name_upper=Upper('name')).get()
>>> print(author.name_upper) >>> print(author.name_upper)
MARGARET SMITH MARGARET SMITH
.. versionchanged:: 1.9
The ability to register the function as a transform was added.

View File

@ -42,12 +42,17 @@ register lookups on itself. The two prominent examples are
A mixin that implements the lookup API on a class. A mixin that implements the lookup API on a class.
.. classmethod:: register_lookup(lookup) .. classmethod:: register_lookup(lookup, lookup_name=None)
Registers a new lookup in the class. For example Registers a new lookup in the class. For example
``DateField.register_lookup(YearExact)`` will register ``YearExact`` ``DateField.register_lookup(YearExact)`` will register ``YearExact``
lookup on ``DateField``. It overrides a lookup that already exists with lookup on ``DateField``. It overrides a lookup that already exists with
the same name. the same name. ``lookup_name`` will be used for this lookup if
provided, otherwise ``lookup.lookup_name`` will be used.
.. versionchanged:: 1.9
The ``lookup_name`` parameter was added.
.. method:: get_lookup(lookup_name) .. method:: get_lookup(lookup_name)
@ -125,7 +130,14 @@ Transform reference
``<expression>__<transformation>`` (e.g. ``date__year``). ``<expression>__<transformation>`` (e.g. ``date__year``).
This class follows the :ref:`Query Expression API <query-expression>`, which This class follows the :ref:`Query Expression API <query-expression>`, which
implies that you can use ``<expression>__<transform1>__<transform2>``. implies that you can use ``<expression>__<transform1>__<transform2>``. It's
a specialized :ref:`Func() expression <func-expressions>` that only accepts
one argument. It can also be used on the right hand side of a filter or
directly as an annotation.
.. versionchanged:: 1.9
``Transform`` is now a subclass of ``Func``.
.. attribute:: bilateral .. attribute:: bilateral
@ -152,18 +164,6 @@ Transform reference
:class:`~django.db.models.Field` instance. By default is the same as :class:`~django.db.models.Field` instance. By default is the same as
its ``lhs.output_field``. its ``lhs.output_field``.
.. method:: as_sql
To be overridden; raises :exc:`NotImplementedError`.
.. method:: get_lookup(lookup_name)
Same as :meth:`~lookups.RegisterLookupMixin.get_lookup()`.
.. method:: get_transform(transform_name)
Same as :meth:`~lookups.RegisterLookupMixin.get_transform()`.
Lookup reference Lookup reference
~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~

View File

@ -520,6 +520,14 @@ Models
* Added the :class:`~django.db.models.functions.Now` database function, which * Added the :class:`~django.db.models.functions.Now` database function, which
returns the current date and time. returns the current date and time.
* :class:`~django.db.models.Transform` is now a subclass of
:ref:`Func() <func-expressions>` which allows ``Transform``\s to be used on
the right hand side of an expression, just like regular ``Func``\s. This
allows registering some database functions like
:class:`~django.db.models.functions.Length`,
:class:`~django.db.models.functions.Lower`, and
:class:`~django.db.models.functions.Upper` as transforms.
* :class:`~django.db.models.SlugField` now accepts an * :class:`~django.db.models.SlugField` now accepts an
:attr:`~django.db.models.SlugField.allow_unicode` argument to allow Unicode :attr:`~django.db.models.SlugField.allow_unicode` argument to allow Unicode
characters in slugs. characters in slugs.

View File

@ -126,11 +126,17 @@ class YearLte(models.lookups.LessThanOrEqual):
return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params
class SQLFunc(models.Lookup): class Exactly(models.lookups.Exact):
def __init__(self, name, *args, **kwargs): """
super(SQLFunc, self).__init__(*args, **kwargs) This lookup is used to test lookup registration.
self.name = name """
lookup_name = 'exactly'
def get_rhs_op(self, connection, rhs):
return connection.operators['exact'] % rhs
class SQLFuncMixin(object):
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
return '%s()', [self.name] return '%s()', [self.name]
@ -139,13 +145,28 @@ class SQLFunc(models.Lookup):
return CustomField() return CustomField()
class SQLFuncLookup(SQLFuncMixin, models.Lookup):
def __init__(self, name, *args, **kwargs):
super(SQLFuncLookup, self).__init__(*args, **kwargs)
self.name = name
class SQLFuncTransform(SQLFuncMixin, models.Transform):
def __init__(self, name, *args, **kwargs):
super(SQLFuncTransform, self).__init__(*args, **kwargs)
self.name = name
class SQLFuncFactory(object): class SQLFuncFactory(object):
def __init__(self, name): def __init__(self, key, name):
self.key = key
self.name = name self.name = name
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return SQLFunc(self.name, *args, **kwargs) if self.key == 'lookupfunc':
return SQLFuncLookup(self.name, *args, **kwargs)
return SQLFuncTransform(self.name, *args, **kwargs)
class CustomField(models.TextField): class CustomField(models.TextField):
@ -153,13 +174,13 @@ class CustomField(models.TextField):
def get_lookup(self, lookup_name): def get_lookup(self, lookup_name):
if lookup_name.startswith('lookupfunc_'): if lookup_name.startswith('lookupfunc_'):
key, name = lookup_name.split('_', 1) key, name = lookup_name.split('_', 1)
return SQLFuncFactory(name) return SQLFuncFactory(key, name)
return super(CustomField, self).get_lookup(lookup_name) return super(CustomField, self).get_lookup(lookup_name)
def get_transform(self, lookup_name): def get_transform(self, lookup_name):
if lookup_name.startswith('transformfunc_'): if lookup_name.startswith('transformfunc_'):
key, name = lookup_name.split('_', 1) key, name = lookup_name.split('_', 1)
return SQLFuncFactory(name) return SQLFuncFactory(key, name)
return super(CustomField, self).get_transform(lookup_name) return super(CustomField, self).get_transform(lookup_name)
@ -200,6 +221,27 @@ class DateTimeTransform(models.Transform):
class LookupTests(TestCase): class LookupTests(TestCase):
def test_custom_name_lookup(self):
a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
custom_lookup_name = 'isactually'
custom_transform_name = 'justtheyear'
try:
models.DateField.register_lookup(YearTransform)
models.DateField.register_lookup(YearTransform, custom_transform_name)
YearTransform.register_lookup(Exactly)
YearTransform.register_lookup(Exactly, custom_lookup_name)
qs1 = Author.objects.filter(birthdate__testyear__exactly=1981)
qs2 = Author.objects.filter(birthdate__justtheyear__isactually=1981)
self.assertQuerysetEqual(qs1, [a1], lambda x: x)
self.assertQuerysetEqual(qs2, [a1], lambda x: x)
finally:
YearTransform._unregister_lookup(Exactly)
YearTransform._unregister_lookup(Exactly, custom_lookup_name)
models.DateField._unregister_lookup(YearTransform)
models.DateField._unregister_lookup(YearTransform, custom_transform_name)
def test_basic_lookup(self): def test_basic_lookup(self):
a1 = Author.objects.create(name='a1', age=1) a1 = Author.objects.create(name='a1', age=1)
a2 = Author.objects.create(name='a2', age=2) a2 = Author.objects.create(name='a2', age=2)
@ -299,6 +341,19 @@ class BilateralTransformTests(TestCase):
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
Author.objects.filter(name__upper__in=Author.objects.values_list('name')) Author.objects.filter(name__upper__in=Author.objects.values_list('name'))
def test_bilateral_multi_value(self):
with register_lookup(models.CharField, UpperBilateralTransform):
Author.objects.bulk_create([
Author(name='Foo'),
Author(name='Bar'),
Author(name='Ray'),
])
self.assertQuerysetEqual(
Author.objects.filter(name__upper__in=['foo', 'bar', 'doe']).order_by('name'),
['Bar', 'Foo'],
lambda a: a.name
)
def test_div3_bilateral_extract(self): def test_div3_bilateral_extract(self):
with register_lookup(models.IntegerField, Div3BilateralTransform): with register_lookup(models.IntegerField, Div3BilateralTransform):
a1 = Author.objects.create(name='a1', age=1) a1 = Author.objects.create(name='a1', age=1)

View File

@ -547,3 +547,97 @@ class FunctionTests(TestCase):
['How to Time Travel'], ['How to Time Travel'],
lambda a: a.title lambda a: a.title
) )
def test_length_transform(self):
try:
CharField.register_lookup(Length, 'length')
Author.objects.create(name='John Smith', alias='smithj')
Author.objects.create(name='Rhonda')
authors = Author.objects.filter(name__length__gt=7)
self.assertQuerysetEqual(
authors.order_by('name'), [
'John Smith',
],
lambda a: a.name
)
finally:
CharField._unregister_lookup(Length, 'length')
def test_lower_transform(self):
try:
CharField.register_lookup(Lower, 'lower')
Author.objects.create(name='John Smith', alias='smithj')
Author.objects.create(name='Rhonda')
authors = Author.objects.filter(name__lower__exact='john smith')
self.assertQuerysetEqual(
authors.order_by('name'), [
'John Smith',
],
lambda a: a.name
)
finally:
CharField._unregister_lookup(Lower, 'lower')
def test_upper_transform(self):
try:
CharField.register_lookup(Upper, 'upper')
Author.objects.create(name='John Smith', alias='smithj')
Author.objects.create(name='Rhonda')
authors = Author.objects.filter(name__upper__exact='JOHN SMITH')
self.assertQuerysetEqual(
authors.order_by('name'), [
'John Smith',
],
lambda a: a.name
)
finally:
CharField._unregister_lookup(Upper, 'upper')
def test_func_transform_bilateral(self):
class UpperBilateral(Upper):
bilateral = True
try:
CharField.register_lookup(UpperBilateral, 'upper')
Author.objects.create(name='John Smith', alias='smithj')
Author.objects.create(name='Rhonda')
authors = Author.objects.filter(name__upper__exact='john smith')
self.assertQuerysetEqual(
authors.order_by('name'), [
'John Smith',
],
lambda a: a.name
)
finally:
CharField._unregister_lookup(UpperBilateral, 'upper')
def test_func_transform_bilateral_multivalue(self):
class UpperBilateral(Upper):
bilateral = True
try:
CharField.register_lookup(UpperBilateral, 'upper')
Author.objects.create(name='John Smith', alias='smithj')
Author.objects.create(name='Rhonda')
authors = Author.objects.filter(name__upper__in=['john smith', 'rhonda'])
self.assertQuerysetEqual(
authors.order_by('name'), [
'John Smith',
'Rhonda',
],
lambda a: a.name
)
finally:
CharField._unregister_lookup(UpperBilateral, 'upper')
def test_function_as_filter(self):
Author.objects.create(name='John Smith', alias='SMITHJ')
Author.objects.create(name='Rhonda')
self.assertQuerysetEqual(
Author.objects.filter(alias=Upper(V('smithj'))),
['John Smith'], lambda x: x.name
)
self.assertQuerysetEqual(
Author.objects.exclude(alias=Upper(V('smithj'))),
['Rhonda'], lambda x: x.name
)