Fixed #27147 -- Allowed specifying bounds of tuple inputs for non-discrete range fields.

This commit is contained in:
Guilherme Martins Crocetti 2021-06-17 18:13:49 -03:00 committed by Mariusz Felisiak
parent 52f6927d7f
commit fc565cb539
8 changed files with 181 additions and 13 deletions

View File

@ -44,6 +44,10 @@ class RangeField(models.Field):
empty_strings_allowed = False
def __init__(self, *args, **kwargs):
if 'default_bounds' in kwargs:
raise TypeError(
f"Cannot use 'default_bounds' with {self.__class__.__name__}."
)
# Initializing base_field here ensures that its model matches the model for self.
if hasattr(self, 'base_field'):
self.base_field = self.base_field()
@ -112,6 +116,37 @@ class RangeField(models.Field):
return super().formfield(**kwargs)
CANONICAL_RANGE_BOUNDS = '[)'
class ContinuousRangeField(RangeField):
"""
Continuous range field. It allows specifying default bounds for list and
tuple inputs.
"""
def __init__(self, *args, default_bounds=CANONICAL_RANGE_BOUNDS, **kwargs):
if default_bounds not in ('[)', '(]', '()', '[]'):
raise ValueError("default_bounds must be one of '[)', '(]', '()', or '[]'.")
self.default_bounds = default_bounds
super().__init__(*args, **kwargs)
def get_prep_value(self, value):
if isinstance(value, (list, tuple)):
return self.range_type(value[0], value[1], self.default_bounds)
return super().get_prep_value(value)
def formfield(self, **kwargs):
kwargs.setdefault('default_bounds', self.default_bounds)
return super().formfield(**kwargs)
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.default_bounds and self.default_bounds != CANONICAL_RANGE_BOUNDS:
kwargs['default_bounds'] = self.default_bounds
return name, path, args, kwargs
class IntegerRangeField(RangeField):
base_field = models.IntegerField
range_type = NumericRange
@ -130,7 +165,7 @@ class BigIntegerRangeField(RangeField):
return 'int8range'
class DecimalRangeField(RangeField):
class DecimalRangeField(ContinuousRangeField):
base_field = models.DecimalField
range_type = NumericRange
form_field = forms.DecimalRangeField
@ -139,7 +174,7 @@ class DecimalRangeField(RangeField):
return 'numrange'
class DateTimeRangeField(RangeField):
class DateTimeRangeField(ContinuousRangeField):
base_field = models.DateTimeField
range_type = DateTimeTZRange
form_field = forms.DateTimeRangeField

View File

