Fixed #35575 -- Added support for constraint validation on GeneratedFields.

This commit is contained in:
Mark Gensler 2024-07-18 08:38:06 +01:00 committed by Sarah Boyce
parent f883bef054
commit 228128618b
7 changed files with 273 additions and 44 deletions

View File

@ -183,17 +183,11 @@ class ExclusionConstraint(BaseConstraint):
) )
replacements = {F(field): value for field, value in replacement_map.items()} replacements = {F(field): value for field, value in replacement_map.items()}
lookups = [] lookups = []
for idx, (expression, operator) in enumerate(self.expressions): for expression, operator in self.expressions:
if isinstance(expression, str): if isinstance(expression, str):
expression = F(expression) expression = F(expression)
if exclude: if exclude and self._expression_refs_exclude(model, expression, exclude):
if isinstance(expression, F): return
if expression.name in exclude:
return
else:
for expr in expression.flatten():
if isinstance(expr, F) and expr.name in exclude:
return
rhs_expression = expression.replace_expressions(replacements) rhs_expression = expression.replace_expressions(replacements)
if hasattr(expression, "get_expression_for_validation"): if hasattr(expression, "get_expression_for_validation"):
expression = expression.get_expression_for_validation() expression = expression.get_expression_for_validation()

View File

@ -1337,18 +1337,33 @@ class Model(AltersData, metaclass=ModelBase):
if exclude is None: if exclude is None:
exclude = set() exclude = set()
meta = meta or self._meta meta = meta or self._meta
field_map = { field_map = {}
field.name: ( generated_fields = []
value for field in meta.local_concrete_fields:
if (value := getattr(self, field.attname)) if field.name in exclude:
and hasattr(value, "resolve_expression") continue
else Value(value, field) if field.generated:
) if any(
for field in meta.local_concrete_fields ref[0] in exclude
if field.name not in exclude and not field.generated for ref in self._get_expr_references(field.expression)
} ):
continue
generated_fields.append(field)
continue
value = getattr(self, field.attname)
if not value or not hasattr(value, "resolve_expression"):
value = Value(value, field)
field_map[field.name] = value
if "pk" not in exclude: if "pk" not in exclude:
field_map["pk"] = Value(self.pk, meta.pk) field_map["pk"] = Value(self.pk, meta.pk)
if generated_fields:
replacements = {F(name): value for name, value in field_map.items()}
for generated_field in generated_fields:
field_map[generated_field.name] = ExpressionWrapper(
generated_field.expression.replace_expressions(replacements),
generated_field.output_field,
)
return field_map return field_map
def prepare_database_save(self, field): def prepare_database_save(self, field):

View File

