From 916e38802f151b34aaca487dc7e928946e81be73 Mon Sep 17 00:00:00 2001 From: Marc Tamlyn Date: Sat, 10 Jan 2015 16:11:15 +0000 Subject: [PATCH] Move % addition to lookups, refactor postgres lookups. These refactorings making overriding some text based lookup names on other fields (specifically `contains`) much cleaner. It also removes a bunch of duplication in the contrib.postgres lookups. --- django/contrib/postgres/fields/array.py | 50 ++++-------------- django/contrib/postgres/fields/hstore.py | 64 +++++------------------- django/contrib/postgres/lookups.py | 38 +++++++++++--- django/db/models/fields/__init__.py | 12 ++--- django/db/models/lookups.py | 52 ++++++++++++++++++- 5 files changed, 108 insertions(+), 108 deletions(-) diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py index 65f3dc6f6a..318afabd2c 100644 --- a/django/contrib/postgres/fields/array.py +++ b/django/contrib/postgres/fields/array.py @@ -1,9 +1,10 @@ import json +from django.contrib.postgres import lookups from django.contrib.postgres.forms import SimpleArrayField from django.contrib.postgres.validators import ArrayMaxLengthValidator from django.core import checks, exceptions -from django.db.models import Field, Lookup, Transform, IntegerField +from django.db.models import Field, Transform, IntegerField from django.utils import six from django.utils.translation import string_concat, ugettext_lazy as _ @@ -74,12 +75,6 @@ class ArrayField(Field): return [self.base_field.get_prep_value(i) for i in value] return value - def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False): - if lookup_type == 'contains': - return [self.get_prep_value(value)] - return super(ArrayField, self).get_db_prep_lookup(lookup_type, value, - connection, prepared=False) - def deconstruct(self): name, path, args, kwargs = super(ArrayField, self).deconstruct() if path == 'django.contrib.postgres.fields.array.ArrayField': @@ -156,46 +151,21 @@ class ArrayField(Field): @ArrayField.register_lookup -class ArrayContainsLookup(Lookup): - lookup_name = 'contains' - - def as_sql(self, compiler, connection): - lhs, lhs_params = self.process_lhs(compiler, connection) - rhs, rhs_params = self.process_rhs(compiler, connection) - params = lhs_params + rhs_params - type_cast = self.lhs.output_field.db_type(connection) - return '%s @> %s::%s' % (lhs, rhs, type_cast), params +class ArrayContains(lookups.DataContains): + def as_sql(self, qn, connection): + sql, params = super(ArrayContains, self).as_sql(qn, connection) + sql += '::%s' % self.lhs.output_field.db_type(connection) + return sql, params -@ArrayField.register_lookup -class ArrayContainedByLookup(Lookup): - lookup_name = 'contained_by' - - def as_sql(self, compiler, connection): - lhs, lhs_params = self.process_lhs(compiler, connection) - rhs, rhs_params = self.process_rhs(compiler, connection) - params = lhs_params + rhs_params - return '%s <@ %s' % (lhs, rhs), params - - -@ArrayField.register_lookup -class ArrayOverlapLookup(Lookup): - lookup_name = 'overlap' - - def as_sql(self, compiler, connection): - lhs, lhs_params = self.process_lhs(compiler, connection) - rhs, rhs_params = self.process_rhs(compiler, connection) - params = lhs_params + rhs_params - return '%s && %s' % (lhs, rhs), params +ArrayField.register_lookup(lookups.ContainedBy) +ArrayField.register_lookup(lookups.Overlap) @ArrayField.register_lookup class ArrayLenTransform(Transform): lookup_name = 'len' - - @property - def output_field(self): - return IntegerField() + output_field = IntegerField() def as_sql(self, compiler, connection): lhs, params = compiler.compile(self.lhs) diff --git a/django/contrib/postgres/fields/hstore.py b/django/contrib/postgres/fields/hstore.py index be998488fb..1524368ecf 100644 --- a/django/contrib/postgres/fields/hstore.py +++ b/django/contrib/postgres/fields/hstore.py @@ -1,9 +1,9 @@ import json -from django.contrib.postgres import forms +from django.contrib.postgres import forms, lookups from django.contrib.postgres.fields.array import ArrayField from django.core import exceptions -from django.db.models import Field, Lookup, Transform, TextField +from django.db.models import Field, Transform, TextField from django.utils import six from django.utils.translation import ugettext_lazy as _ @@ -21,12 +21,6 @@ class HStoreField(Field): def db_type(self, connection): return 'hstore' - def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False): - if lookup_type == 'contains': - return [self.get_prep_value(value)] - return super(HStoreField, self).get_db_prep_lookup(lookup_type, value, - connection, prepared=False) - def get_transform(self, name): transform = super(HStoreField, self).get_transform(name) if transform: @@ -60,48 +54,20 @@ class HStoreField(Field): return super(HStoreField, self).formfield(**defaults) -@HStoreField.register_lookup -class HStoreContainsLookup(Lookup): - lookup_name = 'contains' - - def as_sql(self, compiler, connection): - lhs, lhs_params = self.process_lhs(compiler, connection) - rhs, rhs_params = self.process_rhs(compiler, connection) - params = lhs_params + rhs_params - return '%s @> %s' % (lhs, rhs), params +HStoreField.register_lookup(lookups.DataContains) +HStoreField.register_lookup(lookups.ContainedBy) @HStoreField.register_lookup -class HStoreContainedByLookup(Lookup): - lookup_name = 'contained_by' - - def as_sql(self, compiler, connection): - lhs, lhs_params = self.process_lhs(compiler, connection) - rhs, rhs_params = self.process_rhs(compiler, connection) - params = lhs_params + rhs_params - return '%s <@ %s' % (lhs, rhs), params - - -@HStoreField.register_lookup -class HasKeyLookup(Lookup): +class HasKeyLookup(lookups.PostgresSimpleLookup): lookup_name = 'has_key' - - def as_sql(self, compiler, connection): - lhs, lhs_params = self.process_lhs(compiler, connection) - rhs, rhs_params = self.process_rhs(compiler, connection) - params = lhs_params + rhs_params - return '%s ? %s' % (lhs, rhs), params + operator = '?' @HStoreField.register_lookup -class HasKeysLookup(Lookup): +class HasKeysLookup(lookups.PostgresSimpleLookup): lookup_name = 'has_keys' - - def as_sql(self, compiler, connection): - lhs, lhs_params = self.process_lhs(compiler, connection) - rhs, rhs_params = self.process_rhs(compiler, connection) - params = lhs_params + rhs_params - return '%s ?& %s' % (lhs, rhs), params + operator = '?&' class KeyTransform(Transform): @@ -126,20 +92,14 @@ class KeyTransformFactory(object): @HStoreField.register_lookup -class KeysTransform(Transform): +class KeysTransform(lookups.FunctionTransform): lookup_name = 'keys' + function = 'akeys' output_field = ArrayField(TextField()) - def as_sql(self, compiler, connection): - lhs, params = compiler.compile(self.lhs) - return 'akeys(%s)' % lhs, params - @HStoreField.register_lookup -class ValuesTransform(Transform): +class ValuesTransform(lookups.FunctionTransform): lookup_name = 'values' + function = 'avals' output_field = ArrayField(TextField()) - - def as_sql(self, compiler, connection): - lhs, params = compiler.compile(self.lhs) - return 'avals(%s)' % lhs, params diff --git a/django/contrib/postgres/lookups.py b/django/contrib/postgres/lookups.py index 4cf51dbd9c..eb7cfd8359 100644 --- a/django/contrib/postgres/lookups.py +++ b/django/contrib/postgres/lookups.py @@ -1,10 +1,36 @@ -from django.db.models import Transform +from django.db.models import Lookup, Transform -class Unaccent(Transform): +class PostgresSimpleLookup(Lookup): + def as_sql(self, qn, connection): + lhs, lhs_params = self.process_lhs(qn, connection) + rhs, rhs_params = self.process_rhs(qn, connection) + params = lhs_params + rhs_params + return '%s %s %s' % (lhs, self.operator, rhs), params + + +class FunctionTransform(Transform): + def as_sql(self, qn, connection): + lhs, params = qn.compile(self.lhs) + return "%s(%s)" % (self.function, lhs), params + + +class DataContains(PostgresSimpleLookup): + lookup_name = 'contains' + operator = '@>' + + +class ContainedBy(PostgresSimpleLookup): + lookup_name = 'contained_by' + operator = '<@' + + +class Overlap(PostgresSimpleLookup): + lookup_name = 'overlap' + operator = '&&' + + +class Unaccent(FunctionTransform): bilateral = True lookup_name = 'unaccent' - - def as_postgresql(self, compiler, connection): - lhs, params = compiler.compile(self.lhs) - return "UNACCENT(%s)" % lhs, params + function = 'UNACCENT' diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index adf975cfa0..03c7eafac6 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -746,7 +746,9 @@ class Field(RegisterLookupMixin): return QueryWrapper(('(%s)' % sql), params) if lookup_type in ('month', 'day', 'week_day', 'hour', 'minute', - 'second', 'search', 'regex', 'iregex'): + 'second', 'search', 'regex', 'iregex', 'contains', + 'icontains', 'iexact', 'startswith', 'endswith', + 'istartswith', 'iendswith'): return [value] elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'): return [self.get_db_prep_value(value, connection=connection, @@ -754,14 +756,6 @@ class Field(RegisterLookupMixin): elif lookup_type in ('range', 'in'): return [self.get_db_prep_value(v, connection=connection, prepared=prepared) for v in value] - elif lookup_type in ('contains', 'icontains'): - return ["%%%s%%" % connection.ops.prep_for_like_query(value)] - elif lookup_type == 'iexact': - return [connection.ops.prep_for_iexact_query(value)] - elif lookup_type in ('startswith', 'istartswith'): - return ["%s%%" % connection.ops.prep_for_like_query(value)] - elif lookup_type in ('endswith', 'iendswith'): - return ["%%%s" % connection.ops.prep_for_like_query(value)] elif lookup_type == 'isnull': return [] elif lookup_type == 'year': diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index aea2820a10..d7423762f3 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -222,6 +222,14 @@ default_lookups['exact'] = Exact class IExact(BuiltinLookup): lookup_name = 'iexact' + + def process_rhs(self, qn, connection): + rhs, params = super(IExact, self).process_rhs(qn, connection) + if params: + params[0] = connection.ops.prep_for_iexact_query(params[0]) + return rhs, params + + default_lookups['iexact'] = IExact @@ -317,31 +325,73 @@ class PatternLookup(BuiltinLookup): class Contains(PatternLookup): lookup_name = 'contains' + + def process_rhs(self, qn, connection): + rhs, params = super(Contains, self).process_rhs(qn, connection) + if params and not self.bilateral_transforms: + params[0] = "%%%s%%" % connection.ops.prep_for_like_query(params[0]) + return rhs, params + + default_lookups['contains'] = Contains -class IContains(PatternLookup): +class IContains(Contains): lookup_name = 'icontains' + + default_lookups['icontains'] = IContains class StartsWith(PatternLookup): lookup_name = 'startswith' + + def process_rhs(self, qn, connection): + rhs, params = super(StartsWith, self).process_rhs(qn, connection) + if params and not self.bilateral_transforms: + params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0]) + return rhs, params + + default_lookups['startswith'] = StartsWith class IStartsWith(PatternLookup): lookup_name = 'istartswith' + + def process_rhs(self, qn, connection): + rhs, params = super(IStartsWith, self).process_rhs(qn, connection) + if params and not self.bilateral_transforms: + params[0] = "%s%%" % connection.ops.prep_for_like_query(params[0]) + return rhs, params + + default_lookups['istartswith'] = IStartsWith class EndsWith(PatternLookup): lookup_name = 'endswith' + + def process_rhs(self, qn, connection): + rhs, params = super(EndsWith, self).process_rhs(qn, connection) + if params and not self.bilateral_transforms: + params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0]) + return rhs, params + + default_lookups['endswith'] = EndsWith class IEndsWith(PatternLookup): lookup_name = 'iendswith' + + def process_rhs(self, qn, connection): + rhs, params = super(IEndsWith, self).process_rhs(qn, connection) + if params and not self.bilateral_transforms: + params[0] = "%%%s" % connection.ops.prep_for_like_query(params[0]) + return rhs, params + + default_lookups['iendswith'] = IEndsWith