Move % addition to lookups, refactor postgres lookups.

These refactorings making overriding some text based lookup names on
other fields (specifically `contains`) much cleaner. It also removes a
bunch of duplication in the contrib.postgres lookups.
This commit is contained in:
Marc Tamlyn 2015-01-10 16:11:15 +00:00
parent 74f02557e0
commit 916e38802f
5 changed files with 108 additions and 108 deletions

View File

@ -1,9 +1,10 @@
import json import json
from django.contrib.postgres import lookups
from django.contrib.postgres.forms import SimpleArrayField from django.contrib.postgres.forms import SimpleArrayField
from django.contrib.postgres.validators import ArrayMaxLengthValidator from django.contrib.postgres.validators import ArrayMaxLengthValidator
from django.core import checks, exceptions from django.core import checks, exceptions
from django.db.models import Field, Lookup, Transform, IntegerField from django.db.models import Field, Transform, IntegerField
from django.utils import six from django.utils import six
from django.utils.translation import string_concat, ugettext_lazy as _ from django.utils.translation import string_concat, ugettext_lazy as _
@ -74,12 +75,6 @@ class ArrayField(Field):
return [self.base_field.get_prep_value(i) for i in value] return [self.base_field.get_prep_value(i) for i in value]
return value return value
def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
if lookup_type == 'contains':
return [self.get_prep_value(value)]
return super(ArrayField, self).get_db_prep_lookup(lookup_type, value,
connection, prepared=False)
def deconstruct(self): def deconstruct(self):
name, path, args, kwargs = super(ArrayField, self).deconstruct() name, path, args, kwargs = super(ArrayField, self).deconstruct()
if path == 'django.contrib.postgres.fields.array.ArrayField': if path == 'django.contrib.postgres.fields.array.ArrayField':
@ -156,46 +151,21 @@ class ArrayField(Field):
@ArrayField.register_lookup @ArrayField.register_lookup
class ArrayContainsLookup(Lookup): class ArrayContains(lookups.DataContains):
lookup_name = 'contains' def as_sql(self, qn, connection):
sql, params = super(ArrayContains, self).as_sql(qn, connection)
def as_sql(self, compiler, connection): sql += '::%s' % self.lhs.output_field.db_type(connection)
lhs, lhs_params = self.process_lhs(compiler, connection) return sql, params
rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params
type_cast = self.lhs.output_field.db_type(connection)
return '%s @> %s::%s' % (lhs, rhs, type_cast), params
@ArrayField.register_lookup ArrayField.register_lookup(lookups.ContainedBy)
class ArrayContainedByLookup(Lookup): ArrayField.register_lookup(lookups.Overlap)
lookup_name = 'contained_by'
def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params
return '%s <@ %s' % (lhs, rhs), params
@ArrayField.register_lookup
class ArrayOverlapLookup(Lookup):
lookup_name = 'overlap'
def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params
return '%s && %s' % (lhs, rhs), params
@ArrayField.register_lookup @ArrayField.register_lookup
class ArrayLenTransform(Transform): class ArrayLenTransform(Transform):
lookup_name = 'len' lookup_name = 'len'
output_field = IntegerField()
@property
def output_field(self):
return IntegerField()
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs) lhs, params = compiler.compile(self.lhs)

View File

