Fixed #26903 -- Fixed __contains lookup for Date/DateTimeRangeField.
Thanks Mariusz Felisiak and Tim Graham for polishing the patch.
This commit is contained in:
parent
0034e9af18
commit
6b048b364c
|
@ -1,3 +1,4 @@
|
|||
import datetime
|
||||
import json
|
||||
|
||||
from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range
|
||||
|
@ -131,6 +132,37 @@ RangeField.register_lookup(lookups.ContainedBy)
|
|||
RangeField.register_lookup(lookups.Overlap)
|
||||
|
||||
|
||||
class DateTimeRangeContains(models.Lookup):
|
||||
"""
|
||||
Lookup for Date/DateTimeRange containment to cast the rhs to the correct
|
||||
type.
|
||||
"""
|
||||
lookup_name = 'contains'
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
# Transform rhs value for db lookup.
|
||||
if isinstance(self.rhs, datetime.date):
|
||||
output_field = models.DateTimeField() if isinstance(self.rhs, datetime.datetime) else models.DateField()
|
||||
value = models.Value(self.rhs, output_field=output_field)
|
||||
self.rhs = value.resolve_expression(compiler.query)
|
||||
return super().process_rhs(compiler, connection)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = lhs_params + rhs_params
|
||||
# Cast the rhs if needed.
|
||||
cast_sql = ''
|
||||
if isinstance(self.rhs, models.Expression) and self.rhs._output_field_or_none:
|
||||
cast_internal_type = self.lhs.output_field.base_field.get_internal_type()
|
||||
cast_sql = '::{}'.format(connection.data_types.get(cast_internal_type))
|
||||
return '%s @> %s%s' % (lhs, rhs, cast_sql), params
|
||||
|
||||
|
||||
DateRangeField.register_lookup(DateTimeRangeContains)
|
||||
DateTimeRangeField.register_lookup(DateTimeRangeContains)
|
||||
|
||||
|
||||
class RangeContainedBy(models.Lookup):
|
||||
lookup_name = 'contained_by'
|
||||
type_mapping = {
|
||||
|
|
|
@ -3,7 +3,7 @@ import json
|
|||
|
||||
from django import forms
|
||||
from django.core import exceptions, serializers
|
||||
from django.db.models import F
|
||||
from django.db.models import DateField, DateTimeField, F, Func, Value
|
||||
from django.test import override_settings
|
||||
from django.utils import timezone
|
||||
|
||||
|
@ -87,6 +87,80 @@ class TestSaveLoad(PostgreSQLTestCase):
|
|||
self.assertEqual(field.base_field.model, RangesModel)
|
||||
|
||||
|
||||
class TestRangeContainsLookup(PostgreSQLTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
cls.timestamps = [
|
||||
datetime.datetime(year=2016, month=1, day=1),
|
||||
datetime.datetime(year=2016, month=1, day=2, hour=1),
|
||||
datetime.datetime(year=2016, month=1, day=2, hour=12),
|
||||
datetime.datetime(year=2016, month=1, day=3),
|
||||
datetime.datetime(year=2016, month=1, day=3, hour=1),
|
||||
datetime.datetime(year=2016, month=2, day=2),
|
||||
]
|
||||
cls.aware_timestamps = [
|
||||
timezone.make_aware(timestamp, timezone.get_current_timezone())
|
||||
for timestamp in cls.timestamps
|
||||
]
|
||||
cls.dates = [
|
||||
datetime.date(year=2016, month=1, day=1),
|
||||
datetime.date(year=2016, month=1, day=2),
|
||||
datetime.date(year=2016, month=1, day=3),
|
||||
datetime.date(year=2016, month=1, day=4),
|
||||
datetime.date(year=2016, month=2, day=2),
|
||||
datetime.date(year=2016, month=2, day=3),
|
||||
]
|
||||
cls.obj = RangesModel.objects.create(
|
||||
dates=(cls.dates[0], cls.dates[3]),
|
||||
timestamps=(cls.timestamps[0], cls.timestamps[3]),
|
||||
)
|
||||
cls.aware_obj = RangesModel.objects.create(
|
||||
dates=(cls.dates[0], cls.dates[3]),
|
||||
timestamps=(cls.aware_timestamps[0], cls.aware_timestamps[3]),
|
||||
)
|
||||
# Objects that don't match any queries.
|
||||
for i in range(3, 4):
|
||||
RangesModel.objects.create(
|
||||
dates=(cls.dates[i], cls.dates[i + 1]),
|
||||
timestamps=(cls.timestamps[i], cls.timestamps[i + 1]),
|
||||
)
|
||||
RangesModel.objects.create(
|
||||
dates=(cls.dates[i], cls.dates[i + 1]),
|
||||
timestamps=(cls.aware_timestamps[i], cls.aware_timestamps[i + 1]),
|
||||
)
|
||||
|
||||
def test_datetime_range_contains(self):
|
||||
filter_args = (
|
||||
self.timestamps[1],
|
||||
self.aware_timestamps[1],
|
||||
(self.timestamps[1], self.timestamps[2]),
|
||||
(self.aware_timestamps[1], self.aware_timestamps[2]),
|
||||
Value(self.dates[0], output_field=DateTimeField()),
|
||||
Func(F('dates'), function='lower', output_field=DateTimeField()),
|
||||
)
|
||||
for filter_arg in filter_args:
|
||||
with self.subTest(filter_arg=filter_arg):
|
||||
self.assertCountEqual(
|
||||
RangesModel.objects.filter(**{'timestamps__contains': filter_arg}),
|
||||
[self.obj, self.aware_obj],
|
||||
)
|
||||
|
||||
def test_date_range_contains(self):
|
||||
filter_args = (
|
||||
self.timestamps[1],
|
||||
(self.dates[1], self.dates[2]),
|
||||
Value(self.dates[0], output_field=DateField()),
|
||||
Func(F('timestamps'), function='lower', output_field=DateField()),
|
||||
)
|
||||
for filter_arg in filter_args:
|
||||
with self.subTest(filter_arg=filter_arg):
|
||||
self.assertCountEqual(
|
||||
RangesModel.objects.filter(**{'dates__contains': filter_arg}),
|
||||
[self.obj, self.aware_obj],
|
||||
)
|
||||
|
||||
|
||||
class TestQuerying(PostgreSQLTestCase):
|
||||
|
||||
@classmethod
|
||||
|
|
Loading…
Reference in New Issue