Fixed #31709 -- Added support for opclasses in ExclusionConstraint.

This commit is contained in:
Hannes Ljungberg 2020-06-14 20:50:39 +02:00 committed by Mariusz Felisiak
parent dcb4d79ef7
commit 0d6d4e78b1
4 changed files with 179 additions and 5 deletions

View File

@ -12,7 +12,7 @@ class ExclusionConstraint(BaseConstraint):
def __init__( def __init__(
self, *, name, expressions, index_type=None, condition=None, self, *, name, expressions, index_type=None, condition=None,
deferrable=None, include=None, deferrable=None, include=None, opclasses=(),
): ):
if index_type and index_type.lower() not in {'gist', 'spgist'}: if index_type and index_type.lower() not in {'gist', 'spgist'}:
raise ValueError( raise ValueError(
@ -48,20 +48,37 @@ class ExclusionConstraint(BaseConstraint):
raise ValueError( raise ValueError(
'Covering exclusion constraints only support GiST indexes.' 'Covering exclusion constraints only support GiST indexes.'
) )
if not isinstance(opclasses, (list, tuple)):
raise ValueError(
'ExclusionConstraint.opclasses must be a list or tuple.'
)
if opclasses and len(expressions) != len(opclasses):
raise ValueError(
'ExclusionConstraint.expressions and '
'ExclusionConstraint.opclasses must have the same number of '
'elements.'
)
self.expressions = expressions self.expressions = expressions
self.index_type = index_type or 'GIST' self.index_type = index_type or 'GIST'
self.condition = condition self.condition = condition
self.deferrable = deferrable self.deferrable = deferrable
self.include = tuple(include) if include else () self.include = tuple(include) if include else ()
self.opclasses = opclasses
super().__init__(name=name) super().__init__(name=name)
def _get_expression_sql(self, compiler, connection, query): def _get_expression_sql(self, compiler, connection, query):
expressions = [] expressions = []
for expression, operator in self.expressions: for idx, (expression, operator) in enumerate(self.expressions):
if isinstance(expression, str): if isinstance(expression, str):
expression = F(expression) expression = F(expression)
expression = expression.resolve_expression(query=query) expression = expression.resolve_expression(query=query)
sql, params = expression.as_sql(compiler, connection) sql, params = expression.as_sql(compiler, connection)
try:
opclass = self.opclasses[idx]
if opclass:
sql = '%s %s' % (sql, opclass)
except IndexError:
pass
expressions.append('%s WITH %s' % (sql % params, operator)) expressions.append('%s WITH %s' % (sql % params, operator))
return expressions return expressions
@ -119,6 +136,8 @@ class ExclusionConstraint(BaseConstraint):
kwargs['deferrable'] = self.deferrable kwargs['deferrable'] = self.deferrable
if self.include: if self.include:
kwargs['include'] = self.include kwargs['include'] = self.include
if self.opclasses:
kwargs['opclasses'] = self.opclasses
return path, args, kwargs return path, args, kwargs
def __eq__(self, other): def __eq__(self, other):
@ -129,16 +148,18 @@ class ExclusionConstraint(BaseConstraint):
self.expressions == other.expressions and self.expressions == other.expressions and
self.condition == other.condition and self.condition == other.condition and
self.deferrable == other.deferrable and self.deferrable == other.deferrable and
self.include == other.include self.include == other.include and
self.opclasses == other.opclasses
) )
return super().__eq__(other) return super().__eq__(other)
def __repr__(self): def __repr__(self):
return '<%s: index_type=%s, expressions=%s%s%s%s>' % ( return '<%s: index_type=%s, expressions=%s%s%s%s%s>' % (
self.__class__.__qualname__, self.__class__.__qualname__,
self.index_type, self.index_type,
self.expressions, self.expressions,
'' if self.condition is None else ', condition=%s' % self.condition, '' if self.condition is None else ', condition=%s' % self.condition,
'' if self.deferrable is None else ', deferrable=%s' % self.deferrable, '' if self.deferrable is None else ', deferrable=%s' % self.deferrable,
'' if not self.include else ', include=%s' % repr(self.include), '' if not self.include else ', include=%s' % repr(self.include),
'' if not self.opclasses else ', opclasses=%s' % repr(self.opclasses),
) )

View File

