diff --git a/django/contrib/postgres/constraints.py b/django/contrib/postgres/constraints.py index 3b12098fe9..a9ea4761b3 100644 --- a/django/contrib/postgres/constraints.py +++ b/django/contrib/postgres/constraints.py @@ -1,13 +1,19 @@ +from django.contrib.postgres.indexes import OpClass from django.db import NotSupportedError -from django.db.backends.ddl_references import Statement, Table +from django.db.backends.ddl_references import Expressions, Statement, Table from django.db.models import Deferrable, F, Q from django.db.models.constraints import BaseConstraint -from django.db.models.expressions import Col +from django.db.models.expressions import ExpressionList +from django.db.models.indexes import IndexExpression from django.db.models.sql import Query __all__ = ['ExclusionConstraint'] +class ExclusionConstraintExpression(IndexExpression): + template = '%(expressions)s WITH %(operator)s' + + class ExclusionConstraint(BaseConstraint): template = 'CONSTRAINT %(name)s EXCLUDE USING %(index_type)s (%(expressions)s)%(include)s%(where)s%(deferrable)s' @@ -63,24 +69,19 @@ class ExclusionConstraint(BaseConstraint): self.opclasses = opclasses super().__init__(name=name) - def _get_expression_sql(self, compiler, schema_editor, query): + def _get_expressions(self, schema_editor, query): expressions = [] for idx, (expression, operator) in enumerate(self.expressions): if isinstance(expression, str): expression = F(expression) - expression = expression.resolve_expression(query=query) - sql, params = compiler.compile(expression) - if not isinstance(expression, Col): - sql = f'({sql})' try: - opclass = self.opclasses[idx] - if opclass: - sql = '%s %s' % (sql, opclass) + expression = OpClass(expression, self.opclasses[idx]) except IndexError: pass - sql = sql % tuple(schema_editor.quote_value(p) for p in params) - expressions.append('%s WITH %s' % (sql, operator)) - return expressions + expression = ExclusionConstraintExpression(expression, operator=operator) + expression.set_wrapper_classes(schema_editor.connection) + expressions.append(expression) + return ExpressionList(*expressions).resolve_expression(query) def _get_condition_sql(self, compiler, schema_editor, query): if self.condition is None: @@ -92,17 +93,20 @@ class ExclusionConstraint(BaseConstraint): def constraint_sql(self, model, schema_editor): query = Query(model, alias_cols=False) compiler = query.get_compiler(connection=schema_editor.connection) - expressions = self._get_expression_sql(compiler, schema_editor, query) + expressions = self._get_expressions(schema_editor, query) + table = model._meta.db_table condition = self._get_condition_sql(compiler, schema_editor, query) include = [model._meta.get_field(field_name).column for field_name in self.include] - return self.template % { - 'name': schema_editor.quote_name(self.name), - 'index_type': self.index_type, - 'expressions': ', '.join(expressions), - 'include': schema_editor._index_include_sql(model, include), - 'where': ' WHERE (%s)' % condition if condition else '', - 'deferrable': schema_editor._deferrable_constraint_sql(self.deferrable), - } + return Statement( + self.template, + table=Table(table, schema_editor.quote_name), + name=schema_editor.quote_name(self.name), + index_type=self.index_type, + expressions=Expressions(table, expressions, compiler, schema_editor.quote_value), + where=' WHERE (%s)' % condition if condition else '', + include=schema_editor._index_include_sql(model, include), + deferrable=schema_editor._deferrable_constraint_sql(self.deferrable), + ) def create_sql(self, model, schema_editor): self.check_supported(schema_editor) diff --git a/docs/ref/contrib/postgres/constraints.txt b/docs/ref/contrib/postgres/constraints.txt index be06b907ff..fe7d53bb79 100644 --- a/docs/ref/contrib/postgres/constraints.txt +++ b/docs/ref/contrib/postgres/constraints.txt @@ -53,6 +53,10 @@ operators with strings. For example:: Only commutative operators can be used in exclusion constraints. +.. versionchanged:: 4.1 + + Support for the ``OpClass()`` expression was added. + ``index_type`` -------------- @@ -143,6 +147,20 @@ For example:: creates an exclusion constraint on ``circle`` using ``circle_ops``. +Alternatively, you can use +:class:`OpClass() ` in +:attr:`~ExclusionConstraint.expressions`:: + + ExclusionConstraint( + name='exclude_overlapping_opclasses', + expressions=[(OpClass('circle', 'circle_ops'), RangeOperators.OVERLAPS)], + ) + +.. versionchanged:: 4.1 + + Support for specifying operator classes with the ``OpClass()`` expression + was added. + Examples -------- diff --git a/docs/ref/contrib/postgres/indexes.txt b/docs/ref/contrib/postgres/indexes.txt index 35de2bf31a..ba4706da3c 100644 --- a/docs/ref/contrib/postgres/indexes.txt +++ b/docs/ref/contrib/postgres/indexes.txt @@ -150,10 +150,10 @@ available from the ``django.contrib.postgres.indexes`` module. .. class:: OpClass(expression, name) An ``OpClass()`` expression represents the ``expression`` with a custom - `operator class`_ that can be used to define functional indexes or unique - constraints. To use it, you need to add ``'django.contrib.postgres'`` in - your :setting:`INSTALLED_APPS`. Set the ``name`` parameter to the name of - the `operator class`_. + `operator class`_ that can be used to define functional indexes, functional + unique constraints, or exclusion constraints. To use it, you need to add + ``'django.contrib.postgres'`` in your :setting:`INSTALLED_APPS`. Set the + ``name`` parameter to the name of the `operator class`_. For example:: @@ -163,8 +163,7 @@ available from the ``django.contrib.postgres.indexes`` module. ) creates an index on ``Lower('username')`` using ``varchar_pattern_ops``. - - Another example:: + :: UniqueConstraint( OpClass(Upper('description'), name='text_pattern_ops'), @@ -173,9 +172,23 @@ available from the ``django.contrib.postgres.indexes`` module. creates a unique constraint on ``Upper('description')`` using ``text_pattern_ops``. + :: + + ExclusionConstraint( + name='exclude_overlapping_ops', + expressions=[ + (OpClass('circle', name='circle_ops'), RangeOperators.OVERLAPS), + ], + ) + + creates an exclusion constraint on ``circle`` using ``circle_ops``. .. versionchanged:: 4.0 Support for functional unique constraints was added. + .. versionchanged:: 4.1 + + Support for exclusion constraints was added. + .. _operator class: https://www.postgresql.org/docs/current/indexes-opclass.html diff --git a/docs/releases/4.1.txt b/docs/releases/4.1.txt index f2a168db40..d9a05f3bda 100644 --- a/docs/releases/4.1.txt +++ b/docs/releases/4.1.txt @@ -108,6 +108,10 @@ Minor features ` allows specifying bounds for list and tuple inputs. +* :class:`~django.contrib.postgres.constraints.ExclusionConstraint` now allows + specifying operator classes with the + :class:`OpClass() ` expression. + :mod:`django.contrib.redirects` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/postgres_tests/test_constraints.py b/tests/postgres_tests/test_constraints.py index bf0b57488a..f67210097f 100644 --- a/tests/postgres_tests/test_constraints.py +++ b/tests/postgres_tests/test_constraints.py @@ -198,6 +198,7 @@ class SchemaTests(PostgreSQLTestCase): Scene.objects.create(scene='ScEnE 10', setting="Sir Bedemir's Castle") +@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) class ExclusionConstraintTests(PostgreSQLTestCase): def get_constraints(self, table): """Get the constraints on the table using a new cursor.""" @@ -604,6 +605,24 @@ class ExclusionConstraintTests(PostgreSQLTestCase): ) self._test_range_overlaps(constraint) + def test_range_overlaps_custom_opclass_expression(self): + class TsTzRange(Func): + function = 'TSTZRANGE' + output_field = DateTimeRangeField() + + constraint = ExclusionConstraint( + name='exclude_overlapping_reservations_custom_opclass', + expressions=[ + ( + OpClass(TsTzRange('start', 'end', RangeBoundary()), 'range_ops'), + RangeOperators.OVERLAPS, + ), + (OpClass('room', 'gist_int4_ops'), RangeOperators.EQUAL), + ], + condition=Q(cancelled=False), + ) + self._test_range_overlaps(constraint) + def test_range_overlaps(self): constraint = ExclusionConstraint( name='exclude_overlapping_reservations', @@ -914,6 +933,34 @@ class ExclusionConstraintTests(PostgreSQLTestCase): editor.add_constraint(RangesModel, constraint) self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + def test_opclass_expression(self): + constraint_name = 'ints_adjacent_opclass_expression' + self.assertNotIn( + constraint_name, + self.get_constraints(RangesModel._meta.db_table), + ) + constraint = ExclusionConstraint( + name=constraint_name, + expressions=[(OpClass('ints', 'range_ops'), RangeOperators.ADJACENT_TO)], + ) + with connection.schema_editor() as editor: + editor.add_constraint(RangesModel, constraint) + constraints = self.get_constraints(RangesModel._meta.db_table) + self.assertIn(constraint_name, constraints) + with editor.connection.cursor() as cursor: + cursor.execute(SchemaTests.get_opclass_query, [constraint_name]) + self.assertEqual( + cursor.fetchall(), + [('range_ops', constraint_name)], + ) + # Drop the constraint. + with connection.schema_editor() as editor: + editor.remove_constraint(RangesModel, constraint) + self.assertNotIn( + constraint_name, + self.get_constraints(RangesModel._meta.db_table), + ) + def test_range_equal_cast(self): constraint_name = 'exclusion_equal_room_cast' self.assertNotIn(constraint_name, self.get_constraints(Room._meta.db_table))