Fixed #30916 -- Added support for functional unique constraints.

Thanks Ian Foote and Mariusz Felisiak for reviews.
This commit is contained in:
Hannes Ljungberg 2021-02-06 20:45:54 +01:00 committed by Mariusz Felisiak
parent 19ce1d493a
commit 3aa545281e
15 changed files with 779 additions and 38 deletions

View File

@ -1184,16 +1184,16 @@ class BaseDatabaseSchemaEditor:
def _unique_sql( def _unique_sql(
self, model, fields, name, condition=None, deferrable=None, self, model, fields, name, condition=None, deferrable=None,
include=None, opclasses=None, include=None, opclasses=None, expressions=None,
): ):
if ( if (
deferrable and deferrable and
not self.connection.features.supports_deferrable_unique_constraints not self.connection.features.supports_deferrable_unique_constraints
): ):
return None return None
if condition or include or opclasses: if condition or include or opclasses or expressions:
# Databases support conditional and covering unique constraints via # Databases support conditional, covering, and functional unique
# a unique index. # constraints via a unique index.
sql = self._create_unique_sql( sql = self._create_unique_sql(
model, model,
fields, fields,
@ -1201,6 +1201,7 @@ class BaseDatabaseSchemaEditor:
condition=condition, condition=condition,
include=include, include=include,
opclasses=opclasses, opclasses=opclasses,
expressions=expressions,
) )
if sql: if sql:
self.deferred_sql.append(sql) self.deferred_sql.append(sql)
@ -1216,7 +1217,7 @@ class BaseDatabaseSchemaEditor:
def _create_unique_sql( def _create_unique_sql(
self, model, columns, name=None, condition=None, deferrable=None, self, model, columns, name=None, condition=None, deferrable=None,
include=None, opclasses=None, include=None, opclasses=None, expressions=None,
): ):
if ( if (
( (
@ -1224,23 +1225,28 @@ class BaseDatabaseSchemaEditor:
not self.connection.features.supports_deferrable_unique_constraints not self.connection.features.supports_deferrable_unique_constraints
) or ) or
(condition and not self.connection.features.supports_partial_indexes) or (condition and not self.connection.features.supports_partial_indexes) or
(include and not self.connection.features.supports_covering_indexes) (include and not self.connection.features.supports_covering_indexes) or
(expressions and not self.connection.features.supports_expression_indexes)
): ):
return None return None
def create_unique_name(*args, **kwargs): def create_unique_name(*args, **kwargs):
return self.quote_name(self._create_index_name(*args, **kwargs)) return self.quote_name(self._create_index_name(*args, **kwargs))
compiler = Query(model, alias_cols=False).get_compiler(connection=self.connection)
table = Table(model._meta.db_table, self.quote_name) table = Table(model._meta.db_table, self.quote_name)
if name is None: if name is None:
name = IndexName(model._meta.db_table, columns, '_uniq', create_unique_name) name = IndexName(model._meta.db_table, columns, '_uniq', create_unique_name)
else: else:
name = self.quote_name(name) name = self.quote_name(name)
columns = self._index_columns(table, columns, col_suffixes=(), opclasses=opclasses) if condition or include or opclasses or expressions:
if condition or include or opclasses:
sql = self.sql_create_unique_index sql = self.sql_create_unique_index
else: else:
sql = self.sql_create_unique sql = self.sql_create_unique
if columns:
columns = self._index_columns(table, columns, col_suffixes=(), opclasses=opclasses)
else:
columns = Expressions(model._meta.db_table, expressions, compiler, self.quote_value)
return Statement( return Statement(
sql, sql,
table=table, table=table,
@ -1253,7 +1259,7 @@ class BaseDatabaseSchemaEditor:
def _delete_unique_sql( def _delete_unique_sql(
self, model, name, condition=None, deferrable=None, include=None, self, model, name, condition=None, deferrable=None, include=None,
opclasses=None, opclasses=None, expressions=None,
): ):
if ( if (
( (
@ -1261,10 +1267,12 @@ class BaseDatabaseSchemaEditor:
not self.connection.features.supports_deferrable_unique_constraints not self.connection.features.supports_deferrable_unique_constraints
) or ) or
(condition and not self.connection.features.supports_partial_indexes) or (condition and not self.connection.features.supports_partial_indexes) or
(include and not self.connection.features.supports_covering_indexes) (include and not self.connection.features.supports_covering_indexes) or
(expressions and not self.connection.features.supports_expression_indexes)
): ):
return None return None
if condition or include or opclasses: if condition or include or opclasses or expressions:
sql = self.sql_delete_index sql = self.sql_delete_index
else: else:
sql = self.sql_delete_unique sql = self.sql_delete_unique

View File

@ -289,7 +289,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
constraints[index] = { constraints[index] = {
'columns': OrderedSet(), 'columns': OrderedSet(),
'primary_key': False, 'primary_key': False,
'unique': False, 'unique': not non_unique,
'check': False, 'check': False,
'foreign_key': None, 'foreign_key': None,
} }

View File

@ -307,6 +307,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
SELECT SELECT
ind.index_name, ind.index_name,
LOWER(ind.index_type), LOWER(ind.index_type),
LOWER(ind.uniqueness),
LISTAGG(LOWER(cols.column_name), ',') WITHIN GROUP (ORDER BY cols.column_position), LISTAGG(LOWER(cols.column_name), ',') WITHIN GROUP (ORDER BY cols.column_position),
LISTAGG(cols.descend, ',') WITHIN GROUP (ORDER BY cols.column_position) LISTAGG(cols.descend, ',') WITHIN GROUP (ORDER BY cols.column_position)
FROM FROM
@ -318,13 +319,13 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
FROM user_constraints cons FROM user_constraints cons
WHERE ind.index_name = cons.index_name WHERE ind.index_name = cons.index_name
) AND cols.index_name = ind.index_name ) AND cols.index_name = ind.index_name
GROUP BY ind.index_name, ind.index_type GROUP BY ind.index_name, ind.index_type, ind.uniqueness
""", [table_name]) """, [table_name])
for constraint, type_, columns, orders in cursor.fetchall(): for constraint, type_, unique, columns, orders in cursor.fetchall():
constraint = self.identifier_converter(constraint) constraint = self.identifier_converter(constraint)
constraints[constraint] = { constraints[constraint] = {
'primary_key': False, 'primary_key': False,
'unique': False, 'unique': unique == 'unique',
'foreign_key': None, 'foreign_key': None,
'check': False, 'check': False,
'index': True, 'index': True,

View File

@ -419,13 +419,17 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
self.delete_model(old_field.remote_field.through) self.delete_model(old_field.remote_field.through)
def add_constraint(self, model, constraint): def add_constraint(self, model, constraint):
if isinstance(constraint, UniqueConstraint) and constraint.condition: if isinstance(constraint, UniqueConstraint) and (
constraint.condition or constraint.contains_expressions
):
super().add_constraint(model, constraint) super().add_constraint(model, constraint)
else: else:
self._remake_table(model) self._remake_table(model)
def remove_constraint(self, model, constraint): def remove_constraint(self, model, constraint):
if isinstance(constraint, UniqueConstraint) and constraint.condition: if isinstance(constraint, UniqueConstraint) and (
constraint.condition or constraint.contains_expressions
):
super().remove_constraint(model, constraint) super().remove_constraint(model, constraint)
else: else:
self._remake_table(model) self._remake_table(model)

View File

@ -2039,6 +2039,25 @@ class Model(metaclass=ModelBase):
id='models.W039', id='models.W039',
) )
) )
if not (
connection.features.supports_expression_indexes or
'supports_expression_indexes' in cls._meta.required_db_features
) and any(
isinstance(constraint, UniqueConstraint) and constraint.contains_expressions
for constraint in cls._meta.constraints
):
errors.append(
checks.Warning(
'%s does not support unique constraints on '
'expressions.' % connection.display_name,
hint=(
"A constraint won't be created. Silence this "
"warning if you don't care about it."
),
obj=cls,
id='models.W044',
)
)
fields = set(chain.from_iterable( fields = set(chain.from_iterable(
(*constraint.fields, *constraint.include) (*constraint.fields, *constraint.include)
for constraint in cls._meta.constraints if isinstance(constraint, UniqueConstraint) for constraint in cls._meta.constraints if isinstance(constraint, UniqueConstraint)
@ -2051,6 +2070,12 @@ class Model(metaclass=ModelBase):
'supports_partial_indexes' not in cls._meta.required_db_features 'supports_partial_indexes' not in cls._meta.required_db_features
) and isinstance(constraint.condition, Q): ) and isinstance(constraint.condition, Q):
references.update(cls._get_expr_references(constraint.condition)) references.update(cls._get_expr_references(constraint.condition))
if (
connection.features.supports_expression_indexes or
'supports_expression_indexes' not in cls._meta.required_db_features
) and constraint.contains_expressions:
for expression in constraint.expressions:
references.update(cls._get_expr_references(expression))
elif isinstance(constraint, CheckConstraint): elif isinstance(constraint, CheckConstraint):
if ( if (
connection.features.supports_table_check_constraints or connection.features.supports_table_check_constraints or

View File

@ -1,5 +1,7 @@
from enum import Enum from enum import Enum
from django.db.models.expressions import ExpressionList, F
from django.db.models.indexes import IndexExpression
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
from django.db.models.sql.query import Query from django.db.models.sql.query import Query
@ -10,6 +12,10 @@ class BaseConstraint:
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
@property
def contains_expressions(self):
return False
def constraint_sql(self, model, schema_editor): def constraint_sql(self, model, schema_editor):
raise NotImplementedError('This method must be implemented by a subclass.') raise NotImplementedError('This method must be implemented by a subclass.')
@ -83,16 +89,25 @@ class Deferrable(Enum):
class UniqueConstraint(BaseConstraint): class UniqueConstraint(BaseConstraint):
def __init__( def __init__(
self, self,
*, *expressions,
fields, fields=(),
name, name=None,
condition=None, condition=None,
deferrable=None, deferrable=None,
include=None, include=None,
opclasses=(), opclasses=(),
): ):
if not fields: if not name:
raise ValueError('At least one field is required to define a unique constraint.') raise ValueError('A unique constraint must be named.')
if not expressions and not fields:
raise ValueError(
'At least one field or expression is required to define a '
'unique constraint.'
)
if expressions and fields:
raise ValueError(
'UniqueConstraint.fields and expressions are mutually exclusive.'
)
if not isinstance(condition, (type(None), Q)): if not isinstance(condition, (type(None), Q)):
raise ValueError('UniqueConstraint.condition must be a Q instance.') raise ValueError('UniqueConstraint.condition must be a Q instance.')
if condition and deferrable: if condition and deferrable:
@ -107,6 +122,15 @@ class UniqueConstraint(BaseConstraint):
raise ValueError( raise ValueError(
'UniqueConstraint with opclasses cannot be deferred.' 'UniqueConstraint with opclasses cannot be deferred.'
) )
if expressions and deferrable:
raise ValueError(
'UniqueConstraint with expressions cannot be deferred.'
)
if expressions and opclasses:
raise ValueError(
'UniqueConstraint.opclasses cannot be used with expressions. '
'Use django.contrib.postgres.indexes.OpClass() instead.'
)
if not isinstance(deferrable, (type(None), Deferrable)): if not isinstance(deferrable, (type(None), Deferrable)):
raise ValueError( raise ValueError(
'UniqueConstraint.deferrable must be a Deferrable instance.' 'UniqueConstraint.deferrable must be a Deferrable instance.'
@ -125,8 +149,16 @@ class UniqueConstraint(BaseConstraint):
self.deferrable = deferrable self.deferrable = deferrable
self.include = tuple(include) if include else () self.include = tuple(include) if include else ()
self.opclasses = opclasses self.opclasses = opclasses
self.expressions = tuple(
F(expression) if isinstance(expression, str) else expression
for expression in expressions
)
super().__init__(name) super().__init__(name)
@property
def contains_expressions(self):
return bool(self.expressions)
def _get_condition_sql(self, model, schema_editor): def _get_condition_sql(self, model, schema_editor):
if self.condition is None: if self.condition is None:
return None return None
@ -136,39 +168,55 @@ class UniqueConstraint(BaseConstraint):
sql, params = where.as_sql(compiler, schema_editor.connection) sql, params = where.as_sql(compiler, schema_editor.connection)
return sql % tuple(schema_editor.quote_value(p) for p in params) return sql % tuple(schema_editor.quote_value(p) for p in params)
def _get_index_expressions(self, model, schema_editor):
if not self.expressions:
return None
index_expressions = []
for expression in self.expressions:
index_expression = IndexExpression(expression)
index_expression.set_wrapper_classes(schema_editor.connection)
index_expressions.append(index_expression)
return ExpressionList(*index_expressions).resolve_expression(
Query(model, alias_cols=False),
)
def constraint_sql(self, model, schema_editor): def constraint_sql(self, model, schema_editor):
fields = [model._meta.get_field(field_name).column for field_name in self.fields] fields = [model._meta.get_field(field_name).column for field_name in self.fields]
include = [model._meta.get_field(field_name).column for field_name in self.include] include = [model._meta.get_field(field_name).column for field_name in self.include]
condition = self._get_condition_sql(model, schema_editor) condition = self._get_condition_sql(model, schema_editor)
expressions = self._get_index_expressions(model, schema_editor)
return schema_editor._unique_sql( return schema_editor._unique_sql(
model, fields, self.name, condition=condition, model, fields, self.name, condition=condition,
deferrable=self.deferrable, include=include, deferrable=self.deferrable, include=include,
opclasses=self.opclasses, opclasses=self.opclasses, expressions=expressions,
) )
def create_sql(self, model, schema_editor): def create_sql(self, model, schema_editor):
fields = [model._meta.get_field(field_name).column for field_name in self.fields] fields = [model._meta.get_field(field_name).column for field_name in self.fields]
include = [model._meta.get_field(field_name).column for field_name in self.include] include = [model._meta.get_field(field_name).column for field_name in self.include]
condition = self._get_condition_sql(model, schema_editor) condition = self._get_condition_sql(model, schema_editor)
expressions = self._get_index_expressions(model, schema_editor)
return schema_editor._create_unique_sql( return schema_editor._create_unique_sql(
model, fields, self.name, condition=condition, model, fields, self.name, condition=condition,
deferrable=self.deferrable, include=include, deferrable=self.deferrable, include=include,
opclasses=self.opclasses, opclasses=self.opclasses, expressions=expressions,
) )
def remove_sql(self, model, schema_editor): def remove_sql(self, model, schema_editor):
condition = self._get_condition_sql(model, schema_editor) condition = self._get_condition_sql(model, schema_editor)
include = [model._meta.get_field(field_name).column for field_name in self.include] include = [model._meta.get_field(field_name).column for field_name in self.include]
expressions = self._get_index_expressions(model, schema_editor)
return schema_editor._delete_unique_sql( return schema_editor._delete_unique_sql(
model, self.name, condition=condition, deferrable=self.deferrable, model, self.name, condition=condition, deferrable=self.deferrable,
include=include, opclasses=self.opclasses, include=include, opclasses=self.opclasses, expressions=expressions,
) )
def __repr__(self): def __repr__(self):
return '<%s: fields=%s name=%s%s%s%s%s>' % ( return '<%s:%s%s%s%s%s%s%s>' % (
self.__class__.__qualname__, self.__class__.__qualname__,
repr(self.fields), '' if not self.fields else ' fields=%s' % repr(self.fields),
repr(self.name), '' if not self.expressions else ' expressions=%s' % repr(self.expressions),
' name=%s' % repr(self.name),
'' if self.condition is None else ' condition=%s' % self.condition, '' if self.condition is None else ' condition=%s' % self.condition,
'' if self.deferrable is None else ' deferrable=%s' % self.deferrable, '' if self.deferrable is None else ' deferrable=%s' % self.deferrable,
'' if not self.include else ' include=%s' % repr(self.include), '' if not self.include else ' include=%s' % repr(self.include),
@ -183,13 +231,15 @@ class UniqueConstraint(BaseConstraint):
self.condition == other.condition and self.condition == other.condition and
self.deferrable == other.deferrable and self.deferrable == other.deferrable and
self.include == other.include and self.include == other.include and
self.opclasses == other.opclasses self.opclasses == other.opclasses and
self.expressions == other.expressions
) )
return super().__eq__(other) return super().__eq__(other)
def deconstruct(self): def deconstruct(self):
path, args, kwargs = super().deconstruct() path, args, kwargs = super().deconstruct()
kwargs['fields'] = self.fields if self.fields:
kwargs['fields'] = self.fields
if self.condition: if self.condition:
kwargs['condition'] = self.condition kwargs['condition'] = self.condition
if self.deferrable: if self.deferrable:
@ -198,4 +248,4 @@ class UniqueConstraint(BaseConstraint):
kwargs['include'] = self.include kwargs['include'] = self.include
if self.opclasses: if self.opclasses:
kwargs['opclasses'] = self.opclasses kwargs['opclasses'] = self.opclasses
return path, args, kwargs return path, self.expressions, kwargs

View File

@ -391,6 +391,8 @@ Models
* **models.W042**: Auto-created primary key used when not defining a primary * **models.W042**: Auto-created primary key used when not defining a primary
key type, by default ``django.db.models.AutoField``. key type, by default ``django.db.models.AutoField``.
* **models.W043**: ``<database>`` does not support indexes on expressions. * **models.W043**: ``<database>`` does not support indexes on expressions.
* **models.W044**: ``<database>`` does not support unique constraints on
expressions.
Security Security
-------- --------

View File

@ -183,10 +183,10 @@ available from the ``django.contrib.postgres.indexes`` module.
.. class:: OpClass(expression, name) .. class:: OpClass(expression, name)
An ``OpClass()`` expression represents the ``expression`` with a custom An ``OpClass()`` expression represents the ``expression`` with a custom
`operator class`_ that can be used to define functional indexes. To use it, `operator class`_ that can be used to define functional indexes or unique
you need to add ``'django.contrib.postgres'`` in your constraints. To use it, you need to add ``'django.contrib.postgres'`` in
:setting:`INSTALLED_APPS`. Set the ``name`` parameter to the name of the your :setting:`INSTALLED_APPS`. Set the ``name`` parameter to the name of
`operator class`_. the `operator class`_.
For example:: For example::
@ -197,4 +197,18 @@ available from the ``django.contrib.postgres.indexes`` module.
creates an index on ``Lower('username')`` using ``varchar_pattern_ops``. creates an index on ``Lower('username')`` using ``varchar_pattern_ops``.
Another example::
UniqueConstraint(
OpClass(Upper('description'), name='text_pattern_ops'),
name='upper_description_unique',
)
creates a unique constraint on ``Upper('description')`` using
``text_pattern_ops``.
.. versionchanged:: 4.0
Support for functional unique constraints was added.
.. _operator class: https://www.postgresql.org/docs/current/indexes-opclass.html .. _operator class: https://www.postgresql.org/docs/current/indexes-opclass.html

View File

@ -69,10 +69,30 @@ constraint.
``UniqueConstraint`` ``UniqueConstraint``
==================== ====================
.. class:: UniqueConstraint(*, fields, name, condition=None, deferrable=None, include=None, opclasses=()) .. class:: UniqueConstraint(*expressions, fields=(), name=None, condition=None, deferrable=None, include=None, opclasses=())
Creates a unique constraint in the database. Creates a unique constraint in the database.
``expressions``
---------------
.. attribute:: UniqueConstraint.expressions
.. versionadded:: 4.0
Positional argument ``*expressions`` allows creating functional unique
constraints on expressions and database functions.
For example::
UniqueConstraint(Lower('name').desc(), 'category', name='unique_lower_name_category')
creates a unique constraint on the lowercased value of the ``name`` field in
descending order and the ``category`` field in the default ascending order.
Functional unique constraints have the same database restrictions as
:attr:`Index.expressions`.
``fields`` ``fields``
---------- ----------

View File

@ -28,6 +28,36 @@ The Django 3.2.x series is the last to support Python 3.6 and 3.7.
What's new in Django 4.0 What's new in Django 4.0
======================== ========================
Functional unique constraints
-----------------------------
The new :attr:`*expressions <django.db.models.UniqueConstraint.expressions>`
positional argument of
:class:`UniqueConstraint() <django.db.models.UniqueConstraint>` enables
creating functional unique constraints on expressions and database functions.
For example::
from django.db import models
from django.db.models import UniqueConstraint
from django.db.models.functions import Lower
class MyModel(models.Model):
first_name = models.CharField(max_length=255)
last_name = models.CharField(max_length=255)
class Meta:
indexes = [
UniqueConstraint(
Lower('first_name'),
Lower('last_name').desc(),
name='first_last_name_unique',
),
]
Functional unique constraints are added to models using the
:attr:`Meta.constraints <django.db.models.Options.constraints>` option.
Minor features Minor features
-------------- --------------

View File

@ -2,7 +2,9 @@ from unittest import mock
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db import IntegrityError, connection, models from django.db import IntegrityError, connection, models
from django.db.models import F
from django.db.models.constraints import BaseConstraint from django.db.models.constraints import BaseConstraint
from django.db.models.functions import Lower
from django.db.transaction import atomic from django.db.transaction import atomic
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
@ -25,6 +27,10 @@ class BaseConstraintTests(SimpleTestCase):
with self.assertRaisesMessage(NotImplementedError, msg): with self.assertRaisesMessage(NotImplementedError, msg):
c.constraint_sql(None, None) c.constraint_sql(None, None)
def test_contains_expressions(self):
c = BaseConstraint('name')
self.assertIs(c.contains_expressions, False)
def test_create_sql(self): def test_create_sql(self):
c = BaseConstraint('name') c = BaseConstraint('name')
msg = 'This method must be implemented by a subclass.' msg = 'This method must be implemented by a subclass.'
@ -218,6 +224,25 @@ class UniqueConstraintTests(TestCase):
self.assertEqual(constraint_1, constraint_1) self.assertEqual(constraint_1, constraint_1)
self.assertNotEqual(constraint_1, constraint_2) self.assertNotEqual(constraint_1, constraint_2)
def test_eq_with_expressions(self):
constraint = models.UniqueConstraint(
Lower('title'),
F('author'),
name='book_func_uq',
)
same_constraint = models.UniqueConstraint(
Lower('title'),
'author',
name='book_func_uq',
)
another_constraint = models.UniqueConstraint(
Lower('title'),
name='book_func_uq',
)
self.assertEqual(constraint, same_constraint)
self.assertEqual(constraint, mock.ANY)
self.assertNotEqual(constraint, another_constraint)
def test_repr(self): def test_repr(self):
fields = ['foo', 'bar'] fields = ['foo', 'bar']
name = 'unique_fields' name = 'unique_fields'
@ -275,6 +300,18 @@ class UniqueConstraintTests(TestCase):
"opclasses=['text_pattern_ops', 'varchar_pattern_ops']>", "opclasses=['text_pattern_ops', 'varchar_pattern_ops']>",
) )
def test_repr_with_expressions(self):
constraint = models.UniqueConstraint(
Lower('title'),
F('author'),
name='book_func_uq',
)
self.assertEqual(
repr(constraint),
"<UniqueConstraint: expressions=(Lower(F(title)), F(author)) "
"name='book_func_uq'>",
)
def test_deconstruction(self): def test_deconstruction(self):
fields = ['foo', 'bar'] fields = ['foo', 'bar']
name = 'unique_fields' name = 'unique_fields'
@ -339,6 +376,14 @@ class UniqueConstraintTests(TestCase):
'opclasses': opclasses, 'opclasses': opclasses,
}) })
def test_deconstruction_with_expressions(self):
name = 'unique_fields'
constraint = models.UniqueConstraint(Lower('title'), name=name)
path, args, kwargs = constraint.deconstruct()
self.assertEqual(path, 'django.db.models.UniqueConstraint')
self.assertEqual(args, (Lower('title'),))
self.assertEqual(kwargs, {'name': name})
def test_database_constraint(self): def test_database_constraint(self):
with self.assertRaises(IntegrityError): with self.assertRaises(IntegrityError):
UniqueConstraintProduct.objects.create(name=self.p1.name, color=self.p1.color) UniqueConstraintProduct.objects.create(name=self.p1.name, color=self.p1.color)
@ -434,6 +479,15 @@ class UniqueConstraintTests(TestCase):
deferrable=models.Deferrable.DEFERRED, deferrable=models.Deferrable.DEFERRED,
) )
def test_deferrable_with_expressions(self):
message = 'UniqueConstraint with expressions cannot be deferred.'
with self.assertRaisesMessage(ValueError, message):
models.UniqueConstraint(
Lower('name'),
name='deferred_expression_unique',
deferrable=models.Deferrable.DEFERRED,
)
def test_invalid_defer_argument(self): def test_invalid_defer_argument(self):
message = 'UniqueConstraint.deferrable must be a Deferrable instance.' message = 'UniqueConstraint.deferrable must be a Deferrable instance.'
with self.assertRaisesMessage(ValueError, message): with self.assertRaisesMessage(ValueError, message):
@ -481,3 +535,33 @@ class UniqueConstraintTests(TestCase):
fields=['field'], fields=['field'],
opclasses=['foo', 'bar'], opclasses=['foo', 'bar'],
) )
def test_requires_field_or_expression(self):
msg = (
'At least one field or expression is required to define a unique '
'constraint.'
)
with self.assertRaisesMessage(ValueError, msg):
models.UniqueConstraint(name='name')
def test_expressions_and_fields_mutually_exclusive(self):
msg = 'UniqueConstraint.fields and expressions are mutually exclusive.'
with self.assertRaisesMessage(ValueError, msg):
models.UniqueConstraint(Lower('field_1'), fields=['field_2'], name='name')
def test_expressions_with_opclasses(self):
msg = (
'UniqueConstraint.opclasses cannot be used with expressions. Use '
'django.contrib.postgres.indexes.OpClass() instead.'
)
with self.assertRaisesMessage(ValueError, msg):
models.UniqueConstraint(
Lower('field'),
name='test_func_opclass',
opclasses=['jsonb_path_ops'],
)
def test_requires_name(self):
msg = 'A unique constraint must be named.'
with self.assertRaisesMessage(ValueError, msg):
models.UniqueConstraint(fields=['field'])

