Fixed #34338 -- Allowed customizing code of ValidationError in BaseConstraint and subclasses.

This commit is contained in:
Xavier Fernandez 2023-02-14 21:06:45 +01:00 committed by Mariusz Felisiak
parent 51c9bb7cd1
commit 5b3d3e400a
7 changed files with 268 additions and 23 deletions

View File

@ -32,6 +32,7 @@ class ExclusionConstraint(BaseConstraint):
condition=None,
deferrable=None,
include=None,
violation_error_code=None,
violation_error_message=None,
):
if index_type and index_type.lower() not in {"gist", "spgist"}:
@ -60,7 +61,11 @@ class ExclusionConstraint(BaseConstraint):
self.condition = condition
self.deferrable = deferrable
self.include = tuple(include) if include else ()
super().__init__(name=name, violation_error_message=violation_error_message)
super().__init__(
name=name,
violation_error_code=violation_error_code,
violation_error_message=violation_error_message,
)
def _get_expressions(self, schema_editor, query):
expressions = []
@ -149,12 +154,13 @@ class ExclusionConstraint(BaseConstraint):
and self.condition == other.condition
and self.deferrable == other.deferrable
and self.include == other.include
and self.violation_error_code == other.violation_error_code
and self.violation_error_message == other.violation_error_message
)
return super().__eq__(other)
def __repr__(self):
return "<%s: index_type=%s expressions=%s name=%s%s%s%s%s>" % (
return "<%s: index_type=%s expressions=%s name=%s%s%s%s%s%s>" % (
self.__class__.__qualname__,
repr(self.index_type),
repr(self.expressions),
@ -162,6 +168,11 @@ class ExclusionConstraint(BaseConstraint):
"" if self.condition is None else " condition=%s" % self.condition,
"" if self.deferrable is None else " deferrable=%r" % self.deferrable,
"" if not self.include else " include=%s" % repr(self.include),
(
""
if self.violation_error_code is None
else " violation_error_code=%r" % self.violation_error_code
),
(
""
if self.violation_error_message is None
@ -204,9 +215,13 @@ class ExclusionConstraint(BaseConstraint):
queryset = queryset.exclude(pk=model_class_pk)
if not self.condition:
if queryset.exists():
raise ValidationError(self.get_violation_error_message())
raise ValidationError(
self.get_violation_error_message(), code=self.violation_error_code
)
else:
if (self.condition & Exists(queryset.filter(self.condition))).check(
replacement_map, using=using
):
raise ValidationError(self.get_violation_error_message())
raise ValidationError(
self.get_violation_error_message(), code=self.violation_error_code
)

View File

@ -18,11 +18,16 @@ __all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"
class BaseConstraint:
default_violation_error_message = _("Constraint “%(name)s” is violated.")
violation_error_code = None
violation_error_message = None
# RemovedInDjango60Warning: When the deprecation ends, replace with:
# def __init__(self, *, name, violation_error_message=None):
def __init__(self, *args, name=None, violation_error_message=None):
# def __init__(
# self, *, name, violation_error_code=None, violation_error_message=None
# ):
def __init__(
self, *args, name=None, violation_error_code=None, violation_error_message=None
):
# RemovedInDjango60Warning.
if name is None and not args:
raise TypeError(
@ -30,6 +35,8 @@ class BaseConstraint:
f"argument: 'name'"
)
self.name = name
if violation_error_code is not None:
self.violation_error_code = violation_error_code
if violation_error_message is not None:
self.violation_error_message = violation_error_message
else:
@ -74,6 +81,8 @@ class BaseConstraint:
and self.violation_error_message != self.default_violation_error_message
):
kwargs["violation_error_message"] = self.violation_error_message
if self.violation_error_code is not None:
kwargs["violation_error_code"] = self.violation_error_code
return (path, (), kwargs)
def clone(self):
@ -82,13 +91,19 @@ class BaseConstraint:
class CheckConstraint(BaseConstraint):
def __init__(self, *, check, name, violation_error_message=None):
def __init__(
self, *, check, name, violation_error_code=None, violation_error_message=None
):
self.check = check
if not getattr(check, "conditional", False):
raise TypeError(
"CheckConstraint.check must be a Q instance or boolean expression."
)
super().__init__(name=name, violation_error_message=violation_error_message)
super().__init__(
name=name,
violation_error_code=violation_error_code,
violation_error_message=violation_error_message,
)
def _get_check_sql(self, model, schema_editor):
query = Query(model=model, alias_cols=False)
@ -112,15 +127,22 @@ class CheckConstraint(BaseConstraint):
against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
try:
if not Q(self.check).check(against, using=using):
raise ValidationError(self.get_violation_error_message())
raise ValidationError(
self.get_violation_error_message(), code=self.violation_error_code
)
except FieldError:
pass
def __repr__(self):
return "<%s: check=%s name=%s%s>" % (
return "<%s: check=%s name=%s%s%s>" % (
self.__class__.__qualname__,
self.check,
repr(self.name),
(
""
if self.violation_error_code is None
else " violation_error_code=%r" % self.violation_error_code
),
(
""
if self.violation_error_message is None
@ -134,6 +156,7 @@ class CheckConstraint(BaseConstraint):
return (
self.name == other.name
and self.check == other.check
and self.violation_error_code == other.violation_error_code
and self.violation_error_message == other.violation_error_message
)
return super().__eq__(other)
@ -163,6 +186,7 @@ class UniqueConstraint(BaseConstraint):
deferrable=None,
include=None,
opclasses=(),
violation_error_code=None,
violation_error_message=None,
):
if not name:
@ -213,7 +237,11 @@ class UniqueConstraint(BaseConstraint):
F(expression) if isinstance(expression, str) else expression
for expression in expressions
)
super().__init__(name=name, violation_error_message=violation_error_message)
super().__init__(
name=name,
violation_error_code=violation_error_code,
violation_error_message=violation_error_message,
)
@property
def contains_expressions(self):
@ -293,7 +321,7 @@ class UniqueConstraint(BaseConstraint):
)
def __repr__(self):
return "<%s:%s%s%s%s%s%s%s%s>" % (
return "<%s:%s%s%s%s%s%s%s%s%s>" % (
self.__class__.__qualname__,
"" if not self.fields else " fields=%s" % repr(self.fields),
"" if not self.expressions else " expressions=%s" % repr(self.expressions),
@ -302,6 +330,11 @@ class UniqueConstraint(BaseConstraint):
"" if self.deferrable is None else " deferrable=%r" % self.deferrable,
"" if not self.include else " include=%s" % repr(self.include),
"" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
(
""
if self.violation_error_code is None
else " violation_error_code=%r" % self.violation_error_code
),
(
""
if self.violation_error_message is None
@ -320,6 +353,7 @@ class UniqueConstraint(BaseConstraint):
and self.include == other.include
and self.opclasses == other.opclasses
and self.expressions == other.expressions
and self.violation_error_code == other.violation_error_code
and self.violation_error_message == other.violation_error_message
)
return super().__eq__(other)
@ -385,14 +419,17 @@ class UniqueConstraint(BaseConstraint):
if not self.condition:
if queryset.exists():
if self.expressions:
raise ValidationError(self.get_violation_error_message())
raise ValidationError(
self.get_violation_error_message(),
code=self.violation_error_code,
)
# When fields are defined, use the unique_error_message() for
# backward compatibility.
for model, constraints in instance.get_constraints():
for constraint in constraints:
if constraint is self:
raise ValidationError(
instance.unique_error_message(model, self.fields)
instance.unique_error_message(model, self.fields),
)
else:
against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
@ -400,6 +437,9 @@ class UniqueConstraint(BaseConstraint):
if (self.condition & Exists(queryset.filter(self.condition))).check(
against, using=using
):
raise ValidationError(self.get_violation_error_message())
raise ValidationError(
self.get_violation_error_message(),
code=self.violation_error_code,
)
except FieldError:
pass

View File

@ -12,7 +12,7 @@ PostgreSQL supports additional data integrity constraints available from the
``ExclusionConstraint``
=======================
.. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None, deferrable=None, include=None, violation_error_message=None)
.. class:: ExclusionConstraint(*, name, expressions, index_type=None, condition=None, deferrable=None, include=None, violation_error_code=None, violation_error_message=None)
Creates an exclusion constraint in the database. Internally, PostgreSQL
implements exclusion constraints using indexes. The default index type is
@ -133,6 +133,16 @@ used for queries that select only included fields
``include`` is supported for GiST indexes. PostgreSQL 14+ also supports
``include`` for SP-GiST indexes.
``violation_error_code``
------------------------
.. versionadded:: 5.0
.. attribute:: ExclusionConstraint.violation_error_code
The error code used when ``ValidationError`` is raised during
:ref:`model validation <validating-objects>`. Defaults to ``None``.
``violation_error_message``
---------------------------

View File

@ -48,7 +48,7 @@ option.
``BaseConstraint``
==================
.. class:: BaseConstraint(*, name, violation_error_message=None)
.. class:: BaseConstraint(* name, violation_error_code=None, violation_error_message=None)
Base class for all constraints. Subclasses must implement
``constraint_sql()``, ``create_sql()``, ``remove_sql()`` and
@ -68,6 +68,16 @@ All constraints have the following parameters in common:
The name of the constraint. You must always specify a unique name for the
constraint.
``violation_error_code``
------------------------
.. versionadded:: 5.0
.. attribute:: BaseConstraint.violation_error_code
The error code used when ``ValidationError`` is raised during
:ref:`model validation <validating-objects>`. Defaults to ``None``.
``violation_error_message``
---------------------------
@ -94,7 +104,7 @@ This method must be implemented by a subclass.
``CheckConstraint``
===================
.. class:: CheckConstraint(*, check, name, violation_error_message=None)
.. class:: CheckConstraint(*, check, name, violation_error_code=None, violation_error_message=None)
Creates a check constraint in the database.
@ -121,7 +131,7 @@ ensures the age field is never less than 18.
``UniqueConstraint``
====================
.. class:: UniqueConstraint(*expressions, fields=(), name=None, condition=None, deferrable=None, include=None, opclasses=(), violation_error_message=None)
.. class:: UniqueConstraint(*expressions, fields=(), name=None, condition=None, deferrable=None, include=None, opclasses=(), violation_error_code=None, violation_error_message=None)
Creates a unique constraint in the database.
@ -242,6 +252,22 @@ creates a unique index on ``username`` using ``varchar_pattern_ops``.
``opclasses`` are ignored for databases besides PostgreSQL.
``violation_error_code``
------------------------
.. versionadded:: 5.0
.. attribute:: UniqueConstraint.violation_error_code
The error code used when ``ValidationError`` is raised during
:ref:`model validation <validating-objects>`. Defaults to ``None``.
This code is *not used* for :class:`UniqueConstraint`\s with
:attr:`~UniqueConstraint.fields` and without a
:attr:`~UniqueConstraint.condition`. Such :class:`~UniqueConstraint`\s have the
same error code as constraints defined with :attr:`.Field.unique` or in
:attr:`Meta.unique_together <django.db.models.Options.constraints>`.
``violation_error_message``
---------------------------

View File

@ -78,7 +78,10 @@ Minor features
:mod:`django.contrib.postgres`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* ...
* The new :attr:`~.ExclusionConstraint.violation_error_code` attribute of
:class:`~django.contrib.postgres.constraints.ExclusionConstraint` allows
customizing the ``code`` of ``ValidationError`` raised during
:ref:`model validation <validating-objects>`.
:mod:`django.contrib.redirects`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -182,6 +185,13 @@ Models
and :meth:`.QuerySet.aupdate_or_create` methods allows specifying a different
field values for the create operation.
* The new ``violation_error_code`` attribute of
:class:`~django.db.models.BaseConstraint`,
:class:`~django.db.models.CheckConstraint`, and
:class:`~django.db.models.UniqueConstraint` allows customizing the ``code``
of ``ValidationError`` raised during
:ref:`model validation <validating-objects>`.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~

View File

@ -77,17 +77,26 @@ class BaseConstraintTests(SimpleTestCase):
"custom base_name message",
)
def test_custom_violation_code_message(self):
c = BaseConstraint(name="base_name", violation_error_code="custom_code")
self.assertEqual(c.violation_error_code, "custom_code")
def test_deconstruction(self):
constraint = BaseConstraint(
name="base_name",
violation_error_message="custom %(name)s message",
violation_error_code="custom_code",
)
path, args, kwargs = constraint.deconstruct()
self.assertEqual(path, "django.db.models.BaseConstraint")
self.assertEqual(args, ())
self.assertEqual(
kwargs,
{"name": "base_name", "violation_error_message": "custom %(name)s message"},
{
"name": "base_name",
"violation_error_message": "custom %(name)s message",
"violation_error_code": "custom_code",
},
)
def test_deprecation(self):
@ -148,6 +157,20 @@ class CheckConstraintTests(TestCase):
check=check1, name="price", violation_error_message="custom error"
),
)
self.assertNotEqual(
models.CheckConstraint(check=check1, name="price"),
models.CheckConstraint(
check=check1, name="price", violation_error_code="custom_code"
),
)
self.assertEqual(
models.CheckConstraint(
check=check1, name="price", violation_error_code="custom_code"
),
models.CheckConstraint(
check=check1, name="price", violation_error_code="custom_code"
),
)
def test_repr(self):
constraint = models.CheckConstraint(
@ -172,6 +195,18 @@ class CheckConstraintTests(TestCase):
"violation_error_message='More than 1'>",
)
def test_repr_with_violation_error_code(self):
constraint = models.CheckConstraint(
check=models.Q(price__lt=1),
name="price_lt_one",
violation_error_code="more_than_one",
)
self.assertEqual(
repr(constraint),
"<CheckConstraint: check=(AND: ('price__lt', 1)) name='price_lt_one' "
"violation_error_code='more_than_one'>",
)
def test_invalid_check_types(self):
msg = "CheckConstraint.check must be a Q instance or boolean expression."
with self.assertRaisesMessage(TypeError, msg):
@ -237,6 +272,21 @@ class CheckConstraintTests(TestCase):
# Valid product.
constraint.validate(Product, Product(price=10, discounted_price=5))
def test_validate_custom_error(self):
check = models.Q(price__gt=models.F("discounted_price"))
constraint = models.CheckConstraint(
check=check,
name="price",
violation_error_message="discount is fake",
violation_error_code="fake_discount",
)
# Invalid product.
invalid_product = Product(price=10, discounted_price=42)
msg = "discount is fake"
with self.assertRaisesMessage(ValidationError, msg) as cm:
constraint.validate(Product, invalid_product)
self.assertEqual(cm.exception.code, "fake_discount")
def test_validate_boolean_expressions(self):
constraint = models.CheckConstraint(
check=models.expressions.ExpressionWrapper(
@ -341,6 +391,30 @@ class UniqueConstraintTests(TestCase):
violation_error_message="custom error",
),
)
self.assertNotEqual(
models.UniqueConstraint(
fields=["foo", "bar"],
name="unique",
violation_error_code="custom_error",
),
models.UniqueConstraint(
fields=["foo", "bar"],
name="unique",
violation_error_code="other_custom_error",
),
)
self.assertEqual(
models.UniqueConstraint(
fields=["foo", "bar"],
name="unique",
violation_error_code="custom_error",
),
models.UniqueConstraint(
fields=["foo", "bar"],
name="unique",
violation_error_code="custom_error",
),
)
def test_eq_with_condition(self):
self.assertEqual(
@ -512,6 +586,20 @@ class UniqueConstraintTests(TestCase):
),
)
def test_repr_with_violation_error_code(self):
constraint = models.UniqueConstraint(
models.F("baz__lower"),
name="unique_lower_baz",
violation_error_code="baz",
)
self.assertEqual(
repr(constraint),
(
"<UniqueConstraint: expressions=(F(baz__lower),) "
"name='unique_lower_baz' violation_error_code='baz'>"
),
)
def test_deconstruction(self):
fields = ["foo", "bar"]
name = "unique_fields"
@ -656,12 +744,16 @@ class UniqueConstraintTests(TestCase):
def test_validate(self):
constraint = UniqueConstraintProduct._meta.constraints[0]
# Custom message and error code are ignored.
constraint.violation_error_message = "Custom message"
constraint.violation_error_code = "custom_code"
msg = "Unique constraint product with this Name and Color already exists."
non_unique_product = UniqueConstraintProduct(
name=self.p1.name, color=self.p1.color
)
with self.assertRaisesMessage(ValidationError, msg):
with self.assertRaisesMessage(ValidationError, msg) as cm:
constraint.validate(UniqueConstraintProduct, non_unique_product)
self.assertEqual(cm.exception.code, "unique_together")
# Null values are ignored.
constraint.validate(
UniqueConstraintProduct,
@ -716,6 +808,20 @@ class UniqueConstraintTests(TestCase):
exclude={"name"},
)
@skipUnlessDBFeature("supports_partial_indexes")
def test_validate_conditon_custom_error(self):
p1 = UniqueConstraintConditionProduct.objects.create(name="p1")
constraint = UniqueConstraintConditionProduct._meta.constraints[0]
constraint.violation_error_message = "Custom message"
constraint.violation_error_code = "custom_code"
msg = "Custom message"
with self.assertRaisesMessage(ValidationError, msg) as cm:
constraint.validate(
UniqueConstraintConditionProduct,
UniqueConstraintConditionProduct(name=p1.name, color=None),
)
self.assertEqual(cm.exception.code, "custom_code")
def test_validate_expression(self):
constraint = models.UniqueConstraint(Lower("name"), name="name_lower_uniq")
msg = "Constraint “name_lower_uniq” is violated."

