django/tests/postgres_tests/test_ranges.py

377 lines
13 KiB
Python

import datetime
import json
import unittest
from django import forms
from django.contrib.postgres import forms as pg_forms, fields as pg_fields
from django.contrib.postgres.validators import RangeMaxValueValidator, RangeMinValueValidator
from django.core import exceptions, serializers
from django.db import connection
from django.test import TestCase
from django.utils import timezone
from psycopg2.extras import NumericRange, DateTimeTZRange, DateRange
from .models import RangesModel
def skipUnlessPG92(test):
    """Skip ``test`` unless the database is PostgreSQL 9.2 or newer.

    ``test`` is returned unchanged when ``connection.vendor`` is
    ``'postgresql'`` and ``connection.pg_version`` is at least 90200;
    otherwise it is returned wrapped in ``unittest.skip`` with a reason.
    """
    # Guard clauses: check the vendor first so ``pg_version`` is only read
    # from a PostgreSQL connection.
    if connection.vendor != 'postgresql':
        return unittest.skip('PostgreSQL required')(test)
    if connection.pg_version < 90200:
        return unittest.skip('PostgreSQL >= 9.2 required')(test)
    return test
@skipUnlessPG92
class TestSaveLoad(TestCase):
    """Save range values on ``RangesModel`` and compare what is loaded back."""

    def test_all_fields(self):
        """Every range field reloads equal to the value that was saved."""
        now = timezone.now()
        instance = RangesModel(
            ints=NumericRange(0, 10),
            bigints=NumericRange(10, 20),
            floats=NumericRange(20, 30),
            timestamps=DateTimeTZRange(now - datetime.timedelta(hours=1), now),
            dates=DateRange(now.date() - datetime.timedelta(days=1), now.date()),
        )
        instance.save()
        loaded = RangesModel.objects.get()
        self.assertEqual(instance.ints, loaded.ints)
        self.assertEqual(instance.bigints, loaded.bigints)
        self.assertEqual(instance.floats, loaded.floats)
        self.assertEqual(instance.timestamps, loaded.timestamps)
        self.assertEqual(instance.dates, loaded.dates)

    def test_range_object(self):
        """A ``NumericRange`` assigned to ``ints`` survives a save/load cycle."""
        r = NumericRange(0, 10)
        instance = RangesModel(ints=r)
        instance.save()
        loaded = RangesModel.objects.get()
        self.assertEqual(r, loaded.ints)

    def test_tuple(self):
        """A plain ``(0, 10)`` tuple is loaded back as ``NumericRange(0, 10)``."""
        instance = RangesModel(ints=(0, 10))
        instance.save()
        loaded = RangesModel.objects.get()
        self.assertEqual(NumericRange(0, 10), loaded.ints)

    def test_range_object_boundaries(self):
        """Explicit ``'[]'`` bounds are kept, so the upper value is a member."""
        r = NumericRange(0, 10, '[]')
        instance = RangesModel(floats=r)
        instance.save()
        loaded = RangesModel.objects.get()
        self.assertEqual(r, loaded.floats)
        # assertIn reports both operands on failure, unlike assertTrue(x in y).
        self.assertIn(10, loaded.floats)

    def test_unbounded(self):
        """A range with no lower and no upper bound round-trips unchanged."""
        r = NumericRange(None, None, '()')
        instance = RangesModel(floats=r)
        instance.save()
        loaded = RangesModel.objects.get()
        self.assertEqual(r, loaded.floats)

    def test_empty(self):
        """An empty range round-trips unchanged."""
        r = NumericRange(empty=True)
        instance = RangesModel(ints=r)
        instance.save()
        loaded = RangesModel.objects.get()
        self.assertEqual(r, loaded.ints)

    def test_null(self):
        """``None`` saved in ``ints`` is loaded back as ``None``."""
        instance = RangesModel(ints=None)
        instance.save()
        loaded = RangesModel.objects.get()
        # Identity check: an object that merely compares equal to None
        # must not pass.
        self.assertIsNone(loaded.ints)
