Fixed #29824 -- Added support for database exclusion constraints on PostgreSQL.

Thanks to Nick Pope and Mariusz Felisiak for review.

Co-Authored-By: Mariusz Felisiak <felisiak.mariusz@gmail.com>
This commit is contained in:
Mads Jensen 2019-07-12 13:08:00 +02:00 committed by Mariusz Felisiak
parent 7174cf0b00
commit a3417282ac
9 changed files with 589 additions and 4 deletions

View File

@ -0,0 +1,106 @@
from django.db.backends.ddl_references import Statement, Table
from django.db.models import F, Q
from django.db.models.constraints import BaseConstraint
from django.db.models.sql import Query
__all__ = ['ExclusionConstraint']
class ExclusionConstraint(BaseConstraint):
template = 'CONSTRAINT %(name)s EXCLUDE USING %(index_type)s (%(expressions)s)%(where)s'
def __init__(self, *, name, expressions, index_type=None, condition=None):
if index_type and index_type.lower() not in {'gist', 'spgist'}:
raise ValueError(
'Exclusion constraints only support GiST or SP-GiST indexes.'
)
if not expressions:
raise ValueError(
'At least one expression is required to define an exclusion '
'constraint.'
)
if not all(
isinstance(expr, (list, tuple)) and len(expr) == 2
for expr in expressions
):
raise ValueError('The expressions must be a list of 2-tuples.')
if not isinstance(condition, (type(None), Q)):
raise ValueError(
'ExclusionConstraint.condition must be a Q instance.'
)
self.expressions = expressions
self.index_type = index_type or 'GIST'
self.condition = condition
super().__init__(name=name)
def _get_expression_sql(self, compiler, connection, query):
expressions = []
for expression, operator in self.expressions:
if isinstance(expression, str):
expression = F(expression)
if isinstance(expression, F):
expression = expression.resolve_expression(query=query, simple_col=True)
else:
expression = expression.resolve_expression(query=query)
sql, params = expression.as_sql(compiler, connection)
expressions.append('%s WITH %s' % (sql % params, operator))
return expressions
def _get_condition_sql(self, compiler, schema_editor, query):
if self.condition is None:
return None
where = query.build_where(self.condition)
sql, params = where.as_sql(compiler, schema_editor.connection)
return sql % tuple(schema_editor.quote_value(p) for p in params)
def constraint_sql(self, model, schema_editor):
query = Query(model)
compiler = query.get_compiler(connection=schema_editor.connection)
expressions = self._get_expression_sql(compiler, schema_editor.connection, query)
condition = self._get_condition_sql(compiler, schema_editor, query)
return self.template % {
'name': schema_editor.quote_name(self.name),
'index_type': self.index_type,
'expressions': ', '.join(expressions),
'where': ' WHERE (%s)' % condition if condition else '',
}
def create_sql(self, model, schema_editor):
return Statement(
'ALTER TABLE %(table)s ADD %(constraint)s',
table=Table(model._meta.db_table, schema_editor.quote_name),
constraint=self.constraint_sql(model, schema_editor),
)
def remove_sql(self, model, schema_editor):
return schema_editor._delete_constraint_sql(
schema_editor.sql_delete_check,
model,
schema_editor.quote_name(self.name),
)
def deconstruct(self):
path, args, kwargs = super().deconstruct()
kwargs['expressions'] = self.expressions
if self.condition is not None:
kwargs['condition'] = self.condition
if self.index_type.lower() != 'gist':
kwargs['index_type'] = self.index_type
return path, args, kwargs
def __eq__(self, other):
return (
isinstance(other, self.__class__) and
self.name == other.name and
self.index_type == other.index_type and
self.expressions == other.expressions and
self.condition == other.condition
)
def __repr__(self):
return '<%s: index_type=%s, expressions=%s%s>' % (
self.__class__.__qualname__,
self.index_type,
self.expressions,
'' if self.condition is None else ', condition=%s' % self.condition,
)

View File

