diff --git a/django/contrib/postgres/constraints.py b/django/contrib/postgres/constraints.py new file mode 100644 index 0000000000..2fcb076ecf --- /dev/null +++ b/django/contrib/postgres/constraints.py @@ -0,0 +1,106 @@ +from django.db.backends.ddl_references import Statement, Table +from django.db.models import F, Q +from django.db.models.constraints import BaseConstraint +from django.db.models.sql import Query + +__all__ = ['ExclusionConstraint'] + + +class ExclusionConstraint(BaseConstraint): + template = 'CONSTRAINT %(name)s EXCLUDE USING %(index_type)s (%(expressions)s)%(where)s' + + def __init__(self, *, name, expressions, index_type=None, condition=None): + if index_type and index_type.lower() not in {'gist', 'spgist'}: + raise ValueError( + 'Exclusion constraints only support GiST or SP-GiST indexes.' + ) + if not expressions: + raise ValueError( + 'At least one expression is required to define an exclusion ' + 'constraint.' + ) + if not all( + isinstance(expr, (list, tuple)) and len(expr) == 2 + for expr in expressions + ): + raise ValueError('The expressions must be a list of 2-tuples.') + if not isinstance(condition, (type(None), Q)): + raise ValueError( + 'ExclusionConstraint.condition must be a Q instance.' + ) + self.expressions = expressions + self.index_type = index_type or 'GIST' + self.condition = condition + super().__init__(name=name) + + def _get_expression_sql(self, compiler, connection, query): + expressions = [] + for expression, operator in self.expressions: + if isinstance(expression, str): + expression = F(expression) + if isinstance(expression, F): + expression = expression.resolve_expression(query=query, simple_col=True) + else: + expression = expression.resolve_expression(query=query) + sql, params = expression.as_sql(compiler, connection) + expressions.append('%s WITH %s' % (sql % params, operator)) + return expressions + + def _get_condition_sql(self, compiler, schema_editor, query): + if self.condition is None: + return None + where = query.build_where(self.condition) + sql, params = where.as_sql(compiler, schema_editor.connection) + return sql % tuple(schema_editor.quote_value(p) for p in params) + + def constraint_sql(self, model, schema_editor): + query = Query(model) + compiler = query.get_compiler(connection=schema_editor.connection) + expressions = self._get_expression_sql(compiler, schema_editor.connection, query) + condition = self._get_condition_sql(compiler, schema_editor, query) + return self.template % { + 'name': schema_editor.quote_name(self.name), + 'index_type': self.index_type, + 'expressions': ', '.join(expressions), + 'where': ' WHERE (%s)' % condition if condition else '', + } + + def create_sql(self, model, schema_editor): + return Statement( + 'ALTER TABLE %(table)s ADD %(constraint)s', + table=Table(model._meta.db_table, schema_editor.quote_name), + constraint=self.constraint_sql(model, schema_editor), + ) + + def remove_sql(self, model, schema_editor): + return schema_editor._delete_constraint_sql( + schema_editor.sql_delete_check, + model, + schema_editor.quote_name(self.name), + ) + + def deconstruct(self): + path, args, kwargs = super().deconstruct() + kwargs['expressions'] = self.expressions + if self.condition is not None: + kwargs['condition'] = self.condition + if self.index_type.lower() != 'gist': + kwargs['index_type'] = self.index_type + return path, args, kwargs + + def __eq__(self, other): + return ( + isinstance(other, self.__class__) and + self.name == other.name and + self.index_type == other.index_type and + self.expressions == other.expressions and + self.condition == other.condition + ) + + def __repr__(self): + return '<%s: index_type=%s, expressions=%s%s>' % ( + self.__class__.__qualname__, + self.index_type, + self.expressions, + '' if self.condition is None else ', condition=%s' % self.condition, + ) diff --git a/django/contrib/postgres/fields/ranges.py b/django/contrib/postgres/fields/ranges.py index 21e982cedd..a4fa20adf3 100644 --- a/django/contrib/postgres/fields/ranges.py +++ b/django/contrib/postgres/fields/ranges.py @@ -12,10 +12,20 @@ __all__ = [ 'RangeField', 'IntegerRangeField', 'BigIntegerRangeField', 'DecimalRangeField', 'DateTimeRangeField', 'DateRangeField', 'FloatRangeField', - 'RangeOperators', + 'RangeBoundary', 'RangeOperators', ] +class RangeBoundary(models.Expression): + """A class that represents range boundaries.""" + def __init__(self, inclusive_lower=True, inclusive_upper=False): + self.lower = '[' if inclusive_lower else '(' + self.upper = ']' if inclusive_upper else ')' + + def as_sql(self, compiler, connection): + return "'%s%s'" % (self.lower, self.upper), [] + + class RangeOperators: # https://www.postgresql.org/docs/current/functions-range.html#RANGE-OPERATORS-TABLE EQUAL = '=' diff --git a/docs/ref/contrib/postgres/constraints.txt b/docs/ref/contrib/postgres/constraints.txt new file mode 100644 index 0000000000..fe9e72e605 --- /dev/null +++ b/docs/ref/contrib/postgres/constraints.txt @@ -0,0 +1,151 @@ +======================================== +PostgreSQL specific database constraints +======================================== + +.. module:: django.contrib.postgres.constraints + :synopsis: PostgreSQL specific database constraint + +PostgreSQL supports additional data integrity constraints available from the +``django.contrib.postgres.constraints`` module. They are added in the model +:attr:`Meta.constraints ` option. + +``ExclusionConstraint`` +======================= + +.. versionadded:: 3.0 + +.. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None) + + Creates an exclusion constraint in the database. Internally, PostgreSQL + implements exclusion constraints using indexes. The default index type is + `GiST `_. To use them, + you need to activate the `btree_gist extension + `_ on PostgreSQL. + You can install it using the + :class:`~django.contrib.postgres.operations.BtreeGistExtension` migration + operation. + + If you attempt to insert a new row that conflicts with an existing row, an + :exc:`~django.db.IntegrityError` is raised. Similarly, when update + conflicts with an existing row. + +``name`` +-------- + +.. attribute:: ExclusionConstraint.name + +The name of the constraint. + +``expressions`` +--------------- + +.. attribute:: ExclusionConstraint.expressions + +An iterable of 2-tuples. The first element is an expression or string. The +second element is a SQL operator represented as a string. To avoid typos, you +may use :class:`~django.contrib.postgres.fields.RangeOperators` which maps the +operators with strings. For example:: + + expressions=[ + ('timespan', RangeOperators.ADJACENT_TO), + (F('room'), RangeOperators.EQUAL), + ] + +.. admonition:: Restrictions on operators. + + Only commutative operators can be used in exclusion constraints. + +``index_type`` +-------------- + +.. attribute:: ExclusionConstraint.index_type + +The index type of the constraint. Accepted values are ``GIST`` or ``SPGIST``. +Matching is case insensitive. If not provided, the default index type is +``GIST``. + +``condition`` +------------- + +.. attribute:: ExclusionConstraint.condition + +A :class:`~django.db.models.Q` object that specifies the condition to restrict +a constraint to a subset of rows. For example, +``condition=Q(cancelled=False)``. + +These conditions have the same database restrictions as +:attr:`django.db.models.Index.condition`. + +Examples +-------- + +The following example restricts overlapping reservations in the same room, not +taking canceled reservations into account:: + + from django.contrib.postgres.constraints import ExclusionConstraint + from django.contrib.postgres.fields import DateTimeRangeField, RangeOperators + from django.db import models + from django.db.models import Q + + class Room(models.Model): + number = models.IntegerField() + + + class Reservation(models.Model): + room = models.ForeignKey('Room', on_delete=models.CASCADE) + timespan = DateTimeRangeField() + cancelled = models.BooleanField(default=False) + + class Meta: + constraints = [ + ExclusionConstraint( + name='exclude_overlapping_reservations', + expressions=[ + ('timespan', RangeOperators.OVERLAPS), + ('room', RangeOperators.EQUAL), + ], + condition=Q(cancelled=False), + ), + ] + +In case your model defines a range using two fields, instead of the native +PostgreSQL range types, you should write an expression that uses the equivalent +function (e.g. ``TsTzRange()``), and use the delimiters for the field. Most +often, the delimiters will be ``'[)'``, meaning that the lower bound is +inclusive and the upper bound is exclusive. You may use the +:class:`~django.contrib.postgres.fields.RangeBoundary` that provides an +expression mapping for the `range boundaries `_. For example:: + + from django.contrib.postgres.constraints import ExclusionConstraint + from django.contrib.postgres.fields import ( + DateTimeRangeField, + RangeBoundary, + RangeOperators, + ) + from django.db import models + from django.db.models import Func, Q + + + class TsTzRange(Func): + function = 'TSTZRANGE' + output_field = DateTimeRangeField() + + + class Reservation(models.Model): + room = models.ForeignKey('Room', on_delete=models.CASCADE) + start = models.DateTimeField() + end = models.DateTimeField() + cancelled = models.BooleanField(default=False) + + class Meta: + constraints = [ + ExclusionConstraint( + name='exclude_overlapping_reservations', + expressions=( + (TsTzRange('start', 'end', RangeBoundary()), RangeOperators.OVERLAPS), + ('room', RangeOperators.EQUAL), + ), + condition=Q(cancelled=False), + ), + ] diff --git a/docs/ref/contrib/postgres/fields.txt b/docs/ref/contrib/postgres/fields.txt index 387eb0f02e..14cb1d00cb 100644 --- a/docs/ref/contrib/postgres/fields.txt +++ b/docs/ref/contrib/postgres/fields.txt @@ -944,3 +944,26 @@ corresponding lookups. NOT_LT = '&>' NOT_GT = '&<' ADJACENT_TO = '-|-' + +RangeBoundary() expressions +--------------------------- + +.. versionadded:: 3.0 + +.. class:: RangeBoundary(inclusive_lower=True, inclusive_upper=False) + + .. attribute:: inclusive_lower + + If ``True`` (default), the lower bound is inclusive ``'['``, otherwise + it's exclusive ``'('``. + + .. attribute:: inclusive_upper + + If ``False`` (default), the upper bound is exclusive ``')'``, otherwise + it's inclusive ``']'``. + +A ``RangeBoundary()`` expression represents the range boundaries. It can be +used with a custom range functions that expected boundaries, for example to +define :class:`~django.contrib.postgres.constraints.ExclusionConstraint`. See +`the PostgreSQL documentation for the full details `_. diff --git a/docs/ref/contrib/postgres/index.txt b/docs/ref/contrib/postgres/index.txt index 9485f56409..03ff6da1e0 100644 --- a/docs/ref/contrib/postgres/index.txt +++ b/docs/ref/contrib/postgres/index.txt @@ -29,6 +29,7 @@ a number of PostgreSQL specific data types. :maxdepth: 2 aggregates + constraints fields forms functions diff --git a/docs/releases/3.0.txt b/docs/releases/3.0.txt index 6aa21975bf..34152573f0 100644 --- a/docs/releases/3.0.txt +++ b/docs/releases/3.0.txt @@ -66,6 +66,14 @@ async code before, this may trigger if you were doing it incorrectly. If you see a ``SynchronousOnlyOperation`` error, then closely examine your code and move any database operations to be in a synchronous child thread. +Exclusion constraints on PostgreSQL +----------------------------------- + +The new :class:`~django.contrib.postgres.constraints.ExclusionConstraint` class +enable adding exclusion constraints on PostgreSQL. Constraints are added to +models using the +:attr:`Meta.constraints ` option. + Minor features -------------- @@ -137,6 +145,9 @@ Minor features avoid typos in SQL operators that can be used together with :class:`~django.contrib.postgres.fields.RangeField`. +* The new :class:`~django.contrib.postgres.fields.RangeBoundary` expression + represents the range boundaries. + :mod:`django.contrib.redirects` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/postgres_tests/migrations/0002_create_test_models.py b/tests/postgres_tests/migrations/0002_create_test_models.py index b9f9cee6bf..1b9c45881f 100644 --- a/tests/postgres_tests/migrations/0002_create_test_models.py +++ b/tests/postgres_tests/migrations/0002_create_test_models.py @@ -262,4 +262,25 @@ class Migration(migrations.Migration): }, bases=(models.Model,), ), + migrations.CreateModel( + name='Room', + fields=[ + ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), + ('number', models.IntegerField(unique=True)), + ], + ), + migrations.CreateModel( + name='HotelReservation', + fields=[ + ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), + ('room', models.ForeignKey('postgres_tests.Room', models.CASCADE)), + ('datespan', DateRangeField()), + ('start', models.DateTimeField()), + ('end', models.DateTimeField()), + ('cancelled', models.BooleanField(default=False)), + ], + options={ + 'required_db_vendor': 'postgresql', + }, + ), ] diff --git a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py index 3d170a9a1a..5b2f41160a 100644 --- a/tests/postgres_tests/models.py +++ b/tests/postgres_tests/models.py @@ -183,3 +183,15 @@ class NowTestModel(models.Model): class UUIDTestModel(models.Model): uuid = models.UUIDField(default=None, null=True) + + +class Room(models.Model): + number = models.IntegerField(unique=True) + + +class HotelReservation(PostgreSQLModel): + room = models.ForeignKey('Room', on_delete=models.CASCADE) + datespan = DateRangeField() + start = models.DateTimeField() + end = models.DateTimeField() + cancelled = models.BooleanField(default=False) diff --git a/tests/postgres_tests/test_constraints.py b/tests/postgres_tests/test_constraints.py index 3ecabcbd49..d8665f59f6 100644 --- a/tests/postgres_tests/test_constraints.py +++ b/tests/postgres_tests/test_constraints.py @@ -1,15 +1,19 @@ import datetime from django.db import connection, transaction -from django.db.models import F, Q +from django.db.models import F, Func, Q from django.db.models.constraints import CheckConstraint from django.db.utils import IntegrityError +from django.utils import timezone from . import PostgreSQLTestCase -from .models import RangesModel +from .models import HotelReservation, RangesModel, Room try: - from psycopg2.extras import NumericRange + from django.contrib.postgres.constraints import ExclusionConstraint + from django.contrib.postgres.fields import DateTimeRangeField, RangeBoundary, RangeOperators + + from psycopg2.extras import DateRange, NumericRange except ImportError: pass @@ -77,3 +81,249 @@ class SchemaTests(PostgreSQLTestCase): timestamps=(datetime_1, datetime_2), timestamps_inner=(datetime_1, datetime_2), ) + + +class ExclusionConstraintTests(PostgreSQLTestCase): + def get_constraints(self, table): + """Get the constraints on the table using a new cursor.""" + with connection.cursor() as cursor: + return connection.introspection.get_constraints(cursor, table) + + def test_invalid_condition(self): + msg = 'ExclusionConstraint.condition must be a Q instance.' + with self.assertRaisesMessage(ValueError, msg): + ExclusionConstraint( + index_type='GIST', + name='exclude_invalid_condition', + expressions=[(F('datespan'), RangeOperators.OVERLAPS)], + condition=F('invalid'), + ) + + def test_invalid_index_type(self): + msg = 'Exclusion constraints only support GiST or SP-GiST indexes.' + with self.assertRaisesMessage(ValueError, msg): + ExclusionConstraint( + index_type='gin', + name='exclude_invalid_index_type', + expressions=[(F('datespan'), RangeOperators.OVERLAPS)], + ) + + def test_invalid_expressions(self): + msg = 'The expressions must be a list of 2-tuples.' + for expressions in (['foo'], [('foo')], [('foo_1', 'foo_2', 'foo_3')]): + with self.subTest(expressions), self.assertRaisesMessage(ValueError, msg): + ExclusionConstraint( + index_type='GIST', + name='exclude_invalid_expressions', + expressions=expressions, + ) + + def test_empty_expressions(self): + msg = 'At least one expression is required to define an exclusion constraint.' + for empty_expressions in (None, []): + with self.subTest(empty_expressions), self.assertRaisesMessage(ValueError, msg): + ExclusionConstraint( + index_type='GIST', + name='exclude_empty_expressions', + expressions=empty_expressions, + ) + + def test_repr(self): + constraint = ExclusionConstraint( + name='exclude_overlapping', + expressions=[ + (F('datespan'), RangeOperators.OVERLAPS), + (F('room'), RangeOperators.EQUAL), + ], + ) + self.assertEqual( + repr(constraint), + "", + ) + + def test_eq(self): + constraint_1 = ExclusionConstraint( + name='exclude_overlapping', + expressions=[ + (F('datespan'), RangeOperators.OVERLAPS), + (F('room'), RangeOperators.EQUAL), + ], + condition=Q(cancelled=False), + ) + constraint_2 = ExclusionConstraint( + name='exclude_overlapping', + expressions=[ + ('datespan', RangeOperators.OVERLAPS), + ('room', RangeOperators.EQUAL), + ], + ) + constraint_3 = ExclusionConstraint( + name='exclude_overlapping', + expressions=[('datespan', RangeOperators.OVERLAPS)], + condition=Q(cancelled=False), + ) + self.assertEqual(constraint_1, constraint_1) + self.assertNotEqual(constraint_1, constraint_2) + self.assertNotEqual(constraint_1, constraint_3) + self.assertNotEqual(constraint_2, constraint_3) + self.assertNotEqual(constraint_1, object()) + + def test_deconstruct(self): + constraint = ExclusionConstraint( + name='exclude_overlapping', + expressions=[('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)], + ) + path, args, kwargs = constraint.deconstruct() + self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint') + self.assertEqual(args, ()) + self.assertEqual(kwargs, { + 'name': 'exclude_overlapping', + 'expressions': [('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)], + }) + + def test_deconstruct_index_type(self): + constraint = ExclusionConstraint( + name='exclude_overlapping', + index_type='SPGIST', + expressions=[('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)], + ) + path, args, kwargs = constraint.deconstruct() + self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint') + self.assertEqual(args, ()) + self.assertEqual(kwargs, { + 'name': 'exclude_overlapping', + 'index_type': 'SPGIST', + 'expressions': [('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)], + }) + + def test_deconstruct_condition(self): + constraint = ExclusionConstraint( + name='exclude_overlapping', + expressions=[('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)], + condition=Q(cancelled=False), + ) + path, args, kwargs = constraint.deconstruct() + self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint') + self.assertEqual(args, ()) + self.assertEqual(kwargs, { + 'name': 'exclude_overlapping', + 'expressions': [('datespan', RangeOperators.OVERLAPS), ('room', RangeOperators.EQUAL)], + 'condition': Q(cancelled=False), + }) + + def _test_range_overlaps(self, constraint): + # Create exclusion constraint. + self.assertNotIn(constraint.name, self.get_constraints(HotelReservation._meta.db_table)) + with connection.schema_editor() as editor: + editor.add_constraint(HotelReservation, constraint) + self.assertIn(constraint.name, self.get_constraints(HotelReservation._meta.db_table)) + # Add initial reservations. + room101 = Room.objects.create(number=101) + room102 = Room.objects.create(number=102) + datetimes = [ + timezone.datetime(2018, 6, 20), + timezone.datetime(2018, 6, 24), + timezone.datetime(2018, 6, 26), + timezone.datetime(2018, 6, 28), + timezone.datetime(2018, 6, 29), + ] + HotelReservation.objects.create( + datespan=DateRange(datetimes[0].date(), datetimes[1].date()), + start=datetimes[0], + end=datetimes[1], + room=room102, + ) + HotelReservation.objects.create( + datespan=DateRange(datetimes[1].date(), datetimes[3].date()), + start=datetimes[1], + end=datetimes[3], + room=room102, + ) + # Overlap dates. + with self.assertRaises(IntegrityError), transaction.atomic(): + reservation = HotelReservation( + datespan=(datetimes[1].date(), datetimes[2].date()), + start=datetimes[1], + end=datetimes[2], + room=room102, + ) + reservation.save() + # Valid range. + HotelReservation.objects.bulk_create([ + # Other room. + HotelReservation( + datespan=(datetimes[1].date(), datetimes[2].date()), + start=datetimes[1], + end=datetimes[2], + room=room101, + ), + # Cancelled reservation. + HotelReservation( + datespan=(datetimes[1].date(), datetimes[1].date()), + start=datetimes[1], + end=datetimes[2], + room=room102, + cancelled=True, + ), + # Other adjacent dates. + HotelReservation( + datespan=(datetimes[3].date(), datetimes[4].date()), + start=datetimes[3], + end=datetimes[4], + room=room102, + ), + ]) + + def test_range_overlaps_custom(self): + class TsTzRange(Func): + function = 'TSTZRANGE' + output_field = DateTimeRangeField() + + constraint = ExclusionConstraint( + name='exclude_overlapping_reservations_custom', + expressions=[ + (TsTzRange('start', 'end', RangeBoundary()), RangeOperators.OVERLAPS), + ('room', RangeOperators.EQUAL) + ], + condition=Q(cancelled=False), + ) + self._test_range_overlaps(constraint) + + def test_range_overlaps(self): + constraint = ExclusionConstraint( + name='exclude_overlapping_reservations', + expressions=[ + (F('datespan'), RangeOperators.OVERLAPS), + ('room', RangeOperators.EQUAL) + ], + condition=Q(cancelled=False), + ) + self._test_range_overlaps(constraint) + + def test_range_adjacent(self): + constraint_name = 'ints_adjacent' + self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint = ExclusionConstraint( + name=constraint_name, + expressions=[('ints', RangeOperators.ADJACENT_TO)], + ) + with connection.schema_editor() as editor: + editor.add_constraint(RangesModel, constraint) + self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + RangesModel.objects.create(ints=(20, 50)) + with self.assertRaises(IntegrityError), transaction.atomic(): + RangesModel.objects.create(ints=(10, 20)) + RangesModel.objects.create(ints=(10, 19)) + RangesModel.objects.create(ints=(51, 60))