Fixed #24837 -- field__contained_by=Range

Provide `contained_by` lookups for the equivalent single valued fields
related to the range field types. This acts as the opposite direction to
rangefield__contains.

With thanks to schinckel for the idea and initial tests.
This commit is contained in:
Marc Tamlyn 2015-05-21 20:55:50 +09:30
parent 5987b3c46d
commit 7bda2d8ebc
6 changed files with 175 additions and 3 deletions

View File

@ -98,6 +98,38 @@ RangeField.register_lookup(lookups.ContainedBy)
RangeField.register_lookup(lookups.Overlap) RangeField.register_lookup(lookups.Overlap)
class RangeContainedBy(models.Lookup):
lookup_name = 'contained_by'
type_mapping = {
'integer': 'int4range',
'bigint': 'int8range',
'double precision': 'numrange',
'date': 'daterange',
'timestamp with time zone': 'tstzrange',
}
def as_sql(self, qn, connection):
field = self.lhs.output_field
if isinstance(field, models.FloatField):
sql = '%s::numeric <@ %s::{}'.format(self.type_mapping[field.db_type(connection)])
else:
sql = '%s <@ %s::{}'.format(self.type_mapping[field.db_type(connection)])
lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
params = lhs_params + rhs_params
return sql % (lhs, rhs), params
def get_prep_lookup(self):
return RangeField().get_prep_lookup(self.lookup_name, self.rhs)
models.DateField.register_lookup(RangeContainedBy)
models.DateTimeField.register_lookup(RangeContainedBy)
models.IntegerField.register_lookup(RangeContainedBy)
models.BigIntegerField.register_lookup(RangeContainedBy)
models.FloatField.register_lookup(RangeContainedBy)
@RangeField.register_lookup @RangeField.register_lookup
class FullyLessThan(lookups.PostgresSimpleLookup): class FullyLessThan(lookups.PostgresSimpleLookup):
lookup_name = 'fully_lt' lookup_name = 'fully_lt'

View File

