From ccbf714ebeff51d1370789e5e487a978d0e2dbfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20=22Twidi=22=20Angel?= Date: Thu, 7 Jul 2022 04:26:49 +0200 Subject: [PATCH] Fixed #33829 -- Made BaseConstraint.deconstruct() and equality handle violation_error_message. Regression in 667105877e6723c6985399803a364848891513cc. --- django/contrib/postgres/constraints.py | 1 + django/db/models/constraints.py | 20 +++++- tests/constraints/tests.py | 77 ++++++++++++++++++++++++ tests/postgres_tests/test_constraints.py | 22 +++++++ 4 files changed, 117 insertions(+), 3 deletions(-) diff --git a/django/contrib/postgres/constraints.py b/django/contrib/postgres/constraints.py index c19602b26e..2e6f7f7998 100644 --- a/django/contrib/postgres/constraints.py +++ b/django/contrib/postgres/constraints.py @@ -177,6 +177,7 @@ class ExclusionConstraint(BaseConstraint): and self.deferrable == other.deferrable and self.include == other.include and self.opclasses == other.opclasses + and self.violation_error_message == other.violation_error_message ) return super().__eq__(other) diff --git a/django/db/models/constraints.py b/django/db/models/constraints.py index 9949b50b1e..86f015465a 100644 --- a/django/db/models/constraints.py +++ b/django/db/models/constraints.py @@ -14,12 +14,15 @@ __all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint" class BaseConstraint: - violation_error_message = _("Constraint ā€œ%(name)sā€ is violated.") + default_violation_error_message = _("Constraint ā€œ%(name)sā€ is violated.") + violation_error_message = None def __init__(self, name, violation_error_message=None): self.name = name if violation_error_message is not None: self.violation_error_message = violation_error_message + else: + self.violation_error_message = self.default_violation_error_message @property def contains_expressions(self): @@ -43,7 +46,13 @@ class BaseConstraint: def deconstruct(self): path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__) path = path.replace("django.db.models.constraints", "django.db.models") - return (path, (), {"name": self.name}) + kwargs = {"name": self.name} + if ( + self.violation_error_message is not None + and self.violation_error_message != self.default_violation_error_message + ): + kwargs["violation_error_message"] = self.violation_error_message + return (path, (), kwargs) def clone(self): _, args, kwargs = self.deconstruct() @@ -94,7 +103,11 @@ class CheckConstraint(BaseConstraint): def __eq__(self, other): if isinstance(other, CheckConstraint): - return self.name == other.name and self.check == other.check + return ( + self.name == other.name + and self.check == other.check + and self.violation_error_message == other.violation_error_message + ) return super().__eq__(other) def deconstruct(self): @@ -273,6 +286,7 @@ class UniqueConstraint(BaseConstraint): and self.include == other.include and self.opclasses == other.opclasses and self.expressions == other.expressions + and self.violation_error_message == other.violation_error_message ) return super().__eq__(other) diff --git a/tests/constraints/tests.py b/tests/constraints/tests.py index d9e377438e..4032b418b4 100644 --- a/tests/constraints/tests.py +++ b/tests/constraints/tests.py @@ -65,6 +65,29 @@ class BaseConstraintTests(SimpleTestCase): ) self.assertEqual(c.get_violation_error_message(), "custom base_name message") + def test_custom_violation_error_message_clone(self): + constraint = BaseConstraint( + "base_name", + violation_error_message="custom %(name)s message", + ).clone() + self.assertEqual( + constraint.get_violation_error_message(), + "custom base_name message", + ) + + def test_deconstruction(self): + constraint = BaseConstraint( + "base_name", + violation_error_message="custom %(name)s message", + ) + path, args, kwargs = constraint.deconstruct() + self.assertEqual(path, "django.db.models.BaseConstraint") + self.assertEqual(args, ()) + self.assertEqual( + kwargs, + {"name": "base_name", "violation_error_message": "custom %(name)s message"}, + ) + class CheckConstraintTests(TestCase): def test_eq(self): @@ -84,6 +107,28 @@ class CheckConstraintTests(TestCase): models.CheckConstraint(check=check2, name="price"), ) self.assertNotEqual(models.CheckConstraint(check=check1, name="price"), 1) + self.assertNotEqual( + models.CheckConstraint(check=check1, name="price"), + models.CheckConstraint( + check=check1, name="price", violation_error_message="custom error" + ), + ) + self.assertNotEqual( + models.CheckConstraint( + check=check1, name="price", violation_error_message="custom error" + ), + models.CheckConstraint( + check=check1, name="price", violation_error_message="other custom error" + ), + ) + self.assertEqual( + models.CheckConstraint( + check=check1, name="price", violation_error_message="custom error" + ), + models.CheckConstraint( + check=check1, name="price", violation_error_message="custom error" + ), + ) def test_repr(self): constraint = models.CheckConstraint( @@ -216,6 +261,38 @@ class UniqueConstraintTests(TestCase): self.assertNotEqual( models.UniqueConstraint(fields=["foo", "bar"], name="unique"), 1 ) + self.assertNotEqual( + models.UniqueConstraint(fields=["foo", "bar"], name="unique"), + models.UniqueConstraint( + fields=["foo", "bar"], + name="unique", + violation_error_message="custom error", + ), + ) + self.assertNotEqual( + models.UniqueConstraint( + fields=["foo", "bar"], + name="unique", + violation_error_message="custom error", + ), + models.UniqueConstraint( + fields=["foo", "bar"], + name="unique", + violation_error_message="other custom error", + ), + ) + self.assertEqual( + models.UniqueConstraint( + fields=["foo", "bar"], + name="unique", + violation_error_message="custom error", + ), + models.UniqueConstraint( + fields=["foo", "bar"], + name="unique", + violation_error_message="custom error", + ), + ) def test_eq_with_condition(self): self.assertEqual( diff --git a/tests/postgres_tests/test_constraints.py b/tests/postgres_tests/test_constraints.py index d36c6fd9ed..a33c485a36 100644 --- a/tests/postgres_tests/test_constraints.py +++ b/tests/postgres_tests/test_constraints.py @@ -444,17 +444,39 @@ class ExclusionConstraintTests(PostgreSQLTestCase): ) self.assertNotEqual(constraint_2, constraint_9) self.assertNotEqual(constraint_7, constraint_8) + + constraint_10 = ExclusionConstraint( + name="exclude_overlapping", + expressions=[ + (F("datespan"), RangeOperators.OVERLAPS), + (F("room"), RangeOperators.EQUAL), + ], + condition=Q(cancelled=False), + violation_error_message="custom error", + ) + constraint_11 = ExclusionConstraint( + name="exclude_overlapping", + expressions=[ + (F("datespan"), RangeOperators.OVERLAPS), + (F("room"), RangeOperators.EQUAL), + ], + condition=Q(cancelled=False), + violation_error_message="other custom error", + ) self.assertEqual(constraint_1, constraint_1) self.assertEqual(constraint_1, mock.ANY) self.assertNotEqual(constraint_1, constraint_2) self.assertNotEqual(constraint_1, constraint_3) self.assertNotEqual(constraint_1, constraint_4) + self.assertNotEqual(constraint_1, constraint_10) self.assertNotEqual(constraint_2, constraint_3) self.assertNotEqual(constraint_2, constraint_4) self.assertNotEqual(constraint_2, constraint_7) self.assertNotEqual(constraint_4, constraint_5) self.assertNotEqual(constraint_5, constraint_6) self.assertNotEqual(constraint_1, object()) + self.assertNotEqual(constraint_10, constraint_11) + self.assertEqual(constraint_10, constraint_10) def test_deconstruct(self): constraint = ExclusionConstraint(