Fixed #31649 -- Added support for covering exclusion constraints on PostgreSQL 12+.

This commit is contained in:
Hannes Ljungberg 2020-06-11 21:05:38 +02:00 committed by Mariusz Felisiak
parent db8268bce6
commit e0cdd0fcf5
4 changed files with 175 additions and 7 deletions

View File

@ -1,3 +1,4 @@
from django.db import NotSupportedError
from django.db.backends.ddl_references import Statement, Table from django.db.backends.ddl_references import Statement, Table
from django.db.models import Deferrable, F, Q from django.db.models import Deferrable, F, Q
from django.db.models.constraints import BaseConstraint from django.db.models.constraints import BaseConstraint
@ -7,11 +8,11 @@ __all__ = ['ExclusionConstraint']
class ExclusionConstraint(BaseConstraint): class ExclusionConstraint(BaseConstraint):
template = 'CONSTRAINT %(name)s EXCLUDE USING %(index_type)s (%(expressions)s)%(where)s%(deferrable)s' template = 'CONSTRAINT %(name)s EXCLUDE USING %(index_type)s (%(expressions)s)%(include)s%(where)s%(deferrable)s'
def __init__( def __init__(
self, *, name, expressions, index_type=None, condition=None, self, *, name, expressions, index_type=None, condition=None,
deferrable=None, deferrable=None, include=None,
): ):
if index_type and index_type.lower() not in {'gist', 'spgist'}: if index_type and index_type.lower() not in {'gist', 'spgist'}:
raise ValueError( raise ValueError(
@ -39,10 +40,19 @@ class ExclusionConstraint(BaseConstraint):
raise ValueError( raise ValueError(
'ExclusionConstraint.deferrable must be a Deferrable instance.' 'ExclusionConstraint.deferrable must be a Deferrable instance.'
) )
if not isinstance(include, (type(None), list, tuple)):
raise ValueError(
'ExclusionConstraint.include must be a list or tuple.'
)
if include and index_type and index_type.lower() != 'gist':
raise ValueError(
'Covering exclusion constraints only support GiST indexes.'
)
self.expressions = expressions self.expressions = expressions
self.index_type = index_type or 'GIST' self.index_type = index_type or 'GIST'
self.condition = condition self.condition = condition
self.deferrable = deferrable self.deferrable = deferrable
self.include = tuple(include) if include else ()
super().__init__(name=name) super().__init__(name=name)
def _get_expression_sql(self, compiler, connection, query): def _get_expression_sql(self, compiler, connection, query):
@ -67,15 +77,18 @@ class ExclusionConstraint(BaseConstraint):
compiler = query.get_compiler(connection=schema_editor.connection) compiler = query.get_compiler(connection=schema_editor.connection)
expressions = self._get_expression_sql(compiler, schema_editor.connection, query) expressions = self._get_expression_sql(compiler, schema_editor.connection, query)
condition = self._get_condition_sql(compiler, schema_editor, query) condition = self._get_condition_sql(compiler, schema_editor, query)
include = [model._meta.get_field(field_name).column for field_name in self.include]
return self.template % { return self.template % {
'name': schema_editor.quote_name(self.name), 'name': schema_editor.quote_name(self.name),
'index_type': self.index_type, 'index_type': self.index_type,
'expressions': ', '.join(expressions), 'expressions': ', '.join(expressions),
'include': schema_editor._index_include_sql(model, include),
'where': ' WHERE (%s)' % condition if condition else '', 'where': ' WHERE (%s)' % condition if condition else '',
'deferrable': schema_editor._deferrable_constraint_sql(self.deferrable), 'deferrable': schema_editor._deferrable_constraint_sql(self.deferrable),
} }
def create_sql(self, model, schema_editor): def create_sql(self, model, schema_editor):
self.check_supported(schema_editor)
return Statement( return Statement(
'ALTER TABLE %(table)s ADD %(constraint)s', 'ALTER TABLE %(table)s ADD %(constraint)s',
table=Table(model._meta.db_table, schema_editor.quote_name), table=Table(model._meta.db_table, schema_editor.quote_name),
@ -89,6 +102,12 @@ class ExclusionConstraint(BaseConstraint):
schema_editor.quote_name(self.name), schema_editor.quote_name(self.name),
) )
def check_supported(self, schema_editor):
if self.include and not schema_editor.connection.features.supports_covering_gist_indexes:
raise NotSupportedError(
'Covering exclusion constraints requires PostgreSQL 12+.'
)
def deconstruct(self): def deconstruct(self):
path, args, kwargs = super().deconstruct() path, args, kwargs = super().deconstruct()
kwargs['expressions'] = self.expressions kwargs['expressions'] = self.expressions
@ -98,6 +117,8 @@ class ExclusionConstraint(BaseConstraint):
kwargs['index_type'] = self.index_type kwargs['index_type'] = self.index_type
if self.deferrable: if self.deferrable:
kwargs['deferrable'] = self.deferrable kwargs['deferrable'] = self.deferrable
if self.include:
kwargs['include'] = self.include
return path, args, kwargs return path, args, kwargs
def __eq__(self, other): def __eq__(self, other):
@ -107,15 +128,17 @@ class ExclusionConstraint(BaseConstraint):
self.index_type == other.index_type and self.index_type == other.index_type and
self.expressions == other.expressions and self.expressions == other.expressions and
self.condition == other.condition and self.condition == other.condition and
self.deferrable == other.deferrable self.deferrable == other.deferrable and
self.include == other.include
) )
return super().__eq__(other) return super().__eq__(other)
def __repr__(self): def __repr__(self):
return '<%s: index_type=%s, expressions=%s%s%s>' % ( return '<%s: index_type=%s, expressions=%s%s%s%s>' % (
self.__class__.__qualname__, self.__class__.__qualname__,
self.index_type, self.index_type,
self.expressions, self.expressions,
'' if self.condition is None else ', condition=%s' % self.condition, '' if self.condition is None else ', condition=%s' % self.condition,
'' if self.deferrable is None else ', deferrable=%s' % self.deferrable, '' if self.deferrable is None else ', deferrable=%s' % self.deferrable,
'' if not self.include else ', include=%s' % repr(self.include),
) )