@skipUnlessPG92
class TestQuerying(TestCase):
    """Check each range lookup on ``ints`` against a fixed set of rows.

    ``objs`` holds, in creation order: ``NumericRange(0, 10)``,
    ``NumericRange(5, 15)``, ``NumericRange(None, 0)``, an empty range and
    a row whose ``ints`` is ``None``. Expected results below refer to these
    rows by index.
    """

    @classmethod
    def setUpTestData(cls):
        """Create the fixture rows and keep them, in order, on ``cls.objs``."""
        fixture_values = (
            NumericRange(0, 10),
            NumericRange(5, 15),
            NumericRange(None, 0),
            NumericRange(empty=True),
            None,
        )
        cls.objs = [
            RangesModel.objects.create(ints=value) for value in fixture_values
        ]

    def test_exact(self):
        """``exact`` with ``NumericRange(0, 10)`` returns only row 0."""
        matches = RangesModel.objects.filter(ints__exact=NumericRange(0, 10))
        self.assertSequenceEqual(matches, [self.objs[0]])

    def test_isnull(self):
        """``isnull=True`` returns only the row saved with ``None``."""
        matches = RangesModel.objects.filter(ints__isnull=True)
        self.assertSequenceEqual(matches, [self.objs[4]])

    def test_isempty(self):
        """``isempty=True`` returns only the row holding the empty range."""
        matches = RangesModel.objects.filter(ints__isempty=True)
        self.assertSequenceEqual(matches, [self.objs[3]])

    def test_contains(self):
        """``contains`` with the scalar 8 returns rows 0 and 1."""
        matches = RangesModel.objects.filter(ints__contains=8)
        self.assertSequenceEqual(matches, [self.objs[0], self.objs[1]])

    def test_contains_range(self):
        """``contains`` with ``NumericRange(3, 8)`` returns only row 0."""
        matches = RangesModel.objects.filter(ints__contains=NumericRange(3, 8))
        self.assertSequenceEqual(matches, [self.objs[0]])

    def test_contained_by(self):
        """``contained_by`` ``NumericRange(0, 20)`` returns rows 0, 1 and the empty row."""
        matches = RangesModel.objects.filter(ints__contained_by=NumericRange(0, 20))
        self.assertSequenceEqual(
            matches,
            [self.objs[0], self.objs[1], self.objs[3]],
        )

    def test_overlap(self):
        """``overlap`` with ``NumericRange(3, 8)`` returns rows 0 and 1."""
        matches = RangesModel.objects.filter(ints__overlap=NumericRange(3, 8))
        self.assertSequenceEqual(matches, [self.objs[0], self.objs[1]])

    def test_fully_lt(self):
        """``fully_lt`` ``NumericRange(5, 10)`` returns only row 2."""
        matches = RangesModel.objects.filter(ints__fully_lt=NumericRange(5, 10))
        self.assertSequenceEqual(matches, [self.objs[2]])

    def test_fully_gt(self):
        """``fully_gt`` ``NumericRange(5, 10)`` returns no rows."""
        matches = RangesModel.objects.filter(ints__fully_gt=NumericRange(5, 10))
        self.assertSequenceEqual(matches, [])

    def test_not_lt(self):
        """``not_lt`` ``NumericRange(5, 10)`` returns only row 1."""
        matches = RangesModel.objects.filter(ints__not_lt=NumericRange(5, 10))
        self.assertSequenceEqual(matches, [self.objs[1]])

    def test_not_gt(self):
        """``not_gt`` ``NumericRange(5, 10)`` returns rows 0 and 2."""
        matches = RangesModel.objects.filter(ints__not_gt=NumericRange(5, 10))
        self.assertSequenceEqual(matches, [self.objs[0], self.objs[2]])

    def test_adjacent_to(self):
        """``adjacent_to`` ``NumericRange(0, 5)`` returns rows 1 and 2."""
        matches = RangesModel.objects.filter(ints__adjacent_to=NumericRange(0, 5))
        self.assertSequenceEqual(matches, [self.objs[1], self.objs[2]])

    def test_startswith(self):
        """``startswith=0`` returns only row 0."""
        matches = RangesModel.objects.filter(ints__startswith=0)
        self.assertSequenceEqual(matches, [self.objs[0]])

    def test_endswith(self):
        """``endswith=0`` returns only row 2."""
        matches = RangesModel.objects.filter(ints__endswith=0)
        self.assertSequenceEqual(matches, [self.objs[2]])

    def test_startswith_chaining(self):
        """A further lookup (``gte``) can be chained after ``startswith``."""
        matches = RangesModel.objects.filter(ints__startswith__gte=0)
        self.assertSequenceEqual(matches, [self.objs[0], self.objs[1]])
