diff --git a/django/contrib/postgres/fields/ranges.py b/django/contrib/postgres/fields/ranges.py index 41acc8dcb2..0bb914d383 100644 --- a/django/contrib/postgres/fields/ranges.py +++ b/django/contrib/postgres/fields/ranges.py @@ -1,3 +1,4 @@ +import datetime import json from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range @@ -131,6 +132,37 @@ RangeField.register_lookup(lookups.ContainedBy) RangeField.register_lookup(lookups.Overlap) +class DateTimeRangeContains(models.Lookup): + """ + Lookup for Date/DateTimeRange containment to cast the rhs to the correct + type. + """ + lookup_name = 'contains' + + def process_rhs(self, compiler, connection): + # Transform rhs value for db lookup. + if isinstance(self.rhs, datetime.date): + output_field = models.DateTimeField() if isinstance(self.rhs, datetime.datetime) else models.DateField() + value = models.Value(self.rhs, output_field=output_field) + self.rhs = value.resolve_expression(compiler.query) + return super().process_rhs(compiler, connection) + + def as_sql(self, compiler, connection): + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) + params = lhs_params + rhs_params + # Cast the rhs if needed. + cast_sql = '' + if isinstance(self.rhs, models.Expression) and self.rhs._output_field_or_none: + cast_internal_type = self.lhs.output_field.base_field.get_internal_type() + cast_sql = '::{}'.format(connection.data_types.get(cast_internal_type)) + return '%s @> %s%s' % (lhs, rhs, cast_sql), params + + +DateRangeField.register_lookup(DateTimeRangeContains) +DateTimeRangeField.register_lookup(DateTimeRangeContains) + + class RangeContainedBy(models.Lookup): lookup_name = 'contained_by' type_mapping = { diff --git a/tests/postgres_tests/test_ranges.py b/tests/postgres_tests/test_ranges.py index 3dd60dea0d..d87ad36438 100644 --- a/tests/postgres_tests/test_ranges.py +++ b/tests/postgres_tests/test_ranges.py @@ -3,7 +3,7 @@ import json from django import forms from django.core import exceptions, serializers -from django.db.models import F +from django.db.models import DateField, DateTimeField, F, Func, Value from django.test import override_settings from django.utils import timezone @@ -87,6 +87,80 @@ class TestSaveLoad(PostgreSQLTestCase): self.assertEqual(field.base_field.model, RangesModel) +class TestRangeContainsLookup(PostgreSQLTestCase): + + @classmethod + def setUpTestData(cls): + cls.timestamps = [ + datetime.datetime(year=2016, month=1, day=1), + datetime.datetime(year=2016, month=1, day=2, hour=1), + datetime.datetime(year=2016, month=1, day=2, hour=12), + datetime.datetime(year=2016, month=1, day=3), + datetime.datetime(year=2016, month=1, day=3, hour=1), + datetime.datetime(year=2016, month=2, day=2), + ] + cls.aware_timestamps = [ + timezone.make_aware(timestamp, timezone.get_current_timezone()) + for timestamp in cls.timestamps + ] + cls.dates = [ + datetime.date(year=2016, month=1, day=1), + datetime.date(year=2016, month=1, day=2), + datetime.date(year=2016, month=1, day=3), + datetime.date(year=2016, month=1, day=4), + datetime.date(year=2016, month=2, day=2), + datetime.date(year=2016, month=2, day=3), + ] + cls.obj = RangesModel.objects.create( + dates=(cls.dates[0], cls.dates[3]), + timestamps=(cls.timestamps[0], cls.timestamps[3]), + ) + cls.aware_obj = RangesModel.objects.create( + dates=(cls.dates[0], cls.dates[3]), + timestamps=(cls.aware_timestamps[0], cls.aware_timestamps[3]), + ) + # Objects that don't match any queries. + for i in range(3, 4): + RangesModel.objects.create( + dates=(cls.dates[i], cls.dates[i + 1]), + timestamps=(cls.timestamps[i], cls.timestamps[i + 1]), + ) + RangesModel.objects.create( + dates=(cls.dates[i], cls.dates[i + 1]), + timestamps=(cls.aware_timestamps[i], cls.aware_timestamps[i + 1]), + ) + + def test_datetime_range_contains(self): + filter_args = ( + self.timestamps[1], + self.aware_timestamps[1], + (self.timestamps[1], self.timestamps[2]), + (self.aware_timestamps[1], self.aware_timestamps[2]), + Value(self.dates[0], output_field=DateTimeField()), + Func(F('dates'), function='lower', output_field=DateTimeField()), + ) + for filter_arg in filter_args: + with self.subTest(filter_arg=filter_arg): + self.assertCountEqual( + RangesModel.objects.filter(**{'timestamps__contains': filter_arg}), + [self.obj, self.aware_obj], + ) + + def test_date_range_contains(self): + filter_args = ( + self.timestamps[1], + (self.dates[1], self.dates[2]), + Value(self.dates[0], output_field=DateField()), + Func(F('timestamps'), function='lower', output_field=DateField()), + ) + for filter_arg in filter_args: + with self.subTest(filter_arg=filter_arg): + self.assertCountEqual( + RangesModel.objects.filter(**{'dates__contains': filter_arg}), + [self.obj, self.aware_obj], + ) + + class TestQuerying(PostgreSQLTestCase): @classmethod