@ -12,10 +12,20 @@ __all__ = [
'RangeField', 'IntegerRangeField', 'BigIntegerRangeField',
'DecimalRangeField', 'DateTimeRangeField', 'DateRangeField',
'FloatRangeField',
'RangeOperators',
'RangeBoundary', 'RangeOperators',
]
class RangeBoundary(models.Expression):
"""A class that represents range boundaries."""
def __init__(self, inclusive_lower=True, inclusive_upper=False):
self.lower = '[' if inclusive_lower else '('
self.upper = ']' if inclusive_upper else ')'
def as_sql(self, compiler, connection):
return "'%s%s'" % (self.lower, self.upper), []
class RangeOperators:
# https://www.postgresql.org/docs/current/functions-range.html#RANGE-OPERATORS-TABLE
EQUAL = '='

View File

@ -0,0 +1,151 @@
========================================
PostgreSQL specific database constraints
========================================
.. module:: django.contrib.postgres.constraints
:synopsis: PostgreSQL specific database constraint
PostgreSQL supports additional data integrity constraints available from the
``django.contrib.postgres.constraints`` module. They are added in the model
:attr:`Meta.constraints <django.db.models.Options.constraints>` option.
``ExclusionConstraint``
=======================
.. versionadded:: 3.0
.. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None)
Creates an exclusion constraint in the database. Internally, PostgreSQL
implements exclusion constraints using indexes. The default index type is
`GiST <https://www.postgresql.org/docs/current/gist.html>`_. To use them,
you need to activate the `btree_gist extension
<https://www.postgresql.org/docs/current/btree-gist.html>`_ on PostgreSQL.
You can install it using the
:class:`~django.contrib.postgres.operations.BtreeGistExtension` migration
operation.
If you attempt to insert a new row that conflicts with an existing row, an
:exc:`~django.db.IntegrityError` is raised. Similarly, when update
conflicts with an existing row.
``name``
--------
.. attribute:: ExclusionConstraint.name
The name of the constraint.
``expressions``
---------------
.. attribute:: ExclusionConstraint.expressions
An iterable of 2-tuples. The first element is an expression or string. The
second element is a SQL operator represented as a string. To avoid typos, you
may use :class:`~django.contrib.postgres.fields.RangeOperators` which maps the
operators with strings. For example::
expressions=[
('timespan', RangeOperators.ADJACENT_TO),
(F('room'), RangeOperators.EQUAL),
]
.. admonition:: Restrictions on operators.
Only commutative operators can be used in exclusion constraints.
``index_type``
--------------
.. attribute:: ExclusionConstraint.index_type
The index type of the constraint. Accepted values are ``GIST`` or ``SPGIST``.
Matching is case insensitive. If not provided, the default index type is
``GIST``.
``condition``
-------------
.. attribute:: ExclusionConstraint.condition
A :class:`~django.db.models.Q` object that specifies the condition to restrict
a constraint to a subset of rows. For example,
``condition=Q(cancelled=False)``.
These conditions have the same database restrictions as
:attr:`django.db.models.Index.condition`.
Examples
--------
The following example restricts overlapping reservations in the same room, not
taking canceled reservations into account::
from django.contrib.postgres.constraints import ExclusionConstraint
from django.contrib.postgres.fields import DateTimeRangeField, RangeOperators
from django.db import models
from django.db.models import Q
class Room(models.Model):
number = models.IntegerField()
class Reservation(models.Model):
room = models.ForeignKey('Room', on_delete=models.CASCADE)
timespan = DateTimeRangeField()
cancelled = models.BooleanField(default=False)
class Meta:
constraints = [
ExclusionConstraint(
name='exclude_overlapping_reservations',
expressions=[
('timespan', RangeOperators.OVERLAPS),
('room', RangeOperators.EQUAL),
],
condition=Q(cancelled=False),
),
]
In case your model defines a range using two fields, instead of the native
PostgreSQL range types, you should write an expression that uses the equivalent
function (e.g. ``TsTzRange()``), and use the delimiters for the field. Most
often, the delimiters will be ``'[)'``, meaning that the lower bound is
inclusive and the upper bound is exclusive. You may use the
:class:`~django.contrib.postgres.fields.RangeBoundary` that provides an
expression mapping for the `range boundaries <https://www.postgresql.org/docs/
current/rangetypes.html#RANGETYPES-INCLUSIVITY>`_. For example::
from django.contrib.postgres.constraints import ExclusionConstraint
from django.contrib.postgres.fields import (
DateTimeRangeField,
RangeBoundary,
RangeOperators,
)
from django.db import models
from django.db.models import Func, Q
class TsTzRange(Func):
function = 'TSTZRANGE'
output_field = DateTimeRangeField()
class Reservation(models.Model):
room = models.ForeignKey('Room', on_delete=models.CASCADE)
start = models.DateTimeField()
end = models.DateTimeField()
cancelled = models.BooleanField(default=False)
class Meta:
constraints = [
ExclusionConstraint(
name='exclude_overlapping_reservations',
expressions=(
(TsTzRange('start', 'end', RangeBoundary()), RangeOperators.OVERLAPS),
('room', RangeOperators.EQUAL),
),
condition=Q(cancelled=False),
),
]