View File

@ -12,7 +12,7 @@ PostgreSQL supports additional data integrity constraints available from the
``ExclusionConstraint`` ``ExclusionConstraint``
======================= =======================
.. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None, deferrable=None) .. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None, deferrable=None, include=None)
Creates an exclusion constraint in the database. Internally, PostgreSQL Creates an exclusion constraint in the database. Internally, PostgreSQL
implements exclusion constraints using indexes. The default index type is implements exclusion constraints using indexes. The default index type is
@ -106,6 +106,21 @@ enforced immediately after every command.
Deferred exclusion constraints may lead to a `performance penalty Deferred exclusion constraints may lead to a `performance penalty
<https://www.postgresql.org/docs/current/sql-createtable.html#id-1.9.3.85.9.4>`_. <https://www.postgresql.org/docs/current/sql-createtable.html#id-1.9.3.85.9.4>`_.
``include``
-----------
.. attribute:: ExclusionConstraint.include
.. versionadded:: 3.2
A list or tuple of the names of the fields to be included in the covering
exclusion constraint as non-key columns. This allows index-only scans to be
used for queries that select only included fields
(:attr:`~ExclusionConstraint.include`) and filter only by indexed fields
(:attr:`~ExclusionConstraint.expressions`).
``include`` is supported only for GiST indexes on PostgreSQL 12+.
Examples Examples
-------- --------

View File

