mirror of https://github.com/django/django.git
Fixed #29824 -- Added support for database exclusion constraints on PostgreSQL.
Thanks to Nick Pope and Mariusz Felisiak for review. Co-Authored-By: Mariusz Felisiak <felisiak.mariusz@gmail.com>
This commit is contained in:
parent
7174cf0b00
commit
a3417282ac
|
@ -0,0 +1,106 @@
|
|||
from django.db.backends.ddl_references import Statement, Table
|
||||
from django.db.models import F, Q
|
||||
from django.db.models.constraints import BaseConstraint
|
||||
from django.db.models.sql import Query
|
||||
|
||||
__all__ = ['ExclusionConstraint']
|
||||
|
||||
|
||||
class ExclusionConstraint(BaseConstraint):
|
||||
template = 'CONSTRAINT %(name)s EXCLUDE USING %(index_type)s (%(expressions)s)%(where)s'
|
||||
|
||||
def __init__(self, *, name, expressions, index_type=None, condition=None):
|
||||
if index_type and index_type.lower() not in {'gist', 'spgist'}:
|
||||
raise ValueError(
|
||||
'Exclusion constraints only support GiST or SP-GiST indexes.'
|
||||
)
|
||||
if not expressions:
|
||||
raise ValueError(
|
||||
'At least one expression is required to define an exclusion '
|
||||
'constraint.'
|
||||
)
|
||||
if not all(
|
||||
isinstance(expr, (list, tuple)) and len(expr) == 2
|
||||
for expr in expressions
|
||||
):
|
||||
raise ValueError('The expressions must be a list of 2-tuples.')
|
||||
if not isinstance(condition, (type(None), Q)):
|
||||
raise ValueError(
|
||||
'ExclusionConstraint.condition must be a Q instance.'
|
||||
)
|
||||
self.expressions = expressions
|
||||
self.index_type = index_type or 'GIST'
|
||||
self.condition = condition
|
||||
super().__init__(name=name)
|
||||
|
||||
def _get_expression_sql(self, compiler, connection, query):
|
||||
expressions = []
|
||||
for expression, operator in self.expressions:
|
||||
if isinstance(expression, str):
|
||||
expression = F(expression)
|
||||
if isinstance(expression, F):
|
||||
expression = expression.resolve_expression(query=query, simple_col=True)
|
||||
else:
|
||||
expression = expression.resolve_expression(query=query)
|
||||
sql, params = expression.as_sql(compiler, connection)
|
||||
expressions.append('%s WITH %s' % (sql % params, operator))
|
||||
return expressions
|
||||
|
||||
def _get_condition_sql(self, compiler, schema_editor, query):
|
||||
if self.condition is None:
|
||||
return None
|
||||
where = query.build_where(self.condition)
|
||||
sql, params = where.as_sql(compiler, schema_editor.connection)
|
||||
return sql % tuple(schema_editor.quote_value(p) for p in params)
|
||||
|
||||
def constraint_sql(self, model, schema_editor):
|
||||
query = Query(model)
|
||||
compiler = query.get_compiler(connection=schema_editor.connection)
|
||||
expressions = self._get_expression_sql(compiler, schema_editor.connection, query)
|
||||
condition = self._get_condition_sql(compiler, schema_editor, query)
|
||||
return self.template % {
|
||||
'name': schema_editor.quote_name(self.name),
|
||||
'index_type': self.index_type,
|
||||
'expressions': ', '.join(expressions),
|
||||
'where': ' WHERE (%s)' % condition if condition else '',
|
||||
}
|
||||
|
||||
def create_sql(self, model, schema_editor):
|
||||
return Statement(
|
||||
'ALTER TABLE %(table)s ADD %(constraint)s',
|
||||
table=Table(model._meta.db_table, schema_editor.quote_name),
|
||||
constraint=self.constraint_sql(model, schema_editor),
|
||||
)
|
||||
|
||||
def remove_sql(self, model, schema_editor):
|
||||
return schema_editor._delete_constraint_sql(
|
||||
schema_editor.sql_delete_check,
|
||||
model,
|
||||
schema_editor.quote_name(self.name),
|
||||
)
|
||||
|
||||
def deconstruct(self):
|
||||
path, args, kwargs = super().deconstruct()
|
||||
kwargs['expressions'] = self.expressions
|
||||
if self.condition is not None:
|
||||
kwargs['condition'] = self.condition
|
||||
if self.index_type.lower() != 'gist':
|
||||
kwargs['index_type'] = self.index_type
|
||||
return path, args, kwargs
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
isinstance(other, self.__class__) and
|
||||
self.name == other.name and
|
||||
self.index_type == other.index_type and
|
||||
self.expressions == other.expressions and
|
||||
self.condition == other.condition
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s: index_type=%s, expressions=%s%s>' % (
|
||||
self.__class__.__qualname__,
|
||||
self.index_type,
|
||||
self.expressions,
|
||||
'' if self.condition is None else ', condition=%s' % self.condition,
|
||||
)
|
|
@ -12,10 +12,20 @@ __all__ = [
|
|||
'RangeField', 'IntegerRangeField', 'BigIntegerRangeField',
|
||||
'DecimalRangeField', 'DateTimeRangeField', 'DateRangeField',
|
||||
'FloatRangeField',
|
||||
'RangeOperators',
|
||||
'RangeBoundary', 'RangeOperators',
|
||||
]
|
||||
|
||||
|
||||
class RangeBoundary(models.Expression):
|
||||
"""A class that represents range boundaries."""
|
||||
def __init__(self, inclusive_lower=True, inclusive_upper=False):
|
||||
self.lower = '[' if inclusive_lower else '('
|
||||
self.upper = ']' if inclusive_upper else ')'
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
return "'%s%s'" % (self.lower, self.upper), []
|
||||
|
||||
|
||||
class RangeOperators:
|
||||
# https://www.postgresql.org/docs/current/functions-range.html#RANGE-OPERATORS-TABLE
|
||||
EQUAL = '='
|
||||
|
|
|
@ -0,0 +1,151 @@
|
|||
========================================
|
||||
PostgreSQL specific database constraints
|
||||
========================================
|
||||
|
||||
.. module:: django.contrib.postgres.constraints
|
||||
:synopsis: PostgreSQL specific database constraint
|
||||
|
||||
PostgreSQL supports additional data integrity constraints available from the
|
||||
``django.contrib.postgres.constraints`` module. They are added in the model
|
||||
:attr:`Meta.constraints <django.db.models.Options.constraints>` option.
|
||||
|
||||
``ExclusionConstraint``
|
||||
=======================
|
||||
|
||||
.. versionadded:: 3.0
|
||||
|
||||
.. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None)
|
||||
|
||||
Creates an exclusion constraint in the database. Internally, PostgreSQL
|
||||
implements exclusion constraints using indexes. The default index type is
|
||||
`GiST <https://www.postgresql.org/docs/current/gist.html>`_. To use them,
|
||||
you need to activate the `btree_gist extension
|
||||
<https://www.postgresql.org/docs/current/btree-gist.html>`_ on PostgreSQL.
|
||||
You can install it using the
|
||||
:class:`~django.contrib.postgres.operations.BtreeGistExtension` migration
|
||||
operation.
|
||||
|
||||
If you attempt to insert a new row that conflicts with an existing row, an
|
||||
:exc:`~django.db.IntegrityError` is raised. Similarly, when update
|
||||
conflicts with an existing row.
|
||||
|
||||
``name``
|
||||
--------
|
||||
|
||||
.. attribute:: ExclusionConstraint.name
|
||||
|
||||
The name of the constraint.
|
||||
|
||||
``expressions``
|
||||
---------------
|
||||
|
||||
.. attribute:: ExclusionConstraint.expressions
|
||||
|
||||
An iterable of 2-tuples. The first element is an expression or string. The
|
||||
second element is a SQL operator represented as a string. To avoid typos, you
|
||||
may use :class:`~django.contrib.postgres.fields.RangeOperators` which maps the
|
||||
operators with strings. For example::
|
||||
|
||||
expressions=[
|
||||
('timespan', RangeOperators.ADJACENT_TO),
|
||||
(F('room'), RangeOperators.EQUAL),
|
||||
]
|
||||
|
||||
.. admonition:: Restrictions on operators.
|
||||
|
||||
Only commutative operators can be used in exclusion constraints.
|
||||
|
||||
``index_type``
|
||||
--------------
|
||||
|
||||
.. attribute:: ExclusionConstraint.index_type
|
||||
|
||||
The index type of the constraint. Accepted values are ``GIST`` or ``SPGIST``.
|
||||
Matching is case insensitive. If not provided, the default index type is
|
||||
``GIST``.
|
||||
|
||||
``condition``
|
||||
-------------
|
||||
|
||||
.. attribute:: ExclusionConstraint.condition
|
||||
|
||||
A :class:`~django.db.models.Q` object that specifies the condition to restrict
|
||||
a constraint to a subset of rows. For example,
|
||||
``condition=Q(cancelled=False)``.
|
||||
|
||||
These conditions have the same database restrictions as
|
||||
:attr:`django.db.models.Index.condition`.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
The following example restricts overlapping reservations in the same room, not
|
||||
taking canceled reservations into account::
|
||||
|
||||
from django.contrib.postgres.constraints import ExclusionConstraint
|
||||
from django.contrib.postgres.fields import DateTimeRangeField, RangeOperators
|
||||
from django.db import models
|
||||
from django.db.models import Q
|
||||
|
||||
class Room(models.Model):
|
||||
number = models.IntegerField()
|
||||
|
||||
|
||||
class Reservation(models.Model):
|
||||
room = models.ForeignKey('Room', on_delete=models.CASCADE)
|
||||
timespan = DateTimeRangeField()
|
||||
cancelled = models.BooleanField(default=False)
|
||||
|
||||
class Meta:
|
||||
constraints = [
|
||||
ExclusionConstraint(
|
||||
name='exclude_overlapping_reservations',
|
||||
expressions=[
|
||||
('timespan', RangeOperators.OVERLAPS),
|
||||
('room', RangeOperators.EQUAL),
|
||||
],
|
||||
condition=Q(cancelled=False),
|
||||
),
|
||||
]
|
||||
|
||||
In case your model defines a range using two fields, instead of the native
|
||||
PostgreSQL range types, you should write an expression that uses the equivalent
|
||||
function (e.g. ``TsTzRange()``), and use the delimiters for the field. Most
|
||||
often, the delimiters will be ``'[)'``, meaning that the lower bound is
|
||||
inclusive and the upper bound is exclusive. You may use the
|
||||
:class:`~django.contrib.postgres.fields.RangeBoundary` that provides an
|
||||
expression mapping for the `range boundaries <https://www.postgresql.org/docs/
|
||||
current/rangetypes.html#RANGETYPES-INCLUSIVITY>`_. For example::
|
||||
|
||||
from django.contrib.postgres.constraints import ExclusionConstraint
|
||||
from django.contrib.postgres.fields import (
|
||||
DateTimeRangeField,
|
||||
RangeBoundary,
|
||||
RangeOperators,
|
||||
)
|
||||
from django.db import models
|
||||
from django.db.models import Func, Q
|
||||
|
||||
|
||||
class TsTzRange(Func):
|
||||
function = 'TSTZRANGE'
|
||||
output_field = DateTimeRangeField()
|
||||
|
||||
|
||||
class Reservation(models.Model):
|
||||
room = models.ForeignKey('Room', on_delete=models.CASCADE)
|
||||
start = models.DateTimeField()
|
||||
end = models.DateTimeField()
|
||||
cancelled = models.BooleanField(default=False)
|
||||
|
||||
class Meta:
|
||||
constraints = [
|
||||
ExclusionConstraint(
|
||||
name='exclude_overlapping_reservations',
|
||||
expressions=(
|
||||
(TsTzRange('start', 'end', RangeBoundary()), RangeOperators.OVERLAPS),
|
||||
('room', RangeOperators.EQUAL),
|
||||
),
|
||||
condition=Q(cancelled=False),
|
||||
),
|
||||
]
|
|
@ -944,3 +944,26 @@ corresponding lookups.
|
|||
NOT_LT = '&>'
|
||||
NOT_GT = '&<'
|
||||
ADJACENT_TO = '-|-'
|
||||
|
||||
RangeBoundary() expressions
|
||||
---------------------------
|
||||
|
||||
.. versionadded:: 3.0
|
||||
|
||||
.. class:: RangeBoundary(inclusive_lower=True, inclusive_upper=False)
|
||||
|
||||
.. attribute:: inclusive_lower
|
||||
|
||||
If ``True`` (default), the lower bound is inclusive ``'['``, otherwise
|
||||
it's exclusive ``'('``.
|
||||
|
||||
.. attribute:: inclusive_upper
|
||||
|
||||
If ``False`` (default), the upper bound is exclusive ``')'``, otherwise
|
||||
it's inclusive ``']'``.
|
||||
|
||||
A ``RangeBoundary()`` expression represents the range boundaries. It can be
|
||||
used with a custom range functions that expected boundaries, for example to
|
||||
define :class:`~django.contrib.postgres.constraints.ExclusionConstraint`. See
|
||||
`the PostgreSQL documentation for the full details <https://www.postgresql.org/
|
||||
docs/current/rangetypes.html#RANGETYPES-INCLUSIVITY>`_.
|
||||
|
|
|
@ -29,6 +29,7 @@ a number of PostgreSQL specific data types.
|
|||
:maxdepth: 2
|
||||
|
||||
aggregates
|
||||
constraints
|
||||
fields
|
||||
forms
|
||||
functions
|
||||
|
|
|
@ -66,6 +66,14 @@ async code before, this may trigger if you were doing it incorrectly. If you
|
|||
see a ``SynchronousOnlyOperation`` error, then closely examine your code and
|
||||
move any database operations to be in a synchronous child thread.
|
||||
|
||||
Exclusion constraints on PostgreSQL
|
||||
-----------------------------------
|
||||
|
||||
The new :class:`~django.contrib.postgres.constraints.ExclusionConstraint` class
|
||||
enable adding exclusion constraints on PostgreSQL. Constraints are added to
|
||||
models using the
|
||||
:attr:`Meta.constraints <django.db.models.Options.constraints>` option.
|
||||
|
||||
Minor features
|
||||
--------------
|
||||
|
||||
|
@ -137,6 +145,9 @@ Minor features
|
|||
avoid typos in SQL operators that can be used together with
|
||||
:class:`~django.contrib.postgres.fields.RangeField`.
|
||||
|
||||
* The new :class:`~django.contrib.postgres.fields.RangeBoundary` expression
|
||||
represents the range boundaries.
|
||||
|
||||
:mod:`django.contrib.redirects`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
|
|
@ -262,4 +262,25 @@ class Migration(migrations.Migration):
|
|||
},
|
||||
bases=(models.Model,),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='Room',
|
||||
fields=[
|
||||
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
|
||||
('number', models.IntegerField(unique=True)),
|
||||
],
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='HotelReservation',
|
||||
fields=[
|
||||
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
|
||||
('room', models.ForeignKey('postgres_tests.Room', models.CASCADE)),
|
||||
('datespan', DateRangeField()),
|
||||
('start', models.DateTimeField()),
|
||||
('end', models.DateTimeField()),
|
||||
('cancelled', models.BooleanField(default=False)),
|
||||
],
|
||||
options={
|
||||
'required_db_vendor': 'postgresql',
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
|
@ -183,3 +183,15 @@ class NowTestModel(models.Model):
|
|||
|
||||
class UUIDTestModel(models.Model):
|
||||
uuid = models.UUIDField(default=None, null=True)
|
||||
|
||||
|
||||
class Room(models.Model):
|
||||
number = models.IntegerField(unique=True)
|
||||
|
||||
|
||||
class HotelReservation(PostgreSQLModel):
|
||||
room = models.ForeignKey('Room', on_delete=models.CASCADE)
|
||||
datespan = DateRangeField()
|
||||
start = models.DateTimeField()
|
||||
end = models.DateTimeField()
|
||||
cancelled = models.BooleanField(default=False)
|
||||
|
|
|
@ -1,15 +1,19 @@
|
|||
import datetime
|
||||
|
||||
from django.db import connection, transaction
|
||||
from django.db.models import F, Q
|
||||
from django.db.models import F, Func, Q
|
||||
from django.db.models.constraints import CheckConstraint
|
||||
from django.db.utils import IntegrityError
|
||||
from django.utils import timezone
|
||||
|
||||
from . import PostgreSQLTestCase
|
||||
from .models import RangesModel
|
||||
from .models import HotelReservation, RangesModel, Room
|
||||
|
||||
try:
|
||||
from psycopg2.extras import NumericRange
|
||||
from django.contrib.postgres.constraints import ExclusionConstraint
|
||||
from django.contrib.postgres.fields import DateTimeRangeField, RangeBoundary, RangeOperators
|
||||
|
||||
from psycopg2.extras import DateRange, NumericRange
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
@ -77,3 +81,249 @@ class SchemaTests(PostgreSQLTestCase):
|
|||
timestamps=(datetime_1, datetime_2),
|
||||
timestamps_inner=(datetime_1, datetime_2),
|
||||
)
|
||||
|
||||
|
||||
class ExclusionConstraintTests(PostgreSQLTestCase):
|
||||
def get_constraints(self, table):
|
||||
"""Get the constraints on the table using a new cursor."""
|
||||
with connection.cursor() as cursor:
|
||||
return connection.introspection.get_constraints(cursor, table)
|
||||
|
||||
def test_invalid_condition(self):
|
||||
msg = 'ExclusionConstraint.condition must be a Q instance.'
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
ExclusionConstraint(
|
||||
index_type='GIST',
|
||||
name='exclude_invalid_condition',
|
||||
expressions=[(F('datespan'), RangeOperators.OVERLAPS)],
|
||||
condition=F('invalid'),
|
||||
)
|
||||
|
||||
def test_invalid_index_type(self):
|
||||
msg = 'Exclusion constraints only support GiST or SP-GiST indexes.'
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
ExclusionConstraint(
|
||||
index_type='gin',
|
||||
name='exclude_invalid_index_type',
|
||||
expressions=[(F('datespan'), RangeOperators.OVERLAPS)],
|
||||
)
|
||||
|
||||
def test_invalid_expressions(self):
|
||||
msg = 'The expressions must be a list of 2-tuples.'
|
||||
for expressions in (['foo'], [('foo')], [('foo_1', 'foo_2', 'foo_3')]):
|
||||
with self.subTest(expressions), self.assertRaisesMessage(ValueError, msg):
|
||||
ExclusionConstraint(
|
||||
index_type='GIST',
|
||||
name='exclude_invalid_expressions',
|
||||
expressions=expressions,
|
||||
)
|
||||
|
||||
def test_empty_expressions(self):
|
||||
msg = 'At least one expression is required to define an exclusion constraint.'
|
||||
for empty_expressions in (None, []):
|
||||
with self.subTest(empty_expressions), self.assertRaisesMessage(ValueError, msg):
|
||||
ExclusionConstraint(
|
||||
index_type='GIST',
|
||||
name='exclude_empty_expressions',
|
||||
expressions=empty_expressions,
|
||||
)
|
||||
|
||||
def test_repr(self):
|
||||
constraint = ExclusionConstraint(
|
||||
name='exclude_overlapping',
|
||||
expressions=[
|
||||
(F('datespan'), RangeOperators.OVERLAPS),
|
||||
(F('room'), RangeOperators.EQUAL),
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
repr(constraint),
|
||||
"<ExclusionConstraint: index_type=GIST, expressions=["
|
||||
"(F(datespan), '&&'), (F(room), '=')]>",
|
||||
)
|
||||
constraint = ExclusionConstraint(
|
||||
name='exclude_overlapping',
|
||||
expressions=[(F('datespan'), RangeOperators.ADJACENT_TO)],
|
||||
condition=Q(cancelled=False),
|
||||
index_type='SPGiST',
|
||||
)
|
||||
self.assertEqual(
|
||||
repr(constraint),
|
||||
"<ExclusionConstraint: index_type=SPGiST, expressions=["
|
||||
"(F(datespan), '-|-')], condition=(AND: ('cancelled', False))>",
|
||||
)
|
||||
|
||||
def test_eq(self):
|
||||
constraint_1 = ExclusionConstraint(
|
||||
name='exclude_overlapping',
|
||||
expressions=[
|
||||
(F('datespan'), RangeOperators.OVERLAPS),
|
||||
(F('room'), RangeOperators.EQUAL),
|
||||
],
|
||||
condition=Q(cancelled=False),
|
||||
)
|
||||
constraint_2 = ExclusionConstraint(
|
||||
name='exclude_overlapping',
|
||||
expressions=[
|
||||
('datespan', RangeOperators.OVERLAPS),
|
||||
('room', RangeOperators.EQUAL),
|
||||
],
|
||||
)
|
||||
constraint_3 = ExclusionConstraint(
|
||||
name='exclude_overlapping',
|
||||
expressions=[('datespan', RangeOperators.OVERLAPS)],
|
||||
condition=Q(cancelled=False),
|
||||
)
|
||||
self.assertEqual(constraint_1, constraint_1)
|
||||
self.assertNotEqual(constraint_1, constraint_2)
|
||||
self.assertNotEqual(constraint_1, constraint_3)
|
||||
self.assertNotEqual(constraint_2, constraint_3)
|
||||
self.assertNotEqual(constraint_1, object())
|
||||
|
||||
def test_deconstruct(self):
|
||||
constraint = ExclusionConstraint(
|
||||
name='exclude_overlapping',
|
||||
expressions=[('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)],
|
||||
)
|
||||
path, args, kwargs = constraint.deconstruct()
|
||||
self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint')
|
||||
self.assertEqual(args, ())
|
||||
self.assertEqual(kwargs, {
|
||||
'name': 'exclude_overlapping',
|
||||
'expressions': [('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)],
|
||||
})
|
||||
|
||||
def test_deconstruct_index_type(self):
|
||||
constraint = ExclusionConstraint(
|
||||
name='exclude_overlapping',
|
||||
index_type='SPGIST',
|
||||
expressions=[('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)],
|
||||
)
|
||||
path, args, kwargs = constraint.deconstruct()
|
||||
self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint')
|
||||
self.assertEqual(args, ())
|
||||
self.assertEqual(kwargs, {
|
||||
'name': 'exclude_overlapping',
|
||||
'index_type': 'SPGIST',
|
||||
'expressions': [('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)],
|
||||
})
|
||||
|
||||
def test_deconstruct_condition(self):
|
||||
constraint = ExclusionConstraint(
|
||||
name='exclude_overlapping',
|
||||
expressions=[('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)],
|
||||
condition=Q(cancelled=False),
|
||||
)
|
||||
path, args, kwargs = constraint.deconstruct()
|
||||
self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint')
|
||||
self.assertEqual(args, ())
|
||||
self.assertEqual(kwargs, {
|
||||
'name': 'exclude_overlapping',
|
||||
'expressions': [('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)],
|
||||
'condition': Q(cancelled=False),
|
||||
})
|
||||
|
||||
def _test_range_overlaps(self, constraint):
|
||||
# Create exclusion constraint.
|
||||
self.assertNotIn(constraint.name, self.get_constraints(HotelReservation._meta.db_table))
|
||||
with connection.schema_editor() as editor:
|
||||
editor.add_constraint(HotelReservation, constraint)
|
||||
self.assertIn(constraint.name, self.get_constraints(HotelReservation._meta.db_table))
|
||||
# Add initial reservations.
|
||||
room101 = Room.objects.create(number=101)
|
||||
room102 = Room.objects.create(number=102)
|
||||
datetimes = [
|
||||
timezone.datetime(2018, 6, 20),
|
||||
timezone.datetime(2018, 6, 24),
|
||||
timezone.datetime(2018, 6, 26),
|
||||
timezone.datetime(2018, 6, 28),
|
||||
timezone.datetime(2018, 6, 29),
|
||||
]
|
||||
HotelReservation.objects.create(
|
||||
datespan=DateRange(datetimes[0].date(), datetimes[1].date()),
|
||||
start=datetimes[0],
|
||||
end=datetimes[1],
|
||||
room=room102,
|
||||
)
|
||||
HotelReservation.objects.create(
|
||||
datespan=DateRange(datetimes[1].date(), datetimes[3].date()),
|
||||
start=datetimes[1],
|
||||
end=datetimes[3],
|
||||
room=room102,
|
||||
)
|
||||
# Overlap dates.
|
||||
with self.assertRaises(IntegrityError), transaction.atomic():
|
||||
reservation = HotelReservation(
|
||||
datespan=(datetimes[1].date(), datetimes[2].date()),
|
||||
start=datetimes[1],
|
||||
end=datetimes[2],
|
||||
room=room102,
|
||||
)
|
||||
reservation.save()
|
||||
# Valid range.
|
||||
HotelReservation.objects.bulk_create([
|
||||
# Other room.
|
||||
HotelReservation(
|
||||
datespan=(datetimes[1].date(), datetimes[2].date()),
|
||||
start=datetimes[1],
|
||||
end=datetimes[2],
|
||||
room=room101,
|
||||
),
|
||||
# Cancelled reservation.
|
||||
HotelReservation(
|
||||
datespan=(datetimes[1].date(), datetimes[1].date()),
|
||||
start=datetimes[1],
|
||||
end=datetimes[2],
|
||||
room=room102,
|
||||
cancelled=True,
|
||||
),
|
||||
# Other adjacent dates.
|
||||
HotelReservation(
|
||||
datespan=(datetimes[3].date(), datetimes[4].date()),
|
||||
start=datetimes[3],
|
||||
end=datetimes[4],
|
||||
room=room102,
|
||||
),
|
||||
])
|
||||
|
||||
def test_range_overlaps_custom(self):
|
||||
class TsTzRange(Func):
|
||||
function = 'TSTZRANGE'
|
||||
output_field = DateTimeRangeField()
|
||||
|
||||
constraint = ExclusionConstraint(
|
||||
name='exclude_overlapping_reservations_custom',
|
||||
expressions=[
|
||||
(TsTzRange('start', 'end', RangeBoundary()), RangeOperators.OVERLAPS),
|
||||
('room', RangeOperators.EQUAL)
|
||||
],
|
||||
condition=Q(cancelled=False),
|
||||
)
|
||||
self._test_range_overlaps(constraint)
|
||||
|
||||
def test_range_overlaps(self):
|
||||
constraint = ExclusionConstraint(
|
||||
name='exclude_overlapping_reservations',
|
||||
expressions=[
|
||||
(F('datespan'), RangeOperators.OVERLAPS),
|
||||
('room', RangeOperators.EQUAL)
|
||||
],
|
||||
condition=Q(cancelled=False),
|
||||
)
|
||||
self._test_range_overlaps(constraint)
|
||||
|
||||
def test_range_adjacent(self):
|
||||
constraint_name = 'ints_adjacent'
|
||||
self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
|
||||
constraint = ExclusionConstraint(
|
||||
name=constraint_name,
|
||||
expressions=[('ints', RangeOperators.ADJACENT_TO)],
|
||||
)
|
||||
with connection.schema_editor() as editor:
|
||||
editor.add_constraint(RangesModel, constraint)
|
||||
self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
|
||||
RangesModel.objects.create(ints=(20, 50))
|
||||
with self.assertRaises(IntegrityError), transaction.atomic():
|
||||
RangesModel.objects.create(ints=(10, 20))
|
||||
RangesModel.objects.create(ints=(10, 19))
|
||||
RangesModel.objects.create(ints=(51, 60))
|
||||
|
|
Loading…
Reference in New Issue