@skipUnlessPG92
class TestSerialization(TestCase):
    """Serialize and deserialize ``RangesModel`` through the JSON serializer."""

    # Reference fixture: the non-null range fields are JSON documents stored
    # as strings inside the outer JSON document.
    test_data = (
        '[{"fields": {"ints": "{\\"upper\\": 10, \\"lower\\": 0, '
        '\\"bounds\\": \\"[)\\"}", "floats": "{\\"empty\\": true}", '
        '"bigints": null, "timestamps": null, "dates": null}, '
        '"model": "postgres_tests.rangesmodel", "pk": null}]'
    )

    def test_dumping(self):
        """Serialized output matches ``test_data`` once ``ints`` is decoded."""
        instance = RangesModel(ints=NumericRange(0, 10), floats=NumericRange(empty=True))
        data = serializers.serialize('json', [instance])
        dumped = json.loads(data)
        # ``ints`` is itself a JSON string; decode it on both sides so the
        # comparison is made on the decoded mapping rather than on the raw
        # string.
        dumped[0]['fields']['ints'] = json.loads(dumped[0]['fields']['ints'])
        check = json.loads(self.test_data)
        check[0]['fields']['ints'] = json.loads(check[0]['fields']['ints'])
        self.assertEqual(dumped, check)

    def test_loading(self):
        """Deserializing ``test_data`` yields the expected range values."""
        instance = list(serializers.deserialize('json', self.test_data))[0].object
        self.assertEqual(instance.ints, NumericRange(0, 10))
        self.assertEqual(instance.floats, NumericRange(empty=True))
        # Identity check: a JSON null must load as None itself.
        self.assertIsNone(instance.dates)
class TestValidators(TestCase):
    """Exercise ``RangeMaxValueValidator`` and ``RangeMinValueValidator``."""

    def test_max(self):
        """A limit of 5 accepts (0, 5) and rejects (0, 10) with ``max_value``."""
        at_most_five = RangeMaxValueValidator(5)
        # Must not raise.
        at_most_five(NumericRange(0, 5))
        with self.assertRaises(exceptions.ValidationError) as caught:
            at_most_five(NumericRange(0, 10))
        error = caught.exception
        self.assertEqual(
            error.messages[0],
            'Ensure that this range is completely less than or equal to 5.',
        )
        self.assertEqual(error.code, 'max_value')

    def test_min(self):
        """A limit of 5 accepts (10, 15) and rejects (0, 10) with ``min_value``."""
        at_least_five = RangeMinValueValidator(5)
        # Must not raise.
        at_least_five(NumericRange(10, 15))
        with self.assertRaises(exceptions.ValidationError) as caught:
            at_least_five(NumericRange(0, 10))
        error = caught.exception
        self.assertEqual(
            error.messages[0],
            'Ensure that this range is completely greater than or equal to 5.',
        )
        self.assertEqual(error.code, 'min_value')
