Fixed #31702 -- Added support for PostgreSQL opclasses in UniqueConstraint.

This commit is contained in:
Hannes Ljungberg 2020-06-11 21:37:12 +02:00 committed by Mariusz Felisiak
parent 69e0d9c553
commit 7edc6e53a7
6 changed files with 210 additions and 14 deletions

View File

@ -1092,13 +1092,16 @@ class BaseDatabaseSchemaEditor:
if deferrable == Deferrable.IMMEDIATE:
return ' DEFERRABLE INITIALLY IMMEDIATE'
def _unique_sql(self, model, fields, name, condition=None, deferrable=None, include=None):
def _unique_sql(
self, model, fields, name, condition=None, deferrable=None,
include=None, opclasses=None,
):
if (
deferrable and
not self.connection.features.supports_deferrable_unique_constraints
):
return None
if condition or include:
if condition or include or opclasses:
# Databases support conditional and covering unique constraints via
# a unique index.
sql = self._create_unique_sql(
@ -1107,6 +1110,7 @@ class BaseDatabaseSchemaEditor:
name=name,
condition=condition,
include=include,
opclasses=opclasses,
)
if sql:
self.deferred_sql.append(sql)
@ -1120,7 +1124,10 @@ class BaseDatabaseSchemaEditor:
'constraint': constraint,
}
def _create_unique_sql(self, model, columns, name=None, condition=None, deferrable=None, include=None):
def _create_unique_sql(
self, model, columns, name=None, condition=None, deferrable=None,
include=None, opclasses=None,
):
if (
(
deferrable and
@ -1139,8 +1146,8 @@ class BaseDatabaseSchemaEditor:
name = IndexName(model._meta.db_table, columns, '_uniq', create_unique_name)
else:
name = self.quote_name(name)
columns = Columns(table, columns, self.quote_name)
if condition or include:
columns = self._index_columns(table, columns, col_suffixes=(), opclasses=opclasses)
if condition or include or opclasses:
sql = self.sql_create_unique_index
else:
sql = self.sql_create_unique
@ -1154,7 +1161,10 @@ class BaseDatabaseSchemaEditor:
include=self._index_include_sql(model, include),
)
def _delete_unique_sql(self, model, name, condition=None, deferrable=None, include=None):
def _delete_unique_sql(
self, model, name, condition=None, deferrable=None, include=None,
opclasses=None,
):
if (
(
deferrable and
@ -1164,7 +1174,7 @@ class BaseDatabaseSchemaEditor:
(include and not self.connection.features.supports_covering_indexes)
):
return None
if condition or include:
if condition or include or opclasses:
sql = self.sql_delete_index
else:
sql = self.sql_delete_unique

View File

@ -77,7 +77,16 @@ class Deferrable(Enum):
class UniqueConstraint(BaseConstraint):
def __init__(self, *, fields, name, condition=None, deferrable=None, include=None):
def __init__(
self,
*,
fields,
name,
condition=None,
deferrable=None,
include=None,
opclasses=(),
):
if not fields:
raise ValueError('At least one field is required to define a unique constraint.')
if not isinstance(condition, (type(None), Q)):
@ -92,10 +101,18 @@ class UniqueConstraint(BaseConstraint):
)
if not isinstance(include, (type(None), list, tuple)):
raise ValueError('UniqueConstraint.include must be a list or tuple.')
if not isinstance(opclasses, (list, tuple)):
raise ValueError('UniqueConstraint.opclasses must be a list or tuple.')
if opclasses and len(fields) != len(opclasses):
raise ValueError(
'UniqueConstraint.fields and UniqueConstraint.opclasses must '
'have the same number of elements.'
)
self.fields = tuple(fields)
self.condition = condition
self.deferrable = deferrable
self.include = tuple(include) if include else ()
self.opclasses = opclasses
super().__init__(name)
def _get_condition_sql(self, model, schema_editor):
@ -114,6 +131,7 @@ class UniqueConstraint(BaseConstraint):
return schema_editor._unique_sql(
model, fields, self.name, condition=condition,
deferrable=self.deferrable, include=include,
opclasses=self.opclasses,
)
def create_sql(self, model, schema_editor):
@ -123,6 +141,7 @@ class UniqueConstraint(BaseConstraint):
return schema_editor._create_unique_sql(
model, fields, self.name, condition=condition,
deferrable=self.deferrable, include=include,
opclasses=self.opclasses,
)
def remove_sql(self, model, schema_editor):
@ -130,15 +149,16 @@ class UniqueConstraint(BaseConstraint):
include = [model._meta.get_field(field_name).column for field_name in self.include]
return schema_editor._delete_unique_sql(
model, self.name, condition=condition, deferrable=self.deferrable,
include=include,
include=include, opclasses=self.opclasses,
)
def __repr__(self):
return '<%s: fields=%r name=%r%s%s%s>' % (
return '<%s: fields=%r name=%r%s%s%s%s>' % (
self.__class__.__name__, self.fields, self.name,
'' if self.condition is None else ' condition=%s' % self.condition,
'' if self.deferrable is None else ' deferrable=%s' % self.deferrable,
'' if not self.include else ' include=%s' % repr(self.include),
'' if not self.opclasses else ' opclasses=%s' % repr(self.opclasses),
)
def __eq__(self, other):
@ -148,7 +168,8 @@ class UniqueConstraint(BaseConstraint):
self.fields == other.fields and
self.condition == other.condition and
self.deferrable == other.deferrable and
self.include == other.include
self.include == other.include and
self.opclasses == other.opclasses
)
return super().__eq__(other)
@ -161,4 +182,6 @@ class UniqueConstraint(BaseConstraint):
kwargs['deferrable'] = self.deferrable
if self.include:
kwargs['include'] = self.include
if self.opclasses:
kwargs['opclasses'] = self.opclasses
return path, args, kwargs

View File

@ -73,7 +73,7 @@ constraint.
``UniqueConstraint``
====================
.. class:: UniqueConstraint(*, fields, name, condition=None, deferrable=None, include=None)
.. class:: UniqueConstraint(*, fields, name, condition=None, deferrable=None, include=None, opclasses=())
Creates a unique constraint in the database.
@ -168,3 +168,24 @@ while fetching data only from the index.
``include`` is supported only on PostgreSQL.
Non-key columns have the same database restrictions as :attr:`Index.include`.
``opclasses``
-------------
.. attribute:: UniqueConstraint.opclasses
.. versionadded:: 3.2
The names of the `PostgreSQL operator classes
<https://www.postgresql.org/docs/current/indexes-opclass.html>`_ to use for
this unique index. If you require a custom operator class, you must provide one
for each field in the index.
For example::
UniqueConstraint(name='unique_username', fields=['username'], opclasses=['varchar_pattern_ops'])
creates a unique index on ``username`` using ``varchar_pattern_ops``.
``opclasses`` are ignored for databases besides PostgreSQL.

View File

@ -196,6 +196,9 @@ Models
attributes allow creating covering indexes and covering unique constraints on
PostgreSQL 11+.
* The new :attr:`.UniqueConstraint.opclasses` attribute allows setting
PostgreSQL operator classes.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~

View File

@ -196,6 +196,20 @@ class UniqueConstraintTests(TestCase):
self.assertEqual(constraint_1, constraint_1)
self.assertNotEqual(constraint_1, constraint_2)
def test_eq_with_opclasses(self):
constraint_1 = models.UniqueConstraint(
fields=['foo', 'bar'],
name='opclasses',
opclasses=['text_pattern_ops', 'varchar_pattern_ops'],
)
constraint_2 = models.UniqueConstraint(
fields=['foo', 'bar'],
name='opclasses',
opclasses=['varchar_pattern_ops', 'text_pattern_ops'],
)
self.assertEqual(constraint_1, constraint_1)
self.assertNotEqual(constraint_1, constraint_2)
def test_repr(self):
fields = ['foo', 'bar']
name = 'unique_fields'
@ -241,6 +255,18 @@ class UniqueConstraintTests(TestCase):
"include=('baz_1', 'baz_2')>",
)
def test_repr_with_opclasses(self):
constraint = models.UniqueConstraint(
fields=['foo', 'bar'],
name='opclasses_fields',
opclasses=['text_pattern_ops', 'varchar_pattern_ops'],
)
self.assertEqual(
repr(constraint),
"<UniqueConstraint: fields=('foo', 'bar') name='opclasses_fields' "
"opclasses=['text_pattern_ops', 'varchar_pattern_ops']>",
)
def test_deconstruction(self):
fields = ['foo', 'bar']
name = 'unique_fields'
@ -291,6 +317,20 @@ class UniqueConstraintTests(TestCase):
'include': tuple(include),
})
def test_deconstruction_with_opclasses(self):
fields = ['foo', 'bar']
name = 'unique_fields'
opclasses = ['varchar_pattern_ops', 'text_pattern_ops']
constraint = models.UniqueConstraint(fields=fields, name=name, opclasses=opclasses)
path, args, kwargs = constraint.deconstruct()
self.assertEqual(path, 'django.db.models.UniqueConstraint')
self.assertEqual(args, ())
self.assertEqual(kwargs, {
'fields': tuple(fields),
'name': name,
'opclasses': opclasses,
})
def test_database_constraint(self):
with self.assertRaises(IntegrityError):
UniqueConstraintProduct.objects.create(name=self.p1.name, color=self.p1.color)
@ -392,3 +432,24 @@ class UniqueConstraintTests(TestCase):
fields=['field'],
include='other',
)
def test_invalid_opclasses_argument(self):
msg = 'UniqueConstraint.opclasses must be a list or tuple.'
with self.assertRaisesMessage(ValueError, msg):
models.UniqueConstraint(
name='uniq_opclasses',
fields=['field'],
opclasses='jsonb_path_ops',
)
def test_opclasses_and_fields_same_length(self):
msg = (
'UniqueConstraint.fields and UniqueConstraint.opclasses must have '
'the same number of elements.'
)
with self.assertRaisesMessage(ValueError, msg):
models.UniqueConstraint(
name='uniq_opclasses',
fields=['field'],
opclasses=['foo', 'bar'],
)