@ -631,14 +631,18 @@ model::
class Event(models.Model): class Event(models.Model):
name = models.CharField(max_length=200) name = models.CharField(max_length=200)
ages = IntegerRangeField() ages = IntegerRangeField()
start = models.DateTimeField()
def __str__(self): # __unicode__ on Python 2 def __str__(self): # __unicode__ on Python 2
return self.name return self.name
We will also use the following example objects:: We will also use the following example objects::
>>> Event.objects.create(name='Soft play', ages=(0, 10)) >>> import datetime
>>> Event.objects.create(name='Pub trip', ages=(21, None)) >>> from django.utils import timezone
>>> now = timezone.now()
>>> Event.objects.create(name='Soft play', ages=(0, 10), start=now)
>>> Event.objects.create(name='Pub trip', ages=(21, None), start=now - datetime.timedelta(days=1))
and ``NumericRange``: and ``NumericRange``:
@ -667,6 +671,22 @@ contained_by
>>> Event.objects.filter(ages__contained_by=NumericRange(0, 15)) >>> Event.objects.filter(ages__contained_by=NumericRange(0, 15))
[<Event: Soft play>] [<Event: Soft play>]
.. versionadded 1.9
The `contained_by` lookup is also available on the non-range field types:
:class:`~django.db.models.fields.IntegerField`,
:class:`~django.db.models.fields.BigIntegerField`,
:class:`~django.db.models.fields.FloatField`,
:class:`~django.db.models.fields.DateField`, and
:class:`~django.db.models.fields.DateTimeField`. For example::
>>> from psycopg2.extras import DateTimeTZRange
>>> Event.objects.filter(start__contained_by=DateTimeTZRange(
... timezone.now() - datetime.timedelta(hours=1),
... timezone.now() + datetime.timedelta(hours=1),
... )
[<Event: Soft play>]
.. fieldlookup:: rangefield.overlap .. fieldlookup:: rangefield.overlap
overlap overlap

View File

@ -91,6 +91,8 @@ Minor features
:mod:`django.contrib.postgres` :mod:`django.contrib.postgres`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
* Added support for the :lookup:`rangefield.contained_by` lookup for some built
in fields which correspond to the range fields.
* Added :class:`~django.contrib.postgres.fields.JSONField`. * Added :class:`~django.contrib.postgres.fields.JSONField`.
* Added :doc:`/ref/contrib/postgres/aggregates`. * Added :doc:`/ref/contrib/postgres/aggregates`.

View File

@ -143,6 +143,21 @@ class Migration(migrations.Migration):
('timestamps', DateTimeRangeField(null=True, blank=True)), ('timestamps', DateTimeRangeField(null=True, blank=True)),
('dates', DateRangeField(null=True, blank=True)), ('dates', DateRangeField(null=True, blank=True)),
], ],
options={
'required_db_vendor': 'postgresql'
},
bases=(models.Model,)
),
migrations.CreateModel(
name='RangeLookupsModel',
fields=[
('parent', models.ForeignKey('postgres_tests.RangesModel', blank=True, null=True)),
('integer', models.IntegerField(blank=True, null=True)),
('big_integer', models.BigIntegerField(blank=True, null=True)),
('float', models.FloatField(blank=True, null=True)),
('timestamp', models.DateTimeField(blank=True, null=True)),
('date', models.DateField(blank=True, null=True)),
],
options={ options={
'required_db_vendor': 'postgresql', 'required_db_vendor': 'postgresql',
}, },

View File

@ -60,11 +60,23 @@ if connection.vendor == 'postgresql' and connection.pg_version >= 90200:
floats = FloatRangeField(blank=True, null=True) floats = FloatRangeField(blank=True, null=True)
timestamps = DateTimeRangeField(blank=True, null=True) timestamps = DateTimeRangeField(blank=True, null=True)
dates = DateRangeField(blank=True, null=True) dates = DateRangeField(blank=True, null=True)
class RangeLookupsModel(PostgreSQLModel):
parent = models.ForeignKey(RangesModel, blank=True, null=True)
integer = models.IntegerField(blank=True, null=True)
big_integer = models.BigIntegerField(blank=True, null=True)
float = models.FloatField(blank=True, null=True)
timestamp = models.DateTimeField(blank=True, null=True)
date = models.DateField(blank=True, null=True)
else: else:
# create an object with this name so we don't have failing imports # create an object with this name so we don't have failing imports
class RangesModel(object): class RangesModel(object):
pass pass
class RangeLookupsModel(object):
pass
# Only create this model for postgres >= 9.4 # Only create this model for postgres >= 9.4
if connection.vendor == 'postgresql' and connection.pg_version >= 90400: if connection.vendor == 'postgresql' and connection.pg_version >= 90400:

View File

@ -5,11 +5,12 @@ import unittest
from django import forms from django import forms
from django.core import exceptions, serializers from django.core import exceptions, serializers
from django.db import connection from django.db import connection
from django.db.models import F
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
from django.utils import timezone from django.utils import timezone
from . import PostgreSQLTestCase from . import PostgreSQLTestCase
from .models import RangesModel from .models import RangeLookupsModel, RangesModel
try: try:
from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange
@ -197,6 +198,96 @@ class TestQuerying(TestCase):
) )
@skipUnlessPG92
class TestQueringWithRanges(TestCase):
def test_date_range(self):
objs = [
RangeLookupsModel.objects.create(date='2015-01-01'),
RangeLookupsModel.objects.create(date='2015-05-05'),
]
self.assertSequenceEqual(
RangeLookupsModel.objects.filter(date__contained_by=DateRange('2015-01-01', '2015-05-04')),
[objs[0]],
)
def test_date_range_datetime_field(self):
objs = [
RangeLookupsModel.objects.create(timestamp='2015-01-01'),
RangeLookupsModel.objects.create(timestamp='2015-05-05'),
]
self.assertSequenceEqual(
RangeLookupsModel.objects.filter(timestamp__date__contained_by=DateRange('2015-01-01', '2015-05-04')),
[objs[0]],
)
def test_datetime_range(self):
objs = [
RangeLookupsModel.objects.create(timestamp='2015-01-01T09:00:00'),
RangeLookupsModel.objects.create(timestamp='2015-05-05T17:00:00'),
]
self.assertSequenceEqual(
RangeLookupsModel.objects.filter(
timestamp__contained_by=DateTimeTZRange('2015-01-01T09:00', '2015-05-04T23:55')
),
[objs[0]],
)
def test_integer_range(self):
objs = [
RangeLookupsModel.objects.create(integer=5),
RangeLookupsModel.objects.create(integer=99),
RangeLookupsModel.objects.create(integer=-1),
]
self.assertSequenceEqual(
RangeLookupsModel.objects.filter(integer__contained_by=NumericRange(1, 98)),
[objs[0]]
)
def test_biginteger_range(self):
objs = [
RangeLookupsModel.objects.create(big_integer=5),
RangeLookupsModel.objects.create(big_integer=99),
RangeLookupsModel.objects.create(big_integer=-1),
]
self.assertSequenceEqual(
RangeLookupsModel.objects.filter(big_integer__contained_by=NumericRange(1, 98)),
[objs[0]]
)
def test_float_range(self):
objs = [
RangeLookupsModel.objects.create(float=5),
RangeLookupsModel.objects.create(float=99),
RangeLookupsModel.objects.create(float=-1),
]
self.assertSequenceEqual(
RangeLookupsModel.objects.filter(float__contained_by=NumericRange(1, 98)),
[objs[0]]
)
def test_f_ranges(self):
parent = RangesModel.objects.create(floats=NumericRange(0, 10))
objs = [
RangeLookupsModel.objects.create(float=5, parent=parent),
RangeLookupsModel.objects.create(float=99, parent=parent),
]
self.assertSequenceEqual(
RangeLookupsModel.objects.filter(float__contained_by=F('parent__floats')),
[objs[0]]
)
def test_exclude(self):
objs = [
RangeLookupsModel.objects.create(float=5),
RangeLookupsModel.objects.create(float=99),
RangeLookupsModel.objects.create(float=-1),
]
self.assertSequenceEqual(
RangeLookupsModel.objects.exclude(float__contained_by=NumericRange(0, 100)),
[objs[2]]
)
@skipUnlessPG92 @skipUnlessPG92
class TestSerialization(TestCase): class TestSerialization(TestCase):
test_data = ( test_data = (