View File

@ -2178,3 +2178,144 @@ class ConstraintsTests(TestCase):
] ]
self.assertEqual(Model.check(databases=self.databases), []) self.assertEqual(Model.check(databases=self.databases), [])
def test_func_unique_constraint(self):
class Model(models.Model):
name = models.CharField(max_length=10)
class Meta:
constraints = [
models.UniqueConstraint(Lower('name'), name='lower_name_uq'),
]
warn = Warning(
'%s does not support unique constraints on expressions.'
% connection.display_name,
hint=(
"A constraint won't be created. Silence this warning if you "
"don't care about it."
),
obj=Model,
id='models.W044',
)
expected = [] if connection.features.supports_expression_indexes else [warn]
self.assertEqual(Model.check(databases=self.databases), expected)
def test_func_unique_constraint_required_db_features(self):
class Model(models.Model):
name = models.CharField(max_length=10)
class Meta:
constraints = [
models.UniqueConstraint(Lower('name'), name='lower_name_unq'),
]
required_db_features = {'supports_expression_indexes'}
self.assertEqual(Model.check(databases=self.databases), [])
@skipUnlessDBFeature('supports_expression_indexes')
def test_func_unique_constraint_expression_custom_lookup(self):
class Model(models.Model):
height = models.IntegerField()
weight = models.IntegerField()
class Meta:
constraints = [
models.UniqueConstraint(
models.F('height') / (models.F('weight__abs') + models.Value(5)),
name='name',
),
]
with register_lookup(models.IntegerField, Abs):
self.assertEqual(Model.check(databases=self.databases), [])
@skipUnlessDBFeature('supports_expression_indexes')
def test_func_unique_constraint_pointing_to_missing_field(self):
class Model(models.Model):
class Meta:
constraints = [
models.UniqueConstraint(Lower('missing_field').desc(), name='name'),
]
self.assertEqual(Model.check(databases=self.databases), [
Error(
"'constraints' refers to the nonexistent field "
"'missing_field'.",
obj=Model,
id='models.E012',
),
])
@skipUnlessDBFeature('supports_expression_indexes')
def test_func_unique_constraint_pointing_to_missing_field_nested(self):
class Model(models.Model):
class Meta:
constraints = [
models.UniqueConstraint(Abs(Round('missing_field')), name='name'),
]
self.assertEqual(Model.check(databases=self.databases), [
Error(
"'constraints' refers to the nonexistent field "
"'missing_field'.",
obj=Model,
id='models.E012',
),
])
@skipUnlessDBFeature('supports_expression_indexes')
def test_func_unique_constraint_pointing_to_m2m_field(self):
class Model(models.Model):
m2m = models.ManyToManyField('self')
class Meta:
constraints = [models.UniqueConstraint(Lower('m2m'), name='name')]
self.assertEqual(Model.check(databases=self.databases), [
Error(
"'constraints' refers to a ManyToManyField 'm2m', but "
"ManyToManyFields are not permitted in 'constraints'.",
obj=Model,
id='models.E013',
),
])
@skipUnlessDBFeature('supports_expression_indexes')
def test_func_unique_constraint_pointing_to_non_local_field(self):
class Foo(models.Model):
field1 = models.CharField(max_length=15)
class Bar(Foo):
class Meta:
constraints = [models.UniqueConstraint(Lower('field1'), name='name')]
self.assertEqual(Bar.check(databases=self.databases), [
Error(
"'constraints' refers to field 'field1' which is not local to "
"model 'Bar'.",
hint='This issue may be caused by multi-table inheritance.',
obj=Bar,
id='models.E016',
),
])
@skipUnlessDBFeature('supports_expression_indexes')
def test_func_unique_constraint_pointing_to_fk(self):
class Foo(models.Model):
id = models.CharField(primary_key=True, max_length=255)
class Bar(models.Model):
foo_1 = models.ForeignKey(Foo, models.CASCADE, related_name='bar_1')
foo_2 = models.ForeignKey(Foo, models.CASCADE, related_name='bar_2')
class Meta:
constraints = [
models.UniqueConstraint(
Lower('foo_1_id'),
Lower('foo_2'),
name='name',
),
]
self.assertEqual(Bar.check(databases=self.databases), [])

