Simplified DateTimeRangeContains by making it subclass PostgresSimpleLookup.

This commit is contained in:
Mariusz Felisiak 2019-07-12 17:27:49 +02:00 committed by GitHub
parent 402e6d292f
commit 70c2b90d95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 5 deletions

View File

@ -149,12 +149,13 @@ RangeField.register_lookup(lookups.ContainedBy)
RangeField.register_lookup(lookups.Overlap) RangeField.register_lookup(lookups.Overlap)
class DateTimeRangeContains(models.Lookup): class DateTimeRangeContains(lookups.PostgresSimpleLookup):
""" """
Lookup for Date/DateTimeRange containment to cast the rhs to the correct Lookup for Date/DateTimeRange containment to cast the rhs to the correct
type. type.
""" """
lookup_name = 'contains' lookup_name = 'contains'
operator = '@>'
def process_rhs(self, compiler, connection): def process_rhs(self, compiler, connection):
# Transform rhs value for db lookup. # Transform rhs value for db lookup.
@ -165,9 +166,7 @@ class DateTimeRangeContains(models.Lookup):
return super().process_rhs(compiler, connection) return super().process_rhs(compiler, connection)
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection) sql, params = super().as_sql(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params
# Cast the rhs if needed. # Cast the rhs if needed.
cast_sql = '' cast_sql = ''
if ( if (
@ -178,7 +177,7 @@ class DateTimeRangeContains(models.Lookup):
): ):
cast_internal_type = self.lhs.output_field.base_field.get_internal_type() cast_internal_type = self.lhs.output_field.base_field.get_internal_type()
cast_sql = '::{}'.format(connection.data_types.get(cast_internal_type)) cast_sql = '::{}'.format(connection.data_types.get(cast_internal_type))
return '%s @> %s%s' % (lhs, rhs, cast_sql), params return '%s%s' % (sql, cast_sql), params
DateRangeField.register_lookup(DateTimeRangeContains) DateRangeField.register_lookup(DateTimeRangeContains)