Fixed #35575 -- Added support for constraint validation on GeneratedFields.

This commit is contained in:
Mark Gensler 2024-07-18 08:38:06 +01:00 committed by Sarah Boyce
parent f883bef054
commit 228128618b
7 changed files with 273 additions and 44 deletions

View File

@ -183,17 +183,11 @@ class ExclusionConstraint(BaseConstraint):
)
replacements = {F(field): value for field, value in replacement_map.items()}
lookups = []
for idx, (expression, operator) in enumerate(self.expressions):
for expression, operator in self.expressions:
if isinstance(expression, str):
expression = F(expression)
if exclude:
if isinstance(expression, F):
if expression.name in exclude:
return
else:
for expr in expression.flatten():
if isinstance(expr, F) and expr.name in exclude:
return
if exclude and self._expression_refs_exclude(model, expression, exclude):
return
rhs_expression = expression.replace_expressions(replacements)
if hasattr(expression, "get_expression_for_validation"):
expression = expression.get_expression_for_validation()

View File

@ -1337,18 +1337,33 @@ class Model(AltersData, metaclass=ModelBase):
if exclude is None:
exclude = set()
meta = meta or self._meta
field_map = {
field.name: (
value
if (value := getattr(self, field.attname))
and hasattr(value, "resolve_expression")
else Value(value, field)
)
for field in meta.local_concrete_fields
if field.name not in exclude and not field.generated
}
field_map = {}
generated_fields = []
for field in meta.local_concrete_fields:
if field.name in exclude:
continue
if field.generated:
if any(
ref[0] in exclude
for ref in self._get_expr_references(field.expression)
):
continue
generated_fields.append(field)
continue
value = getattr(self, field.attname)
if not value or not hasattr(value, "resolve_expression"):
value = Value(value, field)
field_map[field.name] = value
if "pk" not in exclude:
field_map["pk"] = Value(self.pk, meta.pk)
if generated_fields:
replacements = {F(name): value for name, value in field_map.items()}
for generated_field in generated_fields:
field_map[generated_field.name] = ExpressionWrapper(
generated_field.expression.replace_expressions(replacements),
generated_field.output_field,
)
return field_map
def prepare_database_save(self, field):

View File

