mirror of https://github.com/django/django.git
Fixed #35575 -- Added support for constraint validation on GeneratedFields.
This commit is contained in:
parent
f883bef054
commit
228128618b
|
@ -183,16 +183,10 @@ class ExclusionConstraint(BaseConstraint):
|
|||
)
|
||||
replacements = {F(field): value for field, value in replacement_map.items()}
|
||||
lookups = []
|
||||
for idx, (expression, operator) in enumerate(self.expressions):
|
||||
for expression, operator in self.expressions:
|
||||
if isinstance(expression, str):
|
||||
expression = F(expression)
|
||||
if exclude:
|
||||
if isinstance(expression, F):
|
||||
if expression.name in exclude:
|
||||
return
|
||||
else:
|
||||
for expr in expression.flatten():
|
||||
if isinstance(expr, F) and expr.name in exclude:
|
||||
if exclude and self._expression_refs_exclude(model, expression, exclude):
|
||||
return
|
||||
rhs_expression = expression.replace_expressions(replacements)
|
||||
if hasattr(expression, "get_expression_for_validation"):
|
||||
|
|
|
@ -1337,18 +1337,33 @@ class Model(AltersData, metaclass=ModelBase):
|
|||
if exclude is None:
|
||||
exclude = set()
|
||||
meta = meta or self._meta
|
||||
field_map = {
|
||||
field.name: (
|
||||
value
|
||||
if (value := getattr(self, field.attname))
|
||||
and hasattr(value, "resolve_expression")
|
||||
else Value(value, field)
|
||||
)
|
||||
for field in meta.local_concrete_fields
|
||||
if field.name not in exclude and not field.generated
|
||||
}
|
||||
field_map = {}
|
||||
generated_fields = []
|
||||
for field in meta.local_concrete_fields:
|
||||
if field.name in exclude:
|
||||
continue
|
||||
if field.generated:
|
||||
if any(
|
||||
ref[0] in exclude
|
||||
for ref in self._get_expr_references(field.expression)
|
||||
):
|
||||
continue
|
||||
generated_fields.append(field)
|
||||
continue
|
||||
value = getattr(self, field.attname)
|
||||
if not value or not hasattr(value, "resolve_expression"):
|
||||
value = Value(value, field)
|
||||
field_map[field.name] = value
|
||||
if "pk" not in exclude:
|
||||
field_map["pk"] = Value(self.pk, meta.pk)
|
||||
if generated_fields:
|
||||
replacements = {F(name): value for name, value in field_map.items()}
|
||||
for generated_field in generated_fields:
|
||||
field_map[generated_field.name] = ExpressionWrapper(
|
||||
generated_field.expression.replace_expressions(replacements),
|
||||
generated_field.output_field,
|
||||
)
|
||||
|
||||
return field_map
|
||||
|
||||
def prepare_database_save(self, field):
|
||||
|
|
|
@ -68,6 +68,19 @@ class BaseConstraint:
|
|||
def remove_sql(self, model, schema_editor):
|
||||
raise NotImplementedError("This method must be implemented by a subclass.")
|
||||
|
||||
@classmethod
|
||||
def _expression_refs_exclude(cls, model, expression, exclude):
|
||||
get_field = model._meta.get_field
|
||||
for field_name, *__ in model._get_expr_references(expression):
|
||||
if field_name in exclude:
|
||||
return True
|
||||
field = get_field(field_name)
|
||||
if field.generated and cls._expression_refs_exclude(
|
||||
model, field.expression, exclude
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
|
||||
raise NotImplementedError("This method must be implemented by a subclass.")
|
||||
|
||||
|
@ -606,10 +619,18 @@ class UniqueConstraint(BaseConstraint):
|
|||
queryset = model._default_manager.using(using)
|
||||
if self.fields:
|
||||
lookup_kwargs = {}
|
||||
generated_field_names = []
|
||||
for field_name in self.fields:
|
||||
if exclude and field_name in exclude:
|
||||
return
|
||||
field = model._meta.get_field(field_name)
|
||||
if field.generated:
|
||||
if exclude and self._expression_refs_exclude(
|
||||
model, field.expression, exclude
|
||||
):
|
||||
return
|
||||
generated_field_names.append(field.name)
|
||||
else:
|
||||
lookup_value = getattr(instance, field.attname)
|
||||
if (
|
||||
self.nulls_distinct is not False
|
||||
|
@ -625,16 +646,28 @@ class UniqueConstraint(BaseConstraint):
|
|||
# a violation since NULL != NULL in SQL.
|
||||
return
|
||||
lookup_kwargs[field.name] = lookup_value
|
||||
queryset = queryset.filter(**lookup_kwargs)
|
||||
lookup_args = []
|
||||
if generated_field_names:
|
||||
field_expression_map = instance._get_field_expression_map(
|
||||
meta=model._meta, exclude=exclude
|
||||
)
|
||||
for field_name in generated_field_names:
|
||||
expression = field_expression_map[field_name]
|
||||
if self.nulls_distinct is False:
|
||||
lhs = F(field_name)
|
||||
condition = Q(Exact(lhs, expression)) | Q(
|
||||
IsNull(lhs, True), IsNull(expression, True)
|
||||
)
|
||||
lookup_args.append(condition)
|
||||
else:
|
||||
lookup_kwargs[field_name] = expression
|
||||
queryset = queryset.filter(*lookup_args, **lookup_kwargs)
|
||||
else:
|
||||
# Ignore constraints with excluded fields.
|
||||
if exclude:
|
||||
for expression in self.expressions:
|
||||
if hasattr(expression, "flatten"):
|
||||
for expr in expression.flatten():
|
||||
if isinstance(expr, F) and expr.name in exclude:
|
||||
return
|
||||
elif isinstance(expression, F) and expression.name in exclude:
|
||||
if exclude and any(
|
||||
self._expression_refs_exclude(model, expression, exclude)
|
||||
for expression in self.expressions
|
||||
):
|
||||
return
|
||||
replacements = {
|
||||
F(field): value
|
||||
|
|
|
@ -215,6 +215,9 @@ Models
|
|||
methods such as
|
||||
:meth:`QuerySet.union()<django.db.models.query.QuerySet.union>` unpredictable.
|
||||
|
||||
* Added support for validation of model constraints which use a
|
||||
:class:`~django.db.models.GeneratedField`.
|
||||
|
||||
Requests and Responses
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from django.db import models
|
||||
from django.db.models.functions import Coalesce, Lower
|
||||
|
||||
|
||||
class Product(models.Model):
|
||||
|
@ -28,6 +29,46 @@ class Product(models.Model):
|
|||
]
|
||||
|
||||
|
||||
class GeneratedFieldStoredProduct(models.Model):
|
||||
name = models.CharField(max_length=255, null=True)
|
||||
price = models.IntegerField(null=True)
|
||||
discounted_price = models.IntegerField(null=True)
|
||||
rebate = models.GeneratedField(
|
||||
expression=Coalesce("price", 0)
|
||||
- Coalesce("discounted_price", Coalesce("price", 0)),
|
||||
output_field=models.IntegerField(),
|
||||
db_persist=True,
|
||||
)
|
||||
lower_name = models.GeneratedField(
|
||||
expression=Lower(models.F("name")),
|
||||
output_field=models.CharField(max_length=255, null=True),
|
||||
db_persist=True,
|
||||
)
|
||||
|
||||
class Meta:
|
||||
required_db_features = {"supports_stored_generated_columns"}
|
||||
|
||||
|
||||
class GeneratedFieldVirtualProduct(models.Model):
|
||||
name = models.CharField(max_length=255, null=True)
|
||||
price = models.IntegerField(null=True)
|
||||
discounted_price = models.IntegerField(null=True)
|
||||
rebate = models.GeneratedField(
|
||||
expression=Coalesce("price", 0)
|
||||
- Coalesce("discounted_price", Coalesce("price", 0)),
|
||||
output_field=models.IntegerField(),
|
||||
db_persist=False,
|
||||
)
|
||||
lower_name = models.GeneratedField(
|
||||
expression=Lower(models.F("name")),
|
||||
output_field=models.CharField(max_length=255, null=True),
|
||||
db_persist=False,
|
||||
)
|
||||
|
||||
class Meta:
|
||||
required_db_features = {"supports_virtual_generated_columns"}
|
||||
|
||||
|
||||
class UniqueConstraintProduct(models.Model):
|
||||
name = models.CharField(max_length=255)
|
||||
color = models.CharField(max_length=32, null=True)
|
||||
|
|
|
@ -4,7 +4,7 @@ from django.core.exceptions import ValidationError
|
|||
from django.db import IntegrityError, connection, models
|
||||
from django.db.models import F
|
||||
from django.db.models.constraints import BaseConstraint, UniqueConstraint
|
||||
from django.db.models.functions import Abs, Lower, Upper
|
||||
from django.db.models.functions import Abs, Lower, Sqrt, Upper
|
||||
from django.db.transaction import atomic
|
||||
from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature
|
||||
from django.test.utils import ignore_warnings
|
||||
|
@ -13,6 +13,8 @@ from django.utils.deprecation import RemovedInDjango60Warning
|
|||
from .models import (
|
||||
ChildModel,
|
||||
ChildUniqueConstraintProduct,
|
||||
GeneratedFieldStoredProduct,
|
||||
GeneratedFieldVirtualProduct,
|
||||
JSONFieldModel,
|
||||
ModelWithDatabaseDefault,
|
||||
Product,
|
||||
|
@ -384,6 +386,29 @@ class CheckConstraintTests(TestCase):
|
|||
with self.assertRaisesMessage(ValidationError, msg):
|
||||
json_exact_constraint.validate(JSONFieldModel, JSONFieldModel(data=data))
|
||||
|
||||
@skipUnlessDBFeature("supports_stored_generated_columns")
|
||||
def test_validate_generated_field_stored(self):
|
||||
self.assertGeneratedFieldIsValidated(model=GeneratedFieldStoredProduct)
|
||||
|
||||
@skipUnlessDBFeature("supports_virtual_generated_columns")
|
||||
def test_validate_generated_field_virtual(self):
|
||||
self.assertGeneratedFieldIsValidated(model=GeneratedFieldVirtualProduct)
|
||||
|
||||
def assertGeneratedFieldIsValidated(self, model):
|
||||
constraint = models.CheckConstraint(
|
||||
condition=models.Q(rebate__range=(0, 100)), name="bounded_rebate"
|
||||
)
|
||||
constraint.validate(model, model(price=50, discounted_price=20))
|
||||
|
||||
invalid_product = model(price=1200, discounted_price=500)
|
||||
msg = f"Constraint “{constraint.name}” is violated."
|
||||
with self.assertRaisesMessage(ValidationError, msg):
|
||||
constraint.validate(model, invalid_product)
|
||||
|
||||
# Excluding referenced or generated fields should skip validation.
|
||||
constraint.validate(model, invalid_product, exclude={"price"})
|
||||
constraint.validate(model, invalid_product, exclude={"rebate"})
|
||||
|
||||
def test_check_deprecation(self):
|
||||
msg = "CheckConstraint.check is deprecated in favor of `.condition`."
|
||||
condition = models.Q(foo="bar")
|
||||
|
@ -1062,6 +1087,90 @@ class UniqueConstraintTests(TestCase):
|
|||
exclude={"name"},
|
||||
)
|
||||
|
||||
@skipUnlessDBFeature("supports_stored_generated_columns")
|
||||
def test_validate_expression_generated_field_stored(self):
|
||||
self.assertGeneratedFieldWithExpressionIsValidated(
|
||||
model=GeneratedFieldStoredProduct
|
||||
)
|
||||
|
||||
@skipUnlessDBFeature("supports_virtual_generated_columns")
|
||||
def test_validate_expression_generated_field_virtual(self):
|
||||
self.assertGeneratedFieldWithExpressionIsValidated(
|
||||
model=GeneratedFieldVirtualProduct
|
||||
)
|
||||
|
||||
def assertGeneratedFieldWithExpressionIsValidated(self, model):
|
||||
constraint = UniqueConstraint(Sqrt("rebate"), name="unique_rebate_sqrt")
|
||||
model.objects.create(price=100, discounted_price=84)
|
||||
|
||||
valid_product = model(price=100, discounted_price=75)
|
||||
constraint.validate(model, valid_product)
|
||||
|
||||
invalid_product = model(price=20, discounted_price=4)
|
||||
with self.assertRaisesMessage(
|
||||
ValidationError, f"Constraint “{constraint.name}” is violated."
|
||||
):
|
||||
constraint.validate(model, invalid_product)
|
||||
|
||||
# Excluding referenced or generated fields should skip validation.
|
||||
constraint.validate(model, invalid_product, exclude={"rebate"})
|
||||
constraint.validate(model, invalid_product, exclude={"price"})
|
||||
|
||||
@skipUnlessDBFeature("supports_stored_generated_columns")
|
||||
def test_validate_fields_generated_field_stored(self):
|
||||
self.assertGeneratedFieldWithFieldsIsValidated(
|
||||
model=GeneratedFieldStoredProduct
|
||||
)
|
||||
|
||||
@skipUnlessDBFeature("supports_virtual_generated_columns")
|
||||
def test_validate_fields_generated_field_virtual(self):
|
||||
self.assertGeneratedFieldWithFieldsIsValidated(
|
||||
model=GeneratedFieldVirtualProduct
|
||||
)
|
||||
|
||||
def assertGeneratedFieldWithFieldsIsValidated(self, model):
|
||||
constraint = models.UniqueConstraint(
|
||||
fields=["lower_name"], name="lower_name_unique"
|
||||
)
|
||||
model.objects.create(name="Box")
|
||||
constraint.validate(model, model(name="Case"))
|
||||
|
||||
invalid_product = model(name="BOX")
|
||||
msg = str(invalid_product.unique_error_message(model, ["lower_name"]))
|
||||
with self.assertRaisesMessage(ValidationError, msg):
|
||||
constraint.validate(model, invalid_product)
|
||||
|
||||
# Excluding referenced or generated fields should skip validation.
|
||||
constraint.validate(model, invalid_product, exclude={"lower_name"})
|
||||
constraint.validate(model, invalid_product, exclude={"name"})
|
||||
|
||||
@skipUnlessDBFeature("supports_stored_generated_columns")
|
||||
def test_validate_fields_generated_field_stored_nulls_distinct(self):
|
||||
self.assertGeneratedFieldNullsDistinctIsValidated(
|
||||
model=GeneratedFieldStoredProduct
|
||||
)
|
||||
|
||||
@skipUnlessDBFeature("supports_virtual_generated_columns")
|
||||
def test_validate_fields_generated_field_virtual_nulls_distinct(self):
|
||||
self.assertGeneratedFieldNullsDistinctIsValidated(
|
||||
model=GeneratedFieldVirtualProduct
|
||||
)
|
||||
|
||||
def assertGeneratedFieldNullsDistinctIsValidated(self, model):
|
||||
constraint = models.UniqueConstraint(
|
||||
fields=["lower_name"],
|
||||
name="lower_name_unique_nulls_distinct",
|
||||
nulls_distinct=False,
|
||||
)
|
||||
model.objects.create(name=None)
|
||||
valid_product = model(name="Box")
|
||||
constraint.validate(model, valid_product)
|
||||
|
||||
invalid_product = model(name=None)
|
||||
msg = str(invalid_product.unique_error_message(model, ["lower_name"]))
|
||||
with self.assertRaisesMessage(ValidationError, msg):
|
||||
constraint.validate(model, invalid_product)
|
||||
|
||||
@skipUnlessDBFeature("supports_table_check_constraints")
|
||||
def test_validate_nullable_textfield_with_isnull_true(self):
|
||||
is_null_constraint = models.UniqueConstraint(
|
||||
|
|
|
@ -14,6 +14,7 @@ from django.db.models import (
|
|||
F,
|
||||
ForeignKey,
|
||||
Func,
|
||||
GeneratedField,
|
||||
IntegerField,
|
||||
Model,
|
||||
Q,
|
||||
|
@ -32,6 +33,7 @@ try:
|
|||
from django.contrib.postgres.constraints import ExclusionConstraint
|
||||
from django.contrib.postgres.fields import (
|
||||
DateTimeRangeField,
|
||||
IntegerRangeField,
|
||||
RangeBoundary,
|
||||
RangeOperators,
|
||||
)
|
||||
|
@ -866,6 +868,38 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
|
|||
constraint.validate(RangesModel, RangesModel(ints=(51, 60)))
|
||||
constraint.validate(RangesModel, RangesModel(ints=(10, 20)), exclude={"ints"})
|
||||
|
||||
@skipUnlessDBFeature("supports_stored_generated_columns")
|
||||
@isolate_apps("postgres_tests")
|
||||
def test_validate_generated_field_range_adjacent(self):
|
||||
class RangesModelGeneratedField(Model):
|
||||
ints = IntegerRangeField(blank=True, null=True)
|
||||
ints_generated = GeneratedField(
|
||||
expression=F("ints"),
|
||||
output_field=IntegerRangeField(null=True),
|
||||
db_persist=True,
|
||||
)
|
||||
|
||||
with connection.schema_editor() as editor:
|
||||
editor.create_model(RangesModelGeneratedField)
|
||||
|
||||
constraint = ExclusionConstraint(
|
||||
name="ints_adjacent",
|
||||
expressions=[("ints_generated", RangeOperators.ADJACENT_TO)],
|
||||
violation_error_code="custom_code",
|
||||
violation_error_message="Custom error message.",
|
||||
)
|
||||
RangesModelGeneratedField.objects.create(ints=(20, 50))
|
||||
|
||||
range_obj = RangesModelGeneratedField(ints=(3, 20))
|
||||
with self.assertRaisesMessage(ValidationError, "Custom error message."):
|
||||
constraint.validate(RangesModelGeneratedField, range_obj)
|
||||
|
||||
# Excluding referenced or generated field should skip validation.
|
||||
constraint.validate(RangesModelGeneratedField, range_obj, exclude={"ints"})
|
||||
constraint.validate(
|
||||
RangesModelGeneratedField, range_obj, exclude={"ints_generated"}
|
||||
)
|
||||
|
||||
def test_validate_with_custom_code_and_condition(self):
|
||||
constraint = ExclusionConstraint(
|
||||
name="ints_adjacent",
|
||||
|
|
Loading…
Reference in New Issue