From db13bca60a6758d5fe63eeb01c00c3f54f650715 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Sun, 5 Aug 2018 22:30:44 -0400 Subject: [PATCH] Fixed #29641 -- Added support for unique constraints in Meta.constraints. This constraint is similar to Meta.unique_together but also allows specifying a name. Co-authored-by: Ian Foote --- django/db/backends/sqlite3/introspection.py | 9 +++++ django/db/models/base.py | 11 +++++- django/db/models/constraints.py | 38 +++++++++++++++++++- docs/ref/models/constraints.txt | 25 +++++++++++++ docs/releases/2.2.txt | 3 +- tests/constraints/models.py | 5 +-- tests/constraints/tests.py | 40 +++++++++++++++++++++ 7 files changed, 126 insertions(+), 5 deletions(-) diff --git a/django/db/backends/sqlite3/introspection.py b/django/db/backends/sqlite3/introspection.py index 47ca25a78a..4f4a54eacf 100644 --- a/django/db/backends/sqlite3/introspection.py +++ b/django/db/backends/sqlite3/introspection.py @@ -260,6 +260,15 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): name_token = next_ttype(sqlparse.tokens.Literal.String.Symbol) name = name_token.value[1:-1] token = next_ttype(sqlparse.tokens.Keyword) + if token.match(sqlparse.tokens.Keyword, 'UNIQUE'): + constraints[name] = { + 'unique': True, + 'columns': [], + 'primary_key': False, + 'foreign_key': False, + 'check': False, + 'index': False, + } if token.match(sqlparse.tokens.Keyword, 'CHECK'): # Column check constraint if name is None: diff --git a/django/db/models/base.py b/django/db/models/base.py index 89faf9d1e1..b57726fbcf 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -16,7 +16,7 @@ from django.db import ( connections, router, transaction, ) from django.db.models.constants import LOOKUP_SEP -from django.db.models.constraints import CheckConstraint +from django.db.models.constraints import CheckConstraint, UniqueConstraint from django.db.models.deletion import CASCADE, Collector from django.db.models.fields.related import ( ForeignObjectRel, OneToOneField, lazy_related_operation, resolve_relation, @@ -982,9 +982,12 @@ class Model(metaclass=ModelBase): unique_checks = [] unique_togethers = [(self.__class__, self._meta.unique_together)] + constraints = [(self.__class__, self._meta.constraints)] for parent_class in self._meta.get_parent_list(): if parent_class._meta.unique_together: unique_togethers.append((parent_class, parent_class._meta.unique_together)) + if parent_class._meta.constraints: + constraints.append((parent_class, parent_class._meta.constraints)) for model_class, unique_together in unique_togethers: for check in unique_together: @@ -992,6 +995,12 @@ class Model(metaclass=ModelBase): # Add the check if the field isn't excluded. unique_checks.append((model_class, tuple(check))) + for model_class, model_constraints in constraints: + for constraint in model_constraints: + if (isinstance(constraint, UniqueConstraint) and + not any(name in exclude for name in constraint.fields)): + unique_checks.append((model_class, constraint.fields)) + # These are checks for the unique_for_. date_checks = [] diff --git a/django/db/models/constraints.py b/django/db/models/constraints.py index 698b278fe8..f6ea0fa12a 100644 --- a/django/db/models/constraints.py +++ b/django/db/models/constraints.py @@ -1,6 +1,6 @@ from django.db.models.sql.query import Query -__all__ = ['CheckConstraint'] +__all__ = ['CheckConstraint', 'UniqueConstraint'] class BaseConstraint: @@ -68,3 +68,39 @@ class CheckConstraint(BaseConstraint): path, args, kwargs = super().deconstruct() kwargs['check'] = self.check return path, args, kwargs + + +class UniqueConstraint(BaseConstraint): + def __init__(self, *, fields, name): + if not fields: + raise ValueError('At least one field is required to define a unique constraint.') + self.fields = tuple(fields) + super().__init__(name) + + def constraint_sql(self, model, schema_editor): + columns = ( + model._meta.get_field(field_name).column + for field_name in self.fields + ) + return schema_editor.sql_unique_constraint % { + 'columns': ', '.join(map(schema_editor.quote_name, columns)), + } + + def create_sql(self, model, schema_editor): + columns = [model._meta.get_field(field_name).column for field_name in self.fields] + return schema_editor._create_unique_sql(model, columns, self.name) + + def __repr__(self): + return '<%s: fields=%r name=%r>' % (self.__class__.__name__, self.fields, self.name) + + def __eq__(self, other): + return ( + isinstance(other, UniqueConstraint) and + self.name == other.name and + self.fields == other.fields + ) + + def deconstruct(self): + path, args, kwargs = super().deconstruct() + kwargs['fields'] = self.fields + return path, args, kwargs diff --git a/docs/ref/models/constraints.txt b/docs/ref/models/constraints.txt index 9e24a1cad8..5c94ce094a 100644 --- a/docs/ref/models/constraints.txt +++ b/docs/ref/models/constraints.txt @@ -43,3 +43,28 @@ ensures the age field is never less than 18. .. attribute:: CheckConstraint.name The name of the constraint. + +``UniqueConstraint`` +==================== + +.. class:: UniqueConstraint(*, fields, name) + + Creates a unique constraint in the database. + +``fields`` +---------- + +.. attribute:: UniqueConstraint.fields + +A list of field names that specifies the unique set of columns you want the +constraint to enforce. + +For example ``UniqueConstraint(fields=['room', 'date'], name='unique_location')`` +ensures only one location can exist for each ``date``. + +``name`` +-------- + +.. attribute:: UniqueConstraint.name + +The name of the constraint. diff --git a/docs/releases/2.2.txt b/docs/releases/2.2.txt index ef68c31eab..8a7a75a0bc 100644 --- a/docs/releases/2.2.txt +++ b/docs/releases/2.2.txt @@ -33,7 +33,8 @@ What's new in Django 2.2 Constraints ----------- -The new :class:`~django.db.models.CheckConstraint` class enables adding custom +The new :class:`~django.db.models.CheckConstraint` and +:class:`~django.db.models.UniqueConstraint` classes enable adding custom database constraints. Constraints are added to models using the :attr:`Meta.constraints ` option. diff --git a/tests/constraints/models.py b/tests/constraints/models.py index 08fbe9e1df..b1645ecedb 100644 --- a/tests/constraints/models.py +++ b/tests/constraints/models.py @@ -3,8 +3,8 @@ from django.db import models class Product(models.Model): name = models.CharField(max_length=255) - price = models.IntegerField() - discounted_price = models.IntegerField() + price = models.IntegerField(null=True) + discounted_price = models.IntegerField(null=True) class Meta: constraints = [ @@ -12,4 +12,5 @@ class Product(models.Model): check=models.Q(price__gt=models.F('discounted_price')), name='price_gt_discounted_price', ), + models.UniqueConstraint(fields=['name'], name='unique_name'), ] diff --git a/tests/constraints/tests.py b/tests/constraints/tests.py index b144c24a28..ddcaf7cfe0 100644 --- a/tests/constraints/tests.py +++ b/tests/constraints/tests.py @@ -1,3 +1,4 @@ +from django.core.exceptions import ValidationError from django.db import IntegrityError, connection, models from django.db.models.constraints import BaseConstraint from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature @@ -50,3 +51,42 @@ class CheckConstraintTests(TestCase): if connection.features.uppercases_column_names: expected_name = expected_name.upper() self.assertIn(expected_name, constraints) + + +class UniqueConstraintTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.p1 = Product.objects.create(name='p1') + + def test_repr(self): + fields = ['foo', 'bar'] + name = 'unique_fields' + constraint = models.UniqueConstraint(fields=fields, name=name) + self.assertEqual( + repr(constraint), + "", + ) + + def test_deconstruction(self): + fields = ['foo', 'bar'] + name = 'unique_fields' + check = models.UniqueConstraint(fields=fields, name=name) + path, args, kwargs = check.deconstruct() + self.assertEqual(path, 'django.db.models.UniqueConstraint') + self.assertEqual(args, ()) + self.assertEqual(kwargs, {'fields': tuple(fields), 'name': name}) + + def test_database_constraint(self): + with self.assertRaises(IntegrityError): + Product.objects.create(name=self.p1.name) + + def test_model_validation(self): + with self.assertRaisesMessage(ValidationError, 'Product with this Name already exists.'): + Product(name=self.p1.name).validate_unique() + + def test_name(self): + constraints = get_constraints(Product._meta.db_table) + expected_name = 'unique_name' + if connection.features.uppercases_column_names: + expected_name = expected_name.upper() + self.assertIn(expected_name, constraints)