@ -68,6 +68,19 @@ class BaseConstraint:
def remove_sql(self, model, schema_editor):
raise NotImplementedError("This method must be implemented by a subclass.")
@classmethod
def _expression_refs_exclude(cls, model, expression, exclude):
get_field = model._meta.get_field
for field_name, *__ in model._get_expr_references(expression):
if field_name in exclude:
return True
field = get_field(field_name)
if field.generated and cls._expression_refs_exclude(
model, field.expression, exclude
):
return True
return False
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
raise NotImplementedError("This method must be implemented by a subclass.")
@ -606,36 +619,56 @@ class UniqueConstraint(BaseConstraint):
queryset = model._default_manager.using(using)
if self.fields:
lookup_kwargs = {}
generated_field_names = []
for field_name in self.fields:
if exclude and field_name in exclude:
return
field = model._meta.get_field(field_name)
lookup_value = getattr(instance, field.attname)
if (
self.nulls_distinct is not False
and lookup_value is None
or (
lookup_value == ""
and connections[
using
].features.interprets_empty_strings_as_nulls
)
):
# A composite constraint containing NULL value cannot cause
# a violation since NULL != NULL in SQL.
return
lookup_kwargs[field.name] = lookup_value
queryset = queryset.filter(**lookup_kwargs)
if field.generated:
if exclude and self._expression_refs_exclude(
model, field.expression, exclude
):
return
generated_field_names.append(field.name)
else:
lookup_value = getattr(instance, field.attname)
if (
self.nulls_distinct is not False
and lookup_value is None
or (
lookup_value == ""
and connections[
using
].features.interprets_empty_strings_as_nulls
)
):
# A composite constraint containing NULL value cannot cause
# a violation since NULL != NULL in SQL.
return
lookup_kwargs[field.name] = lookup_value
lookup_args = []
if generated_field_names:
field_expression_map = instance._get_field_expression_map(
meta=model._meta, exclude=exclude
)
for field_name in generated_field_names:
expression = field_expression_map[field_name]
if self.nulls_distinct is False:
lhs = F(field_name)
condition = Q(Exact(lhs, expression)) | Q(
IsNull(lhs, True), IsNull(expression, True)
)
lookup_args.append(condition)
else:
lookup_kwargs[field_name] = expression
queryset = queryset.filter(*lookup_args, **lookup_kwargs)
else:
# Ignore constraints with excluded fields.
if exclude:
for expression in self.expressions:
if hasattr(expression, "flatten"):
for expr in expression.flatten():
if isinstance(expr, F) and expr.name in exclude:
return
elif isinstance(expression, F) and expression.name in exclude:
return
if exclude and any(
self._expression_refs_exclude(model, expression, exclude)
for expression in self.expressions
):
return
replacements = {
F(field): value
for field, value in instance._get_field_expression_map(

View File

@ -215,6 +215,9 @@ Models
methods such as
:meth:`QuerySet.union()<django.db.models.query.QuerySet.union>` unpredictable.
* Added support for validation of model constraints which use a
:class:`~django.db.models.GeneratedField`.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~

View File

@ -1,4 +1,5 @@
from django.db import models
from django.db.models.functions import Coalesce, Lower
class Product(models.Model):
@ -28,6 +29,46 @@ class Product(models.Model):
]
class GeneratedFieldStoredProduct(models.Model):
name = models.CharField(max_length=255, null=True)
price = models.IntegerField(null=True)
discounted_price = models.IntegerField(null=True)
rebate = models.GeneratedField(
expression=Coalesce("price", 0)
- Coalesce("discounted_price", Coalesce("price", 0)),
output_field=models.IntegerField(),
db_persist=True,
)
lower_name = models.GeneratedField(
expression=Lower(models.F("name")),
output_field=models.CharField(max_length=255, null=True),
db_persist=True,
)
class Meta:
required_db_features = {"supports_stored_generated_columns"}
class GeneratedFieldVirtualProduct(models.Model):
name = models.CharField(max_length=255, null=True)
price = models.IntegerField(null=True)
discounted_price = models.IntegerField(null=True)
rebate = models.GeneratedField(
expression=Coalesce("price", 0)
- Coalesce("discounted_price", Coalesce("price", 0)),
output_field=models.IntegerField(),
db_persist=False,
)
lower_name = models.GeneratedField(
expression=Lower(models.F("name")),
output_field=models.CharField(max_length=255, null=True),
db_persist=False,
)
class Meta:
required_db_features = {"supports_virtual_generated_columns"}
class UniqueConstraintProduct(models.Model):
name = models.CharField(max_length=255)
color = models.CharField(max_length=32, null=True)

View File

@ -4,7 +4,7 @@ from django.core.exceptions import ValidationError
from django.db import IntegrityError, connection, models
from django.db.models import F
from django.db.models.constraints import BaseConstraint, UniqueConstraint
from django.db.models.functions import Abs, Lower, Upper
from django.db.models.functions import Abs, Lower, Sqrt, Upper
from django.db.transaction import atomic
from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import ignore_warnings
@ -13,6 +13,8 @@ from django.utils.deprecation import RemovedInDjango60Warning
from .models import (
ChildModel,
ChildUniqueConstraintProduct,
GeneratedFieldStoredProduct,
GeneratedFieldVirtualProduct,
JSONFieldModel,
ModelWithDatabaseDefault,
Product,
@ -384,6 +386,29 @@ class CheckConstraintTests(TestCase):
with self.assertRaisesMessage(ValidationError, msg):
json_exact_constraint.validate(JSONFieldModel, JSONFieldModel(data=data))
@skipUnlessDBFeature("supports_stored_generated_columns")
def test_validate_generated_field_stored(self):
self.assertGeneratedFieldIsValidated(model=GeneratedFieldStoredProduct)
@skipUnlessDBFeature("supports_virtual_generated_columns")
def test_validate_generated_field_virtual(self):
self.assertGeneratedFieldIsValidated(model=GeneratedFieldVirtualProduct)
def assertGeneratedFieldIsValidated(self, model):
constraint = models.CheckConstraint(
condition=models.Q(rebate__range=(0, 100)), name="bounded_rebate"
)
constraint.validate(model, model(price=50, discounted_price=20))
invalid_product = model(price=1200, discounted_price=500)
msg = f"Constraint “{constraint.name}” is violated."
with self.assertRaisesMessage(ValidationError, msg):
constraint.validate(model, invalid_product)
# Excluding referenced or generated fields should skip validation.
constraint.validate(model, invalid_product, exclude={"price"})
constraint.validate(model, invalid_product, exclude={"rebate"})
def test_check_deprecation(self):
msg = "CheckConstraint.check is deprecated in favor of `.condition`."
condition = models.Q(foo="bar")
@ -1062,6 +1087,90 @@ class UniqueConstraintTests(TestCase):
exclude={"name"},
)
@skipUnlessDBFeature("supports_stored_generated_columns")
def test_validate_expression_generated_field_stored(self):
self.assertGeneratedFieldWithExpressionIsValidated(
model=GeneratedFieldStoredProduct
)
@skipUnlessDBFeature("supports_virtual_generated_columns")
def test_validate_expression_generated_field_virtual(self):
self.assertGeneratedFieldWithExpressionIsValidated(
model=GeneratedFieldVirtualProduct
)
def assertGeneratedFieldWithExpressionIsValidated(self, model):
constraint = UniqueConstraint(Sqrt("rebate"), name="unique_rebate_sqrt")
model.objects.create(price=100, discounted_price=84)
valid_product = model(price=100, discounted_price=75)
constraint.validate(model, valid_product)
invalid_product = model(price=20, discounted_price=4)
with self.assertRaisesMessage(
ValidationError, f"Constraint “{constraint.name}” is violated."
):
constraint.validate(model, invalid_product)
# Excluding referenced or generated fields should skip validation.
constraint.validate(model, invalid_product, exclude={"rebate"})
constraint.validate(model, invalid_product, exclude={"price"})
@skipUnlessDBFeature("supports_stored_generated_columns")
def test_validate_fields_generated_field_stored(self):
self.assertGeneratedFieldWithFieldsIsValidated(
model=GeneratedFieldStoredProduct
)
@skipUnlessDBFeature("supports_virtual_generated_columns")
def test_validate_fields_generated_field_virtual(self):
self.assertGeneratedFieldWithFieldsIsValidated(
model=GeneratedFieldVirtualProduct
)
def assertGeneratedFieldWithFieldsIsValidated(self, model):
constraint = models.UniqueConstraint(
fields=["lower_name"], name="lower_name_unique"
)
model.objects.create(name="Box")
constraint.validate(model, model(name="Case"))
invalid_product = model(name="BOX")
msg = str(invalid_product.unique_error_message(model, ["lower_name"]))
with self.assertRaisesMessage(ValidationError, msg):
constraint.validate(model, invalid_product)
# Excluding referenced or generated fields should skip validation.
constraint.validate(model, invalid_product, exclude={"lower_name"})
constraint.validate(model, invalid_product, exclude={"name"})
@skipUnlessDBFeature("supports_stored_generated_columns")
def test_validate_fields_generated_field_stored_nulls_distinct(self):
self.assertGeneratedFieldNullsDistinctIsValidated(
model=GeneratedFieldStoredProduct
)
@skipUnlessDBFeature("supports_virtual_generated_columns")
def test_validate_fields_generated_field_virtual_nulls_distinct(self):
self.assertGeneratedFieldNullsDistinctIsValidated(
model=GeneratedFieldVirtualProduct
)
def assertGeneratedFieldNullsDistinctIsValidated(self, model):
constraint = models.UniqueConstraint(
fields=["lower_name"],
name="lower_name_unique_nulls_distinct",
nulls_distinct=False,
)
model.objects.create(name=None)
valid_product = model(name="Box")
constraint.validate(model, valid_product)
invalid_product = model(name=None)
msg = str(invalid_product.unique_error_message(model, ["lower_name"]))
with self.assertRaisesMessage(ValidationError, msg):
constraint.validate(model, invalid_product)
@skipUnlessDBFeature("supports_table_check_constraints")
def test_validate_nullable_textfield_with_isnull_true(self):
is_null_constraint = models.UniqueConstraint(

View File

@ -14,6 +14,7 @@ from django.db.models import (
F,
ForeignKey,
Func,
GeneratedField,
IntegerField,
Model,
Q,
@ -32,6 +33,7 @@ try:
from django.contrib.postgres.constraints import ExclusionConstraint
from django.contrib.postgres.fields import (
DateTimeRangeField,
IntegerRangeField,
RangeBoundary,
RangeOperators,
)
@ -866,6 +868,38 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
constraint.validate(RangesModel, RangesModel(ints=(51, 60)))
constraint.validate(RangesModel, RangesModel(ints=(10, 20)), exclude={"ints"})
@skipUnlessDBFeature("supports_stored_generated_columns")
@isolate_apps("postgres_tests")
def test_validate_generated_field_range_adjacent(self):
class RangesModelGeneratedField(Model):
ints = IntegerRangeField(blank=True, null=True)
ints_generated = GeneratedField(
expression=F("ints"),
output_field=IntegerRangeField(null=True),
db_persist=True,
)
with connection.schema_editor() as editor:
editor.create_model(RangesModelGeneratedField)
constraint = ExclusionConstraint(
name="ints_adjacent",
expressions=[("ints_generated", RangeOperators.ADJACENT_TO)],
violation_error_code="custom_code",
violation_error_message="Custom error message.",
)
RangesModelGeneratedField.objects.create(ints=(20, 50))
range_obj = RangesModelGeneratedField(ints=(3, 20))
with self.assertRaisesMessage(ValidationError, "Custom error message."):
constraint.validate(RangesModelGeneratedField, range_obj)
# Excluding referenced or generated field should skip validation.
constraint.validate(RangesModelGeneratedField, range_obj, exclude={"ints"})
constraint.validate(
RangesModelGeneratedField, range_obj, exclude={"ints_generated"}
)
def test_validate_with_custom_code_and_condition(self):
constraint = ExclusionConstraint(
name="ints_adjacent",