Move % addition to lookups, refactor postgres lookups.
These refactorings making overriding some text based lookup names on other fields (specifically `contains`) much cleaner. It also removes a bunch of duplication in the contrib.postgres lookups.
This commit is contained in:
parent
74f02557e0
commit
916e38802f
|
@ -1,9 +1,10 @@
|
|||
import json
|
||||
|
||||
from django.contrib.postgres import lookups
|
||||
from django.contrib.postgres.forms import SimpleArrayField
|
||||
from django.contrib.postgres.validators import ArrayMaxLengthValidator
|
||||
from django.core import checks, exceptions
|
||||
from django.db.models import Field, Lookup, Transform, IntegerField
|
||||
from django.db.models import Field, Transform, IntegerField
|
||||
from django.utils import six
|
||||
from django.utils.translation import string_concat, ugettext_lazy as _
|
||||
|
||||
|
@ -74,12 +75,6 @@ class ArrayField(Field):
|
|||
return [self.base_field.get_prep_value(i) for i in value]
|
||||
return value
|
||||
|
||||
def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
|
||||
if lookup_type == 'contains':
|
||||
return [self.get_prep_value(value)]
|
||||
return super(ArrayField, self).get_db_prep_lookup(lookup_type, value,
|
||||
connection, prepared=False)
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super(ArrayField, self).deconstruct()
|
||||
if path == 'django.contrib.postgres.fields.array.ArrayField':
|
||||
|
@ -156,46 +151,21 @@ class ArrayField(Field):
|
|||
|
||||
|
||||
@ArrayField.register_lookup
|
||||
class ArrayContainsLookup(Lookup):
|
||||
lookup_name = 'contains'
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = lhs_params + rhs_params
|
||||
type_cast = self.lhs.output_field.db_type(connection)
|
||||
return '%s @> %s::%s' % (lhs, rhs, type_cast), params
|
||||
class ArrayContains(lookups.DataContains):
|
||||
def as_sql(self, qn, connection):
|
||||
sql, params = super(ArrayContains, self).as_sql(qn, connection)
|
||||
sql += '::%s' % self.lhs.output_field.db_type(connection)
|
||||
return sql, params
|
||||
|
||||
|
||||
@ArrayField.register_lookup
|
||||
class ArrayContainedByLookup(Lookup):
|
||||
lookup_name = 'contained_by'
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = lhs_params + rhs_params
|
||||
return '%s <@ %s' % (lhs, rhs), params
|
||||
|
||||
|
||||
@ArrayField.register_lookup
|
||||
class ArrayOverlapLookup(Lookup):
|
||||
lookup_name = 'overlap'
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = lhs_params + rhs_params
|
||||
return '%s && %s' % (lhs, rhs), params
|
||||
ArrayField.register_lookup(lookups.ContainedBy)
|
||||
ArrayField.register_lookup(lookups.Overlap)
|
||||
|
||||
|
||||
@ArrayField.register_lookup
|
||||
class ArrayLenTransform(Transform):
|
||||
lookup_name = 'len'
|
||||
|
||||
@property
|
||||
def output_field(self):
|
||||
return IntegerField()
|
||||
output_field = IntegerField()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, params = compiler.compile(self.lhs)
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import json
|
||||
|
||||
from django.contrib.postgres import forms
|
||||
from django.contrib.postgres import forms, lookups
|
||||
from django.contrib.postgres.fields.array import ArrayField
|
||||
from django.core import exceptions
|
||||
from django.db.models import Field, Lookup, Transform, TextField
|
||||
from django.db.models import Field, Transform, TextField
|
||||
from django.utils import six
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
|
@ -21,12 +21,6 @@ class HStoreField(Field):
|
|||
def db_type(self, connection):
|
||||
return 'hstore'
|
||||
|
||||
def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
|
||||
if lookup_type == 'contains':
|
||||
return [self.get_prep_value(value)]
|
||||
return super(HStoreField, self).get_db_prep_lookup(lookup_type, value,
|
||||
connection, prepared=False)
|
||||
|
||||
def get_transform(self, name):
|
||||
transform = super(HStoreField, self).get_transform(name)
|
||||
if transform:
|
||||
|
@ -60,48 +54,20 @@ class HStoreField(Field):
|
|||
return super(HStoreField, self).formfield(**defaults)
|
||||
|
||||
|
||||
@HStoreField.register_lookup
|
||||
class HStoreContainsLookup(Lookup):
|
||||
lookup_name = 'contains'
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = lhs_params + rhs_params
|
||||
return '%s @> %s' % (lhs, rhs), params
|
||||
HStoreField.register_lookup(lookups.DataContains)
|
||||
HStoreField.register_lookup(lookups.ContainedBy)
|
||||
|
||||
|
||||
@HStoreField.register_lookup
|
||||
class HStoreContainedByLookup(Lookup):
|
||||
lookup_name = 'contained_by'
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = lhs_params + rhs_params
|
||||
return '%s <@ %s' % (lhs, rhs), params
|
||||
|
||||
|
||||
@HStoreField.register_lookup
|
||||
class HasKeyLookup(Lookup):
|
||||
class HasKeyLookup(lookups.PostgresSimpleLookup):
|
||||
lookup_name = 'has_key'
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = lhs_params + rhs_params
|
||||
return '%s ? %s' % (lhs, rhs), params
|
||||
operator = '?'
|
||||
|
||||
|
||||
@HStoreField.register_lookup
|
||||
class HasKeysLookup(Lookup):
|
||||
class HasKeysLookup(lookups.PostgresSimpleLookup):
|
||||
lookup_name = 'has_keys'
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = lhs_params + rhs_params
|
||||
return '%s ?& %s' % (lhs, rhs), params
|
||||
operator = '?&'
|
||||
|
||||
|
||||
class KeyTransform(Transform):
|
||||
|
@ -126,20 +92,14 @@ class KeyTransformFactory(object):
|
|||
|
||||
|
||||
@HStoreField.register_lookup
|
||||
class KeysTransform(Transform):
|
||||
class KeysTransform(lookups.FunctionTransform):
|
||||
lookup_name = 'keys'
|
||||
function = 'akeys'
|
||||
output_field = ArrayField(TextField())
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, params = compiler.compile(self.lhs)
|
||||
return 'akeys(%s)' % lhs, params
|
||||
|
||||
|
||||
@HStoreField.register_lookup
|
||||
class ValuesTransform(Transform):
|
||||
class ValuesTransform(lookups.FunctionTransform):
|
||||
lookup_name = 'values'
|
||||
function = 'avals'
|
||||
output_field = ArrayField(TextField())
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, params = compiler.compile(self.lhs)
|
||||
return 'avals(%s)' % lhs, params
|
||||
|
|
|
@ -1,10 +1,36 @@
|
|||
from django.db.models import Transform
|
||||
from django.db.models import Lookup, Transform
|
||||
|
||||
|
||||
class Unaccent(Transform):
|
||||
class PostgresSimpleLookup(Lookup):
|
||||
def as_sql(self, qn, connection):
|
||||
lhs, lhs_params = self.process_lhs(qn, connection)
|
||||
rhs, rhs_params = self.process_rhs(qn, connection)
|
||||
params = lhs_params + rhs_params
|
||||
return '%s %s %s' % (lhs, self.operator, rhs), params
|
||||
|
||||
|
||||
class FunctionTransform(Transform):
|
||||
def as_sql(self, qn, connection):
|
||||
lhs, params = qn.compile(self.lhs)
|
||||
return "%s(%s)" % (self.function, lhs), params
|
||||
|
||||
|
||||
class DataContains(PostgresSimpleLookup):
|
||||
lookup_name = 'contains'
|
||||
operator = '@>'
|
||||
|
||||
|
||||
class ContainedBy(PostgresSimpleLookup):
|
||||
lookup_name = 'contained_by'
|
||||
operator = '<@'
|
||||
|
||||
|
||||
class Overlap(PostgresSimpleLookup):
|
||||
lookup_name = 'overlap'
|
||||
operator = '&&'
|
||||
|
||||
|
||||
class Unaccent(FunctionTransform):
|
||||
bilateral = True
|
||||
lookup_name = 'unaccent'
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
lhs, params = compiler.compile(self.lhs)
|
||||
return "UNACCENT(%s)" % lhs, params
|
||||
function = 'UNACCENT'
|
||||
|
|
|
@ -746,7 +746,9 @@ class Field(RegisterLookupMixin):
|
|||
return QueryWrapper(('(%s)' % sql), params)
|
||||
|
||||
if lookup_type in ('month', 'day', 'week_day', 'hour', 'minute',
|
||||
'second', 'search', 'regex', 'iregex'):
|
||||
'second', 'search', 'regex', 'iregex', 'contains',
|
||||
'icontains', 'iexact', 'startswith', 'endswith',
|
||||
'istartswith', 'iendswith'):
|
||||
return [value]
|
||||
elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'):
|
||||
return [self.get_db_prep_value(value, connection=connection,
|
||||
|
@ -754,14 +756,6 @@ class Field(RegisterLookupMixin):
|
|||
elif lookup_type in ('range', 'in'):
|
||||
return [self.get_db_prep_value(v, connection=connection,
|
||||
prepared=prepared) for v in value]
|
||||
elif lookup_type in ('contains', 'icontains'):
|
||||
return ["%%%s%%" % connection.ops.prep_for_like_query(value)]
|
||||
elif lookup_type == 'iexact':
|
||||
return [connection.ops.prep_for_iexact_query(value)]
|
||||
elif lookup_type in ('startswith', 'istartswith'):
|
||||
return ["%s%%" % connection.ops.prep_for_like_query(value)]
|
||||
elif lookup_type in ('endswith', 'iendswith'):
|
||||
return ["%%%s" % connection.ops.prep_for_like_query(value)]
|
||||
elif lookup_type == 'isnull':
|
||||
return []
|
||||
elif lookup_type == 'year':
|
||||
|
|
|
@ -222,6 +222,14 @@ default_lookups['exact'] = Exact
|
|||
|
||||
class IExact(BuiltinLookup):
|
||||
lookup_name = 'iexact'
|
||||
|
||||
def process_rhs(self, qn, connection):
|
||||
rhs, params = super(IExact, self).process_rhs(qn, connection)
|
||||
if params:
|
||||
params[0] = connection.ops.prep_for_iexact_query(params[0])
|
||||
return rhs, params
|
||||
|
||||
|
||||
default_lookups['iexact'] = IExact
|
||||
|
||||
|
||||
|
@ -317,31 +325,73 @@ class PatternLookup(BuiltinLookup):
|
|||
|
||||
class Contains(PatternLookup):
|
||||
lookup_name = 'contains'
|
||||
|
||||
def process_rhs(self, qn, connection):
|
||||
rhs, params = super(Contains, self).process_rhs(qn, connection)
|
||||
if params and not self.bilateral_transforms:
|
||||
params[0] = "%%%s%%" % connection.ops.prep_for_like_query(params[0])
|
||||
return rhs, params
|
||||
|
||||
|
||||
default_lookups['contains'] = Contains
|
||||
|
||||
|
||||
class IContains(PatternLookup):
|
||||
class IContains(Contains):
|
||||
lookup_name = 'icontains'
|
||||
|
||||
|
||||
default_lookups['icontains'] = IContains
|
||||
|
||||
|
||||
class StartsWith(PatternLookup):
|
||||
lookup_name = 'startswith'
|
||||
|
||||
def process_rhs(self, qn, connection):
|
||||
rhs, params = super(StartsWith, self).process_rhs(qn, connection)
|
||||
if params and not self.bilateral_transforms:
|
||||
params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
|
||||
return rhs, params
|
||||
|
||||
|
||||
default_lookups['startswith'] = StartsWith
|
||||
|
||||
|
||||
class IStartsWith(PatternLookup):
|
||||
lookup_name = 'istartswith'
|
||||
|
||||
def process_rhs(self, qn, connection):
|
||||
rhs, params = super(IStartsWith, self).process_rhs(qn, connection)
|
||||
if params and not self.bilateral_transforms:
|
||||
params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0])
|
||||
return rhs, params
|
||||
|
||||
|
||||
default_lookups['istartswith'] = IStartsWith
|
||||
|
||||
|
||||
class EndsWith(PatternLookup):
|
||||
lookup_name = 'endswith'
|
||||
|
||||
def process_rhs(self, qn, connection):
|
||||
rhs, params = super(EndsWith, self).process_rhs(qn, connection)
|
||||
if params and not self.bilateral_transforms:
|
||||
params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
|
||||
return rhs, params
|
||||
|
||||
|
||||
default_lookups['endswith'] = EndsWith
|
||||
|
||||
|
||||
class IEndsWith(PatternLookup):
|
||||
lookup_name = 'iendswith'
|
||||
|
||||
def process_rhs(self, qn, connection):
|
||||
rhs, params = super(IEndsWith, self).process_rhs(qn, connection)
|
||||
if params and not self.bilateral_transforms:
|
||||
params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0])
|
||||
return rhs, params
|
||||
|
||||
|
||||
default_lookups['iendswith'] = IEndsWith
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue