Fixed #26903 -- Fixed __contains lookup for Date/DateTimeRangeField.

Thanks Mariusz Felisiak and Tim Graham for polishing the patch.
This commit is contained in:
Mariusz Felisiak 2017-02-07 18:46:18 +01:00 committed by Tim Graham
parent 0034e9af18
commit 6b048b364c
2 changed files with 107 additions and 1 deletions

View File

@ -1,3 +1,4 @@
import datetime
import json import json
from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range
@ -131,6 +132,37 @@ RangeField.register_lookup(lookups.ContainedBy)
RangeField.register_lookup(lookups.Overlap) RangeField.register_lookup(lookups.Overlap)
class DateTimeRangeContains(models.Lookup):
"""
Lookup for Date/DateTimeRange containment to cast the rhs to the correct
type.
"""
lookup_name = 'contains'
def process_rhs(self, compiler, connection):
# Transform rhs value for db lookup.
if isinstance(self.rhs, datetime.date):
output_field = models.DateTimeField() if isinstance(self.rhs, datetime.datetime) else models.DateField()
value = models.Value(self.rhs, output_field=output_field)
self.rhs = value.resolve_expression(compiler.query)
return super().process_rhs(compiler, connection)
def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params
# Cast the rhs if needed.
cast_sql = ''
if isinstance(self.rhs, models.Expression) and self.rhs._output_field_or_none:
cast_internal_type = self.lhs.output_field.base_field.get_internal_type()
cast_sql = '::{}'.format(connection.data_types.get(cast_internal_type))
return '%s @> %s%s' % (lhs, rhs, cast_sql), params
DateRangeField.register_lookup(DateTimeRangeContains)
DateTimeRangeField.register_lookup(DateTimeRangeContains)
class RangeContainedBy(models.Lookup): class RangeContainedBy(models.Lookup):
lookup_name = 'contained_by' lookup_name = 'contained_by'
type_mapping = { type_mapping = {

View File

@ -3,7 +3,7 @@ import json
from django import forms from django import forms
from django.core import exceptions, serializers from django.core import exceptions, serializers
from django.db.models import F from django.db.models import DateField, DateTimeField, F, Func, Value
from django.test import override_settings from django.test import override_settings
from django.utils import timezone from django.utils import timezone
@ -87,6 +87,80 @@ class TestSaveLoad(PostgreSQLTestCase):
self.assertEqual(field.base_field.model, RangesModel) self.assertEqual(field.base_field.model, RangesModel)
class TestRangeContainsLookup(PostgreSQLTestCase):
@classmethod
def setUpTestData(cls):
cls.timestamps = [
datetime.datetime(year=2016, month=1, day=1),
datetime.datetime(year=2016, month=1, day=2, hour=1),
datetime.datetime(year=2016, month=1, day=2, hour=12),
datetime.datetime(year=2016, month=1, day=3),
datetime.datetime(year=2016, month=1, day=3, hour=1),
datetime.datetime(year=2016, month=2, day=2),
]
cls.aware_timestamps = [
timezone.make_aware(timestamp, timezone.get_current_timezone())
for timestamp in cls.timestamps
]
cls.dates = [
datetime.date(year=2016, month=1, day=1),
datetime.date(year=2016, month=1, day=2),
datetime.date(year=2016, month=1, day=3),
datetime.date(year=2016, month=1, day=4),
datetime.date(year=2016, month=2, day=2),
datetime.date(year=2016, month=2, day=3),
]
cls.obj = RangesModel.objects.create(
dates=(cls.dates[0], cls.dates[3]),
timestamps=(cls.timestamps[0], cls.timestamps[3]),
)
cls.aware_obj = RangesModel.objects.create(
dates=(cls.dates[0], cls.dates[3]),
timestamps=(cls.aware_timestamps[0], cls.aware_timestamps[3]),
)
# Objects that don't match any queries.
for i in range(3, 4):
RangesModel.objects.create(
dates=(cls.dates[i], cls.dates[i + 1]),
timestamps=(cls.timestamps[i], cls.timestamps[i + 1]),
)
RangesModel.objects.create(
dates=(cls.dates[i], cls.dates[i + 1]),
timestamps=(cls.aware_timestamps[i], cls.aware_timestamps[i + 1]),
)
def test_datetime_range_contains(self):
filter_args = (
self.timestamps[1],
self.aware_timestamps[1],
(self.timestamps[1], self.timestamps[2]),
(self.aware_timestamps[1], self.aware_timestamps[2]),
Value(self.dates[0], output_field=DateTimeField()),
Func(F('dates'), function='lower', output_field=DateTimeField()),
)
for filter_arg in filter_args:
with self.subTest(filter_arg=filter_arg):
self.assertCountEqual(
RangesModel.objects.filter(**{'timestamps__contains': filter_arg}),
[self.obj, self.aware_obj],
)
def test_date_range_contains(self):
filter_args = (
self.timestamps[1],
(self.dates[1], self.dates[2]),
Value(self.dates[0], output_field=DateField()),
Func(F('timestamps'), function='lower', output_field=DateField()),
)
for filter_arg in filter_args:
with self.subTest(filter_arg=filter_arg):
self.assertCountEqual(
RangesModel.objects.filter(**{'dates__contains': filter_arg}),
[self.obj, self.aware_obj],
)
class TestQuerying(PostgreSQLTestCase): class TestQuerying(PostgreSQLTestCase):
@classmethod @classmethod