View File

@ -397,6 +397,17 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
"(F(datespan), '-|-')] name='exclude_overlapping' "
"violation_error_message='Overlapping must be excluded'>",
)
constraint = ExclusionConstraint(
name="exclude_overlapping",
expressions=[(F("datespan"), RangeOperators.ADJACENT_TO)],
violation_error_code="overlapping_must_be_excluded",
)
self.assertEqual(
repr(constraint),
"<ExclusionConstraint: index_type='GIST' expressions=["
"(F(datespan), '-|-')] name='exclude_overlapping' "
"violation_error_code='overlapping_must_be_excluded'>",
)
def test_eq(self):
constraint_1 = ExclusionConstraint(
@ -470,6 +481,16 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
condition=Q(cancelled=False),
violation_error_message="other custom error",
)
constraint_12 = ExclusionConstraint(
name="exclude_overlapping",
expressions=[
(F("datespan"), RangeOperators.OVERLAPS),
(F("room"), RangeOperators.EQUAL),
],
condition=Q(cancelled=False),
violation_error_code="custom_code",
violation_error_message="other custom error",
)
self.assertEqual(constraint_1, constraint_1)
self.assertEqual(constraint_1, mock.ANY)
self.assertNotEqual(constraint_1, constraint_2)
@ -483,7 +504,9 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
self.assertNotEqual(constraint_5, constraint_6)
self.assertNotEqual(constraint_1, object())
self.assertNotEqual(constraint_10, constraint_11)
self.assertNotEqual(constraint_11, constraint_12)
self.assertEqual(constraint_10, constraint_10)
self.assertEqual(constraint_12, constraint_12)
def test_deconstruct(self):
constraint = ExclusionConstraint(
@ -760,17 +783,32 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
constraint = ExclusionConstraint(
name="ints_adjacent",
expressions=[("ints", RangeOperators.ADJACENT_TO)],
violation_error_code="custom_code",
violation_error_message="Custom error message.",
)
range_obj = RangesModel.objects.create(ints=(20, 50))
constraint.validate(RangesModel, range_obj)
msg = "Custom error message."
with self.assertRaisesMessage(ValidationError, msg):
with self.assertRaisesMessage(ValidationError, msg) as cm:
constraint.validate(RangesModel, RangesModel(ints=(10, 20)))
self.assertEqual(cm.exception.code, "custom_code")
constraint.validate(RangesModel, RangesModel(ints=(10, 19)))
constraint.validate(RangesModel, RangesModel(ints=(51, 60)))
constraint.validate(RangesModel, RangesModel(ints=(10, 20)), exclude={"ints"})
def test_validate_with_custom_code_and_condition(self):
constraint = ExclusionConstraint(
name="ints_adjacent",
expressions=[("ints", RangeOperators.ADJACENT_TO)],
violation_error_code="custom_code",
condition=Q(ints__lt=(100, 200)),
)
range_obj = RangesModel.objects.create(ints=(20, 50))
constraint.validate(RangesModel, range_obj)
with self.assertRaises(ValidationError) as cm:
constraint.validate(RangesModel, RangesModel(ints=(10, 20)))
self.assertEqual(cm.exception.code, "custom_code")
def test_expressions_with_params(self):
constraint_name = "scene_left_equal"
self.assertNotIn(constraint_name, self.get_constraints(Scene._meta.db_table))