@ -1,9 +1,9 @@
import json import json
from django.contrib.postgres import forms from django.contrib.postgres import forms, lookups
from django.contrib.postgres.fields.array import ArrayField from django.contrib.postgres.fields.array import ArrayField
from django.core import exceptions from django.core import exceptions
from django.db.models import Field, Lookup, Transform, TextField from django.db.models import Field, Transform, TextField
from django.utils import six from django.utils import six
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
@ -21,12 +21,6 @@ class HStoreField(Field):
def db_type(self, connection): def db_type(self, connection):
return 'hstore' return 'hstore'
def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
if lookup_type == 'contains':
return [self.get_prep_value(value)]
return super(HStoreField, self).get_db_prep_lookup(lookup_type, value,
connection, prepared=False)
def get_transform(self, name): def get_transform(self, name):
transform = super(HStoreField, self).get_transform(name) transform = super(HStoreField, self).get_transform(name)
if transform: if transform:
@ -60,48 +54,20 @@ class HStoreField(Field):
return super(HStoreField, self).formfield(**defaults) return super(HStoreField, self).formfield(**defaults)
@HStoreField.register_lookup HStoreField.register_lookup(lookups.DataContains)
class HStoreContainsLookup(Lookup): HStoreField.register_lookup(lookups.ContainedBy)
lookup_name = 'contains'
def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params
return '%s @> %s' % (lhs, rhs), params
@HStoreField.register_lookup @HStoreField.register_lookup
class HStoreContainedByLookup(Lookup): class HasKeyLookup(lookups.PostgresSimpleLookup):
lookup_name = 'contained_by'
def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params
return '%s <@ %s' % (lhs, rhs), params
@HStoreField.register_lookup
class HasKeyLookup(Lookup):
lookup_name = 'has_key' lookup_name = 'has_key'
operator = '?'
def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params
return '%s ? %s' % (lhs, rhs), params
@HStoreField.register_lookup @HStoreField.register_lookup
class HasKeysLookup(Lookup): class HasKeysLookup(lookups.PostgresSimpleLookup):
lookup_name = 'has_keys' lookup_name = 'has_keys'
operator = '?&'
def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params
return '%s ?& %s' % (lhs, rhs), params
class KeyTransform(Transform): class KeyTransform(Transform):
@ -126,20 +92,14 @@ class KeyTransformFactory(object):
@HStoreField.register_lookup @HStoreField.register_lookup
class KeysTransform(Transform): class KeysTransform(lookups.FunctionTransform):
lookup_name = 'keys' lookup_name = 'keys'
function = 'akeys'
output_field = ArrayField(TextField()) output_field = ArrayField(TextField())
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
return 'akeys(%s)' % lhs, params
@HStoreField.register_lookup @HStoreField.register_lookup
class ValuesTransform(Transform): class ValuesTransform(lookups.FunctionTransform):
lookup_name = 'values' lookup_name = 'values'
function = 'avals'
output_field = ArrayField(TextField()) output_field = ArrayField(TextField())
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
return 'avals(%s)' % lhs, params

View File

@ -1,10 +1,36 @@
from django.db.models import Transform from django.db.models import Lookup, Transform
class Unaccent(Transform): class PostgresSimpleLookup(Lookup):
def as_sql(self, qn, connection):
lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
params = lhs_params + rhs_params
return '%s %s %s' % (lhs, self.operator, rhs), params
class FunctionTransform(Transform):
def as_sql(self, qn, connection):
lhs, params = qn.compile(self.lhs)
return "%s(%s)" % (self.function, lhs), params
class DataContains(PostgresSimpleLookup):
lookup_name = 'contains'
operator = '@>'
class ContainedBy(PostgresSimpleLookup):
lookup_name = 'contained_by'
operator = '<@'
class Overlap(PostgresSimpleLookup):
lookup_name = 'overlap'
operator = '&&'
class Unaccent(FunctionTransform):
bilateral = True bilateral = True
lookup_name = 'unaccent' lookup_name = 'unaccent'
function = 'UNACCENT'
def as_postgresql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
return "UNACCENT(%s)" % lhs, params

View File