View File

@ -4,12 +4,14 @@ from unittest import mock
from django.db import (
IntegrityError, NotSupportedError, connection, transaction,
)
from django.db.models import CheckConstraint, Deferrable, F, Func, Q
from django.db.models import (
CheckConstraint, Deferrable, F, Func, Q, UniqueConstraint,
)
from django.test import skipUnlessDBFeature
from django.utils import timezone
from . import PostgreSQLTestCase
from .models import HotelReservation, RangesModel, Room
from .models import HotelReservation, RangesModel, Room, Scene
try:
from django.contrib.postgres.constraints import ExclusionConstraint
@ -21,6 +23,13 @@ except ImportError:
class SchemaTests(PostgreSQLTestCase):
get_opclass_query = '''
SELECT opcname, c.relname FROM pg_opclass AS oc
JOIN pg_index as i on oc.oid = ANY(i.indclass)
JOIN pg_class as c on c.oid = i.indexrelid
WHERE c.relname = %s
'''
def get_constraints(self, table):
"""Get the constraints on the table using a new cursor."""
with connection.cursor() as cursor:
@ -84,6 +93,75 @@ class SchemaTests(PostgreSQLTestCase):
timestamps_inner=(datetime_1, datetime_2),
)
def test_opclass(self):
constraint = UniqueConstraint(
name='test_opclass',
fields=['scene'],
opclasses=['varchar_pattern_ops'],
)
with connection.schema_editor() as editor:
editor.add_constraint(Scene, constraint)
self.assertIn(constraint.name, self.get_constraints(Scene._meta.db_table))
with editor.connection.cursor() as cursor:
cursor.execute(self.get_opclass_query, [constraint.name])
self.assertEqual(
cursor.fetchall(),
[('varchar_pattern_ops', constraint.name)],
)
# Drop the constraint.
with connection.schema_editor() as editor:
editor.remove_constraint(Scene, constraint)
self.assertNotIn(constraint.name, self.get_constraints(Scene._meta.db_table))
def test_opclass_multiple_columns(self):
constraint = UniqueConstraint(
name='test_opclass_multiple',
fields=['scene', 'setting'],
opclasses=['varchar_pattern_ops', 'text_pattern_ops'],
)
with connection.schema_editor() as editor:
editor.add_constraint(Scene, constraint)
with editor.connection.cursor() as cursor:
cursor.execute(self.get_opclass_query, [constraint.name])
expected_opclasses = (
('varchar_pattern_ops', constraint.name),
('text_pattern_ops', constraint.name),
)
self.assertCountEqual(cursor.fetchall(), expected_opclasses)
def test_opclass_partial(self):
constraint = UniqueConstraint(
name='test_opclass_partial',
fields=['scene'],
opclasses=['varchar_pattern_ops'],
condition=Q(setting__contains="Sir Bedemir's Castle"),
)
with connection.schema_editor() as editor:
editor.add_constraint(Scene, constraint)
with editor.connection.cursor() as cursor:
cursor.execute(self.get_opclass_query, [constraint.name])
self.assertCountEqual(
cursor.fetchall(),
[('varchar_pattern_ops', constraint.name)],
)
@skipUnlessDBFeature('supports_covering_indexes')
def test_opclass_include(self):
constraint = UniqueConstraint(
name='test_opclass_include',
fields=['scene'],
opclasses=['varchar_pattern_ops'],
include=['setting'],
)
with connection.schema_editor() as editor:
editor.add_constraint(Scene, constraint)
with editor.connection.cursor() as cursor:
cursor.execute(self.get_opclass_query, [constraint.name])
self.assertCountEqual(
cursor.fetchall(),
[('varchar_pattern_ops', constraint.name)],
)
class ExclusionConstraintTests(PostgreSQLTestCase):
def get_constraints(self, table):