Refs #12990 -- Moved PostgresSimpleLookup to the django.db.models.lookups.PostgresOperatorLookup.

This commit is contained in:
Mariusz Felisiak 2020-04-01 10:55:53 +02:00 committed by GitHub
parent a7e4ff370c
commit 5c24c16e68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 39 deletions

View File

@ -5,6 +5,7 @@ from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range
from django.contrib.postgres import forms, lookups from django.contrib.postgres import forms, lookups
from django.db import models from django.db import models
from django.db.models.lookups import PostgresOperatorLookup
from .utils import AttributeSetter from .utils import AttributeSetter
@ -161,13 +162,13 @@ RangeField.register_lookup(lookups.ContainedBy)
RangeField.register_lookup(lookups.Overlap) RangeField.register_lookup(lookups.Overlap)
class DateTimeRangeContains(lookups.PostgresSimpleLookup): class DateTimeRangeContains(PostgresOperatorLookup):
""" """
Lookup for Date/DateTimeRange containment to cast the rhs to the correct Lookup for Date/DateTimeRange containment to cast the rhs to the correct
type. type.
""" """
lookup_name = 'contains' lookup_name = 'contains'
operator = RangeOperators.CONTAINS postgres_operator = RangeOperators.CONTAINS
def process_rhs(self, compiler, connection): def process_rhs(self, compiler, connection):
# Transform rhs value for db lookup. # Transform rhs value for db lookup.
@ -177,8 +178,8 @@ class DateTimeRangeContains(lookups.PostgresSimpleLookup):
self.rhs = value.resolve_expression(compiler.query) self.rhs = value.resolve_expression(compiler.query)
return super().process_rhs(compiler, connection) return super().process_rhs(compiler, connection)
def as_sql(self, compiler, connection): def as_postgresql(self, compiler, connection):
sql, params = super().as_sql(compiler, connection) sql, params = super().as_postgresql(compiler, connection)
# Cast the rhs if needed. # Cast the rhs if needed.
cast_sql = '' cast_sql = ''
if ( if (
@ -196,7 +197,7 @@ DateRangeField.register_lookup(DateTimeRangeContains)
DateTimeRangeField.register_lookup(DateTimeRangeContains) DateTimeRangeField.register_lookup(DateTimeRangeContains)
class RangeContainedBy(lookups.PostgresSimpleLookup): class RangeContainedBy(PostgresOperatorLookup):
lookup_name = 'contained_by' lookup_name = 'contained_by'
type_mapping = { type_mapping = {
'smallint': 'int4range', 'smallint': 'int4range',
@ -207,7 +208,7 @@ class RangeContainedBy(lookups.PostgresSimpleLookup):
'date': 'daterange', 'date': 'daterange',
'timestamp with time zone': 'tstzrange', 'timestamp with time zone': 'tstzrange',
} }
operator = RangeOperators.CONTAINED_BY postgres_operator = RangeOperators.CONTAINED_BY
def process_rhs(self, compiler, connection): def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection) rhs, rhs_params = super().process_rhs(compiler, connection)
@ -236,33 +237,33 @@ models.DecimalField.register_lookup(RangeContainedBy)
@RangeField.register_lookup @RangeField.register_lookup
class FullyLessThan(lookups.PostgresSimpleLookup): class FullyLessThan(PostgresOperatorLookup):
lookup_name = 'fully_lt' lookup_name = 'fully_lt'
operator = RangeOperators.FULLY_LT postgres_operator = RangeOperators.FULLY_LT
@RangeField.register_lookup @RangeField.register_lookup
class FullGreaterThan(lookups.PostgresSimpleLookup): class FullGreaterThan(PostgresOperatorLookup):
lookup_name = 'fully_gt' lookup_name = 'fully_gt'
operator = RangeOperators.FULLY_GT postgres_operator = RangeOperators.FULLY_GT
@RangeField.register_lookup @RangeField.register_lookup
class NotLessThan(lookups.PostgresSimpleLookup): class NotLessThan(PostgresOperatorLookup):
lookup_name = 'not_lt' lookup_name = 'not_lt'
operator = RangeOperators.NOT_LT postgres_operator = RangeOperators.NOT_LT
@RangeField.register_lookup @RangeField.register_lookup
class NotGreaterThan(lookups.PostgresSimpleLookup): class NotGreaterThan(PostgresOperatorLookup):
lookup_name = 'not_gt' lookup_name = 'not_gt'
operator = RangeOperators.NOT_GT postgres_operator = RangeOperators.NOT_GT
@RangeField.register_lookup @RangeField.register_lookup
class AdjacentToLookup(lookups.PostgresSimpleLookup): class AdjacentToLookup(PostgresOperatorLookup):
lookup_name = 'adjacent_to' lookup_name = 'adjacent_to'
operator = RangeOperators.ADJACENT_TO postgres_operator = RangeOperators.ADJACENT_TO
@RangeField.register_lookup @RangeField.register_lookup

View File

@ -1,41 +1,33 @@
from django.db.models import Lookup, Transform from django.db.models import Transform
from django.db.models.lookups import Exact, FieldGetDbPrepValueMixin from django.db.models.lookups import Exact, PostgresOperatorLookup
from .search import SearchVector, SearchVectorExact, SearchVectorField from .search import SearchVector, SearchVectorExact, SearchVectorField
class PostgresSimpleLookup(FieldGetDbPrepValueMixin, Lookup): class DataContains(PostgresOperatorLookup):
def as_sql(self, qn, connection):
lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
params = tuple(lhs_params) + tuple(rhs_params)
return '%s %s %s' % (lhs, self.operator, rhs), params
class DataContains(PostgresSimpleLookup):
lookup_name = 'contains' lookup_name = 'contains'
operator = '@>' postgres_operator = '@>'
class ContainedBy(PostgresSimpleLookup): class ContainedBy(PostgresOperatorLookup):
lookup_name = 'contained_by' lookup_name = 'contained_by'
operator = '<@' postgres_operator = '<@'
class Overlap(PostgresSimpleLookup): class Overlap(PostgresOperatorLookup):
lookup_name = 'overlap' lookup_name = 'overlap'
operator = '&&' postgres_operator = '&&'
class HasKey(PostgresSimpleLookup): class HasKey(PostgresOperatorLookup):
lookup_name = 'has_key' lookup_name = 'has_key'
operator = '?' postgres_operator = '?'
prepare_rhs = False prepare_rhs = False
class HasKeys(PostgresSimpleLookup): class HasKeys(PostgresOperatorLookup):
lookup_name = 'has_keys' lookup_name = 'has_keys'
operator = '?&' postgres_operator = '?&'
def get_prep_lookup(self): def get_prep_lookup(self):
return [str(item) for item in self.rhs] return [str(item) for item in self.rhs]
@ -43,7 +35,7 @@ class HasKeys(PostgresSimpleLookup):
class HasAnyKeys(HasKeys): class HasAnyKeys(HasKeys):
lookup_name = 'has_any_keys' lookup_name = 'has_any_keys'
operator = '?|' postgres_operator = '?|'
class Unaccent(Transform): class Unaccent(Transform):
@ -63,9 +55,9 @@ class SearchLookup(SearchVectorExact):
return lhs, lhs_params return lhs, lhs_params
class TrigramSimilar(PostgresSimpleLookup): class TrigramSimilar(PostgresOperatorLookup):
lookup_name = 'trigram_similar' lookup_name = 'trigram_similar'
operator = '%%' postgres_operator = '%%'
class JSONExact(Exact): class JSONExact(Exact):

View File

@ -256,6 +256,17 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
return sql, tuple(params) return sql, tuple(params)
class PostgresOperatorLookup(FieldGetDbPrepValueMixin, Lookup):
"""Lookup defined by operators on PostgreSQL."""
postgres_operator = None
def as_postgresql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = tuple(lhs_params) + tuple(rhs_params)
return '%s %s %s' % (lhs, self.postgres_operator, rhs), params
@Field.register_lookup @Field.register_lookup
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup): class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = 'exact' lookup_name = 'exact'