mirror of https://github.com/django/django.git
Fixed #26903 -- Fixed __contains lookup for Date/DateTimeRangeField.
Thanks Mariusz Felisiak and Tim Graham for polishing the patch.
This commit is contained in:
parent
0034e9af18
commit
6b048b364c
|
@ -1,3 +1,4 @@
|
||||||
|
import datetime
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range
|
from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range
|
||||||
|
@ -131,6 +132,37 @@ RangeField.register_lookup(lookups.ContainedBy)
|
||||||
RangeField.register_lookup(lookups.Overlap)
|
RangeField.register_lookup(lookups.Overlap)
|
||||||
|
|
||||||
|
|
||||||
|
class DateTimeRangeContains(models.Lookup):
|
||||||
|
"""
|
||||||
|
Lookup for Date/DateTimeRange containment to cast the rhs to the correct
|
||||||
|
type.
|
||||||
|
"""
|
||||||
|
lookup_name = 'contains'
|
||||||
|
|
||||||
|
def process_rhs(self, compiler, connection):
|
||||||
|
# Transform rhs value for db lookup.
|
||||||
|
if isinstance(self.rhs, datetime.date):
|
||||||
|
output_field = models.DateTimeField() if isinstance(self.rhs, datetime.datetime) else models.DateField()
|
||||||
|
value = models.Value(self.rhs, output_field=output_field)
|
||||||
|
self.rhs = value.resolve_expression(compiler.query)
|
||||||
|
return super().process_rhs(compiler, connection)
|
||||||
|
|
||||||
|
def as_sql(self, compiler, connection):
|
||||||
|
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||||
|
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||||
|
params = lhs_params + rhs_params
|
||||||
|
# Cast the rhs if needed.
|
||||||
|
cast_sql = ''
|
||||||
|
if isinstance(self.rhs, models.Expression) and self.rhs._output_field_or_none:
|
||||||
|
cast_internal_type = self.lhs.output_field.base_field.get_internal_type()
|
||||||
|
cast_sql = '::{}'.format(connection.data_types.get(cast_internal_type))
|
||||||
|
return '%s @> %s%s' % (lhs, rhs, cast_sql), params
|
||||||
|
|
||||||
|
|
||||||
|
DateRangeField.register_lookup(DateTimeRangeContains)
|
||||||
|
DateTimeRangeField.register_lookup(DateTimeRangeContains)
|
||||||
|
|
||||||
|
|
||||||
class RangeContainedBy(models.Lookup):
|
class RangeContainedBy(models.Lookup):
|
||||||
lookup_name = 'contained_by'
|
lookup_name = 'contained_by'
|
||||||
type_mapping = {
|
type_mapping = {
|
||||||
|
|
|
@ -3,7 +3,7 @@ import json
|
||||||
|
|
||||||
from django import forms
|
from django import forms
|
||||||
from django.core import exceptions, serializers
|
from django.core import exceptions, serializers
|
||||||
from django.db.models import F
|
from django.db.models import DateField, DateTimeField, F, Func, Value
|
||||||
from django.test import override_settings
|
from django.test import override_settings
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
|
||||||
|
@ -87,6 +87,80 @@ class TestSaveLoad(PostgreSQLTestCase):
|
||||||
self.assertEqual(field.base_field.model, RangesModel)
|
self.assertEqual(field.base_field.model, RangesModel)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRangeContainsLookup(PostgreSQLTestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpTestData(cls):
|
||||||
|
cls.timestamps = [
|
||||||
|
datetime.datetime(year=2016, month=1, day=1),
|
||||||
|
datetime.datetime(year=2016, month=1, day=2, hour=1),
|
||||||
|
datetime.datetime(year=2016, month=1, day=2, hour=12),
|
||||||
|
datetime.datetime(year=2016, month=1, day=3),
|
||||||
|
datetime.datetime(year=2016, month=1, day=3, hour=1),
|
||||||
|
datetime.datetime(year=2016, month=2, day=2),
|
||||||
|
]
|
||||||
|
cls.aware_timestamps = [
|
||||||
|
timezone.make_aware(timestamp, timezone.get_current_timezone())
|
||||||
|
for timestamp in cls.timestamps
|
||||||
|
]
|
||||||
|
cls.dates = [
|
||||||
|
datetime.date(year=2016, month=1, day=1),
|
||||||
|
datetime.date(year=2016, month=1, day=2),
|
||||||
|
datetime.date(year=2016, month=1, day=3),
|
||||||
|
datetime.date(year=2016, month=1, day=4),
|
||||||
|
datetime.date(year=2016, month=2, day=2),
|
||||||
|
datetime.date(year=2016, month=2, day=3),
|
||||||
|
]
|
||||||
|
cls.obj = RangesModel.objects.create(
|
||||||
|
dates=(cls.dates[0], cls.dates[3]),
|
||||||
|
timestamps=(cls.timestamps[0], cls.timestamps[3]),
|
||||||
|
)
|
||||||
|
cls.aware_obj = RangesModel.objects.create(
|
||||||
|
dates=(cls.dates[0], cls.dates[3]),
|
||||||
|
timestamps=(cls.aware_timestamps[0], cls.aware_timestamps[3]),
|
||||||
|
)
|
||||||
|
# Objects that don't match any queries.
|
||||||
|
for i in range(3, 4):
|
||||||
|
RangesModel.objects.create(
|
||||||
|
dates=(cls.dates[i], cls.dates[i + 1]),
|
||||||
|
timestamps=(cls.timestamps[i], cls.timestamps[i + 1]),
|
||||||
|
)
|
||||||
|
RangesModel.objects.create(
|
||||||
|
dates=(cls.dates[i], cls.dates[i + 1]),
|
||||||
|
timestamps=(cls.aware_timestamps[i], cls.aware_timestamps[i + 1]),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_datetime_range_contains(self):
|
||||||
|
filter_args = (
|
||||||
|
self.timestamps[1],
|
||||||
|
self.aware_timestamps[1],
|
||||||
|
(self.timestamps[1], self.timestamps[2]),
|
||||||
|
(self.aware_timestamps[1], self.aware_timestamps[2]),
|
||||||
|
Value(self.dates[0], output_field=DateTimeField()),
|
||||||
|
Func(F('dates'), function='lower', output_field=DateTimeField()),
|
||||||
|
)
|
||||||
|
for filter_arg in filter_args:
|
||||||
|
with self.subTest(filter_arg=filter_arg):
|
||||||
|
self.assertCountEqual(
|
||||||
|
RangesModel.objects.filter(**{'timestamps__contains': filter_arg}),
|
||||||
|
[self.obj, self.aware_obj],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_date_range_contains(self):
|
||||||
|
filter_args = (
|
||||||
|
self.timestamps[1],
|
||||||
|
(self.dates[1], self.dates[2]),
|
||||||
|
Value(self.dates[0], output_field=DateField()),
|
||||||
|
Func(F('timestamps'), function='lower', output_field=DateField()),
|
||||||
|
)
|
||||||
|
for filter_arg in filter_args:
|
||||||
|
with self.subTest(filter_arg=filter_arg):
|
||||||
|
self.assertCountEqual(
|
||||||
|
RangesModel.objects.filter(**{'dates__contains': filter_arg}),
|
||||||
|
[self.obj, self.aware_obj],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestQuerying(PostgreSQLTestCase):
|
class TestQuerying(PostgreSQLTestCase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
Loading…
Reference in New Issue