@ -12,7 +12,7 @@ PostgreSQL supports additional data integrity constraints available from the
``ExclusionConstraint`` ``ExclusionConstraint``
======================= =======================
.. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None, deferrable=None, include=None) .. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None, deferrable=None, include=None, opclasses=())
Creates an exclusion constraint in the database. Internally, PostgreSQL Creates an exclusion constraint in the database. Internally, PostgreSQL
implements exclusion constraints using indexes. The default index type is implements exclusion constraints using indexes. The default index type is
@ -121,6 +121,28 @@ used for queries that select only included fields
``include`` is supported only for GiST indexes on PostgreSQL 12+. ``include`` is supported only for GiST indexes on PostgreSQL 12+.
``opclasses``
-------------
.. attribute:: ExclusionConstraint.opclasses
.. versionadded:: 3.2
The names of the `PostgreSQL operator classes
<https://www.postgresql.org/docs/current/indexes-opclass.html>`_ to use for
this constraint. If you require a custom operator class, you must provide one
for each expression in the constraint.
For example::
ExclusionConstraint(
name='exclude_overlapping_opclasses',
expressions=[('circle', RangeOperators.OVERLAPS)],
opclasses=['circle_ops'],
)
creates an exclusion constraint on ``circle`` using ``circle_ops``.
Examples Examples
-------- --------

View File

@ -73,6 +73,9 @@ Minor features
* The new :attr:`.ExclusionConstraint.include` attribute allows creating * The new :attr:`.ExclusionConstraint.include` attribute allows creating
covering exclusion constraints on PostgreSQL 12+. covering exclusion constraints on PostgreSQL 12+.
* The new :attr:`.ExclusionConstraint.opclasses` attribute allows setting
PostgreSQL operator classes.
* The new :attr:`.JSONBAgg.ordering` attribute determines the ordering of the * The new :attr:`.JSONBAgg.ordering` attribute determines the ordering of the
aggregated elements. aggregated elements.

View File

