Fixed #29641 -- Added support for unique constraints in Meta.constraints.

This constraint is similar to Meta.unique_together but also allows
specifying a name.

Co-authored-by: Ian Foote <python@ian.feete.org>
This commit is contained in:
Simon Charette 2018-08-05 22:30:44 -04:00 committed by Tim Graham
parent 8eae094638
commit db13bca60a
7 changed files with 126 additions and 5 deletions

View File

@ -260,6 +260,15 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
name_token = next_ttype(sqlparse.tokens.Literal.String.Symbol) name_token = next_ttype(sqlparse.tokens.Literal.String.Symbol)
name = name_token.value[1:-1] name = name_token.value[1:-1]
token = next_ttype(sqlparse.tokens.Keyword) token = next_ttype(sqlparse.tokens.Keyword)
if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
constraints[name] = {
'unique': True,
'columns': [],
'primary_key': False,
'foreign_key': False,
'check': False,
'index': False,
}
if token.match(sqlparse.tokens.Keyword, 'CHECK'): if token.match(sqlparse.tokens.Keyword, 'CHECK'):
# Column check constraint # Column check constraint
if name is None: if name is None:

View File

@ -16,7 +16,7 @@ from django.db import (
connections, router, transaction, connections, router, transaction,
) )
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.constraints import CheckConstraint from django.db.models.constraints import CheckConstraint, UniqueConstraint
from django.db.models.deletion import CASCADE, Collector from django.db.models.deletion import CASCADE, Collector
from django.db.models.fields.related import ( from django.db.models.fields.related import (
ForeignObjectRel, OneToOneField, lazy_related_operation, resolve_relation, ForeignObjectRel, OneToOneField, lazy_related_operation, resolve_relation,
@ -982,9 +982,12 @@ class Model(metaclass=ModelBase):
unique_checks = [] unique_checks = []
unique_togethers = [(self.__class__, self._meta.unique_together)] unique_togethers = [(self.__class__, self._meta.unique_together)]
constraints = [(self.__class__, self._meta.constraints)]
for parent_class in self._meta.get_parent_list(): for parent_class in self._meta.get_parent_list():
if parent_class._meta.unique_together: if parent_class._meta.unique_together:
unique_togethers.append((parent_class, parent_class._meta.unique_together)) unique_togethers.append((parent_class, parent_class._meta.unique_together))
if parent_class._meta.constraints:
constraints.append((parent_class, parent_class._meta.constraints))
for model_class, unique_together in unique_togethers: for model_class, unique_together in unique_togethers:
for check in unique_together: for check in unique_together:
@ -992,6 +995,12 @@ class Model(metaclass=ModelBase):
# Add the check if the field isn't excluded. # Add the check if the field isn't excluded.
unique_checks.append((model_class, tuple(check))) unique_checks.append((model_class, tuple(check)))
for model_class, model_constraints in constraints:
for constraint in model_constraints:
if (isinstance(constraint, UniqueConstraint) and
not any(name in exclude for name in constraint.fields)):
unique_checks.append((model_class, constraint.fields))
# These are checks for the unique_for_<date/year/month>. # These are checks for the unique_for_<date/year/month>.
date_checks = [] date_checks = []

View File

@ -1,6 +1,6 @@
from django.db.models.sql.query import Query from django.db.models.sql.query import Query
__all__ = ['CheckConstraint'] __all__ = ['CheckConstraint', 'UniqueConstraint']
class BaseConstraint: class BaseConstraint:
@ -68,3 +68,39 @@ class CheckConstraint(BaseConstraint):
path, args, kwargs = super().deconstruct() path, args, kwargs = super().deconstruct()
kwargs['check'] = self.check kwargs['check'] = self.check
return path, args, kwargs return path, args, kwargs
class UniqueConstraint(BaseConstraint):
def __init__(self, *, fields, name):
if not fields:
raise ValueError('At least one field is required to define a unique constraint.')
self.fields = tuple(fields)
super().__init__(name)
def constraint_sql(self, model, schema_editor):
columns = (
model._meta.get_field(field_name).column
for field_name in self.fields
)
return schema_editor.sql_unique_constraint % {
'columns': ', '.join(map(schema_editor.quote_name, columns)),
}
def create_sql(self, model, schema_editor):
columns = [model._meta.get_field(field_name).column for field_name in self.fields]
return schema_editor._create_unique_sql(model, columns, self.name)
def __repr__(self):
return '<%s: fields=%r name=%r>' % (self.__class__.__name__, self.fields, self.name)
def __eq__(self, other):
return (
isinstance(other, UniqueConstraint) and
self.name == other.name and
self.fields == other.fields
)
def deconstruct(self):
path, args, kwargs = super().deconstruct()
kwargs['fields'] = self.fields
return path, args, kwargs

View File

@ -43,3 +43,28 @@ ensures the age field is never less than 18.
.. attribute:: CheckConstraint.name .. attribute:: CheckConstraint.name
The name of the constraint. The name of the constraint.
``UniqueConstraint``
====================
.. class:: UniqueConstraint(*, fields, name)
Creates a unique constraint in the database.
``fields``
----------
.. attribute:: UniqueConstraint.fields
A list of field names that specifies the unique set of columns you want the
constraint to enforce.
For example ``UniqueConstraint(fields=['room', 'date'], name='unique_location')``
ensures only one location can exist for each ``date``.
``name``
--------
.. attribute:: UniqueConstraint.name
The name of the constraint.

View File

@ -33,7 +33,8 @@ What's new in Django 2.2
Constraints Constraints
----------- -----------
The new :class:`~django.db.models.CheckConstraint` class enables adding custom The new :class:`~django.db.models.CheckConstraint` and
:class:`~django.db.models.UniqueConstraint` classes enable adding custom
database constraints. Constraints are added to models using the database constraints. Constraints are added to models using the
:attr:`Meta.constraints <django.db.models.Options.constraints>` option. :attr:`Meta.constraints <django.db.models.Options.constraints>` option.

View File

@ -3,8 +3,8 @@ from django.db import models
class Product(models.Model): class Product(models.Model):
name = models.CharField(max_length=255) name = models.CharField(max_length=255)
price = models.IntegerField() price = models.IntegerField(null=True)
discounted_price = models.IntegerField() discounted_price = models.IntegerField(null=True)
class Meta: class Meta:
constraints = [ constraints = [
@ -12,4 +12,5 @@ class Product(models.Model):
check=models.Q(price__gt=models.F('discounted_price')), check=models.Q(price__gt=models.F('discounted_price')),
name='price_gt_discounted_price', name='price_gt_discounted_price',
), ),
models.UniqueConstraint(fields=['name'], name='unique_name'),
] ]

