Move % addition to lookups, refactor postgres lookups.

These refactorings making overriding some text based lookup names on
other fields (specifically `contains`) much cleaner. It also removes a
bunch of duplication in the contrib.postgres lookups.
This commit is contained in:
Marc Tamlyn 2015-01-10 16:11:15 +00:00
parent 74f02557e0
commit 916e38802f
5 changed files with 108 additions and 108 deletions

View File

@ -1,9 +1,10 @@
import json
from django.contrib.postgres import lookups
from django.contrib.postgres.forms import SimpleArrayField
from django.contrib.postgres.validators import ArrayMaxLengthValidator
from django.core import checks, exceptions
from django.db.models import Field, Lookup, Transform, IntegerField
from django.db.models import Field, Transform, IntegerField
from django.utils import six
from django.utils.translation import string_concat, ugettext_lazy as _
@ -74,12 +75,6 @@ class ArrayField(Field):
return [self.base_field.get_prep_value(i) for i in value]
return value
def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
if lookup_type == 'contains':
return [self.get_prep_value(value)]
return super(ArrayField, self).get_db_prep_lookup(lookup_type, value,
connection, prepared=False)
def deconstruct(self):
name, path, args, kwargs = super(ArrayField, self).deconstruct()
if path == 'django.contrib.postgres.fields.array.ArrayField':
@ -156,46 +151,21 @@ class ArrayField(Field):
@ArrayField.register_lookup
class ArrayContainsLookup(Lookup):
lookup_name = 'contains'
def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params
type_cast = self.lhs.output_field.db_type(connection)
return '%s @> %s::%s' % (lhs, rhs, type_cast), params
class ArrayContains(lookups.DataContains):
def as_sql(self, qn, connection):
sql, params = super(ArrayContains, self).as_sql(qn, connection)
sql += '::%s' % self.lhs.output_field.db_type(connection)
return sql, params
@ArrayField.register_lookup
class ArrayContainedByLookup(Lookup):
lookup_name = 'contained_by'
def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params
return '%s <@ %s' % (lhs, rhs), params
@ArrayField.register_lookup
class ArrayOverlapLookup(Lookup):
lookup_name = 'overlap'
def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params
return '%s && %s' % (lhs, rhs), params
ArrayField.register_lookup(lookups.ContainedBy)
ArrayField.register_lookup(lookups.Overlap)
@ArrayField.register_lookup
class ArrayLenTransform(Transform):
lookup_name = 'len'
@property
def output_field(self):
return IntegerField()
output_field = IntegerField()
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)

View File

@ -1,9 +1,9 @@
import json
from django.contrib.postgres import forms
from django.contrib.postgres import forms, lookups
from django.contrib.postgres.fields.array import ArrayField
from django.core import exceptions
from django.db.models import Field, Lookup, Transform, TextField
from django.db.models import Field, Transform, TextField
from django.utils import six
from django.utils.translation import ugettext_lazy as _
@ -21,12 +21,6 @@ class HStoreField(Field):
def db_type(self, connection):
return 'hstore'
def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
if lookup_type == 'contains':
return [self.get_prep_value(value)]
return super(HStoreField, self).get_db_prep_lookup(lookup_type, value,
connection, prepared=False)
def get_transform(self, name):
transform = super(HStoreField, self).get_transform(name)
if transform:
@ -60,48 +54,20 @@ class HStoreField(Field):
return super(HStoreField, self).formfield(**defaults)
@HStoreField.register_lookup
class HStoreContainsLookup(Lookup):
lookup_name = 'contains'
def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params
return '%s @> %s' % (lhs, rhs), params
HStoreField.register_lookup(lookups.DataContains)
HStoreField.register_lookup(lookups.ContainedBy)
@HStoreField.register_lookup
class HStoreContainedByLookup(Lookup):
lookup_name = 'contained_by'
def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params
return '%s <@ %s' % (lhs, rhs), params
@HStoreField.register_lookup
class HasKeyLookup(Lookup):
class HasKeyLookup(lookups.PostgresSimpleLookup):
lookup_name = 'has_key'
def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params
return '%s ? %s' % (lhs, rhs), params
operator = '?'
@HStoreField.register_lookup
class HasKeysLookup(Lookup):
class HasKeysLookup(lookups.PostgresSimpleLookup):
lookup_name = 'has_keys'
def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params
return '%s ?& %s' % (lhs, rhs), params
operator = '?&'
class KeyTransform(Transform):
@ -126,20 +92,14 @@ class KeyTransformFactory(object):
@HStoreField.register_lookup
class KeysTransform(Transform):
class KeysTransform(lookups.FunctionTransform):
lookup_name = 'keys'
function = 'akeys'
output_field = ArrayField(TextField())
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
return 'akeys(%s)' % lhs, params
@HStoreField.register_lookup
class ValuesTransform(Transform):
class ValuesTransform(lookups.FunctionTransform):
lookup_name = 'values'
function = 'avals'
output_field = ArrayField(TextField())
def as_sql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
return 'avals(%s)' % lhs, params