class TestFormField(TestCase):
    """Cleaning, rendering and model-field mapping of the range form fields."""

    def test_valid_integer(self):
        """Two integer strings clean to a ``NumericRange`` of ints."""
        field = pg_forms.IntegerRangeField()
        value = field.clean(['1', '2'])
        self.assertEqual(value, NumericRange(1, 2))

    def test_valid_floats(self):
        """Two float strings clean to a ``NumericRange`` of floats."""
        field = pg_forms.FloatRangeField()
        value = field.clean(['1.12345', '2.001'])
        self.assertEqual(value, NumericRange(1.12345, 2.001))

    def test_valid_timestamps(self):
        """Two datetime strings clean to a ``DateTimeTZRange``."""
        field = pg_forms.DateTimeRangeField()
        value = field.clean(['01/01/2014 00:00:00', '02/02/2014 12:12:12'])
        lower = datetime.datetime(2014, 1, 1, 0, 0, 0)
        upper = datetime.datetime(2014, 2, 2, 12, 12, 12)
        self.assertEqual(value, DateTimeTZRange(lower, upper))

    def test_valid_dates(self):
        """Two date strings clean to a ``DateRange``."""
        field = pg_forms.DateRangeField()
        value = field.clean(['01/01/2014', '02/02/2014'])
        lower = datetime.date(2014, 1, 1)
        upper = datetime.date(2014, 2, 2)
        self.assertEqual(value, DateRange(lower, upper))

    def test_using_split_datetime_widget(self):
        """A subclass with ``SplitDateTimeField`` as base renders four inputs and cleans them."""
        class SplitDateTimeRangeField(pg_forms.DateTimeRangeField):
            base_field = forms.SplitDateTimeField

        class SplitForm(forms.Form):
            field = SplitDateTimeRangeField()

        form = SplitForm()
        # The HTML literal is kept byte-for-byte, hence the column-0 lines.
        self.assertHTMLEqual(str(form), '''
<tr>
<th>
<label for="id_field_0">Field:</label>
</th>
<td>
<input id="id_field_0_0" name="field_0_0" type="text" />
<input id="id_field_0_1" name="field_0_1" type="text" />
<input id="id_field_1_0" name="field_1_0" type="text" />
<input id="id_field_1_1" name="field_1_1" type="text" />
</td>
</tr>
''')
        form = SplitForm({
            'field_0_0': '01/01/2014',
            'field_0_1': '00:00:00',
            'field_1_0': '02/02/2014',
            'field_1_1': '12:12:12',
        })
        self.assertTrue(form.is_valid())
        lower = datetime.datetime(2014, 1, 1, 0, 0, 0)
        upper = datetime.datetime(2014, 2, 2, 12, 12, 12)
        self.assertEqual(form.cleaned_data['field'], DateTimeTZRange(lower, upper))

    def test_none(self):
        """Two empty strings on a non-required field clean to ``None``."""
        field = pg_forms.IntegerRangeField(required=False)
        value = field.clean(['', ''])
        # Identity check: the cleaned value must be None itself, not merely
        # something that compares equal to it.
        self.assertIsNone(value)

    def test_rendering(self):
        """An ``IntegerRangeField`` renders as two ``number`` inputs."""
        class RangeForm(forms.Form):
            ints = pg_forms.IntegerRangeField()

        # The HTML literal is kept byte-for-byte, hence the column-0 lines.
        self.assertHTMLEqual(str(RangeForm()), '''
<tr>
<th><label for="id_ints_0">Ints:</label></th>
<td>
<input id="id_ints_0" name="ints_0" type="number" />
<input id="id_ints_1" name="ints_1" type="number" />
</td>
</tr>
''')

    def test_lower_bound_higher(self):
        """A lower bound above the upper bound is rejected with ``bound_ordering``."""
        field = pg_forms.IntegerRangeField()
        with self.assertRaises(exceptions.ValidationError) as cm:
            field.clean(['10', '2'])
        self.assertEqual(cm.exception.messages[0], 'The start of the range must not exceed the end of the range.')
        self.assertEqual(cm.exception.code, 'bound_ordering')

    def test_open(self):
        """An empty lower value cleans to a range with no lower bound."""
        field = pg_forms.IntegerRangeField()
        value = field.clean(['', '0'])
        self.assertEqual(value, NumericRange(None, 0))

    def test_incorrect_data_type(self):
        """A bare string instead of a two-item list is rejected as ``invalid``."""
        field = pg_forms.IntegerRangeField()
        with self.assertRaises(exceptions.ValidationError) as cm:
            field.clean('1')
        self.assertEqual(cm.exception.messages[0], 'Enter two valid values.')
        self.assertEqual(cm.exception.code, 'invalid')

    def test_invalid_lower(self):
        """A non-numeric lower value produces the integer field's error message."""
        field = pg_forms.IntegerRangeField()
        with self.assertRaises(exceptions.ValidationError) as cm:
            field.clean(['a', '2'])
        self.assertEqual(cm.exception.messages[0], 'Enter a whole number.')

    def test_invalid_upper(self):
        """A non-numeric upper value produces the integer field's error message."""
        field = pg_forms.IntegerRangeField()
        with self.assertRaises(exceptions.ValidationError) as cm:
            field.clean(['1', 'b'])
        self.assertEqual(cm.exception.messages[0], 'Enter a whole number.')

    def test_required(self):
        """A required field rejects two empty values but accepts one bound."""
        field = pg_forms.IntegerRangeField(required=True)
        with self.assertRaises(exceptions.ValidationError) as cm:
            field.clean(['', ''])
        self.assertEqual(cm.exception.messages[0], 'This field is required.')
        # Supplying only the lower bound satisfies ``required``.
        value = field.clean([1, ''])
        self.assertEqual(value, NumericRange(1, None))

    def test_model_field_formfield_integer(self):
        """The integer range model field maps to the integer range form field."""
        model_field = pg_fields.IntegerRangeField()
        form_field = model_field.formfield()
        self.assertIsInstance(form_field, pg_forms.IntegerRangeField)

    def test_model_field_formfield_biginteger(self):
        """The big-integer range model field also maps to the integer form field."""
        model_field = pg_fields.BigIntegerRangeField()
        form_field = model_field.formfield()
        self.assertIsInstance(form_field, pg_forms.IntegerRangeField)

    def test_model_field_formfield_float(self):
        """The float range model field maps to the float range form field."""
        model_field = pg_fields.FloatRangeField()
        form_field = model_field.formfield()
        self.assertIsInstance(form_field, pg_forms.FloatRangeField)

    def test_model_field_formfield_date(self):
        """The date range model field maps to the date range form field."""
        model_field = pg_fields.DateRangeField()
        form_field = model_field.formfield()
        self.assertIsInstance(form_field, pg_forms.DateRangeField)

    def test_model_field_formfield_datetime(self):
        """The datetime range model field maps to the datetime range form field."""
        model_field = pg_fields.DateTimeRangeField()
        form_field = model_field.formfield()
        self.assertIsInstance(form_field, pg_forms.DateTimeRangeField)