From 5b3d3e400ab9334ba429ca360c9818c6dfc3a51b Mon Sep 17 00:00:00 2001 From: Xavier Fernandez Date: Tue, 14 Feb 2023 21:06:45 +0100 Subject: [PATCH] Fixed #34338 -- Allowed customizing code of ValidationError in BaseConstraint and subclasses. --- django/contrib/postgres/constraints.py | 23 ++++- django/db/models/constraints.py | 62 +++++++++--- docs/ref/contrib/postgres/constraints.txt | 12 ++- docs/ref/models/constraints.txt | 32 ++++++- docs/releases/5.0.txt | 12 ++- tests/constraints/tests.py | 110 +++++++++++++++++++++- tests/postgres_tests/test_constraints.py | 40 +++++++- 7 files changed, 268 insertions(+), 23 deletions(-) diff --git a/django/contrib/postgres/constraints.py b/django/contrib/postgres/constraints.py index ad3a5f61f55..c61072b5a54 100644 --- a/django/contrib/postgres/constraints.py +++ b/django/contrib/postgres/constraints.py @@ -32,6 +32,7 @@ class ExclusionConstraint(BaseConstraint): condition=None, deferrable=None, include=None, + violation_error_code=None, violation_error_message=None, ): if index_type and index_type.lower() not in {"gist", "spgist"}: @@ -60,7 +61,11 @@ class ExclusionConstraint(BaseConstraint): self.condition = condition self.deferrable = deferrable self.include = tuple(include) if include else () - super().__init__(name=name, violation_error_message=violation_error_message) + super().__init__( + name=name, + violation_error_code=violation_error_code, + violation_error_message=violation_error_message, + ) def _get_expressions(self, schema_editor, query): expressions = [] @@ -149,12 +154,13 @@ class ExclusionConstraint(BaseConstraint): and self.condition == other.condition and self.deferrable == other.deferrable and self.include == other.include + and self.violation_error_code == other.violation_error_code and self.violation_error_message == other.violation_error_message ) return super().__eq__(other) def __repr__(self): - return "<%s: index_type=%s expressions=%s name=%s%s%s%s%s>" % ( + return "<%s: index_type=%s expressions=%s name=%s%s%s%s%s%s>" % ( self.__class__.__qualname__, repr(self.index_type), repr(self.expressions), @@ -162,6 +168,11 @@ class ExclusionConstraint(BaseConstraint): "" if self.condition is None else " condition=%s" % self.condition, "" if self.deferrable is None else " deferrable=%r" % self.deferrable, "" if not self.include else " include=%s" % repr(self.include), + ( + "" + if self.violation_error_code is None + else " violation_error_code=%r" % self.violation_error_code + ), ( "" if self.violation_error_message is None @@ -204,9 +215,13 @@ class ExclusionConstraint(BaseConstraint): queryset = queryset.exclude(pk=model_class_pk) if not self.condition: if queryset.exists(): - raise ValidationError(self.get_violation_error_message()) + raise ValidationError( + self.get_violation_error_message(), code=self.violation_error_code + ) else: if (self.condition & Exists(queryset.filter(self.condition))).check( replacement_map, using=using ): - raise ValidationError(self.get_violation_error_message()) + raise ValidationError( + self.get_violation_error_message(), code=self.violation_error_code + ) diff --git a/django/db/models/constraints.py b/django/db/models/constraints.py index 075ecee1be7..0df0782b6f1 100644 --- a/django/db/models/constraints.py +++ b/django/db/models/constraints.py @@ -18,11 +18,16 @@ __all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint" class BaseConstraint: default_violation_error_message = _("Constraint “%(name)s” is violated.") + violation_error_code = None violation_error_message = None # RemovedInDjango60Warning: When the deprecation ends, replace with: - # def __init__(self, *, name, violation_error_message=None): - def __init__(self, *args, name=None, violation_error_message=None): + # def __init__( + # self, *, name, violation_error_code=None, violation_error_message=None + # ): + def __init__( + self, *args, name=None, violation_error_code=None, violation_error_message=None + ): # RemovedInDjango60Warning. if name is None and not args: raise TypeError( @@ -30,6 +35,8 @@ class BaseConstraint: f"argument: 'name'" ) self.name = name + if violation_error_code is not None: + self.violation_error_code = violation_error_code if violation_error_message is not None: self.violation_error_message = violation_error_message else: @@ -74,6 +81,8 @@ class BaseConstraint: and self.violation_error_message != self.default_violation_error_message ): kwargs["violation_error_message"] = self.violation_error_message + if self.violation_error_code is not None: + kwargs["violation_error_code"] = self.violation_error_code return (path, (), kwargs) def clone(self): @@ -82,13 +91,19 @@ class BaseConstraint: class CheckConstraint(BaseConstraint): - def __init__(self, *, check, name, violation_error_message=None): + def __init__( + self, *, check, name, violation_error_code=None, violation_error_message=None + ): self.check = check if not getattr(check, "conditional", False): raise TypeError( "CheckConstraint.check must be a Q instance or boolean expression." ) - super().__init__(name=name, violation_error_message=violation_error_message) + super().__init__( + name=name, + violation_error_code=violation_error_code, + violation_error_message=violation_error_message, + ) def _get_check_sql(self, model, schema_editor): query = Query(model=model, alias_cols=False) @@ -112,15 +127,22 @@ class CheckConstraint(BaseConstraint): against = instance._get_field_value_map(meta=model._meta, exclude=exclude) try: if not Q(self.check).check(against, using=using): - raise ValidationError(self.get_violation_error_message()) + raise ValidationError( + self.get_violation_error_message(), code=self.violation_error_code + ) except FieldError: pass def __repr__(self): - return "<%s: check=%s name=%s%s>" % ( + return "<%s: check=%s name=%s%s%s>" % ( self.__class__.__qualname__, self.check, repr(self.name), + ( + "" + if self.violation_error_code is None + else " violation_error_code=%r" % self.violation_error_code + ), ( "" if self.violation_error_message is None @@ -134,6 +156,7 @@ class CheckConstraint(BaseConstraint): return ( self.name == other.name and self.check == other.check + and self.violation_error_code == other.violation_error_code and self.violation_error_message == other.violation_error_message ) return super().__eq__(other) @@ -163,6 +186,7 @@ class UniqueConstraint(BaseConstraint): deferrable=None, include=None, opclasses=(), + violation_error_code=None, violation_error_message=None, ): if not name: @@ -213,7 +237,11 @@ class UniqueConstraint(BaseConstraint): F(expression) if isinstance(expression, str) else expression for expression in expressions ) - super().__init__(name=name, violation_error_message=violation_error_message) + super().__init__( + name=name, + violation_error_code=violation_error_code, + violation_error_message=violation_error_message, + ) @property def contains_expressions(self): @@ -293,7 +321,7 @@ class UniqueConstraint(BaseConstraint): ) def __repr__(self): - return "<%s:%s%s%s%s%s%s%s%s>" % ( + return "<%s:%s%s%s%s%s%s%s%s%s>" % ( self.__class__.__qualname__, "" if not self.fields else " fields=%s" % repr(self.fields), "" if not self.expressions else " expressions=%s" % repr(self.expressions), @@ -302,6 +330,11 @@ class UniqueConstraint(BaseConstraint): "" if self.deferrable is None else " deferrable=%r" % self.deferrable, "" if not self.include else " include=%s" % repr(self.include), "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses), + ( + "" + if self.violation_error_code is None + else " violation_error_code=%r" % self.violation_error_code + ), ( "" if self.violation_error_message is None @@ -320,6 +353,7 @@ class UniqueConstraint(BaseConstraint): and self.include == other.include and self.opclasses == other.opclasses and self.expressions == other.expressions + and self.violation_error_code == other.violation_error_code and self.violation_error_message == other.violation_error_message ) return super().__eq__(other) @@ -385,14 +419,17 @@ class UniqueConstraint(BaseConstraint): if not self.condition: if queryset.exists(): if self.expressions: - raise ValidationError(self.get_violation_error_message()) + raise ValidationError( + self.get_violation_error_message(), + code=self.violation_error_code, + ) # When fields are defined, use the unique_error_message() for # backward compatibility. for model, constraints in instance.get_constraints(): for constraint in constraints: if constraint is self: raise ValidationError( - instance.unique_error_message(model, self.fields) + instance.unique_error_message(model, self.fields), ) else: against = instance._get_field_value_map(meta=model._meta, exclude=exclude) @@ -400,6 +437,9 @@ class UniqueConstraint(BaseConstraint): if (self.condition & Exists(queryset.filter(self.condition))).check( against, using=using ): - raise ValidationError(self.get_violation_error_message()) + raise ValidationError( + self.get_violation_error_message(), + code=self.violation_error_code, + ) except FieldError: pass diff --git a/docs/ref/contrib/postgres/constraints.txt b/docs/ref/contrib/postgres/constraints.txt index fcf50b8b5f9..abc5e4e4e46 100644 --- a/docs/ref/contrib/postgres/constraints.txt +++ b/docs/ref/contrib/postgres/constraints.txt @@ -12,7 +12,7 @@ PostgreSQL supports additional data integrity constraints available from the ``ExclusionConstraint`` ======================= -.. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None, deferrable=None, include=None, violation_error_message=None) +.. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None, deferrable=None, include=None, violation_error_code=None, violation_error_message=None) Creates an exclusion constraint in the database. Internally, PostgreSQL implements exclusion constraints using indexes. The default index type is @@ -133,6 +133,16 @@ used for queries that select only included fields ``include`` is supported for GiST indexes. PostgreSQL 14+ also supports ``include`` for SP-GiST indexes. +``violation_error_code`` +------------------------ + +.. versionadded:: 5.0 + +.. attribute:: ExclusionConstraint.violation_error_code + +The error code used when ``ValidationError`` is raised during +:ref:`model validation `. Defaults to ``None``. + ``violation_error_message`` --------------------------- diff --git a/docs/ref/models/constraints.txt b/docs/ref/models/constraints.txt index baaacd8754d..f248de03154 100644 --- a/docs/ref/models/constraints.txt +++ b/docs/ref/models/constraints.txt @@ -48,7 +48,7 @@ option. ``BaseConstraint`` ================== -.. class:: BaseConstraint(*, name, violation_error_message=None) +.. class:: BaseConstraint(* name, violation_error_code=None, violation_error_message=None) Base class for all constraints. Subclasses must implement ``constraint_sql()``, ``create_sql()``, ``remove_sql()`` and @@ -68,6 +68,16 @@ All constraints have the following parameters in common: The name of the constraint. You must always specify a unique name for the constraint. +``violation_error_code`` +------------------------ + +.. versionadded:: 5.0 + +.. attribute:: BaseConstraint.violation_error_code + +The error code used when ``ValidationError`` is raised during +:ref:`model validation `. Defaults to ``None``. + ``violation_error_message`` --------------------------- @@ -94,7 +104,7 @@ This method must be implemented by a subclass. ``CheckConstraint`` =================== -.. class:: CheckConstraint(*, check, name, violation_error_message=None) +.. class:: CheckConstraint(*, check, name, violation_error_code=None, violation_error_message=None) Creates a check constraint in the database. @@ -121,7 +131,7 @@ ensures the age field is never less than 18. ``UniqueConstraint`` ==================== -.. class:: UniqueConstraint(*expressions, fields=(), name=None, condition=None, deferrable=None, include=None, opclasses=(), violation_error_message=None) +.. class:: UniqueConstraint(*expressions, fields=(), name=None, condition=None, deferrable=None, include=None, opclasses=(), violation_error_code=None, violation_error_message=None) Creates a unique constraint in the database. @@ -242,6 +252,22 @@ creates a unique index on ``username`` using ``varchar_pattern_ops``. ``opclasses`` are ignored for databases besides PostgreSQL. +``violation_error_code`` +------------------------ + +.. versionadded:: 5.0 + +.. attribute:: UniqueConstraint.violation_error_code + +The error code used when ``ValidationError`` is raised during +:ref:`model validation `. Defaults to ``None``. + +This code is *not used* for :class:`UniqueConstraint`\s with +:attr:`~UniqueConstraint.fields` and without a +:attr:`~UniqueConstraint.condition`. Such :class:`~UniqueConstraint`\s have the +same error code as constraints defined with :attr:`.Field.unique` or in +:attr:`Meta.unique_together `. + ``violation_error_message`` --------------------------- diff --git a/docs/releases/5.0.txt b/docs/releases/5.0.txt index 14e767cfd80..84689f223db 100644 --- a/docs/releases/5.0.txt +++ b/docs/releases/5.0.txt @@ -78,7 +78,10 @@ Minor features :mod:`django.contrib.postgres` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -* ... +* The new :attr:`~.ExclusionConstraint.violation_error_code` attribute of + :class:`~django.contrib.postgres.constraints.ExclusionConstraint` allows + customizing the ``code`` of ``ValidationError`` raised during + :ref:`model validation `. :mod:`django.contrib.redirects` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -182,6 +185,13 @@ Models and :meth:`.QuerySet.aupdate_or_create` methods allows specifying a different field values for the create operation. +* The new ``violation_error_code`` attribute of + :class:`~django.db.models.BaseConstraint`, + :class:`~django.db.models.CheckConstraint`, and + :class:`~django.db.models.UniqueConstraint` allows customizing the ``code`` + of ``ValidationError`` raised during + :ref:`model validation `. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/constraints/tests.py b/tests/constraints/tests.py index b45dc6499a3..ec2b030ac5d 100644 --- a/tests/constraints/tests.py +++ b/tests/constraints/tests.py @@ -77,17 +77,26 @@ class BaseConstraintTests(SimpleTestCase): "custom base_name message", ) + def test_custom_violation_code_message(self): + c = BaseConstraint(name="base_name", violation_error_code="custom_code") + self.assertEqual(c.violation_error_code, "custom_code") + def test_deconstruction(self): constraint = BaseConstraint( name="base_name", violation_error_message="custom %(name)s message", + violation_error_code="custom_code", ) path, args, kwargs = constraint.deconstruct() self.assertEqual(path, "django.db.models.BaseConstraint") self.assertEqual(args, ()) self.assertEqual( kwargs, - {"name": "base_name", "violation_error_message": "custom %(name)s message"}, + { + "name": "base_name", + "violation_error_message": "custom %(name)s message", + "violation_error_code": "custom_code", + }, ) def test_deprecation(self): @@ -148,6 +157,20 @@ class CheckConstraintTests(TestCase): check=check1, name="price", violation_error_message="custom error" ), ) + self.assertNotEqual( + models.CheckConstraint(check=check1, name="price"), + models.CheckConstraint( + check=check1, name="price", violation_error_code="custom_code" + ), + ) + self.assertEqual( + models.CheckConstraint( + check=check1, name="price", violation_error_code="custom_code" + ), + models.CheckConstraint( + check=check1, name="price", violation_error_code="custom_code" + ), + ) def test_repr(self): constraint = models.CheckConstraint( @@ -172,6 +195,18 @@ class CheckConstraintTests(TestCase): "violation_error_message='More than 1'>", ) + def test_repr_with_violation_error_code(self): + constraint = models.CheckConstraint( + check=models.Q(price__lt=1), + name="price_lt_one", + violation_error_code="more_than_one", + ) + self.assertEqual( + repr(constraint), + "", + ) + def test_invalid_check_types(self): msg = "CheckConstraint.check must be a Q instance or boolean expression." with self.assertRaisesMessage(TypeError, msg): @@ -237,6 +272,21 @@ class CheckConstraintTests(TestCase): # Valid product. constraint.validate(Product, Product(price=10, discounted_price=5)) + def test_validate_custom_error(self): + check = models.Q(price__gt=models.F("discounted_price")) + constraint = models.CheckConstraint( + check=check, + name="price", + violation_error_message="discount is fake", + violation_error_code="fake_discount", + ) + # Invalid product. + invalid_product = Product(price=10, discounted_price=42) + msg = "discount is fake" + with self.assertRaisesMessage(ValidationError, msg) as cm: + constraint.validate(Product, invalid_product) + self.assertEqual(cm.exception.code, "fake_discount") + def test_validate_boolean_expressions(self): constraint = models.CheckConstraint( check=models.expressions.ExpressionWrapper( @@ -341,6 +391,30 @@ class UniqueConstraintTests(TestCase): violation_error_message="custom error", ), ) + self.assertNotEqual( + models.UniqueConstraint( + fields=["foo", "bar"], + name="unique", + violation_error_code="custom_error", + ), + models.UniqueConstraint( + fields=["foo", "bar"], + name="unique", + violation_error_code="other_custom_error", + ), + ) + self.assertEqual( + models.UniqueConstraint( + fields=["foo", "bar"], + name="unique", + violation_error_code="custom_error", + ), + models.UniqueConstraint( + fields=["foo", "bar"], + name="unique", + violation_error_code="custom_error", + ), + ) def test_eq_with_condition(self): self.assertEqual( @@ -512,6 +586,20 @@ class UniqueConstraintTests(TestCase): ), ) + def test_repr_with_violation_error_code(self): + constraint = models.UniqueConstraint( + models.F("baz__lower"), + name="unique_lower_baz", + violation_error_code="baz", + ) + self.assertEqual( + repr(constraint), + ( + "" + ), + ) + def test_deconstruction(self): fields = ["foo", "bar"] name = "unique_fields" @@ -656,12 +744,16 @@ class UniqueConstraintTests(TestCase): def test_validate(self): constraint = UniqueConstraintProduct._meta.constraints[0] + # Custom message and error code are ignored. + constraint.violation_error_message = "Custom message" + constraint.violation_error_code = "custom_code" msg = "Unique constraint product with this Name and Color already exists." non_unique_product = UniqueConstraintProduct( name=self.p1.name, color=self.p1.color ) - with self.assertRaisesMessage(ValidationError, msg): + with self.assertRaisesMessage(ValidationError, msg) as cm: constraint.validate(UniqueConstraintProduct, non_unique_product) + self.assertEqual(cm.exception.code, "unique_together") # Null values are ignored. constraint.validate( UniqueConstraintProduct, @@ -716,6 +808,20 @@ class UniqueConstraintTests(TestCase): exclude={"name"}, ) + @skipUnlessDBFeature("supports_partial_indexes") + def test_validate_conditon_custom_error(self): + p1 = UniqueConstraintConditionProduct.objects.create(name="p1") + constraint = UniqueConstraintConditionProduct._meta.constraints[0] + constraint.violation_error_message = "Custom message" + constraint.violation_error_code = "custom_code" + msg = "Custom message" + with self.assertRaisesMessage(ValidationError, msg) as cm: + constraint.validate( + UniqueConstraintConditionProduct, + UniqueConstraintConditionProduct(name=p1.name, color=None), + ) + self.assertEqual(cm.exception.code, "custom_code") + def test_validate_expression(self): constraint = models.UniqueConstraint(Lower("name"), name="name_lower_uniq") msg = "Constraint “name_lower_uniq” is violated." diff --git a/tests/postgres_tests/test_constraints.py b/tests/postgres_tests/test_constraints.py index a5248e14916..bf478337430 100644 --- a/tests/postgres_tests/test_constraints.py +++ b/tests/postgres_tests/test_constraints.py @@ -397,6 +397,17 @@ class ExclusionConstraintTests(PostgreSQLTestCase): "(F(datespan), '-|-')] name='exclude_overlapping' " "violation_error_message='Overlapping must be excluded'>", ) + constraint = ExclusionConstraint( + name="exclude_overlapping", + expressions=[(F("datespan"), RangeOperators.ADJACENT_TO)], + violation_error_code="overlapping_must_be_excluded", + ) + self.assertEqual( + repr(constraint), + "", + ) def test_eq(self): constraint_1 = ExclusionConstraint( @@ -470,6 +481,16 @@ class ExclusionConstraintTests(PostgreSQLTestCase): condition=Q(cancelled=False), violation_error_message="other custom error", ) + constraint_12 = ExclusionConstraint( + name="exclude_overlapping", + expressions=[ + (F("datespan"), RangeOperators.OVERLAPS), + (F("room"), RangeOperators.EQUAL), + ], + condition=Q(cancelled=False), + violation_error_code="custom_code", + violation_error_message="other custom error", + ) self.assertEqual(constraint_1, constraint_1) self.assertEqual(constraint_1, mock.ANY) self.assertNotEqual(constraint_1, constraint_2) @@ -483,7 +504,9 @@ class ExclusionConstraintTests(PostgreSQLTestCase): self.assertNotEqual(constraint_5, constraint_6) self.assertNotEqual(constraint_1, object()) self.assertNotEqual(constraint_10, constraint_11) + self.assertNotEqual(constraint_11, constraint_12) self.assertEqual(constraint_10, constraint_10) + self.assertEqual(constraint_12, constraint_12) def test_deconstruct(self): constraint = ExclusionConstraint( @@ -760,17 +783,32 @@ class ExclusionConstraintTests(PostgreSQLTestCase): constraint = ExclusionConstraint( name="ints_adjacent", expressions=[("ints", RangeOperators.ADJACENT_TO)], + violation_error_code="custom_code", violation_error_message="Custom error message.", ) range_obj = RangesModel.objects.create(ints=(20, 50)) constraint.validate(RangesModel, range_obj) msg = "Custom error message." - with self.assertRaisesMessage(ValidationError, msg): + with self.assertRaisesMessage(ValidationError, msg) as cm: constraint.validate(RangesModel, RangesModel(ints=(10, 20))) + self.assertEqual(cm.exception.code, "custom_code") constraint.validate(RangesModel, RangesModel(ints=(10, 19))) constraint.validate(RangesModel, RangesModel(ints=(51, 60))) constraint.validate(RangesModel, RangesModel(ints=(10, 20)), exclude={"ints"}) + def test_validate_with_custom_code_and_condition(self): + constraint = ExclusionConstraint( + name="ints_adjacent", + expressions=[("ints", RangeOperators.ADJACENT_TO)], + violation_error_code="custom_code", + condition=Q(ints__lt=(100, 200)), + ) + range_obj = RangesModel.objects.create(ints=(20, 50)) + constraint.validate(RangesModel, range_obj) + with self.assertRaises(ValidationError) as cm: + constraint.validate(RangesModel, RangesModel(ints=(10, 20))) + self.assertEqual(cm.exception.code, "custom_code") + def test_expressions_with_params(self): constraint_name = "scene_left_equal" self.assertNotIn(constraint_name, self.get_constraints(Scene._meta.db_table))