View File

@ -1,10 +1,36 @@
from django.db.models import Transform
from django.db.models import Lookup, Transform
class Unaccent(Transform):
class PostgresSimpleLookup(Lookup):
def as_sql(self, qn, connection):
lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
params = lhs_params + rhs_params
return '%s %s %s' % (lhs, self.operator, rhs), params
class FunctionTransform(Transform):
def as_sql(self, qn, connection):
lhs, params = qn.compile(self.lhs)
return "%s(%s)" % (self.function, lhs), params
class DataContains(PostgresSimpleLookup):
lookup_name = 'contains'
operator = '@>'
class ContainedBy(PostgresSimpleLookup):
lookup_name = 'contained_by'
operator = '<@'
class Overlap(PostgresSimpleLookup):
lookup_name = 'overlap'
operator = '&&'
class Unaccent(FunctionTransform):
bilateral = True
lookup_name = 'unaccent'
def as_postgresql(self, compiler, connection):
lhs, params = compiler.compile(self.lhs)
return "UNACCENT(%s)" % lhs, params
function = 'UNACCENT'

View File

@ -746,7 +746,9 @@ class Field(RegisterLookupMixin):
return QueryWrapper(('(%s)' % sql), params)
if lookup_type in ('month', 'day', 'week_day', 'hour', 'minute',
'second', 'search', 'regex', 'iregex'):
'second', 'search', 'regex', 'iregex', 'contains',
'icontains', 'iexact', 'startswith', 'endswith',
'istartswith', 'iendswith'):
return [value]
elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'):
return [self.get_db_prep_value(value, connection=connection,
@ -754,14 +756,6 @@ class Field(RegisterLookupMixin):
elif lookup_type in ('range', 'in'):
return [self.get_db_prep_value(v, connection=connection,
prepared=prepared) for v in value]
elif lookup_type in ('contains', 'icontains'):
return ["%%%s%%" % connection.ops.prep_for_like_query(value)]
elif lookup_type == 'iexact':
return [connection.ops.prep_for_iexact_query(value)]
elif lookup_type in ('startswith', 'istartswith'):
return ["%s%%" % connection.ops.prep_for_like_query(value)]
elif lookup_type in ('endswith', 'iendswith'):
return ["%%%s" % connection.ops.prep_for_like_query(value)]
elif lookup_type == 'isnull':
return []
elif lookup_type == 'year':

View File

@ -222,6 +222,14 @@ default_lookups['exact'] = Exact
class IExact(BuiltinLookup):
lookup_name = 'iexact'
def process_rhs(self, qn, connection):
rhs, params = super(IExact, self).process_rhs(qn, connection)
if params:
params[0] = connection.ops.prep_for_iexact_query(params[0])
return rhs, params
default_lookups['iexact'] = IExact
@ -317,31 +325,73 @@ class PatternLookup(BuiltinLookup):
class Contains(PatternLookup):
lookup_name = 'contains'
def process_rhs(self, qn, connection):
rhs, params = super(Contains, self).process_rhs(qn, connection)
if params and not self.bilateral_transforms:
params[0] = "%%%s%%" % connection.ops.prep_for_like_query(params[0])
return rhs, params
default_lookups['contains'] = Contains
class IContains(PatternLookup):
class IContains(Contains):
lookup_name = 'icontains'
default_lookups['icontains'] = IContains
class StartsWith(PatternLookup):
lookup_name = 'startswith'
def process_rhs(self, qn, connection):
rhs, params = super(StartsWith, self).process_rhs(qn, connection)
if params and not self.bilateral_transforms:
params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
return rhs, params
default_lookups['startswith'] = StartsWith
class IStartsWith(PatternLookup):
lookup_name = 'istartswith'
def process_rhs(self, qn, connection):
rhs, params = super(IStartsWith, self).process_rhs(qn, connection)
if params and not self.bilateral_transforms:
params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
return rhs, params
default_lookups['istartswith'] = IStartsWith
class EndsWith(PatternLookup):
lookup_name = 'endswith'
def process_rhs(self, qn, connection):
rhs, params = super(EndsWith, self).process_rhs(qn, connection)
if params and not self.bilateral_transforms:
params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
return rhs, params
default_lookups['endswith'] = EndsWith
class IEndsWith(PatternLookup):
lookup_name = 'iendswith'
def process_rhs(self, qn, connection):
rhs, params = super(IEndsWith, self).process_rhs(qn, connection)
if params and not self.bilateral_transforms:
params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
return rhs, params
default_lookups['iendswith'] = IEndsWith