View File

@ -2562,6 +2562,99 @@ class OperationTests(OperationTestBase):
'name': 'covering_pink_constraint_rm', 'name': 'covering_pink_constraint_rm',
}) })
def test_add_func_unique_constraint(self):
app_label = 'test_adfuncuc'
constraint_name = f'{app_label}_pony_abs_uq'
table_name = f'{app_label}_pony'
project_state = self.set_up_test_model(app_label)
constraint = models.UniqueConstraint(Abs('weight'), name=constraint_name)
operation = migrations.AddConstraint('Pony', constraint)
self.assertEqual(
operation.describe(),
'Create constraint test_adfuncuc_pony_abs_uq on model Pony',
)
self.assertEqual(
operation.migration_name_fragment,
'pony_test_adfuncuc_pony_abs_uq',
)
new_state = project_state.clone()
operation.state_forwards(app_label, new_state)
self.assertEqual(len(new_state.models[app_label, 'pony'].options['constraints']), 1)
self.assertIndexNameNotExists(table_name, constraint_name)
# Add constraint.
with connection.schema_editor() as editor:
operation.database_forwards(app_label, editor, project_state, new_state)
Pony = new_state.apps.get_model(app_label, 'Pony')
Pony.objects.create(weight=4.0)
if connection.features.supports_expression_indexes:
self.assertIndexNameExists(table_name, constraint_name)
with self.assertRaises(IntegrityError):
Pony.objects.create(weight=-4.0)
else:
self.assertIndexNameNotExists(table_name, constraint_name)
Pony.objects.create(weight=-4.0)
# Reversal.
with connection.schema_editor() as editor:
operation.database_backwards(app_label, editor, new_state, project_state)
self.assertIndexNameNotExists(table_name, constraint_name)
# Constraint doesn't work.
Pony.objects.create(weight=-4.0)
# Deconstruction.
definition = operation.deconstruct()
self.assertEqual(definition[0], 'AddConstraint')
self.assertEqual(definition[1], [])
self.assertEqual(
definition[2],
{'model_name': 'Pony', 'constraint': constraint},
)
def test_remove_func_unique_constraint(self):
app_label = 'test_rmfuncuc'
constraint_name = f'{app_label}_pony_abs_uq'
table_name = f'{app_label}_pony'
project_state = self.set_up_test_model(app_label, constraints=[
models.UniqueConstraint(Abs('weight'), name=constraint_name),
])
self.assertTableExists(table_name)
if connection.features.supports_expression_indexes:
self.assertIndexNameExists(table_name, constraint_name)
operation = migrations.RemoveConstraint('Pony', constraint_name)
self.assertEqual(
operation.describe(),
'Remove constraint test_rmfuncuc_pony_abs_uq from model Pony',
)
self.assertEqual(
operation.migration_name_fragment,
'remove_pony_test_rmfuncuc_pony_abs_uq',
)
new_state = project_state.clone()
operation.state_forwards(app_label, new_state)
self.assertEqual(len(new_state.models[app_label, 'pony'].options['constraints']), 0)
Pony = new_state.apps.get_model(app_label, 'Pony')
self.assertEqual(len(Pony._meta.constraints), 0)
# Remove constraint.
with connection.schema_editor() as editor:
operation.database_forwards(app_label, editor, project_state, new_state)
self.assertIndexNameNotExists(table_name, constraint_name)
# Constraint doesn't work.
Pony.objects.create(pink=1, weight=4.0)
Pony.objects.create(pink=1, weight=-4.0).delete()
# Reversal.
with connection.schema_editor() as editor:
operation.database_backwards(app_label, editor, new_state, project_state)
if connection.features.supports_expression_indexes:
self.assertIndexNameExists(table_name, constraint_name)
with self.assertRaises(IntegrityError):
Pony.objects.create(weight=-4.0)
else:
self.assertIndexNameNotExists(table_name, constraint_name)
Pony.objects.create(weight=-4.0)
# Deconstruction.
definition = operation.deconstruct()
self.assertEqual(definition[0], 'RemoveConstraint')
self.assertEqual(definition[1], [])
self.assertEqual(definition[2], {'model_name': 'Pony', 'name': constraint_name})
def test_alter_model_options(self): def test_alter_model_options(self):
""" """
Tests the AlterModelOptions operation. Tests the AlterModelOptions operation.