@ -70,7 +70,8 @@ Minor features
:mod:`django.contrib.postgres` :mod:`django.contrib.postgres`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* ... * The new :attr:`.ExclusionConstraint.include` attribute allows creating
covering exclusion constraints on PostgreSQL 12+.
:mod:`django.contrib.redirects` :mod:`django.contrib.redirects`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -1,8 +1,11 @@
import datetime import datetime
from unittest import mock from unittest import mock
from django.db import IntegrityError, connection, transaction from django.db import (
IntegrityError, NotSupportedError, connection, transaction,
)
from django.db.models import CheckConstraint, Deferrable, F, Func, Q from django.db.models import CheckConstraint, Deferrable, F, Func, Q
from django.test import skipUnlessDBFeature
from django.utils import timezone from django.utils import timezone
from . import PostgreSQLTestCase from . import PostgreSQLTestCase
@ -146,6 +149,25 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
deferrable=Deferrable.DEFERRED, deferrable=Deferrable.DEFERRED,
) )
def test_invalid_include_type(self):
msg = 'ExclusionConstraint.include must be a list or tuple.'
with self.assertRaisesMessage(ValueError, msg):
ExclusionConstraint(
name='exclude_invalid_include',
expressions=[(F('datespan'), RangeOperators.OVERLAPS)],
include='invalid',
)
def test_invalid_include_index_type(self):
msg = 'Covering exclusion constraints only support GiST indexes.'
with self.assertRaisesMessage(ValueError, msg):
ExclusionConstraint(
name='exclude_invalid_index_type',
expressions=[(F('datespan'), RangeOperators.OVERLAPS)],
include=['cancelled'],
index_type='spgist',
)
def test_repr(self): def test_repr(self):
constraint = ExclusionConstraint( constraint = ExclusionConstraint(
name='exclude_overlapping', name='exclude_overlapping',
@ -180,6 +202,16 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
"<ExclusionConstraint: index_type=GIST, expressions=[" "<ExclusionConstraint: index_type=GIST, expressions=["
"(F(datespan), '-|-')], deferrable=Deferrable.IMMEDIATE>", "(F(datespan), '-|-')], deferrable=Deferrable.IMMEDIATE>",
) )
constraint = ExclusionConstraint(
name='exclude_overlapping',
expressions=[(F('datespan'), RangeOperators.ADJACENT_TO)],
include=['cancelled', 'room'],
)
self.assertEqual(
repr(constraint),
"<ExclusionConstraint: index_type=GIST, expressions=["
"(F(datespan), '-|-')], include=('cancelled', 'room')>",
)
def test_eq(self): def test_eq(self):
constraint_1 = ExclusionConstraint( constraint_1 = ExclusionConstraint(
@ -218,6 +250,23 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
], ],
deferrable=Deferrable.IMMEDIATE, deferrable=Deferrable.IMMEDIATE,
) )
constraint_6 = ExclusionConstraint(
name='exclude_overlapping',
expressions=[
('datespan', RangeOperators.OVERLAPS),
('room', RangeOperators.EQUAL),
],
deferrable=Deferrable.IMMEDIATE,
include=['cancelled'],
)
constraint_7 = ExclusionConstraint(
name='exclude_overlapping',
expressions=[
('datespan', RangeOperators.OVERLAPS),
('room', RangeOperators.EQUAL),
],
include=['cancelled'],
)
self.assertEqual(constraint_1, constraint_1) self.assertEqual(constraint_1, constraint_1)
self.assertEqual(constraint_1, mock.ANY) self.assertEqual(constraint_1, mock.ANY)
self.assertNotEqual(constraint_1, constraint_2) self.assertNotEqual(constraint_1, constraint_2)
@ -225,7 +274,9 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
self.assertNotEqual(constraint_1, constraint_4) self.assertNotEqual(constraint_1, constraint_4)
self.assertNotEqual(constraint_2, constraint_3) self.assertNotEqual(constraint_2, constraint_3)
self.assertNotEqual(constraint_2, constraint_4) self.assertNotEqual(constraint_2, constraint_4)
self.assertNotEqual(constraint_2, constraint_7)
self.assertNotEqual(constraint_4, constraint_5) self.assertNotEqual(constraint_4, constraint_5)
self.assertNotEqual(constraint_5, constraint_6)
self.assertNotEqual(constraint_1, object()) self.assertNotEqual(constraint_1, object())
def test_deconstruct(self): def test_deconstruct(self):
@ -286,6 +337,21 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
'deferrable': Deferrable.DEFERRED, 'deferrable': Deferrable.DEFERRED,
}) })
def test_deconstruct_include(self):
constraint = ExclusionConstraint(
name='exclude_overlapping',
expressions=[('datespan', RangeOperators.OVERLAPS)],
include=['cancelled', 'room'],
)
path, args, kwargs = constraint.deconstruct()
self.assertEqual(path, 'django.contrib.postgres.constraints.ExclusionConstraint')
self.assertEqual(args, ())
self.assertEqual(kwargs, {
'name': 'exclude_overlapping',
'expressions': [('datespan', RangeOperators.OVERLAPS)],
'include': ('cancelled', 'room'),
})
def _test_range_overlaps(self, constraint): def _test_range_overlaps(self, constraint):
# Create exclusion constraint. # Create exclusion constraint.
self.assertNotIn(constraint.name, self.get_constraints(HotelReservation._meta.db_table)) self.assertNotIn(constraint.name, self.get_constraints(HotelReservation._meta.db_table))
@ -417,3 +483,66 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
adjacent_range.delete() adjacent_range.delete()
RangesModel.objects.create(ints=(10, 19)) RangesModel.objects.create(ints=(10, 19))
RangesModel.objects.create(ints=(51, 60)) RangesModel.objects.create(ints=(51, 60))
@skipUnlessDBFeature('supports_covering_gist_indexes')
def test_range_adjacent_include(self):
constraint_name = 'ints_adjacent_include'
self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
constraint = ExclusionConstraint(
name=constraint_name,
expressions=[('ints', RangeOperators.ADJACENT_TO)],
include=['decimals', 'ints'],
index_type='gist',
)
with connection.schema_editor() as editor:
editor.add_constraint(RangesModel, constraint)
self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
RangesModel.objects.create(ints=(20, 50))
with self.assertRaises(IntegrityError), transaction.atomic():
RangesModel.objects.create(ints=(10, 20))
RangesModel.objects.create(ints=(10, 19))
RangesModel.objects.create(ints=(51, 60))
@skipUnlessDBFeature('supports_covering_gist_indexes')
def test_range_adjacent_include_condition(self):
constraint_name = 'ints_adjacent_include_condition'
self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
constraint = ExclusionConstraint(
name=constraint_name,
expressions=[('ints', RangeOperators.ADJACENT_TO)],
include=['decimals'],
condition=Q(id__gte=100),
)
with connection.schema_editor() as editor:
editor.add_constraint(RangesModel, constraint)
self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
@skipUnlessDBFeature('supports_covering_gist_indexes')
def test_range_adjacent_include_deferrable(self):
constraint_name = 'ints_adjacent_include_deferrable'
self.assertNotIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
constraint = ExclusionConstraint(
name=constraint_name,
expressions=[('ints', RangeOperators.ADJACENT_TO)],
include=['decimals'],
deferrable=Deferrable.DEFERRED,
)
with connection.schema_editor() as editor:
editor.add_constraint(RangesModel, constraint)
self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
def test_include_not_supported(self):
constraint_name = 'ints_adjacent_include_not_supported'
constraint = ExclusionConstraint(
name=constraint_name,
expressions=[('ints', RangeOperators.ADJACENT_TO)],
include=['id'],
)
msg = 'Covering exclusion constraints requires PostgreSQL 12+.'
with connection.schema_editor() as editor:
with mock.patch(
'django.db.backends.postgresql.features.DatabaseFeatures.supports_covering_gist_indexes',
False,
):
with self.assertRaisesMessage(NotSupportedError, msg):
editor.add_constraint(RangesModel, constraint)