View File

@ -944,3 +944,26 @@ corresponding lookups.
NOT_LT = '&>'
NOT_GT = '&<'
ADJACENT_TO = '-|-'
RangeBoundary() expressions
---------------------------
.. versionadded:: 3.0
.. class:: RangeBoundary(inclusive_lower=True, inclusive_upper=False)
.. attribute:: inclusive_lower
If ``True`` (default), the lower bound is inclusive ``'['``, otherwise
it's exclusive ``'('``.
.. attribute:: inclusive_upper
If ``False`` (default), the upper bound is exclusive ``')'``, otherwise
it's inclusive ``']'``.
A ``RangeBoundary()`` expression represents the range boundaries. It can be
used with a custom range functions that expected boundaries, for example to
define :class:`~django.contrib.postgres.constraints.ExclusionConstraint`. See
`the PostgreSQL documentation for the full details <https://www.postgresql.org/
docs/current/rangetypes.html#RANGETYPES-INCLUSIVITY>`_.

View File

@ -29,6 +29,7 @@ a number of PostgreSQL specific data types.
:maxdepth: 2
aggregates
constraints
fields
forms
functions

View File

@ -66,6 +66,14 @@ async code before, this may trigger if you were doing it incorrectly. If you
see a ``SynchronousOnlyOperation`` error, then closely examine your code and
move any database operations to be in a synchronous child thread.
Exclusion constraints on PostgreSQL
-----------------------------------
The new :class:`~django.contrib.postgres.constraints.ExclusionConstraint` class
enable adding exclusion constraints on PostgreSQL. Constraints are added to
models using the
:attr:`Meta.constraints <django.db.models.Options.constraints>` option.
Minor features
--------------
@ -137,6 +145,9 @@ Minor features
avoid typos in SQL operators that can be used together with
:class:`~django.contrib.postgres.fields.RangeField`.
* The new :class:`~django.contrib.postgres.fields.RangeBoundary` expression
represents the range boundaries.
:mod:`django.contrib.redirects`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -262,4 +262,25 @@ class Migration(migrations.Migration):
},
bases=(models.Model,),
),
migrations.CreateModel(
name='Room',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('number', models.IntegerField(unique=True)),
],
),
migrations.CreateModel(
name='HotelReservation',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('room', models.ForeignKey('postgres_tests.Room', models.CASCADE)),
('datespan', DateRangeField()),
('start', models.DateTimeField()),
('end', models.DateTimeField()),
('cancelled', models.BooleanField(default=False)),
],
options={
'required_db_vendor': 'postgresql',
},
),
]

View File

@ -183,3 +183,15 @@ class NowTestModel(models.Model):
class UUIDTestModel(models.Model):
uuid = models.UUIDField(default=None, null=True)
class Room(models.Model):
number = models.IntegerField(unique=True)
class HotelReservation(PostgreSQLModel):
room = models.ForeignKey('Room', on_delete=models.CASCADE)
datespan = DateRangeField()
start = models.DateTimeField()
end = models.DateTimeField()
cancelled = models.BooleanField(default=False)

View File