View File

@ -1,6 +1,7 @@
import datetime import datetime
from unittest import mock from unittest import mock
from django.contrib.postgres.indexes import OpClass
from django.db import ( from django.db import (
IntegrityError, NotSupportedError, connection, transaction, IntegrityError, NotSupportedError, connection, transaction,
) )
@ -8,8 +9,8 @@ from django.db.models import (
CheckConstraint, Deferrable, F, Func, IntegerField, Q, UniqueConstraint, CheckConstraint, Deferrable, F, Func, IntegerField, Q, UniqueConstraint,
) )
from django.db.models.fields.json import KeyTextTransform from django.db.models.fields.json import KeyTextTransform
from django.db.models.functions import Cast, Left from django.db.models.functions import Cast, Left, Lower
from django.test import skipUnlessDBFeature from django.test import modify_settings, skipUnlessDBFeature
from django.utils import timezone from django.utils import timezone
from . import PostgreSQLTestCase from . import PostgreSQLTestCase
@ -26,6 +27,7 @@ except ImportError:
pass pass
@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
class SchemaTests(PostgreSQLTestCase): class SchemaTests(PostgreSQLTestCase):
get_opclass_query = ''' get_opclass_query = '''
SELECT opcname, c.relname FROM pg_opclass AS oc SELECT opcname, c.relname FROM pg_opclass AS oc
@ -166,6 +168,33 @@ class SchemaTests(PostgreSQLTestCase):
[('varchar_pattern_ops', constraint.name)], [('varchar_pattern_ops', constraint.name)],
) )
@skipUnlessDBFeature('supports_expression_indexes')
def test_opclass_func(self):
constraint = UniqueConstraint(
OpClass(Lower('scene'), name='text_pattern_ops'),
name='test_opclass_func',
)
with connection.schema_editor() as editor:
editor.add_constraint(Scene, constraint)
constraints = self.get_constraints(Scene._meta.db_table)
self.assertIs(constraints[constraint.name]['unique'], True)
self.assertIn(constraint.name, constraints)
with editor.connection.cursor() as cursor:
cursor.execute(self.get_opclass_query, [constraint.name])
self.assertEqual(
cursor.fetchall(),
[('text_pattern_ops', constraint.name)],
)
Scene.objects.create(scene='Scene 10', setting='The dark forest of Ewing')
with self.assertRaises(IntegrityError), transaction.atomic():
Scene.objects.create(scene='ScEnE 10', setting="Sir Bedemir's Castle")
Scene.objects.create(scene='Scene 5', setting="Sir Bedemir's Castle")
# Drop the constraint.
with connection.schema_editor() as editor:
editor.remove_constraint(Scene, constraint)
self.assertNotIn(constraint.name, self.get_constraints(Scene._meta.db_table))
Scene.objects.create(scene='ScEnE 10', setting="Sir Bedemir's Castle")
class ExclusionConstraintTests(PostgreSQLTestCase): class ExclusionConstraintTests(PostgreSQLTestCase):
def get_constraints(self, table): def get_constraints(self, table):

