diff --git a/django/db/backends/base/schema.py b/django/db/backends/base/schema.py index b5a75538bd..76dea84cff 100644 --- a/django/db/backends/base/schema.py +++ b/django/db/backends/base/schema.py @@ -1092,13 +1092,16 @@ class BaseDatabaseSchemaEditor: if deferrable == Deferrable.IMMEDIATE: return ' DEFERRABLE INITIALLY IMMEDIATE' - def _unique_sql(self, model, fields, name, condition=None, deferrable=None, include=None): + def _unique_sql( + self, model, fields, name, condition=None, deferrable=None, + include=None, opclasses=None, + ): if ( deferrable and not self.connection.features.supports_deferrable_unique_constraints ): return None - if condition or include: + if condition or include or opclasses: # Databases support conditional and covering unique constraints via # a unique index. sql = self._create_unique_sql( @@ -1107,6 +1110,7 @@ class BaseDatabaseSchemaEditor: name=name, condition=condition, include=include, + opclasses=opclasses, ) if sql: self.deferred_sql.append(sql) @@ -1120,7 +1124,10 @@ class BaseDatabaseSchemaEditor: 'constraint': constraint, } - def _create_unique_sql(self, model, columns, name=None, condition=None, deferrable=None, include=None): + def _create_unique_sql( + self, model, columns, name=None, condition=None, deferrable=None, + include=None, opclasses=None, + ): if ( ( deferrable and @@ -1139,8 +1146,8 @@ class BaseDatabaseSchemaEditor: name = IndexName(model._meta.db_table, columns, '_uniq', create_unique_name) else: name = self.quote_name(name) - columns = Columns(table, columns, self.quote_name) - if condition or include: + columns = self._index_columns(table, columns, col_suffixes=(), opclasses=opclasses) + if condition or include or opclasses: sql = self.sql_create_unique_index else: sql = self.sql_create_unique @@ -1154,7 +1161,10 @@ class BaseDatabaseSchemaEditor: include=self._index_include_sql(model, include), ) - def _delete_unique_sql(self, model, name, condition=None, deferrable=None, include=None): + def _delete_unique_sql( + self, model, name, condition=None, deferrable=None, include=None, + opclasses=None, + ): if ( ( deferrable and @@ -1164,7 +1174,7 @@ class BaseDatabaseSchemaEditor: (include and not self.connection.features.supports_covering_indexes) ): return None - if condition or include: + if condition or include or opclasses: sql = self.sql_delete_index else: sql = self.sql_delete_unique diff --git a/django/db/models/constraints.py b/django/db/models/constraints.py index 17a8226915..c6dd39e762 100644 --- a/django/db/models/constraints.py +++ b/django/db/models/constraints.py @@ -77,7 +77,16 @@ class Deferrable(Enum): class UniqueConstraint(BaseConstraint): - def __init__(self, *, fields, name, condition=None, deferrable=None, include=None): + def __init__( + self, + *, + fields, + name, + condition=None, + deferrable=None, + include=None, + opclasses=(), + ): if not fields: raise ValueError('At least one field is required to define a unique constraint.') if not isinstance(condition, (type(None), Q)): @@ -92,10 +101,18 @@ class UniqueConstraint(BaseConstraint): ) if not isinstance(include, (type(None), list, tuple)): raise ValueError('UniqueConstraint.include must be a list or tuple.') + if not isinstance(opclasses, (list, tuple)): + raise ValueError('UniqueConstraint.opclasses must be a list or tuple.') + if opclasses and len(fields) != len(opclasses): + raise ValueError( + 'UniqueConstraint.fields and UniqueConstraint.opclasses must ' + 'have the same number of elements.' + ) self.fields = tuple(fields) self.condition = condition self.deferrable = deferrable self.include = tuple(include) if include else () + self.opclasses = opclasses super().__init__(name) def _get_condition_sql(self, model, schema_editor): @@ -114,6 +131,7 @@ class UniqueConstraint(BaseConstraint): return schema_editor._unique_sql( model, fields, self.name, condition=condition, deferrable=self.deferrable, include=include, + opclasses=self.opclasses, ) def create_sql(self, model, schema_editor): @@ -123,6 +141,7 @@ class UniqueConstraint(BaseConstraint): return schema_editor._create_unique_sql( model, fields, self.name, condition=condition, deferrable=self.deferrable, include=include, + opclasses=self.opclasses, ) def remove_sql(self, model, schema_editor): @@ -130,15 +149,16 @@ class UniqueConstraint(BaseConstraint): include = [model._meta.get_field(field_name).column for field_name in self.include] return schema_editor._delete_unique_sql( model, self.name, condition=condition, deferrable=self.deferrable, - include=include, + include=include, opclasses=self.opclasses, ) def __repr__(self): - return '<%s: fields=%r name=%r%s%s%s>' % ( + return '<%s: fields=%r name=%r%s%s%s%s>' % ( self.__class__.__name__, self.fields, self.name, '' if self.condition is None else ' condition=%s' % self.condition, '' if self.deferrable is None else ' deferrable=%s' % self.deferrable, '' if not self.include else ' include=%s' % repr(self.include), + '' if not self.opclasses else ' opclasses=%s' % repr(self.opclasses), ) def __eq__(self, other): @@ -148,7 +168,8 @@ class UniqueConstraint(BaseConstraint): self.fields == other.fields and self.condition == other.condition and self.deferrable == other.deferrable and - self.include == other.include + self.include == other.include and + self.opclasses == other.opclasses ) return super().__eq__(other) @@ -161,4 +182,6 @@ class UniqueConstraint(BaseConstraint): kwargs['deferrable'] = self.deferrable if self.include: kwargs['include'] = self.include + if self.opclasses: + kwargs['opclasses'] = self.opclasses return path, args, kwargs diff --git a/docs/ref/models/constraints.txt b/docs/ref/models/constraints.txt index 819bb3a20b..1536a8692a 100644 --- a/docs/ref/models/constraints.txt +++ b/docs/ref/models/constraints.txt @@ -73,7 +73,7 @@ constraint. ``UniqueConstraint`` ==================== -.. class:: UniqueConstraint(*, fields, name, condition=None, deferrable=None, include=None) +.. class:: UniqueConstraint(*, fields, name, condition=None, deferrable=None, include=None, opclasses=()) Creates a unique constraint in the database. @@ -168,3 +168,24 @@ while fetching data only from the index. ``include`` is supported only on PostgreSQL. Non-key columns have the same database restrictions as :attr:`Index.include`. + + +``opclasses`` +------------- + +.. attribute:: UniqueConstraint.opclasses + +.. versionadded:: 3.2 + +The names of the `PostgreSQL operator classes +`_ to use for +this unique index. If you require a custom operator class, you must provide one +for each field in the index. + +For example:: + + UniqueConstraint(name='unique_username', fields=['username'], opclasses=['varchar_pattern_ops']) + +creates a unique index on ``username`` using ``varchar_pattern_ops``. + +``opclasses`` are ignored for databases besides PostgreSQL. diff --git a/docs/releases/3.2.txt b/docs/releases/3.2.txt index 3b831ee8ec..416530400e 100644 --- a/docs/releases/3.2.txt +++ b/docs/releases/3.2.txt @@ -196,6 +196,9 @@ Models attributes allow creating covering indexes and covering unique constraints on PostgreSQL 11+. +* The new :attr:`.UniqueConstraint.opclasses` attribute allows setting + PostgreSQL operator classes. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/constraints/tests.py b/tests/constraints/tests.py index 02320e30b1..4ec1f2a8e8 100644 --- a/tests/constraints/tests.py +++ b/tests/constraints/tests.py @@ -196,6 +196,20 @@ class UniqueConstraintTests(TestCase): self.assertEqual(constraint_1, constraint_1) self.assertNotEqual(constraint_1, constraint_2) + def test_eq_with_opclasses(self): + constraint_1 = models.UniqueConstraint( + fields=['foo', 'bar'], + name='opclasses', + opclasses=['text_pattern_ops', 'varchar_pattern_ops'], + ) + constraint_2 = models.UniqueConstraint( + fields=['foo', 'bar'], + name='opclasses', + opclasses=['varchar_pattern_ops', 'text_pattern_ops'], + ) + self.assertEqual(constraint_1, constraint_1) + self.assertNotEqual(constraint_1, constraint_2) + def test_repr(self): fields = ['foo', 'bar'] name = 'unique_fields' @@ -241,6 +255,18 @@ class UniqueConstraintTests(TestCase): "include=('baz_1', 'baz_2')>", ) + def test_repr_with_opclasses(self): + constraint = models.UniqueConstraint( + fields=['foo', 'bar'], + name='opclasses_fields', + opclasses=['text_pattern_ops', 'varchar_pattern_ops'], + ) + self.assertEqual( + repr(constraint), + "", + ) + def test_deconstruction(self): fields = ['foo', 'bar'] name = 'unique_fields' @@ -291,6 +317,20 @@ class UniqueConstraintTests(TestCase): 'include': tuple(include), }) + def test_deconstruction_with_opclasses(self): + fields = ['foo', 'bar'] + name = 'unique_fields' + opclasses = ['varchar_pattern_ops', 'text_pattern_ops'] + constraint = models.UniqueConstraint(fields=fields, name=name, opclasses=opclasses) + path, args, kwargs = constraint.deconstruct() + self.assertEqual(path, 'django.db.models.UniqueConstraint') + self.assertEqual(args, ()) + self.assertEqual(kwargs, { + 'fields': tuple(fields), + 'name': name, + 'opclasses': opclasses, + }) + def test_database_constraint(self): with self.assertRaises(IntegrityError): UniqueConstraintProduct.objects.create(name=self.p1.name, color=self.p1.color) @@ -392,3 +432,24 @@ class UniqueConstraintTests(TestCase): fields=['field'], include='other', ) + + def test_invalid_opclasses_argument(self): + msg = 'UniqueConstraint.opclasses must be a list or tuple.' + with self.assertRaisesMessage(ValueError, msg): + models.UniqueConstraint( + name='uniq_opclasses', + fields=['field'], + opclasses='jsonb_path_ops', + ) + + def test_opclasses_and_fields_same_length(self): + msg = ( + 'UniqueConstraint.fields and UniqueConstraint.opclasses must have ' + 'the same number of elements.' + ) + with self.assertRaisesMessage(ValueError, msg): + models.UniqueConstraint( + name='uniq_opclasses', + fields=['field'], + opclasses=['foo', 'bar'], + ) diff --git a/tests/postgres_tests/test_constraints.py b/tests/postgres_tests/test_constraints.py index b7a866a2b5..cfe0981d3c 100644 --- a/tests/postgres_tests/test_constraints.py +++ b/tests/postgres_tests/test_constraints.py @@ -4,12 +4,14 @@ from unittest import mock from django.db import ( IntegrityError, NotSupportedError, connection, transaction, ) -from django.db.models import CheckConstraint, Deferrable, F, Func, Q +from django.db.models import ( + CheckConstraint, Deferrable, F, Func, Q, UniqueConstraint, +) from django.test import skipUnlessDBFeature from django.utils import timezone from . import PostgreSQLTestCase -from .models import HotelReservation, RangesModel, Room +from .models import HotelReservation, RangesModel, Room, Scene try: from django.contrib.postgres.constraints import ExclusionConstraint @@ -21,6 +23,13 @@ except ImportError: class SchemaTests(PostgreSQLTestCase): + get_opclass_query = ''' + SELECT opcname, c.relname FROM pg_opclass AS oc + JOIN pg_index as i on oc.oid = ANY(i.indclass) + JOIN pg_class as c on c.oid = i.indexrelid + WHERE c.relname = %s + ''' + def get_constraints(self, table): """Get the constraints on the table using a new cursor.""" with connection.cursor() as cursor: @@ -84,6 +93,75 @@ class SchemaTests(PostgreSQLTestCase): timestamps_inner=(datetime_1, datetime_2), ) + def test_opclass(self): + constraint = UniqueConstraint( + name='test_opclass', + fields=['scene'], + opclasses=['varchar_pattern_ops'], + ) + with connection.schema_editor() as editor: + editor.add_constraint(Scene, constraint) + self.assertIn(constraint.name, self.get_constraints(Scene._meta.db_table)) + with editor.connection.cursor() as cursor: + cursor.execute(self.get_opclass_query, [constraint.name]) + self.assertEqual( + cursor.fetchall(), + [('varchar_pattern_ops', constraint.name)], + ) + # Drop the constraint. + with connection.schema_editor() as editor: + editor.remove_constraint(Scene, constraint) + self.assertNotIn(constraint.name, self.get_constraints(Scene._meta.db_table)) + + def test_opclass_multiple_columns(self): + constraint = UniqueConstraint( + name='test_opclass_multiple', + fields=['scene', 'setting'], + opclasses=['varchar_pattern_ops', 'text_pattern_ops'], + ) + with connection.schema_editor() as editor: + editor.add_constraint(Scene, constraint) + with editor.connection.cursor() as cursor: + cursor.execute(self.get_opclass_query, [constraint.name]) + expected_opclasses = ( + ('varchar_pattern_ops', constraint.name), + ('text_pattern_ops', constraint.name), + ) + self.assertCountEqual(cursor.fetchall(), expected_opclasses) + + def test_opclass_partial(self): + constraint = UniqueConstraint( + name='test_opclass_partial', + fields=['scene'], + opclasses=['varchar_pattern_ops'], + condition=Q(setting__contains="Sir Bedemir's Castle"), + ) + with connection.schema_editor() as editor: + editor.add_constraint(Scene, constraint) + with editor.connection.cursor() as cursor: + cursor.execute(self.get_opclass_query, [constraint.name]) + self.assertCountEqual( + cursor.fetchall(), + [('varchar_pattern_ops', constraint.name)], + ) + + @skipUnlessDBFeature('supports_covering_indexes') + def test_opclass_include(self): + constraint = UniqueConstraint( + name='test_opclass_include', + fields=['scene'], + opclasses=['varchar_pattern_ops'], + include=['setting'], + ) + with connection.schema_editor() as editor: + editor.add_constraint(Scene, constraint) + with editor.connection.cursor() as cursor: + cursor.execute(self.get_opclass_query, [constraint.name]) + self.assertCountEqual( + cursor.fetchall(), + [('varchar_pattern_ops', constraint.name)], + ) + class ExclusionConstraintTests(PostgreSQLTestCase): def get_constraints(self, table):