@ -1,15 +1,19 @@
import datetime
from django.db import connection, transaction
from django.db.models import F, Q
from django.db.models import F, Func, Q
from django.db.models.constraints import CheckConstraint
from django.db.utils import IntegrityError
from django.utils import timezone
from . import PostgreSQLTestCase
from .models import RangesModel
from .models import HotelReservation, RangesModel, Room
try:
from psycopg2.extras import NumericRange
from django.contrib.postgres.constraints import ExclusionConstraint
from django.contrib.postgres.fields import DateTimeRangeField, RangeBoundary, RangeOperators
from psycopg2.extras import DateRange, NumericRange
except ImportError:
pass
@ -77,3 +81,249 @@ class SchemaTests(PostgreSQLTestCase):
timestamps=(datetime_1, datetime_2),
timestamps_inner=(datetime_1, datetime_2),
)
class ExclusionConstraintTests(PostgreSQLTestCase):
def get_constraints(self, table):
"""Get the constraints on the table using a new cursor."""
with connection.cursor() as cursor:
return connection.introspection.get_constraints(cursor, table)
def test_invalid_condition(self):
msg = 'ExclusionConstraint.condition must be a Q instance.'
with self.assertRaisesMessage(ValueError, msg):
ExclusionConstraint(
index_type='GIST',
name='exclude_invalid_condition',
expressions=[(F('datespan'), RangeOperators.OVERLAPS)],
condition=F('invalid'),
)
def test_invalid_index_type(self):
msg = 'Exclusion constraints only support GiST or SP-GiST indexes.'
with self.assertRaisesMessage(ValueError, msg):
ExclusionConstraint(
index_type='gin',
name='exclude_invalid_index_type',
expressions=[(F('datespan'), RangeOperators.OVERLAPS)],
)
def test_invalid_expressions(self):
msg = 'The expressions must be a list of 2-tuples.'
for expressions in (['foo'], [('foo')], [('foo_1', 'foo_2', 'foo_3')]):
with self.subTest(expressions), self.assertRaisesMessage(ValueError, msg):
ExclusionConstraint(
index_type='GIST',
name='exclude_invalid_expressions',
expressions=expressions,
)
def test_empty_expressions(self):
msg = 'At least one expression is required to define an exclusion constraint.'
for empty_expressions in (None, []):
with self.subTest(empty_expressions), self.assertRaisesMessage(ValueError, msg):
ExclusionConstraint(
index_type='GIST',
name='exclude_empty_expressions',
expressions=empty_expressions,
)
def test_repr(self):
constraint = ExclusionConstraint(
name='exclude_overlapping',
expressions=[
(F('datespan'), RangeOperators.OVERLAPS),
(F('room'), RangeOperators.EQUAL),
],
)
self.assertEqual(
repr(constraint),
"<ExclusionConstraint: index_type=GIST, expressions=["
"(F(datespan), '&&'), (F(room), '=')]>",
)
constraint = ExclusionConstraint(
name='exclude_overlapping',
expressions=[(F('datespan'), RangeOperators.ADJACENT_TO)],
condition=Q(cancelled=False),
index_type='SPGiST',
)
self.assertEqual(
repr(constraint),
"<ExclusionConstraint: index_type=SPGiST, expressions=["
"(F(datespan), '-|-')], condition=(AND: ('cancelled', False))>",
)
def test_eq(self):
constraint_1 = ExclusionConstraint(
name='exclude_overlapping',
expressions=[
(F('datespan'), RangeOperators.OVERLAPS),
(F('room'), RangeOperators.EQUAL),
],
condition=Q(cancelled=False),
)
constraint_2 = ExclusionConstraint(
name='exclude_overlapping',
expressions=[
('datespan', RangeOperators.OVERLAPS),
('room', RangeOperators.EQUAL),
],
)
constraint_3 = ExclusionConstraint(
name='exclude_overlapping',
expressions=[('datespan', RangeOperators.OVERLAPS)],
condition=Q(cancelled=False),
)
self.assertEqual(constraint_1, constraint_1)
self.assertNotEqual(constraint_1, constraint_2)
self.assertNotEqual(constraint_1, constraint_3)
self.assertNotEqual(constraint_2, constraint_3)
self.assertNotEqual(constraint_1, object())
def test_deconstruct(self):
constraint = ExclusionConstraint(
name='exclude_overlapping',
expressions=[('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)],
)
path, args, kwargs = constraint.deconstruct()
self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint')
self.assertEqual(args, ())
self.assertEqual(kwargs, {
'name': 'exclude_overlapping',
'expressions': [('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)],
})
def test_deconstruct_index_type(self):
constraint = ExclusionConstraint(
name='exclude_overlapping',
index_type='SPGIST',
expressions=[('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)],
)
path, args, kwargs = constraint.deconstruct()
self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint')
self.assertEqual(args, ())
self.assertEqual(kwargs, {
'name': 'exclude_overlapping',
'index_type': 'SPGIST',
'expressions': [('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)],
})
def test_deconstruct_condition(self):
constraint = ExclusionConstraint(
name='exclude_overlapping',
expressions=[('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)],
condition=Q(cancelled=False),
)
path, args, kwargs = constraint.deconstruct()
self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint')
self.assertEqual(args, ())
self.assertEqual(kwargs, {
'name': 'exclude_overlapping',
'expressions': [('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)],
'condition': Q(cancelled=False),
})
def _test_range_overlaps(self, constraint):
# Create exclusion constraint.
self.assertNotIn(constraint.name, self.get_constraints(HotelReservation._meta.db_table))
with connection.schema_editor() as editor:
editor.add_constraint(HotelReservation, constraint)
self.assertIn(constraint.name, self.get_constraints(HotelReservation._meta.db_table))
# Add initial reservations.
room101 = Room.objects.create(number=101)
room102 = Room.objects.create(number=102)
datetimes = [
timezone.datetime(2018, 6, 20),
timezone.datetime(2018, 6, 24),
timezone.datetime(2018, 6, 26),
timezone.datetime(2018, 6, 28),
timezone.datetime(2018, 6, 29),
]
HotelReservation.objects.create(
datespan=DateRange(datetimes[0].date(), datetimes[1].date()),
start=datetimes[0],
end=datetimes[1],
room=room102,
)
HotelReservation.objects.create(
datespan=DateRange(datetimes[1].date(), datetimes[3].date()),
start=datetimes[1],
end=datetimes[3],
room=room102,
)
# Overlap dates.
with self.assertRaises(IntegrityError), transaction.atomic():
reservation = HotelReservation(
datespan=(datetimes[1].date(), datetimes[2].date()),
start=datetimes[1],
end=datetimes[2],
room=room102,
)
reservation.save()
# Valid range.
HotelReservation.objects.bulk_create([
# Other room.
HotelReservation(
datespan=(datetimes[1].date(), datetimes[2].date()),
start=datetimes[1],
end=datetimes[2],
room=room101,
),
# Cancelled reservation.
HotelReservation(
datespan=(datetimes[1].date(), datetimes[1].date()),
start=datetimes[1],
end=datetimes[2],
room=room102,
cancelled=True,
),
# Other adjacent dates.
HotelReservation(
datespan=(datetimes[3].date(), datetimes[4].date()),
start=datetimes[3],
end=datetimes[4],
room=room102,
),
])
def test_range_overlaps_custom(self):
class TsTzRange(Func):
function = 'TSTZRANGE'
output_field = DateTimeRangeField()
constraint = ExclusionConstraint(
name='exclude_overlapping_reservations_custom',
expressions=[
(TsTzRange('start', 'end', RangeBoundary()), RangeOperators.OVERLAPS),
('room', RangeOperators.EQUAL)
],
condition=Q(cancelled=False),
)
self._test_range_overlaps(constraint)
def test_range_overlaps(self):
constraint = ExclusionConstraint(
name='exclude_overlapping_reservations',
expressions=[
(F('datespan'), RangeOperators.OVERLAPS),
('room', RangeOperators.EQUAL)
],
condition=Q(cancelled=False),
)
self._test_range_overlaps(constraint)
def test_range_adjacent(self):
constraint_name = 'ints_adjacent'
self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
constraint = ExclusionConstraint(
name=constraint_name,
expressions=[('ints', RangeOperators.ADJACENT_TO)],
)
with connection.schema_editor() as editor:
editor.add_constraint(RangesModel, constraint)
self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
RangesModel.objects.create(ints=(20, 50))
with self.assertRaises(IntegrityError), transaction.atomic():
RangesModel.objects.create(ints=(10, 20))
RangesModel.objects.create(ints=(10, 19))
RangesModel.objects.create(ints=(51, 60))