Fixed #33342 -- Added support for using OpClass() in exclusion constraints.

This commit is contained in:
Hannes Ljungberg 2021-11-03 22:21:50 +01:00 committed by Mariusz Felisiak
parent a0d43a7a6e
commit 0e656c02fe
5 changed files with 114 additions and 28 deletions

View File

@ -1,13 +1,19 @@
from django.contrib.postgres.indexes import OpClass
from django.db import NotSupportedError from django.db import NotSupportedError
from django.db.backends.ddl_references import Statement, Table from django.db.backends.ddl_references import Expressions, Statement, Table
from django.db.models import Deferrable, F, Q from django.db.models import Deferrable, F, Q
from django.db.models.constraints import BaseConstraint from django.db.models.constraints import BaseConstraint
from django.db.models.expressions import Col from django.db.models.expressions import ExpressionList
from django.db.models.indexes import IndexExpression
from django.db.models.sql import Query from django.db.models.sql import Query
__all__ = ['ExclusionConstraint'] __all__ = ['ExclusionConstraint']
class ExclusionConstraintExpression(IndexExpression):
template = '%(expressions)s WITH %(operator)s'
class ExclusionConstraint(BaseConstraint): class ExclusionConstraint(BaseConstraint):
template = 'CONSTRAINT %(name)s EXCLUDE USING %(index_type)s (%(expressions)s)%(include)s%(where)s%(deferrable)s' template = 'CONSTRAINT %(name)s EXCLUDE USING %(index_type)s (%(expressions)s)%(include)s%(where)s%(deferrable)s'
@ -63,24 +69,19 @@ class ExclusionConstraint(BaseConstraint):
self.opclasses = opclasses self.opclasses = opclasses
super().__init__(name=name) super().__init__(name=name)
def _get_expression_sql(self, compiler, schema_editor, query): def _get_expressions(self, schema_editor, query):
expressions = [] expressions = []
for idx, (expression, operator) in enumerate(self.expressions): for idx, (expression, operator) in enumerate(self.expressions):
if isinstance(expression, str): if isinstance(expression, str):
expression = F(expression) expression = F(expression)
expression = expression.resolve_expression(query=query)
sql, params = compiler.compile(expression)
if not isinstance(expression, Col):
sql = f'({sql})'
try: try:
opclass = self.opclasses[idx] expression = OpClass(expression, self.opclasses[idx])
if opclass:
sql = '%s %s' % (sql, opclass)
except IndexError: except IndexError:
pass pass
sql = sql % tuple(schema_editor.quote_value(p) for p in params) expression = ExclusionConstraintExpression(expression, operator=operator)
expressions.append('%s WITH %s' % (sql, operator)) expression.set_wrapper_classes(schema_editor.connection)
return expressions expressions.append(expression)
return ExpressionList(*expressions).resolve_expression(query)
def _get_condition_sql(self, compiler, schema_editor, query): def _get_condition_sql(self, compiler, schema_editor, query):
if self.condition is None: if self.condition is None:
@ -92,17 +93,20 @@ class ExclusionConstraint(BaseConstraint):
def constraint_sql(self, model, schema_editor): def constraint_sql(self, model, schema_editor):
query = Query(model, alias_cols=False) query = Query(model, alias_cols=False)
compiler = query.get_compiler(connection=schema_editor.connection) compiler = query.get_compiler(connection=schema_editor.connection)
expressions = self._get_expression_sql(compiler, schema_editor, query) expressions = self._get_expressions(schema_editor, query)
table = model._meta.db_table
condition = self._get_condition_sql(compiler, schema_editor, query) condition = self._get_condition_sql(compiler, schema_editor, query)
include = [model._meta.get_field(field_name).column for field_name in self.include] include = [model._meta.get_field(field_name).column for field_name in self.include]
return self.template % { return Statement(
'name': schema_editor.quote_name(self.name), self.template,
'index_type': self.index_type, table=Table(table, schema_editor.quote_name),
'expressions': ', '.join(expressions), name=schema_editor.quote_name(self.name),
'include': schema_editor._index_include_sql(model, include), index_type=self.index_type,
'where': ' WHERE (%s)' % condition if condition else '', expressions=Expressions(table, expressions, compiler, schema_editor.quote_value),
'deferrable': schema_editor._deferrable_constraint_sql(self.deferrable), where=' WHERE (%s)' % condition if condition else '',
} include=schema_editor._index_include_sql(model, include),
deferrable=schema_editor._deferrable_constraint_sql(self.deferrable),
)
def create_sql(self, model, schema_editor): def create_sql(self, model, schema_editor):
self.check_supported(schema_editor) self.check_supported(schema_editor)

View File

@ -53,6 +53,10 @@ operators with strings. For example::
Only commutative operators can be used in exclusion constraints. Only commutative operators can be used in exclusion constraints.
.. versionchanged:: 4.1
Support for the ``OpClass()`` expression was added.
``index_type`` ``index_type``
-------------- --------------
@ -143,6 +147,20 @@ For example::
creates an exclusion constraint on ``circle`` using ``circle_ops``. creates an exclusion constraint on ``circle`` using ``circle_ops``.
Alternatively, you can use
:class:`OpClass() <django.contrib.postgres.indexes.OpClass>` in
:attr:`~ExclusionConstraint.expressions`::
ExclusionConstraint(
name='exclude_overlapping_opclasses',
expressions=[(OpClass('circle', 'circle_ops'), RangeOperators.OVERLAPS)],
)
.. versionchanged:: 4.1
Support for specifying operator classes with the ``OpClass()`` expression
was added.
Examples Examples
-------- --------

View File