View File

@ -1,3 +1,4 @@
from django.core.exceptions import ValidationError
from django.db import IntegrityError, connection, models from django.db import IntegrityError, connection, models
from django.db.models.constraints import BaseConstraint from django.db.models.constraints import BaseConstraint
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
@ -50,3 +51,42 @@ class CheckConstraintTests(TestCase):
if connection.features.uppercases_column_names: if connection.features.uppercases_column_names:
expected_name = expected_name.upper() expected_name = expected_name.upper()
self.assertIn(expected_name, constraints) self.assertIn(expected_name, constraints)
class UniqueConstraintTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.p1 = Product.objects.create(name='p1')
def test_repr(self):
fields = ['foo', 'bar']
name = 'unique_fields'
constraint = models.UniqueConstraint(fields=fields, name=name)
self.assertEqual(
repr(constraint),
"<UniqueConstraint: fields=('foo', 'bar') name='unique_fields'>",
)
def test_deconstruction(self):
fields = ['foo', 'bar']
name = 'unique_fields'
check = models.UniqueConstraint(fields=fields, name=name)
path, args, kwargs = check.deconstruct()
self.assertEqual(path, 'django.db.models.UniqueConstraint')
self.assertEqual(args, ())
self.assertEqual(kwargs, {'fields': tuple(fields), 'name': name})
def test_database_constraint(self):
with self.assertRaises(IntegrityError):
Product.objects.create(name=self.p1.name)
def test_model_validation(self):
with self.assertRaisesMessage(ValidationError, 'Product with this Name already exists.'):
Product(name=self.p1.name).validate_unique()
def test_name(self):
constraints = get_constraints(Product._meta.db_table)
expected_name = 'unique_name'
if connection.features.uppercases_column_names:
expected_name = expected_name.upper()
self.assertIn(expected_name, constraints)