@ -746,7 +746,9 @@ class Field(RegisterLookupMixin):
return QueryWrapper(('(%s)' % sql), params) return QueryWrapper(('(%s)' % sql), params)
if lookup_type in ('month', 'day', 'week_day', 'hour', 'minute', if lookup_type in ('month', 'day', 'week_day', 'hour', 'minute',
'second', 'search', 'regex', 'iregex'): 'second', 'search', 'regex', 'iregex', 'contains',
'icontains', 'iexact', 'startswith', 'endswith',
'istartswith', 'iendswith'):
return [value] return [value]
elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'): elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'):
return [self.get_db_prep_value(value, connection=connection, return [self.get_db_prep_value(value, connection=connection,
@ -754,14 +756,6 @@ class Field(RegisterLookupMixin):
elif lookup_type in ('range', 'in'): elif lookup_type in ('range', 'in'):
return [self.get_db_prep_value(v, connection=connection, return [self.get_db_prep_value(v, connection=connection,
prepared=prepared) for v in value] prepared=prepared) for v in value]
elif lookup_type in ('contains', 'icontains'):
return ["%%%s%%" % connection.ops.prep_for_like_query(value)]
elif lookup_type == 'iexact':
return [connection.ops.prep_for_iexact_query(value)]
elif lookup_type in ('startswith', 'istartswith'):
return ["%s%%" % connection.ops.prep_for_like_query(value)]
elif lookup_type in ('endswith', 'iendswith'):
return ["%%%s" % connection.ops.prep_for_like_query(value)]
elif lookup_type == 'isnull': elif lookup_type == 'isnull':
return [] return []
elif lookup_type == 'year': elif lookup_type == 'year':

View File

@ -222,6 +222,14 @@ default_lookups['exact'] = Exact
class IExact(BuiltinLookup): class IExact(BuiltinLookup):
lookup_name = 'iexact' lookup_name = 'iexact'
def process_rhs(self, qn, connection):
rhs, params = super(IExact, self).process_rhs(qn, connection)
if params:
params[0] = connection.ops.prep_for_iexact_query(params[0])
return rhs, params
default_lookups['iexact'] = IExact default_lookups['iexact'] = IExact
@ -317,31 +325,73 @@ class PatternLookup(BuiltinLookup):
class Contains(PatternLookup): class Contains(PatternLookup):
lookup_name = 'contains' lookup_name = 'contains'
def process_rhs(self, qn, connection):
rhs, params = super(Contains, self).process_rhs(qn, connection)
if params and not self.bilateral_transforms:
params[0] = "%%%s%%" % connection.ops.prep_for_like_query(params[0])
return rhs, params
default_lookups['contains'] = Contains default_lookups['contains'] = Contains
class IContains(PatternLookup): class IContains(Contains):
lookup_name = 'icontains' lookup_name = 'icontains'
default_lookups['icontains'] = IContains default_lookups['icontains'] = IContains
class StartsWith(PatternLookup): class StartsWith(PatternLookup):
lookup_name = 'startswith' lookup_name = 'startswith'
def process_rhs(self, qn, connection):
rhs, params = super(StartsWith, self).process_rhs(qn, connection)
if params and not self.bilateral_transforms:
params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
return rhs, params
default_lookups['startswith'] = StartsWith default_lookups['startswith'] = StartsWith
class IStartsWith(PatternLookup): class IStartsWith(PatternLookup):
lookup_name = 'istartswith' lookup_name = 'istartswith'
def process_rhs(self, qn, connection):
rhs, params = super(IStartsWith, self).process_rhs(qn, connection)
if params and not self.bilateral_transforms:
params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
return rhs, params
default_lookups['istartswith'] = IStartsWith default_lookups['istartswith'] = IStartsWith
class EndsWith(PatternLookup): class EndsWith(PatternLookup):
lookup_name = 'endswith' lookup_name = 'endswith'
def process_rhs(self, qn, connection):
rhs, params = super(EndsWith, self).process_rhs(qn, connection)
if params and not self.bilateral_transforms:
params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
return rhs, params
default_lookups['endswith'] = EndsWith default_lookups['endswith'] = EndsWith
class IEndsWith(PatternLookup): class IEndsWith(PatternLookup):
lookup_name = 'iendswith' lookup_name = 'iendswith'
def process_rhs(self, qn, connection):
rhs, params = super(IEndsWith, self).process_rhs(qn, connection)
if params and not self.bilateral_transforms:
params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
return rhs, params
default_lookups['iendswith'] = IEndsWith default_lookups['iendswith'] = IEndsWith