mirror of https://github.com/django/django.git
Fixed #35575 -- Added support for constraint validation on GeneratedFields.
This commit is contained in:
parent
f883bef054
commit
228128618b
|
@ -183,17 +183,11 @@ class ExclusionConstraint(BaseConstraint):
|
||||||
)
|
)
|
||||||
replacements = {F(field): value for field, value in replacement_map.items()}
|
replacements = {F(field): value for field, value in replacement_map.items()}
|
||||||
lookups = []
|
lookups = []
|
||||||
for idx, (expression, operator) in enumerate(self.expressions):
|
for expression, operator in self.expressions:
|
||||||
if isinstance(expression, str):
|
if isinstance(expression, str):
|
||||||
expression = F(expression)
|
expression = F(expression)
|
||||||
if exclude:
|
if exclude and self._expression_refs_exclude(model, expression, exclude):
|
||||||
if isinstance(expression, F):
|
return
|
||||||
if expression.name in exclude:
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
for expr in expression.flatten():
|
|
||||||
if isinstance(expr, F) and expr.name in exclude:
|
|
||||||
return
|
|
||||||
rhs_expression = expression.replace_expressions(replacements)
|
rhs_expression = expression.replace_expressions(replacements)
|
||||||
if hasattr(expression, "get_expression_for_validation"):
|
if hasattr(expression, "get_expression_for_validation"):
|
||||||
expression = expression.get_expression_for_validation()
|
expression = expression.get_expression_for_validation()
|
||||||
|
|
|
@ -1337,18 +1337,33 @@ class Model(AltersData, metaclass=ModelBase):
|
||||||
if exclude is None:
|
if exclude is None:
|
||||||
exclude = set()
|
exclude = set()
|
||||||
meta = meta or self._meta
|
meta = meta or self._meta
|
||||||
field_map = {
|
field_map = {}
|
||||||
field.name: (
|
generated_fields = []
|
||||||
value
|
for field in meta.local_concrete_fields:
|
||||||
if (value := getattr(self, field.attname))
|
if field.name in exclude:
|
||||||
and hasattr(value, "resolve_expression")
|
continue
|
||||||
else Value(value, field)
|
if field.generated:
|
||||||
)
|
if any(
|
||||||
for field in meta.local_concrete_fields
|
ref[0] in exclude
|
||||||
if field.name not in exclude and not field.generated
|
for ref in self._get_expr_references(field.expression)
|
||||||
}
|
):
|
||||||
|
continue
|
||||||
|
generated_fields.append(field)
|
||||||
|
continue
|
||||||
|
value = getattr(self, field.attname)
|
||||||
|
if not value or not hasattr(value, "resolve_expression"):
|
||||||
|
value = Value(value, field)
|
||||||
|
field_map[field.name] = value
|
||||||
if "pk" not in exclude:
|
if "pk" not in exclude:
|
||||||
field_map["pk"] = Value(self.pk, meta.pk)
|
field_map["pk"] = Value(self.pk, meta.pk)
|
||||||
|
if generated_fields:
|
||||||
|
replacements = {F(name): value for name, value in field_map.items()}
|
||||||
|
for generated_field in generated_fields:
|
||||||
|
field_map[generated_field.name] = ExpressionWrapper(
|
||||||
|
generated_field.expression.replace_expressions(replacements),
|
||||||
|
generated_field.output_field,
|
||||||
|
)
|
||||||
|
|
||||||
return field_map
|
return field_map
|
||||||
|
|
||||||
def prepare_database_save(self, field):
|
def prepare_database_save(self, field):
|
||||||
|
|
|
@ -68,6 +68,19 @@ class BaseConstraint:
|
||||||
def remove_sql(self, model, schema_editor):
|
def remove_sql(self, model, schema_editor):
|
||||||
raise NotImplementedError("This method must be implemented by a subclass.")
|
raise NotImplementedError("This method must be implemented by a subclass.")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _expression_refs_exclude(cls, model, expression, exclude):
|
||||||
|
get_field = model._meta.get_field
|
||||||
|
for field_name, *__ in model._get_expr_references(expression):
|
||||||
|
if field_name in exclude:
|
||||||
|
return True
|
||||||
|
field = get_field(field_name)
|
||||||
|
if field.generated and cls._expression_refs_exclude(
|
||||||
|
model, field.expression, exclude
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
|
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
|
||||||
raise NotImplementedError("This method must be implemented by a subclass.")
|
raise NotImplementedError("This method must be implemented by a subclass.")
|
||||||
|
|
||||||
|
@ -606,36 +619,56 @@ class UniqueConstraint(BaseConstraint):
|
||||||
queryset = model._default_manager.using(using)
|
queryset = model._default_manager.using(using)
|
||||||
if self.fields:
|
if self.fields:
|
||||||
lookup_kwargs = {}
|
lookup_kwargs = {}
|
||||||
|
generated_field_names = []
|
||||||
for field_name in self.fields:
|
for field_name in self.fields:
|
||||||
if exclude and field_name in exclude:
|
if exclude and field_name in exclude:
|
||||||
return
|
return
|
||||||
field = model._meta.get_field(field_name)
|
field = model._meta.get_field(field_name)
|
||||||
lookup_value = getattr(instance, field.attname)
|
if field.generated:
|
||||||
if (
|
if exclude and self._expression_refs_exclude(
|
||||||
self.nulls_distinct is not False
|
model, field.expression, exclude
|
||||||
and lookup_value is None
|
):
|
||||||
or (
|
return
|
||||||
lookup_value == ""
|
generated_field_names.append(field.name)
|
||||||
and connections[
|
else:
|
||||||
using
|
lookup_value = getattr(instance, field.attname)
|
||||||
].features.interprets_empty_strings_as_nulls
|
if (
|
||||||
)
|
self.nulls_distinct is not False
|
||||||
):
|
and lookup_value is None
|
||||||
# A composite constraint containing NULL value cannot cause
|
or (
|
||||||
# a violation since NULL != NULL in SQL.
|
lookup_value == ""
|
||||||
return
|
and connections[
|
||||||
lookup_kwargs[field.name] = lookup_value
|
using
|
||||||
queryset = queryset.filter(**lookup_kwargs)
|
].features.interprets_empty_strings_as_nulls
|
||||||
|
)
|
||||||
|
):
|
||||||
|
# A composite constraint containing NULL value cannot cause
|
||||||
|
# a violation since NULL != NULL in SQL.
|
||||||
|
return
|
||||||
|
lookup_kwargs[field.name] = lookup_value
|
||||||
|
lookup_args = []
|
||||||
|
if generated_field_names:
|
||||||
|
field_expression_map = instance._get_field_expression_map(
|
||||||
|
meta=model._meta, exclude=exclude
|
||||||
|
)
|
||||||
|
for field_name in generated_field_names:
|
||||||
|
expression = field_expression_map[field_name]
|
||||||
|
if self.nulls_distinct is False:
|
||||||
|
lhs = F(field_name)
|
||||||
|
condition = Q(Exact(lhs, expression)) | Q(
|
||||||
|
IsNull(lhs, True), IsNull(expression, True)
|
||||||
|
)
|
||||||
|
lookup_args.append(condition)
|
||||||
|
else:
|
||||||
|
lookup_kwargs[field_name] = expression
|
||||||
|
queryset = queryset.filter(*lookup_args, **lookup_kwargs)
|
||||||
else:
|
else:
|
||||||
# Ignore constraints with excluded fields.
|
# Ignore constraints with excluded fields.
|
||||||
if exclude:
|
if exclude and any(
|
||||||
for expression in self.expressions:
|
self._expression_refs_exclude(model, expression, exclude)
|
||||||
if hasattr(expression, "flatten"):
|
for expression in self.expressions
|
||||||
for expr in expression.flatten():
|
):
|
||||||
if isinstance(expr, F) and expr.name in exclude:
|
return
|
||||||
return
|
|
||||||
elif isinstance(expression, F) and expression.name in exclude:
|
|
||||||
return
|
|
||||||
replacements = {
|
replacements = {
|
||||||
F(field): value
|
F(field): value
|
||||||
for field, value in instance._get_field_expression_map(
|
for field, value in instance._get_field_expression_map(
|
||||||
|
|
|
@ -215,6 +215,9 @@ Models
|
||||||
methods such as
|
methods such as
|
||||||
:meth:`QuerySet.union()<django.db.models.query.QuerySet.union>` unpredictable.
|
:meth:`QuerySet.union()<django.db.models.query.QuerySet.union>` unpredictable.
|
||||||
|
|
||||||
|
* Added support for validation of model constraints which use a
|
||||||
|
:class:`~django.db.models.GeneratedField`.
|
||||||
|
|
||||||
Requests and Responses
|
Requests and Responses
|
||||||
~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
from django.db.models.functions import Coalesce, Lower
|
||||||
|
|
||||||
|
|
||||||
class Product(models.Model):
|
class Product(models.Model):
|
||||||
|
@ -28,6 +29,46 @@ class Product(models.Model):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratedFieldStoredProduct(models.Model):
|
||||||
|
name = models.CharField(max_length=255, null=True)
|
||||||
|
price = models.IntegerField(null=True)
|
||||||
|
discounted_price = models.IntegerField(null=True)
|
||||||
|
rebate = models.GeneratedField(
|
||||||
|
expression=Coalesce("price", 0)
|
||||||
|
- Coalesce("discounted_price", Coalesce("price", 0)),
|
||||||
|
output_field=models.IntegerField(),
|
||||||
|
db_persist=True,
|
||||||
|
)
|
||||||
|
lower_name = models.GeneratedField(
|
||||||
|
expression=Lower(models.F("name")),
|
||||||
|
output_field=models.CharField(max_length=255, null=True),
|
||||||
|
db_persist=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
required_db_features = {"supports_stored_generated_columns"}
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratedFieldVirtualProduct(models.Model):
|
||||||
|
name = models.CharField(max_length=255, null=True)
|
||||||
|
price = models.IntegerField(null=True)
|
||||||
|
discounted_price = models.IntegerField(null=True)
|
||||||
|
rebate = models.GeneratedField(
|
||||||
|
expression=Coalesce("price", 0)
|
||||||
|
- Coalesce("discounted_price", Coalesce("price", 0)),
|
||||||
|
output_field=models.IntegerField(),
|
||||||
|
db_persist=False,
|
||||||
|
)
|
||||||
|
lower_name = models.GeneratedField(
|
||||||
|
expression=Lower(models.F("name")),
|
||||||
|
output_field=models.CharField(max_length=255, null=True),
|
||||||
|
db_persist=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
required_db_features = {"supports_virtual_generated_columns"}
|
||||||
|
|
||||||
|
|
||||||
class UniqueConstraintProduct(models.Model):
|
class UniqueConstraintProduct(models.Model):
|
||||||
name = models.CharField(max_length=255)
|
name = models.CharField(max_length=255)
|
||||||
color = models.CharField(max_length=32, null=True)
|
color = models.CharField(max_length=32, null=True)
|
||||||
|
|
|
@ -4,7 +4,7 @@ from django.core.exceptions import ValidationError
|
||||||
from django.db import IntegrityError, connection, models
|
from django.db import IntegrityError, connection, models
|
||||||
from django.db.models import F
|
from django.db.models import F
|
||||||
from django.db.models.constraints import BaseConstraint, UniqueConstraint
|
from django.db.models.constraints import BaseConstraint, UniqueConstraint
|
||||||
from django.db.models.functions import Abs, Lower, Upper
|
from django.db.models.functions import Abs, Lower, Sqrt, Upper
|
||||||
from django.db.transaction import atomic
|
from django.db.transaction import atomic
|
||||||
from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature
|
from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature
|
||||||
from django.test.utils import ignore_warnings
|
from django.test.utils import ignore_warnings
|
||||||
|
@ -13,6 +13,8 @@ from django.utils.deprecation import RemovedInDjango60Warning
|
||||||
from .models import (
|
from .models import (
|
||||||
ChildModel,
|
ChildModel,
|
||||||
ChildUniqueConstraintProduct,
|
ChildUniqueConstraintProduct,
|
||||||
|
GeneratedFieldStoredProduct,
|
||||||
|
GeneratedFieldVirtualProduct,
|
||||||
JSONFieldModel,
|
JSONFieldModel,
|
||||||
ModelWithDatabaseDefault,
|
ModelWithDatabaseDefault,
|
||||||
Product,
|
Product,
|
||||||
|
@ -384,6 +386,29 @@ class CheckConstraintTests(TestCase):
|
||||||
with self.assertRaisesMessage(ValidationError, msg):
|
with self.assertRaisesMessage(ValidationError, msg):
|
||||||
json_exact_constraint.validate(JSONFieldModel, JSONFieldModel(data=data))
|
json_exact_constraint.validate(JSONFieldModel, JSONFieldModel(data=data))
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_stored_generated_columns")
|
||||||
|
def test_validate_generated_field_stored(self):
|
||||||
|
self.assertGeneratedFieldIsValidated(model=GeneratedFieldStoredProduct)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_virtual_generated_columns")
|
||||||
|
def test_validate_generated_field_virtual(self):
|
||||||
|
self.assertGeneratedFieldIsValidated(model=GeneratedFieldVirtualProduct)
|
||||||
|
|
||||||
|
def assertGeneratedFieldIsValidated(self, model):
|
||||||
|
constraint = models.CheckConstraint(
|
||||||
|
condition=models.Q(rebate__range=(0, 100)), name="bounded_rebate"
|
||||||
|
)
|
||||||
|
constraint.validate(model, model(price=50, discounted_price=20))
|
||||||
|
|
||||||
|
invalid_product = model(price=1200, discounted_price=500)
|
||||||
|
msg = f"Constraint “{constraint.name}” is violated."
|
||||||
|
with self.assertRaisesMessage(ValidationError, msg):
|
||||||
|
constraint.validate(model, invalid_product)
|
||||||
|
|
||||||
|
# Excluding referenced or generated fields should skip validation.
|
||||||
|
constraint.validate(model, invalid_product, exclude={"price"})
|
||||||
|
constraint.validate(model, invalid_product, exclude={"rebate"})
|
||||||
|
|
||||||
def test_check_deprecation(self):
|
def test_check_deprecation(self):
|
||||||
msg = "CheckConstraint.check is deprecated in favor of `.condition`."
|
msg = "CheckConstraint.check is deprecated in favor of `.condition`."
|
||||||
condition = models.Q(foo="bar")
|
condition = models.Q(foo="bar")
|
||||||
|
@ -1062,6 +1087,90 @@ class UniqueConstraintTests(TestCase):
|
||||||
exclude={"name"},
|
exclude={"name"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_stored_generated_columns")
|
||||||
|
def test_validate_expression_generated_field_stored(self):
|
||||||
|
self.assertGeneratedFieldWithExpressionIsValidated(
|
||||||
|
model=GeneratedFieldStoredProduct
|
||||||
|
)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_virtual_generated_columns")
|
||||||
|
def test_validate_expression_generated_field_virtual(self):
|
||||||
|
self.assertGeneratedFieldWithExpressionIsValidated(
|
||||||
|
model=GeneratedFieldVirtualProduct
|
||||||
|
)
|
||||||
|
|
||||||
|
def assertGeneratedFieldWithExpressionIsValidated(self, model):
|
||||||
|
constraint = UniqueConstraint(Sqrt("rebate"), name="unique_rebate_sqrt")
|
||||||
|
model.objects.create(price=100, discounted_price=84)
|
||||||
|
|
||||||
|
valid_product = model(price=100, discounted_price=75)
|
||||||
|
constraint.validate(model, valid_product)
|
||||||
|
|
||||||
|
invalid_product = model(price=20, discounted_price=4)
|
||||||
|
with self.assertRaisesMessage(
|
||||||
|
ValidationError, f"Constraint “{constraint.name}” is violated."
|
||||||
|
):
|
||||||
|
constraint.validate(model, invalid_product)
|
||||||
|
|
||||||
|
# Excluding referenced or generated fields should skip validation.
|
||||||
|
constraint.validate(model, invalid_product, exclude={"rebate"})
|
||||||
|
constraint.validate(model, invalid_product, exclude={"price"})
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_stored_generated_columns")
|
||||||
|
def test_validate_fields_generated_field_stored(self):
|
||||||
|
self.assertGeneratedFieldWithFieldsIsValidated(
|
||||||
|
model=GeneratedFieldStoredProduct
|
||||||
|
)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_virtual_generated_columns")
|
||||||
|
def test_validate_fields_generated_field_virtual(self):
|
||||||
|
self.assertGeneratedFieldWithFieldsIsValidated(
|
||||||
|
model=GeneratedFieldVirtualProduct
|
||||||
|
)
|
||||||
|
|
||||||
|
def assertGeneratedFieldWithFieldsIsValidated(self, model):
|
||||||
|
constraint = models.UniqueConstraint(
|
||||||
|
fields=["lower_name"], name="lower_name_unique"
|
||||||
|
)
|
||||||
|
model.objects.create(name="Box")
|
||||||
|
constraint.validate(model, model(name="Case"))
|
||||||
|
|
||||||
|
invalid_product = model(name="BOX")
|
||||||
|
msg = str(invalid_product.unique_error_message(model, ["lower_name"]))
|
||||||
|
with self.assertRaisesMessage(ValidationError, msg):
|
||||||
|
constraint.validate(model, invalid_product)
|
||||||
|
|
||||||
|
# Excluding referenced or generated fields should skip validation.
|
||||||
|
constraint.validate(model, invalid_product, exclude={"lower_name"})
|
||||||
|
constraint.validate(model, invalid_product, exclude={"name"})
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_stored_generated_columns")
|
||||||
|
def test_validate_fields_generated_field_stored_nulls_distinct(self):
|
||||||
|
self.assertGeneratedFieldNullsDistinctIsValidated(
|
||||||
|
model=GeneratedFieldStoredProduct
|
||||||
|
)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_virtual_generated_columns")
|
||||||
|
def test_validate_fields_generated_field_virtual_nulls_distinct(self):
|
||||||
|
self.assertGeneratedFieldNullsDistinctIsValidated(
|
||||||
|
model=GeneratedFieldVirtualProduct
|
||||||
|
)
|
||||||
|
|
||||||
|
def assertGeneratedFieldNullsDistinctIsValidated(self, model):
|
||||||
|
constraint = models.UniqueConstraint(
|
||||||
|
fields=["lower_name"],
|
||||||
|
name="lower_name_unique_nulls_distinct",
|
||||||
|
nulls_distinct=False,
|
||||||
|
)
|
||||||
|
model.objects.create(name=None)
|
||||||
|
valid_product = model(name="Box")
|
||||||
|
constraint.validate(model, valid_product)
|
||||||
|
|
||||||
|
invalid_product = model(name=None)
|
||||||
|
msg = str(invalid_product.unique_error_message(model, ["lower_name"]))
|
||||||
|
with self.assertRaisesMessage(ValidationError, msg):
|
||||||
|
constraint.validate(model, invalid_product)
|
||||||
|
|
||||||
@skipUnlessDBFeature("supports_table_check_constraints")
|
@skipUnlessDBFeature("supports_table_check_constraints")
|
||||||
def test_validate_nullable_textfield_with_isnull_true(self):
|
def test_validate_nullable_textfield_with_isnull_true(self):
|
||||||
is_null_constraint = models.UniqueConstraint(
|
is_null_constraint = models.UniqueConstraint(
|
||||||
|
|
|
@ -14,6 +14,7 @@ from django.db.models import (
|
||||||
F,
|
F,
|
||||||
ForeignKey,
|
ForeignKey,
|
||||||
Func,
|
Func,
|
||||||
|
GeneratedField,
|
||||||
IntegerField,
|
IntegerField,
|
||||||
Model,
|
Model,
|
||||||
Q,
|
Q,
|
||||||
|
@ -32,6 +33,7 @@ try:
|
||||||
from django.contrib.postgres.constraints import ExclusionConstraint
|
from django.contrib.postgres.constraints import ExclusionConstraint
|
||||||
from django.contrib.postgres.fields import (
|
from django.contrib.postgres.fields import (
|
||||||
DateTimeRangeField,
|
DateTimeRangeField,
|
||||||
|
IntegerRangeField,
|
||||||
RangeBoundary,
|
RangeBoundary,
|
||||||
RangeOperators,
|
RangeOperators,
|
||||||
)
|
)
|
||||||
|
@ -866,6 +868,38 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
|
||||||
constraint.validate(RangesModel, RangesModel(ints=(51, 60)))
|
constraint.validate(RangesModel, RangesModel(ints=(51, 60)))
|
||||||
constraint.validate(RangesModel, RangesModel(ints=(10, 20)), exclude={"ints"})
|
constraint.validate(RangesModel, RangesModel(ints=(10, 20)), exclude={"ints"})
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("supports_stored_generated_columns")
|
||||||
|
@isolate_apps("postgres_tests")
|
||||||
|
def test_validate_generated_field_range_adjacent(self):
|
||||||
|
class RangesModelGeneratedField(Model):
|
||||||
|
ints = IntegerRangeField(blank=True, null=True)
|
||||||
|
ints_generated = GeneratedField(
|
||||||
|
expression=F("ints"),
|
||||||
|
output_field=IntegerRangeField(null=True),
|
||||||
|
db_persist=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
editor.create_model(RangesModelGeneratedField)
|
||||||
|
|
||||||
|
constraint = ExclusionConstraint(
|
||||||
|
name="ints_adjacent",
|
||||||
|
expressions=[("ints_generated", RangeOperators.ADJACENT_TO)],
|
||||||
|
violation_error_code="custom_code",
|
||||||
|
violation_error_message="Custom error message.",
|
||||||
|
)
|
||||||
|
RangesModelGeneratedField.objects.create(ints=(20, 50))
|
||||||
|
|
||||||
|
range_obj = RangesModelGeneratedField(ints=(3, 20))
|
||||||
|
with self.assertRaisesMessage(ValidationError, "Custom error message."):
|
||||||
|
constraint.validate(RangesModelGeneratedField, range_obj)
|
||||||
|
|
||||||
|
# Excluding referenced or generated field should skip validation.
|
||||||
|
constraint.validate(RangesModelGeneratedField, range_obj, exclude={"ints"})
|
||||||
|
constraint.validate(
|
||||||
|
RangesModelGeneratedField, range_obj, exclude={"ints_generated"}
|
||||||
|
)
|
||||||
|
|
||||||
def test_validate_with_custom_code_and_condition(self):
|
def test_validate_with_custom_code_and_condition(self):
|
||||||
constraint = ExclusionConstraint(
|
constraint = ExclusionConstraint(
|
||||||
name="ints_adjacent",
|
name="ints_adjacent",
|
||||||
|
|
Loading…
Reference in New Issue