@ -150,10 +150,10 @@ available from the ``django.contrib.postgres.indexes`` module.
.. class:: OpClass(expression, name) .. class:: OpClass(expression, name)
An ``OpClass()`` expression represents the ``expression`` with a custom An ``OpClass()`` expression represents the ``expression`` with a custom
`operator class`_ that can be used to define functional indexes or unique `operator class`_ that can be used to define functional indexes, functional
constraints. To use it, you need to add ``'django.contrib.postgres'`` in unique constraints, or exclusion constraints. To use it, you need to add
your :setting:`INSTALLED_APPS`. Set the ``name`` parameter to the name of ``'django.contrib.postgres'`` in your :setting:`INSTALLED_APPS`. Set the
the `operator class`_. ``name`` parameter to the name of the `operator class`_.
For example:: For example::
@ -163,8 +163,7 @@ available from the ``django.contrib.postgres.indexes`` module.
) )
creates an index on ``Lower('username')`` using ``varchar_pattern_ops``. creates an index on ``Lower('username')`` using ``varchar_pattern_ops``.
::
Another example::
UniqueConstraint( UniqueConstraint(
OpClass(Upper('description'), name='text_pattern_ops'), OpClass(Upper('description'), name='text_pattern_ops'),
@ -173,9 +172,23 @@ available from the ``django.contrib.postgres.indexes`` module.
creates a unique constraint on ``Upper('description')`` using creates a unique constraint on ``Upper('description')`` using
``text_pattern_ops``. ``text_pattern_ops``.
::
ExclusionConstraint(
name='exclude_overlapping_ops',
expressions=[
(OpClass('circle', name='circle_ops'), RangeOperators.OVERLAPS),
],
)
creates an exclusion constraint on ``circle`` using ``circle_ops``.
.. versionchanged:: 4.0 .. versionchanged:: 4.0
Support for functional unique constraints was added. Support for functional unique constraints was added.
.. versionchanged:: 4.1
Support for exclusion constraints was added.
.. _operator class: https://www.postgresql.org/docs/current/indexes-opclass.html .. _operator class: https://www.postgresql.org/docs/current/indexes-opclass.html

View File

@ -108,6 +108,10 @@ Minor features
<django.contrib.postgres.fields.DecimalRangeField.default_bounds>` allows <django.contrib.postgres.fields.DecimalRangeField.default_bounds>` allows
specifying bounds for list and tuple inputs. specifying bounds for list and tuple inputs.
* :class:`~django.contrib.postgres.constraints.ExclusionConstraint` now allows
specifying operator classes with the
:class:`OpClass() <django.contrib.postgres.indexes.OpClass>` expression.
:mod:`django.contrib.redirects` :mod:`django.contrib.redirects`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -198,6 +198,7 @@ class SchemaTests(PostgreSQLTestCase):
Scene.objects.create(scene='ScEnE 10', setting="Sir Bedemir's Castle") Scene.objects.create(scene='ScEnE 10', setting="Sir Bedemir's Castle")
@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
class ExclusionConstraintTests(PostgreSQLTestCase): class ExclusionConstraintTests(PostgreSQLTestCase):
def get_constraints(self, table): def get_constraints(self, table):
"""Get the constraints on the table using a new cursor.""" """Get the constraints on the table using a new cursor."""
@ -604,6 +605,24 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
) )
self._test_range_overlaps(constraint) self._test_range_overlaps(constraint)
def test_range_overlaps_custom_opclass_expression(self):
class TsTzRange(Func):
function = 'TSTZRANGE'
output_field = DateTimeRangeField()
constraint = ExclusionConstraint(
name='exclude_overlapping_reservations_custom_opclass',
expressions=[
(
OpClass(TsTzRange('start', 'end', RangeBoundary()), 'range_ops'),
RangeOperators.OVERLAPS,
),
(OpClass('room', 'gist_int4_ops'), RangeOperators.EQUAL),
],
condition=Q(cancelled=False),
)
self._test_range_overlaps(constraint)
def test_range_overlaps(self): def test_range_overlaps(self):
constraint = ExclusionConstraint( constraint = ExclusionConstraint(
name='exclude_overlapping_reservations', name='exclude_overlapping_reservations',
@ -914,6 +933,34 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
editor.add_constraint(RangesModel, constraint) editor.add_constraint(RangesModel, constraint)
self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table)) self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
def test_opclass_expression(self):
constraint_name = 'ints_adjacent_opclass_expression'
self.assertNotIn(
constraint_name,
self.get_constraints(RangesModel._meta.db_table),
)
constraint = ExclusionConstraint(
name=constraint_name,
expressions=[(OpClass('ints', 'range_ops'), RangeOperators.ADJACENT_TO)],
)
with connection.schema_editor() as editor:
editor.add_constraint(RangesModel, constraint)
constraints = self.get_constraints(RangesModel._meta.db_table)
self.assertIn(constraint_name, constraints)
with editor.connection.cursor() as cursor:
cursor.execute(SchemaTests.get_opclass_query, [constraint_name])
self.assertEqual(
cursor.fetchall(),
[('range_ops', constraint_name)],
)
# Drop the constraint.
with connection.schema_editor() as editor:
editor.remove_constraint(RangesModel, constraint)
self.assertNotIn(
constraint_name,
self.get_constraints(RangesModel._meta.db_table),
)
def test_range_equal_cast(self): def test_range_equal_cast(self):
constraint_name = 'exclusion_equal_room_cast' constraint_name = 'exclusion_equal_room_cast'
self.assertNotIn(constraint_name, self.get_constraints(Room._meta.db_table)) self.assertNotIn(constraint_name, self.get_constraints(Room._meta.db_table))