@ -68,6 +68,19 @@ class BaseConstraint:
def remove_sql(self, model, schema_editor): def remove_sql(self, model, schema_editor):
raise NotImplementedError("This method must be implemented by a subclass.") raise NotImplementedError("This method must be implemented by a subclass.")
@classmethod
def _expression_refs_exclude(cls, model, expression, exclude):
get_field = model._meta.get_field
for field_name, *__ in model._get_expr_references(expression):
if field_name in exclude:
return True
field = get_field(field_name)
if field.generated and cls._expression_refs_exclude(
model, field.expression, exclude
):
return True
return False
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS): def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
raise NotImplementedError("This method must be implemented by a subclass.") raise NotImplementedError("This method must be implemented by a subclass.")
@ -606,36 +619,56 @@ class UniqueConstraint(BaseConstraint):
queryset = model._default_manager.using(using) queryset = model._default_manager.using(using)
if self.fields: if self.fields:
lookup_kwargs = {} lookup_kwargs = {}
generated_field_names = []
for field_name in self.fields: for field_name in self.fields:
if exclude and field_name in exclude: if exclude and field_name in exclude:
return return
field = model._meta.get_field(field_name) field = model._meta.get_field(field_name)
lookup_value = getattr(instance, field.attname) if field.generated:
if ( if exclude and self._expression_refs_exclude(
self.nulls_distinct is not False model, field.expression, exclude
and lookup_value is None ):
or ( return
lookup_value == "" generated_field_names.append(field.name)
and connections[ else:
using lookup_value = getattr(instance, field.attname)
].features.interprets_empty_strings_as_nulls if (
) self.nulls_distinct is not False
): and lookup_value is None
# A composite constraint containing NULL value cannot cause or (
# a violation since NULL != NULL in SQL. lookup_value == ""
return and connections[
lookup_kwargs[field.name] = lookup_value using
queryset = queryset.filter(**lookup_kwargs) ].features.interprets_empty_strings_as_nulls
)
):
# A composite constraint containing NULL value cannot cause
# a violation since NULL != NULL in SQL.
return
lookup_kwargs[field.name] = lookup_value
lookup_args = []
if generated_field_names:
field_expression_map = instance._get_field_expression_map(
meta=model._meta, exclude=exclude
)
for field_name in generated_field_names:
expression = field_expression_map[field_name]
if self.nulls_distinct is False:
lhs = F(field_name)
condition = Q(Exact(lhs, expression)) | Q(
IsNull(lhs, True), IsNull(expression, True)
)
lookup_args.append(condition)
else:
lookup_kwargs[field_name] = expression
queryset = queryset.filter(*lookup_args, **lookup_kwargs)
else: else:
# Ignore constraints with excluded fields. # Ignore constraints with excluded fields.
if exclude: if exclude and any(
for expression in self.expressions: self._expression_refs_exclude(model, expression, exclude)
if hasattr(expression, "flatten"): for expression in self.expressions
for expr in expression.flatten(): ):
if isinstance(expr, F) and expr.name in exclude: return
return
elif isinstance(expression, F) and expression.name in exclude:
return
replacements = { replacements = {
F(field): value F(field): value
for field, value in instance._get_field_expression_map( for field, value in instance._get_field_expression_map(

View File

@ -215,6 +215,9 @@ Models
methods such as methods such as
:meth:`QuerySet.union()<django.db.models.query.QuerySet.union>` unpredictable. :meth:`QuerySet.union()<django.db.models.query.QuerySet.union>` unpredictable.
* Added support for validation of model constraints which use a
:class:`~django.db.models.GeneratedField`.
Requests and Responses Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~

View File

@ -1,4 +1,5 @@
from django.db import models from django.db import models
from django.db.models.functions import Coalesce, Lower
class Product(models.Model): class Product(models.Model):
@ -28,6 +29,46 @@ class Product(models.Model):
] ]
class GeneratedFieldStoredProduct(models.Model):
name = models.CharField(max_length=255, null=True)
price = models.IntegerField(null=True)
discounted_price = models.IntegerField(null=True)
rebate = models.GeneratedField(
expression=Coalesce("price", 0)
- Coalesce("discounted_price", Coalesce("price", 0)),
output_field=models.IntegerField(),
db_persist=True,
)
lower_name = models.GeneratedField(
expression=Lower(models.F("name")),
output_field=models.CharField(max_length=255, null=True),
db_persist=True,
)
class Meta:
required_db_features = {"supports_stored_generated_columns"}
class GeneratedFieldVirtualProduct(models.Model):
name = models.CharField(max_length=255, null=True)
price = models.IntegerField(null=True)
discounted_price = models.IntegerField(null=True)
rebate = models.GeneratedField(
expression=Coalesce("price", 0)
- Coalesce("discounted_price", Coalesce("price", 0)),
output_field=models.IntegerField(),
db_persist=False,
)
lower_name = models.GeneratedField(
expression=Lower(models.F("name")),
output_field=models.CharField(max_length=255, null=True),
db_persist=False,
)
class Meta:
required_db_features = {"supports_virtual_generated_columns"}
class UniqueConstraintProduct(models.Model): class UniqueConstraintProduct(models.Model):
name = models.CharField(max_length=255) name = models.CharField(max_length=255)
color = models.CharField(max_length=32, null=True) color = models.CharField(max_length=32, null=True)

View File

@ -4,7 +4,7 @@ from django.core.exceptions import ValidationError
from django.db import IntegrityError, connection, models from django.db import IntegrityError, connection, models
from django.db.models import F from django.db.models import F
from django.db.models.constraints import BaseConstraint, UniqueConstraint from django.db.models.constraints import BaseConstraint, UniqueConstraint
from django.db.models.functions import Abs, Lower, Upper from django.db.models.functions import Abs, Lower, Sqrt, Upper
from django.db.transaction import atomic from django.db.transaction import atomic
from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import ignore_warnings from django.test.utils import ignore_warnings
@ -13,6 +13,8 @@ from django.utils.deprecation import RemovedInDjango60Warning
from .models import ( from .models import (
ChildModel, ChildModel,
ChildUniqueConstraintProduct, ChildUniqueConstraintProduct,
GeneratedFieldStoredProduct,
GeneratedFieldVirtualProduct,
JSONFieldModel, JSONFieldModel,
ModelWithDatabaseDefault, ModelWithDatabaseDefault,
Product, Product,
@ -384,6 +386,29 @@ class CheckConstraintTests(TestCase):
with self.assertRaisesMessage(ValidationError, msg): with self.assertRaisesMessage(ValidationError, msg):
json_exact_constraint.validate(JSONFieldModel, JSONFieldModel(data=data)) json_exact_constraint.validate(JSONFieldModel, JSONFieldModel(data=data))
@skipUnlessDBFeature("supports_stored_generated_columns")
def test_validate_generated_field_stored(self):
self.assertGeneratedFieldIsValidated(model=GeneratedFieldStoredProduct)
@skipUnlessDBFeature("supports_virtual_generated_columns")
def test_validate_generated_field_virtual(self):
self.assertGeneratedFieldIsValidated(model=GeneratedFieldVirtualProduct)
def assertGeneratedFieldIsValidated(self, model):
constraint = models.CheckConstraint(
condition=models.Q(rebate__range=(0, 100)), name="bounded_rebate"
)
constraint.validate(model, model(price=50, discounted_price=20))
invalid_product = model(price=1200, discounted_price=500)
msg = f"Constraint “{constraint.name}” is violated."
with self.assertRaisesMessage(ValidationError, msg):
constraint.validate(model, invalid_product)
# Excluding referenced or generated fields should skip validation.
constraint.validate(model, invalid_product, exclude={"price"})
constraint.validate(model, invalid_product, exclude={"rebate"})
def test_check_deprecation(self): def test_check_deprecation(self):
msg = "CheckConstraint.check is deprecated in favor of `.condition`." msg = "CheckConstraint.check is deprecated in favor of `.condition`."
condition = models.Q(foo="bar") condition = models.Q(foo="bar")
@ -1062,6 +1087,90 @@ class UniqueConstraintTests(TestCase):
exclude={"name"}, exclude={"name"},
) )
@skipUnlessDBFeature("supports_stored_generated_columns")
def test_validate_expression_generated_field_stored(self):
self.assertGeneratedFieldWithExpressionIsValidated(
model=GeneratedFieldStoredProduct
)
@skipUnlessDBFeature("supports_virtual_generated_columns")
def test_validate_expression_generated_field_virtual(self):
self.assertGeneratedFieldWithExpressionIsValidated(
model=GeneratedFieldVirtualProduct
)
def assertGeneratedFieldWithExpressionIsValidated(self, model):
constraint = UniqueConstraint(Sqrt("rebate"), name="unique_rebate_sqrt")
model.objects.create(price=100, discounted_price=84)
valid_product = model(price=100, discounted_price=75)
constraint.validate(model, valid_product)
invalid_product = model(price=20, discounted_price=4)
with self.assertRaisesMessage(
ValidationError, f"Constraint “{constraint.name}” is violated."
):
constraint.validate(model, invalid_product)
# Excluding referenced or generated fields should skip validation.
constraint.validate(model, invalid_product, exclude={"rebate"})
constraint.validate(model, invalid_product, exclude={"price"})
@skipUnlessDBFeature("supports_stored_generated_columns")
def test_validate_fields_generated_field_stored(self):
self.assertGeneratedFieldWithFieldsIsValidated(
model=GeneratedFieldStoredProduct
)
@skipUnlessDBFeature("supports_virtual_generated_columns")
def test_validate_fields_generated_field_virtual(self):
self.assertGeneratedFieldWithFieldsIsValidated(
model=GeneratedFieldVirtualProduct
)
def assertGeneratedFieldWithFieldsIsValidated(self, model):
constraint = models.UniqueConstraint(
fields=["lower_name"], name="lower_name_unique"
)
model.objects.create(name="Box")
constraint.validate(model, model(name="Case"))
invalid_product = model(name="BOX")
msg = str(invalid_product.unique_error_message(model, ["lower_name"]))
with self.assertRaisesMessage(ValidationError, msg):
constraint.validate(model, invalid_product)
# Excluding referenced or generated fields should skip validation.
constraint.validate(model, invalid_product, exclude={"lower_name"})
constraint.validate(model, invalid_product, exclude={"name"})
@skipUnlessDBFeature("supports_stored_generated_columns")
def test_validate_fields_generated_field_stored_nulls_distinct(self):
self.assertGeneratedFieldNullsDistinctIsValidated(
model=GeneratedFieldStoredProduct
)
@skipUnlessDBFeature("supports_virtual_generated_columns")
def test_validate_fields_generated_field_virtual_nulls_distinct(self):
self.assertGeneratedFieldNullsDistinctIsValidated(
model=GeneratedFieldVirtualProduct
)
def assertGeneratedFieldNullsDistinctIsValidated(self, model):
constraint = models.UniqueConstraint(
fields=["lower_name"],
name="lower_name_unique_nulls_distinct",
nulls_distinct=False,
)
model.objects.create(name=None)
valid_product = model(name="Box")
constraint.validate(model, valid_product)
invalid_product = model(name=None)
msg = str(invalid_product.unique_error_message(model, ["lower_name"]))
with self.assertRaisesMessage(ValidationError, msg):
constraint.validate(model, invalid_product)
@skipUnlessDBFeature("supports_table_check_constraints") @skipUnlessDBFeature("supports_table_check_constraints")
def test_validate_nullable_textfield_with_isnull_true(self): def test_validate_nullable_textfield_with_isnull_true(self):
is_null_constraint = models.UniqueConstraint( is_null_constraint = models.UniqueConstraint(

View File

@ -14,6 +14,7 @@ from django.db.models import (
F, F,
ForeignKey, ForeignKey,
Func, Func,
GeneratedField,
IntegerField, IntegerField,
Model, Model,
Q, Q,
@ -32,6 +33,7 @@ try:
from django.contrib.postgres.constraints import ExclusionConstraint from django.contrib.postgres.constraints import ExclusionConstraint
from django.contrib.postgres.fields import ( from django.contrib.postgres.fields import (
DateTimeRangeField, DateTimeRangeField,
IntegerRangeField,
RangeBoundary, RangeBoundary,
RangeOperators, RangeOperators,
) )
@ -866,6 +868,38 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
constraint.validate(RangesModel, RangesModel(ints=(51, 60))) constraint.validate(RangesModel, RangesModel(ints=(51, 60)))
constraint.validate(RangesModel, RangesModel(ints=(10, 20)), exclude={"ints"}) constraint.validate(RangesModel, RangesModel(ints=(10, 20)), exclude={"ints"})
@skipUnlessDBFeature("supports_stored_generated_columns")
@isolate_apps("postgres_tests")
def test_validate_generated_field_range_adjacent(self):
class RangesModelGeneratedField(Model):
ints = IntegerRangeField(blank=True, null=True)
ints_generated = GeneratedField(
expression=F("ints"),
output_field=IntegerRangeField(null=True),
db_persist=True,
)
with connection.schema_editor() as editor:
editor.create_model(RangesModelGeneratedField)
constraint = ExclusionConstraint(
name="ints_adjacent",
expressions=[("ints_generated", RangeOperators.ADJACENT_TO)],
violation_error_code="custom_code",
violation_error_message="Custom error message.",
)
RangesModelGeneratedField.objects.create(ints=(20, 50))
range_obj = RangesModelGeneratedField(ints=(3, 20))
with self.assertRaisesMessage(ValidationError, "Custom error message."):
constraint.validate(RangesModelGeneratedField, range_obj)
# Excluding referenced or generated field should skip validation.
constraint.validate(RangesModelGeneratedField, range_obj, exclude={"ints"})
constraint.validate(
RangesModelGeneratedField, range_obj, exclude={"ints_generated"}
)
def test_validate_with_custom_code_and_condition(self): def test_validate_with_custom_code_and_condition(self):
constraint = ExclusionConstraint( constraint = ExclusionConstraint(
name="ints_adjacent", name="ints_adjacent",