diff --git a/django/contrib/postgres/constraints.py b/django/contrib/postgres/constraints.py index cdb238c405..2ae72c402e 100644 --- a/django/contrib/postgres/constraints.py +++ b/django/contrib/postgres/constraints.py @@ -12,7 +12,7 @@ class ExclusionConstraint(BaseConstraint): def __init__( self, *, name, expressions, index_type=None, condition=None, - deferrable=None, include=None, + deferrable=None, include=None, opclasses=(), ): if index_type and index_type.lower() not in {'gist', 'spgist'}: raise ValueError( @@ -48,20 +48,37 @@ class ExclusionConstraint(BaseConstraint): raise ValueError( 'Covering exclusion constraints only support GiST indexes.' ) + if not isinstance(opclasses, (list, tuple)): + raise ValueError( + 'ExclusionConstraint.opclasses must be a list or tuple.' + ) + if opclasses and len(expressions) != len(opclasses): + raise ValueError( + 'ExclusionConstraint.expressions and ' + 'ExclusionConstraint.opclasses must have the same number of ' + 'elements.' + ) self.expressions = expressions self.index_type = index_type or 'GIST' self.condition = condition self.deferrable = deferrable self.include = tuple(include) if include else () + self.opclasses = opclasses super().__init__(name=name) def _get_expression_sql(self, compiler, connection, query): expressions = [] - for expression, operator in self.expressions: + for idx, (expression, operator) in enumerate(self.expressions): if isinstance(expression, str): expression = F(expression) expression = expression.resolve_expression(query=query) sql, params = expression.as_sql(compiler, connection) + try: + opclass = self.opclasses[idx] + if opclass: + sql = '%s %s' % (sql, opclass) + except IndexError: + pass expressions.append('%s WITH %s' % (sql % params, operator)) return expressions @@ -119,6 +136,8 @@ class ExclusionConstraint(BaseConstraint): kwargs['deferrable'] = self.deferrable if self.include: kwargs['include'] = self.include + if self.opclasses: + kwargs['opclasses'] = self.opclasses return path, args, kwargs def __eq__(self, other): @@ -129,16 +148,18 @@ class ExclusionConstraint(BaseConstraint): self.expressions == other.expressions and self.condition == other.condition and self.deferrable == other.deferrable and - self.include == other.include + self.include == other.include and + self.opclasses == other.opclasses ) return super().__eq__(other) def __repr__(self): - return '<%s: index_type=%s, expressions=%s%s%s%s>' % ( + return '<%s: index_type=%s, expressions=%s%s%s%s%s>' % ( self.__class__.__qualname__, self.index_type, self.expressions, '' if self.condition is None else ', condition=%s' % self.condition, '' if self.deferrable is None else ', deferrable=%s' % self.deferrable, '' if not self.include else ', include=%s' % repr(self.include), + '' if not self.opclasses else ', opclasses=%s' % repr(self.opclasses), ) diff --git a/docs/ref/contrib/postgres/constraints.txt b/docs/ref/contrib/postgres/constraints.txt index fdc547265e..25e6ae5ae0 100644 --- a/docs/ref/contrib/postgres/constraints.txt +++ b/docs/ref/contrib/postgres/constraints.txt @@ -12,7 +12,7 @@ PostgreSQL supports additional data integrity constraints available from the ``ExclusionConstraint`` ======================= -.. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None, deferrable=None, include=None) +.. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None, deferrable=None, include=None, opclasses=()) Creates an exclusion constraint in the database. Internally, PostgreSQL implements exclusion constraints using indexes. The default index type is @@ -121,6 +121,28 @@ used for queries that select only included fields ``include`` is supported only for GiST indexes on PostgreSQL 12+. +``opclasses`` +------------- + +.. attribute:: ExclusionConstraint.opclasses + +.. versionadded:: 3.2 + +The names of the `PostgreSQL operator classes +`_ to use for +this constraint. If you require a custom operator class, you must provide one +for each expression in the constraint. + +For example:: + + ExclusionConstraint( + name='exclude_overlapping_opclasses', + expressions=[('circle', RangeOperators.OVERLAPS)], + opclasses=['circle_ops'], + ) + +creates an exclusion constraint on ``circle`` using ``circle_ops``. + Examples -------- diff --git a/docs/releases/3.2.txt b/docs/releases/3.2.txt index 416530400e..6b2fb41144 100644 --- a/docs/releases/3.2.txt +++ b/docs/releases/3.2.txt @@ -73,6 +73,9 @@ Minor features * The new :attr:`.ExclusionConstraint.include` attribute allows creating covering exclusion constraints on PostgreSQL 12+. +* The new :attr:`.ExclusionConstraint.opclasses` attribute allows setting + PostgreSQL operator classes. + * The new :attr:`.JSONBAgg.ordering` attribute determines the ordering of the aggregated elements. diff --git a/tests/postgres_tests/test_constraints.py b/tests/postgres_tests/test_constraints.py index cfe0981d3c..bdefa0c76b 100644 --- a/tests/postgres_tests/test_constraints.py +++ b/tests/postgres_tests/test_constraints.py @@ -246,6 +246,28 @@ class ExclusionConstraintTests(PostgreSQLTestCase): index_type='spgist', ) + def test_invalid_opclasses_type(self): + msg = 'ExclusionConstraint.opclasses must be a list or tuple.' + with self.assertRaisesMessage(ValueError, msg): + ExclusionConstraint( + name='exclude_invalid_opclasses', + expressions=[(F('datespan'), RangeOperators.OVERLAPS)], + opclasses='invalid', + ) + + def test_opclasses_and_expressions_same_length(self): + msg = ( + 'ExclusionConstraint.expressions and ' + 'ExclusionConstraint.opclasses must have the same number of ' + 'elements.' + ) + with self.assertRaisesMessage(ValueError, msg): + ExclusionConstraint( + name='exclude_invalid_expressions_opclasses_length', + expressions=[(F('datespan'), RangeOperators.OVERLAPS)], + opclasses=['foo', 'bar'], + ) + def test_repr(self): constraint = ExclusionConstraint( name='exclude_overlapping', @@ -290,6 +312,16 @@ class ExclusionConstraintTests(PostgreSQLTestCase): "", ) + constraint = ExclusionConstraint( + name='exclude_overlapping', + expressions=[(F('datespan'), RangeOperators.ADJACENT_TO)], + opclasses=['range_ops'], + ) + self.assertEqual( + repr(constraint), + "", + ) def test_eq(self): constraint_1 = ExclusionConstraint( @@ -345,6 +377,23 @@ class ExclusionConstraintTests(PostgreSQLTestCase): ], include=['cancelled'], ) + constraint_8 = ExclusionConstraint( + name='exclude_overlapping', + expressions=[ + ('datespan', RangeOperators.OVERLAPS), + ('room', RangeOperators.EQUAL), + ], + include=['cancelled'], + opclasses=['range_ops', 'range_ops'] + ) + constraint_9 = ExclusionConstraint( + name='exclude_overlapping', + expressions=[ + ('datespan', RangeOperators.OVERLAPS), + ('room', RangeOperators.EQUAL), + ], + opclasses=['range_ops', 'range_ops'] + ) self.assertEqual(constraint_1, constraint_1) self.assertEqual(constraint_1, mock.ANY) self.assertNotEqual(constraint_1, constraint_2) @@ -353,8 +402,10 @@ class ExclusionConstraintTests(PostgreSQLTestCase): self.assertNotEqual(constraint_2, constraint_3) self.assertNotEqual(constraint_2, constraint_4) self.assertNotEqual(constraint_2, constraint_7) + self.assertNotEqual(constraint_2, constraint_9) self.assertNotEqual(constraint_4, constraint_5) self.assertNotEqual(constraint_5, constraint_6) + self.assertNotEqual(constraint_7, constraint_8) self.assertNotEqual(constraint_1, object()) def test_deconstruct(self): @@ -430,6 +481,21 @@ class ExclusionConstraintTests(PostgreSQLTestCase): 'include': ('cancelled', 'room'), }) + def test_deconstruct_opclasses(self): + constraint = ExclusionConstraint( + name='exclude_overlapping', + expressions=[('datespan', RangeOperators.OVERLAPS)], + opclasses=['range_ops'], + ) + path, args, kwargs = constraint.deconstruct() + self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint') + self.assertEqual(args, ()) + self.assertEqual(kwargs, { + 'name': 'exclude_overlapping', + 'expressions': [('datespan', RangeOperators.OVERLAPS)], + 'opclasses': ['range_ops'], + }) + def _test_range_overlaps(self, constraint): # Create exclusion constraint. self.assertNotIn(constraint.name, self.get_constraints(HotelReservation._meta.db_table)) @@ -505,6 +571,7 @@ class ExclusionConstraintTests(PostgreSQLTestCase): ('room', RangeOperators.EQUAL) ], condition=Q(cancelled=False), + opclasses=['range_ops', 'gist_int4_ops'], ) self._test_range_overlaps(constraint) @@ -624,3 +691,64 @@ class ExclusionConstraintTests(PostgreSQLTestCase): ): with self.assertRaisesMessage(NotSupportedError, msg): editor.add_constraint(RangesModel, constraint) + + def test_range_adjacent_opclasses(self): + constraint_name = 'ints_adjacent_opclasses' + self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint = ExclusionConstraint( + name=constraint_name, + expressions=[('ints', RangeOperators.ADJACENT_TO)], + opclasses=['range_ops'], + ) + with connection.schema_editor() as editor: + editor.add_constraint(RangesModel, constraint) + self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + RangesModel.objects.create(ints=(20, 50)) + with self.assertRaises(IntegrityError), transaction.atomic(): + RangesModel.objects.create(ints=(10, 20)) + RangesModel.objects.create(ints=(10, 19)) + RangesModel.objects.create(ints=(51, 60)) + # Drop the constraint. + with connection.schema_editor() as editor: + editor.remove_constraint(RangesModel, constraint) + self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + + def test_range_adjacent_opclasses_condition(self): + constraint_name = 'ints_adjacent_opclasses_condition' + self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint = ExclusionConstraint( + name=constraint_name, + expressions=[('ints', RangeOperators.ADJACENT_TO)], + opclasses=['range_ops'], + condition=Q(id__gte=100), + ) + with connection.schema_editor() as editor: + editor.add_constraint(RangesModel, constraint) + self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + + def test_range_adjacent_opclasses_deferrable(self): + constraint_name = 'ints_adjacent_opclasses_deferrable' + self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint = ExclusionConstraint( + name=constraint_name, + expressions=[('ints', RangeOperators.ADJACENT_TO)], + opclasses=['range_ops'], + deferrable=Deferrable.DEFERRED, + ) + with connection.schema_editor() as editor: + editor.add_constraint(RangesModel, constraint) + self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + + @skipUnlessDBFeature('supports_covering_gist_indexes') + def test_range_adjacent_opclasses_include(self): + constraint_name = 'ints_adjacent_opclasses_include' + self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) + constraint = ExclusionConstraint( + name=constraint_name, + expressions=[('ints', RangeOperators.ADJACENT_TO)], + opclasses=['range_ops'], + include=['decimals'], + ) + with connection.schema_editor() as editor: + editor.add_constraint(RangesModel, constraint) + self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))