View File

@ -2189,6 +2189,246 @@ class SchemaTests(TransactionTestCase):
AuthorWithUniqueNameAndBirthday._meta.constraints = [] AuthorWithUniqueNameAndBirthday._meta.constraints = []
editor.remove_constraint(AuthorWithUniqueNameAndBirthday, constraint) editor.remove_constraint(AuthorWithUniqueNameAndBirthday, constraint)
@skipUnlessDBFeature('supports_expression_indexes')
def test_func_unique_constraint(self):
with connection.schema_editor() as editor:
editor.create_model(Author)
constraint = UniqueConstraint(Upper('name').desc(), name='func_upper_uq')
# Add constraint.
with connection.schema_editor() as editor:
editor.add_constraint(Author, constraint)
sql = constraint.create_sql(Author, editor)
table = Author._meta.db_table
constraints = self.get_constraints(table)
if connection.features.supports_index_column_ordering:
self.assertIndexOrder(table, constraint.name, ['DESC'])
self.assertIn(constraint.name, constraints)
self.assertIs(constraints[constraint.name]['unique'], True)
# SQL contains a database function.
self.assertIs(sql.references_column(table, 'name'), True)
self.assertIn('UPPER(%s)' % editor.quote_name('name'), str(sql))
# Remove constraint.
with connection.schema_editor() as editor:
editor.remove_constraint(Author, constraint)
self.assertNotIn(constraint.name, self.get_constraints(table))
@skipUnlessDBFeature('supports_expression_indexes')
def test_composite_func_unique_constraint(self):
with connection.schema_editor() as editor:
editor.create_model(Author)
editor.create_model(BookWithSlug)
constraint = UniqueConstraint(
Upper('title'),
Lower('slug'),
name='func_upper_lower_unq',
)
# Add constraint.
with connection.schema_editor() as editor:
editor.add_constraint(BookWithSlug, constraint)
sql = constraint.create_sql(BookWithSlug, editor)
table = BookWithSlug._meta.db_table
constraints = self.get_constraints(table)
self.assertIn(constraint.name, constraints)
self.assertIs(constraints[constraint.name]['unique'], True)
# SQL contains database functions.
self.assertIs(sql.references_column(table, 'title'), True)
self.assertIs(sql.references_column(table, 'slug'), True)
sql = str(sql)
self.assertIn('UPPER(%s)' % editor.quote_name('title'), sql)
self.assertIn('LOWER(%s)' % editor.quote_name('slug'), sql)
self.assertLess(sql.index('UPPER'), sql.index('LOWER'))
# Remove constraint.
with connection.schema_editor() as editor:
editor.remove_constraint(BookWithSlug, constraint)
self.assertNotIn(constraint.name, self.get_constraints(table))
@skipUnlessDBFeature('supports_expression_indexes')
def test_unique_constraint_field_and_expression(self):
with connection.schema_editor() as editor:
editor.create_model(Author)
constraint = UniqueConstraint(
F('height').desc(),
'uuid',
Lower('name').asc(),
name='func_f_lower_field_unq',
)
# Add constraint.
with connection.schema_editor() as editor:
editor.add_constraint(Author, constraint)
sql = constraint.create_sql(Author, editor)
table = Author._meta.db_table
if connection.features.supports_index_column_ordering:
self.assertIndexOrder(table, constraint.name, ['DESC', 'ASC', 'ASC'])
constraints = self.get_constraints(table)
self.assertIs(constraints[constraint.name]['unique'], True)
self.assertEqual(len(constraints[constraint.name]['columns']), 3)
self.assertEqual(constraints[constraint.name]['columns'][1], 'uuid')
# SQL contains database functions and columns.
self.assertIs(sql.references_column(table, 'height'), True)
self.assertIs(sql.references_column(table, 'name'), True)
self.assertIs(sql.references_column(table, 'uuid'), True)
self.assertIn('LOWER(%s)' % editor.quote_name('name'), str(sql))
# Remove constraint.
with connection.schema_editor() as editor:
editor.remove_constraint(Author, constraint)
self.assertNotIn(constraint.name, self.get_constraints(table))
@skipUnlessDBFeature('supports_expression_indexes', 'supports_partial_indexes')
def test_func_unique_constraint_partial(self):
with connection.schema_editor() as editor:
editor.create_model(Author)
constraint = UniqueConstraint(
Upper('name'),
name='func_upper_cond_weight_uq',
condition=Q(weight__isnull=False),
)
# Add constraint.
with connection.schema_editor() as editor:
editor.add_constraint(Author, constraint)
sql = constraint.create_sql(Author, editor)
table = Author._meta.db_table
constraints = self.get_constraints(table)
self.assertIn(constraint.name, constraints)
self.assertIs(constraints[constraint.name]['unique'], True)
self.assertIs(sql.references_column(table, 'name'), True)
self.assertIn('UPPER(%s)' % editor.quote_name('name'), str(sql))
self.assertIn(
'WHERE %s IS NOT NULL' % editor.quote_name('weight'),
str(sql),
)
# Remove constraint.
with connection.schema_editor() as editor:
editor.remove_constraint(Author, constraint)
self.assertNotIn(constraint.name, self.get_constraints(table))
@skipUnlessDBFeature('supports_expression_indexes', 'supports_covering_indexes')
def test_func_unique_constraint_covering(self):
with connection.schema_editor() as editor:
editor.create_model(Author)
constraint = UniqueConstraint(
Upper('name'),
name='func_upper_covering_uq',
include=['weight', 'height'],
)
# Add constraint.
with connection.schema_editor() as editor:
editor.add_constraint(Author, constraint)
sql = constraint.create_sql(Author, editor)
table = Author._meta.db_table
constraints = self.get_constraints(table)
self.assertIn(constraint.name, constraints)
self.assertIs(constraints[constraint.name]['unique'], True)
self.assertEqual(
constraints[constraint.name]['columns'],
[None, 'weight', 'height'],
)
self.assertIs(sql.references_column(table, 'name'), True)
self.assertIs(sql.references_column(table, 'weight'), True)
self.assertIs(sql.references_column(table, 'height'), True)
self.assertIn('UPPER(%s)' % editor.quote_name('name'), str(sql))
self.assertIn(
'INCLUDE (%s, %s)' % (
editor.quote_name('weight'),
editor.quote_name('height'),
),
str(sql),
)
# Remove constraint.
with connection.schema_editor() as editor:
editor.remove_constraint(Author, constraint)
self.assertNotIn(constraint.name, self.get_constraints(table))
@skipUnlessDBFeature('supports_expression_indexes')
def test_func_unique_constraint_lookups(self):
with connection.schema_editor() as editor:
editor.create_model(Author)
with register_lookup(CharField, Lower), register_lookup(IntegerField, Abs):
constraint = UniqueConstraint(
F('name__lower'),
F('weight__abs'),
name='func_lower_abs_lookup_uq',
)
# Add constraint.
with connection.schema_editor() as editor:
editor.add_constraint(Author, constraint)
sql = constraint.create_sql(Author, editor)
table = Author._meta.db_table
constraints = self.get_constraints(table)
self.assertIn(constraint.name, constraints)
self.assertIs(constraints[constraint.name]['unique'], True)
# SQL contains columns.
self.assertIs(sql.references_column(table, 'name'), True)
self.assertIs(sql.references_column(table, 'weight'), True)
# Remove constraint.
with connection.schema_editor() as editor:
editor.remove_constraint(Author, constraint)
self.assertNotIn(constraint.name, self.get_constraints(table))
@skipUnlessDBFeature('supports_expression_indexes')
def test_func_unique_constraint_collate(self):
collation = connection.features.test_collations.get('non_default')
if not collation:
self.skipTest(
'This backend does not support case-insensitive collations.'
)
with connection.schema_editor() as editor:
editor.create_model(Author)
editor.create_model(BookWithSlug)
constraint = UniqueConstraint(
Collate(F('title'), collation=collation).desc(),
Collate('slug', collation=collation),
name='func_collate_uq',
)
# Add constraint.
with connection.schema_editor() as editor:
editor.add_constraint(BookWithSlug, constraint)
sql = constraint.create_sql(BookWithSlug, editor)
table = BookWithSlug._meta.db_table
constraints = self.get_constraints(table)
self.assertIn(constraint.name, constraints)
self.assertIs(constraints[constraint.name]['unique'], True)
if connection.features.supports_index_column_ordering:
self.assertIndexOrder(table, constraint.name, ['DESC', 'ASC'])
# SQL contains columns and a collation.
self.assertIs(sql.references_column(table, 'title'), True)
self.assertIs(sql.references_column(table, 'slug'), True)
self.assertIn('COLLATE %s' % editor.quote_name(collation), str(sql))
# Remove constraint.
with connection.schema_editor() as editor:
editor.remove_constraint(BookWithSlug, constraint)
self.assertNotIn(constraint.name, self.get_constraints(table))
@skipIfDBFeature('supports_expression_indexes')
def test_func_unique_constraint_unsupported(self):
# UniqueConstraint is ignored on databases that don't support indexes on
# expressions.
with connection.schema_editor() as editor:
editor.create_model(Author)
constraint = UniqueConstraint(F('name'), name='func_name_uq')
with connection.schema_editor() as editor, self.assertNumQueries(0):
self.assertIsNone(editor.add_constraint(Author, constraint))
self.assertIsNone(editor.remove_constraint(Author, constraint))
@skipUnlessDBFeature('supports_expression_indexes')
def test_func_unique_constraint_nonexistent_field(self):
constraint = UniqueConstraint(Lower('nonexistent'), name='func_nonexistent_uq')
msg = (
"Cannot resolve keyword 'nonexistent' into field. Choices are: "
"height, id, name, uuid, weight"
)
with self.assertRaisesMessage(FieldError, msg):
with connection.schema_editor() as editor:
editor.add_constraint(Author, constraint)
@skipUnlessDBFeature('supports_expression_indexes')
def test_func_unique_constraint_nondeterministic(self):
with connection.schema_editor() as editor:
editor.create_model(Author)
constraint = UniqueConstraint(Random(), name='func_random_uq')
with connection.schema_editor() as editor:
with self.assertRaises(DatabaseError):
editor.add_constraint(Author, constraint)
def test_index_together(self): def test_index_together(self):
""" """
Tests removing and adding index_together constraints on a model. Tests removing and adding index_together constraints on a model.