@ -42,6 +42,9 @@ class BaseRangeField(forms.MultiValueField):
kwargs['fields'] = [self.base_field(required=False), self.base_field(required=False)]
kwargs.setdefault('required', False)
kwargs.setdefault('require_all_fields', False)
self.range_kwargs = {}
if default_bounds := kwargs.pop('default_bounds', None):
self.range_kwargs = {'bounds': default_bounds}
super().__init__(**kwargs)
def prepare_value(self, value):
@ -68,7 +71,7 @@ class BaseRangeField(forms.MultiValueField):
code='bound_ordering',
)
try:
range_value = self.range_type(lower, upper)
range_value = self.range_type(lower, upper, **self.range_kwargs)
except TypeError:
raise exceptions.ValidationError(
self.error_messages['invalid'],

View File

@ -503,9 +503,9 @@ All of the range fields translate to :ref:`psycopg2 Range objects
<psycopg2:adapt-range>` in Python, but also accept tuples as input if no bounds
information is necessary. The default is lower bound included, upper bound
excluded, that is ``[)`` (see the PostgreSQL documentation for details about
`different bounds`_).
.. _different bounds: https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-IO
`different bounds`_). The default bounds can be changed for non-discrete range
fields (:class:`.DateTimeRangeField` and :class:`.DecimalRangeField`) by using
the ``default_bounds`` argument.
``IntegerRangeField``
---------------------
@ -538,23 +538,43 @@ excluded, that is ``[)`` (see the PostgreSQL documentation for details about
``DecimalRangeField``
---------------------
.. class:: DecimalRangeField(**options)
.. class:: DecimalRangeField(default_bounds='[)', **options)
Stores a range of floating point values. Based on a
:class:`~django.db.models.DecimalField`. Represented by a ``numrange`` in
the database and a :class:`~psycopg2:psycopg2.extras.NumericRange` in
Python.
.. attribute:: DecimalRangeField.default_bounds
.. versionadded:: 4.1
Optional. The value of ``bounds`` for list and tuple inputs. The
default is lower bound included, upper bound excluded, that is ``[)``
(see the PostgreSQL documentation for details about
`different bounds`_). ``default_bounds`` is not used for
:class:`~psycopg2:psycopg2.extras.NumericRange` inputs.
``DateTimeRangeField``
----------------------
.. class:: DateTimeRangeField(**options)
.. class:: DateTimeRangeField(default_bounds='[)', **options)
Stores a range of timestamps. Based on a
:class:`~django.db.models.DateTimeField`. Represented by a ``tstzrange`` in
the database and a :class:`~psycopg2:psycopg2.extras.DateTimeTZRange` in
Python.
.. attribute:: DateTimeRangeField.default_bounds
.. versionadded:: 4.1
Optional. The value of ``bounds`` for list and tuple inputs. The
default is lower bound included, upper bound excluded, that is ``[)``
(see the PostgreSQL documentation for details about
`different bounds`_). ``default_bounds`` is not used for
:class:`~psycopg2:psycopg2.extras.DateTimeTZRange` inputs.
``DateRangeField``
------------------
@ -884,3 +904,5 @@ used with a custom range functions that expected boundaries, for example to
define :class:`~django.contrib.postgres.constraints.ExclusionConstraint`. See
`the PostgreSQL documentation for the full details <https://www.postgresql.org/
docs/current/rangetypes.html#RANGETYPES-INCLUSIVITY>`_.
.. _different bounds: https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-IO

View File

@ -76,6 +76,12 @@ Minor features
supports covering exclusion constraints using SP-GiST indexes on PostgreSQL
14+.
* The new ``default_bounds`` attribute of :attr:`DateTimeRangeField
<django.contrib.postgres.fields.DateTimeRangeField.default_bounds>` and
:attr:`DecimalRangeField
<django.contrib.postgres.fields.DecimalRangeField.default_bounds>` allows
specifying bounds for list and tuple inputs.
:mod:`django.contrib.redirects`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -26,14 +26,23 @@ except ImportError:
})
return name, path, args, kwargs
class DummyContinuousRangeField(models.Field):
def __init__(self, *args, default_bounds='[)', **kwargs):
super().__init__(**kwargs)
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
kwargs['default_bounds'] = '[)'
return name, path, args, kwargs
ArrayField = DummyArrayField
BigIntegerRangeField = models.Field
CICharField = models.Field
CIEmailField = models.Field
CITextField = models.Field
DateRangeField = models.Field
DateTimeRangeField = models.Field
DecimalRangeField = models.Field
DateTimeRangeField = DummyContinuousRangeField
DecimalRangeField = DummyContinuousRangeField
HStoreField = models.Field
IntegerRangeField = models.Field
SearchVector = models.Expression

View File

@ -249,6 +249,7 @@ class Migration(migrations.Migration):
('decimals', DecimalRangeField(null=True, blank=True)),
('timestamps', DateTimeRangeField(null=True, blank=True)),
('timestamps_inner', DateTimeRangeField(null=True, blank=True)),
('timestamps_closed_bounds', DateTimeRangeField(null=True, blank=True, default_bounds='[]')),
('dates', DateRangeField(null=True, blank=True)),
('dates_inner', DateRangeField(null=True, blank=True)),
],

View File

@ -135,6 +135,9 @@ class RangesModel(PostgreSQLModel):
decimals = DecimalRangeField(blank=True, null=True)
timestamps = DateTimeRangeField(blank=True, null=True)
timestamps_inner = DateTimeRangeField(blank=True, null=True)
timestamps_closed_bounds = DateTimeRangeField(
blank=True, null=True, default_bounds='[]',
)
dates = DateRangeField(blank=True, null=True)
dates_inner = DateRangeField(blank=True, null=True)