@ -246,6 +246,28 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
index_type='spgist', index_type='spgist',
) )
def test_invalid_opclasses_type(self):
msg = 'ExclusionConstraint.opclasses must be a list or tuple.'
with self.assertRaisesMessage(ValueError, msg):
ExclusionConstraint(
name='exclude_invalid_opclasses',
expressions=[(F('datespan'), RangeOperators.OVERLAPS)],
opclasses='invalid',
)
def test_opclasses_and_expressions_same_length(self):
msg = (
'ExclusionConstraint.expressions and '
'ExclusionConstraint.opclasses must have the same number of '
'elements.'
)
with self.assertRaisesMessage(ValueError, msg):
ExclusionConstraint(
name='exclude_invalid_expressions_opclasses_length',
expressions=[(F('datespan'), RangeOperators.OVERLAPS)],
opclasses=['foo', 'bar'],
)
def test_repr(self): def test_repr(self):
constraint = ExclusionConstraint( constraint = ExclusionConstraint(
name='exclude_overlapping', name='exclude_overlapping',
@ -290,6 +312,16 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
"<ExclusionConstraint: index_type=GIST, expressions=[" "<ExclusionConstraint: index_type=GIST, expressions=["
"(F(datespan), '-|-')], include=('cancelled', 'room')>", "(F(datespan), '-|-')], include=('cancelled', 'room')>",
) )
constraint = ExclusionConstraint(
name='exclude_overlapping',
expressions=[(F('datespan'), RangeOperators.ADJACENT_TO)],
opclasses=['range_ops'],
)
self.assertEqual(
repr(constraint),
"<ExclusionConstraint: index_type=GIST, expressions=["
"(F(datespan), '-|-')], opclasses=['range_ops']>",
)
def test_eq(self): def test_eq(self):
constraint_1 = ExclusionConstraint( constraint_1 = ExclusionConstraint(
@ -345,6 +377,23 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
], ],
include=['cancelled'], include=['cancelled'],
) )
constraint_8 = ExclusionConstraint(
name='exclude_overlapping',
expressions=[
('datespan', RangeOperators.OVERLAPS),
('room', RangeOperators.EQUAL),
],
include=['cancelled'],
opclasses=['range_ops', 'range_ops']
)
constraint_9 = ExclusionConstraint(
name='exclude_overlapping',
expressions=[
('datespan', RangeOperators.OVERLAPS),
('room', RangeOperators.EQUAL),
],
opclasses=['range_ops', 'range_ops']
)
self.assertEqual(constraint_1, constraint_1) self.assertEqual(constraint_1, constraint_1)
self.assertEqual(constraint_1, mock.ANY) self.assertEqual(constraint_1, mock.ANY)
self.assertNotEqual(constraint_1, constraint_2) self.assertNotEqual(constraint_1, constraint_2)
@ -353,8 +402,10 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
self.assertNotEqual(constraint_2, constraint_3) self.assertNotEqual(constraint_2, constraint_3)
self.assertNotEqual(constraint_2, constraint_4) self.assertNotEqual(constraint_2, constraint_4)
self.assertNotEqual(constraint_2, constraint_7) self.assertNotEqual(constraint_2, constraint_7)
self.assertNotEqual(constraint_2, constraint_9)
self.assertNotEqual(constraint_4, constraint_5) self.assertNotEqual(constraint_4, constraint_5)
self.assertNotEqual(constraint_5, constraint_6) self.assertNotEqual(constraint_5, constraint_6)
self.assertNotEqual(constraint_7, constraint_8)
self.assertNotEqual(constraint_1, object()) self.assertNotEqual(constraint_1, object())
def test_deconstruct(self): def test_deconstruct(self):
@ -430,6 +481,21 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
'include': ('cancelled', 'room'), 'include': ('cancelled', 'room'),
}) })
def test_deconstruct_opclasses(self):
constraint = ExclusionConstraint(
name='exclude_overlapping',
expressions=[('datespan', RangeOperators.OVERLAPS)],
opclasses=['range_ops'],
)
path, args, kwargs = constraint.deconstruct()
self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint')
self.assertEqual(args, ())
self.assertEqual(kwargs, {
'name': 'exclude_overlapping',
'expressions': [('datespan', RangeOperators.OVERLAPS)],
'opclasses': ['range_ops'],
})
def _test_range_overlaps(self, constraint): def _test_range_overlaps(self, constraint):
# Create exclusion constraint. # Create exclusion constraint.
self.assertNotIn(constraint.name, self.get_constraints(HotelReservation._meta.db_table)) self.assertNotIn(constraint.name, self.get_constraints(HotelReservation._meta.db_table))
@ -505,6 +571,7 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
('room', RangeOperators.EQUAL) ('room', RangeOperators.EQUAL)
], ],
condition=Q(cancelled=False), condition=Q(cancelled=False),
opclasses=['range_ops', 'gist_int4_ops'],
) )
self._test_range_overlaps(constraint) self._test_range_overlaps(constraint)
@ -624,3 +691,64 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
): ):
with self.assertRaisesMessage(NotSupportedError, msg): with self.assertRaisesMessage(NotSupportedError, msg):
editor.add_constraint(RangesModel, constraint) editor.add_constraint(RangesModel, constraint)
def test_range_adjacent_opclasses(self):
constraint_name = 'ints_adjacent_opclasses'
self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
constraint = ExclusionConstraint(
name=constraint_name,
expressions=[('ints', RangeOperators.ADJACENT_TO)],
opclasses=['range_ops'],
)
with connection.schema_editor() as editor:
editor.add_constraint(RangesModel, constraint)
self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
RangesModel.objects.create(ints=(20, 50))
with self.assertRaises(IntegrityError), transaction.atomic():
RangesModel.objects.create(ints=(10, 20))
RangesModel.objects.create(ints=(10, 19))
RangesModel.objects.create(ints=(51, 60))
# Drop the constraint.
with connection.schema_editor() as editor:
editor.remove_constraint(RangesModel, constraint)
self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
def test_range_adjacent_opclasses_condition(self):
constraint_name = 'ints_adjacent_opclasses_condition'
self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
constraint = ExclusionConstraint(
name=constraint_name,
expressions=[('ints', RangeOperators.ADJACENT_TO)],
opclasses=['range_ops'],
condition=Q(id__gte=100),
)
with connection.schema_editor() as editor:
editor.add_constraint(RangesModel, constraint)
self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
def test_range_adjacent_opclasses_deferrable(self):
constraint_name = 'ints_adjacent_opclasses_deferrable'
self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
constraint = ExclusionConstraint(
name=constraint_name,
expressions=[('ints', RangeOperators.ADJACENT_TO)],
opclasses=['range_ops'],
deferrable=Deferrable.DEFERRED,
)
with connection.schema_editor() as editor:
editor.add_constraint(RangesModel, constraint)
self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
@skipUnlessDBFeature('supports_covering_gist_indexes')
def test_range_adjacent_opclasses_include(self):
constraint_name = 'ints_adjacent_opclasses_include'
self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
constraint = ExclusionConstraint(
name=constraint_name,
expressions=[('ints', RangeOperators.ADJACENT_TO)],
opclasses=['range_ops'],
include=['decimals'],
)
with connection.schema_editor() as editor:
editor.add_constraint(RangesModel, constraint)
self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))