View File

@ -50,6 +50,41 @@ class BasicTests(PostgreSQLSimpleTestCase):
instance = Model(field=value)
self.assertEqual(instance.get_field_display(), display)
def test_discrete_range_fields_unsupported_default_bounds(self):
discrete_range_types = [
pg_fields.BigIntegerRangeField,
pg_fields.IntegerRangeField,
pg_fields.DateRangeField,
]
for field_type in discrete_range_types:
msg = f"Cannot use 'default_bounds' with {field_type.__name__}."
with self.assertRaisesMessage(TypeError, msg):
field_type(choices=[((51, 100), '51-100')], default_bounds='[]')
def test_continuous_range_fields_default_bounds(self):
continuous_range_types = [
pg_fields.DecimalRangeField,
pg_fields.DateTimeRangeField,
]
for field_type in continuous_range_types:
field = field_type(choices=[((51, 100), '51-100')], default_bounds='[]')
self.assertEqual(field.default_bounds, '[]')
def test_invalid_default_bounds(self):
tests = [')]', ')[', '](', '])', '([', '[(', 'x', '', None]
msg = "default_bounds must be one of '[)', '(]', '()', or '[]'."
for invalid_bounds in tests:
with self.assertRaisesMessage(ValueError, msg):
pg_fields.DecimalRangeField(default_bounds=invalid_bounds)
def test_deconstruct(self):
field = pg_fields.DecimalRangeField()
*_, kwargs = field.deconstruct()
self.assertEqual(kwargs, {})
field = pg_fields.DecimalRangeField(default_bounds='[]')
*_, kwargs = field.deconstruct()
self.assertEqual(kwargs, {'default_bounds': '[]'})
class TestSaveLoad(PostgreSQLTestCase):
@ -83,6 +118,19 @@ class TestSaveLoad(PostgreSQLTestCase):
loaded = RangesModel.objects.get()
self.assertEqual(NumericRange(0, 10), loaded.ints)
def test_tuple_range_with_default_bounds(self):
range_ = (timezone.now(), timezone.now() + datetime.timedelta(hours=1))
RangesModel.objects.create(timestamps_closed_bounds=range_, timestamps=range_)
loaded = RangesModel.objects.get()
self.assertEqual(
loaded.timestamps_closed_bounds,
DateTimeTZRange(range_[0], range_[1], '[]'),
)
self.assertEqual(
loaded.timestamps,
DateTimeTZRange(range_[0], range_[1], '[)'),
)
def test_range_object_boundaries(self):
r = NumericRange(0, 10, '[]')
instance = RangesModel(decimals=r)
@ -91,6 +139,16 @@ class TestSaveLoad(PostgreSQLTestCase):
self.assertEqual(r, loaded.decimals)
self.assertIn(10, loaded.decimals)
def test_range_object_boundaries_range_with_default_bounds(self):
range_ = DateTimeTZRange(
timezone.now(),
timezone.now() + datetime.timedelta(hours=1),
bounds='()',
)
RangesModel.objects.create(timestamps_closed_bounds=range_)
loaded = RangesModel.objects.get()
self.assertEqual(loaded.timestamps_closed_bounds, range_)
def test_unbounded(self):
r = NumericRange(None, None, '()')
instance = RangesModel(decimals=r)
@ -478,6 +536,8 @@ class TestSerialization(PostgreSQLSimpleTestCase):
'"bigints": null, "timestamps": "{\\"upper\\": \\"2014-02-02T12:12:12+00:00\\", '
'\\"lower\\": \\"2014-01-01T00:00:00+00:00\\", \\"bounds\\": \\"[)\\"}", '
'"timestamps_inner": null, '
'"timestamps_closed_bounds": "{\\"upper\\": \\"2014-02-02T12:12:12+00:00\\", '
'\\"lower\\": \\"2014-01-01T00:00:00+00:00\\", \\"bounds\\": \\"()\\"}", '
'"dates": "{\\"upper\\": \\"2014-02-02\\", \\"lower\\": \\"2014-01-01\\", \\"bounds\\": \\"[)\\"}", '
'"dates_inner": null }, '
'"model": "postgres_tests.rangesmodel", "pk": null}]'
@ -492,15 +552,19 @@ class TestSerialization(PostgreSQLSimpleTestCase):
instance = RangesModel(
ints=NumericRange(0, 10), decimals=NumericRange(empty=True),
timestamps=DateTimeTZRange(self.lower_dt, self.upper_dt),
timestamps_closed_bounds=DateTimeTZRange(
self.lower_dt, self.upper_dt, bounds='()',
),
dates=DateRange(self.lower_date, self.upper_date),
)
data = serializers.serialize('json', [instance])
dumped = json.loads(data)
for field in ('ints', 'dates', 'timestamps'):
for field in ('ints', 'dates', 'timestamps', 'timestamps_closed_bounds'):
dumped[0]['fields'][field] = json.loads(dumped[0]['fields'][field])
check = json.loads(self.test_data)
for field in ('ints', 'dates', 'timestamps'):
for field in ('ints', 'dates', 'timestamps', 'timestamps_closed_bounds'):
check[0]['fields'][field] = json.loads(check[0]['fields'][field])
self.assertEqual(dumped, check)
def test_loading(self):
@ -510,6 +574,10 @@ class TestSerialization(PostgreSQLSimpleTestCase):
self.assertIsNone(instance.bigints)
self.assertEqual(instance.dates, DateRange(self.lower_date, self.upper_date))
self.assertEqual(instance.timestamps, DateTimeTZRange(self.lower_dt, self.upper_dt))
self.assertEqual(
instance.timestamps_closed_bounds,
DateTimeTZRange(self.lower_dt, self.upper_dt, bounds='()'),
)
def test_serialize_range_with_null(self):
instance = RangesModel(ints=NumericRange(None, 10))
@ -886,26 +954,47 @@ class TestFormField(PostgreSQLSimpleTestCase):
model_field = pg_fields.IntegerRangeField()
form_field = model_field.formfield()
self.assertIsInstance(form_field, pg_forms.IntegerRangeField)
self.assertEqual(form_field.range_kwargs, {})
def test_model_field_formfield_biginteger(self):
model_field = pg_fields.BigIntegerRangeField()
form_field = model_field.formfield()
self.assertIsInstance(form_field, pg_forms.IntegerRangeField)
self.assertEqual(form_field.range_kwargs, {})
def test_model_field_formfield_float(self):
model_field = pg_fields.DecimalRangeField()
model_field = pg_fields.DecimalRangeField(default_bounds='()')
form_field = model_field.formfield()
self.assertIsInstance(form_field, pg_forms.DecimalRangeField)
self.assertEqual(form_field.range_kwargs, {'bounds': '()'})
def test_model_field_formfield_date(self):
model_field = pg_fields.DateRangeField()
form_field = model_field.formfield()
self.assertIsInstance(form_field, pg_forms.DateRangeField)
self.assertEqual(form_field.range_kwargs, {})
def test_model_field_formfield_datetime(self):
model_field = pg_fields.DateTimeRangeField()
form_field = model_field.formfield()
self.assertIsInstance(form_field, pg_forms.DateTimeRangeField)
self.assertEqual(
form_field.range_kwargs,
{'bounds': pg_fields.ranges.CANONICAL_RANGE_BOUNDS},
)
def test_model_field_formfield_datetime_default_bounds(self):
model_field = pg_fields.DateTimeRangeField(default_bounds='[]')
form_field = model_field.formfield()
self.assertIsInstance(form_field, pg_forms.DateTimeRangeField)
self.assertEqual(form_field.range_kwargs, {'bounds': '[]'})
def test_model_field_with_default_bounds(self):
field = pg_forms.DateTimeRangeField(default_bounds='[]')
value = field.clean(['2014-01-01 00:00:00', '2014-02-03 12:13:14'])
lower = datetime.datetime(2014, 1, 1, 0, 0, 0)
upper = datetime.datetime(2014, 2, 3, 12, 13, 14)
self.assertEqual(value, DateTimeTZRange(lower, upper, '[]'))
def test_has_changed(self):
for field, value in (