Fixed #31530 -- Added system checks for invalid model field names in CheckConstraint.check and UniqueConstraint.condition.
This commit is contained in:
parent
659a73bc0a
commit
b7b7df5fbc
|
@ -28,7 +28,7 @@ from django.db.models.fields.related import (
|
||||||
from django.db.models.functions import Coalesce
|
from django.db.models.functions import Coalesce
|
||||||
from django.db.models.manager import Manager
|
from django.db.models.manager import Manager
|
||||||
from django.db.models.options import Options
|
from django.db.models.options import Options
|
||||||
from django.db.models.query import Q
|
from django.db.models.query import F, Q
|
||||||
from django.db.models.signals import (
|
from django.db.models.signals import (
|
||||||
class_prepared, post_init, post_save, pre_init, pre_save,
|
class_prepared, post_init, post_save, pre_init, pre_save,
|
||||||
)
|
)
|
||||||
|
@ -1878,6 +1878,22 @@ class Model(metaclass=ModelBase):
|
||||||
|
|
||||||
return errors
|
return errors
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_expr_references(cls, expr):
|
||||||
|
if isinstance(expr, Q):
|
||||||
|
for child in expr.children:
|
||||||
|
if isinstance(child, tuple):
|
||||||
|
lookup, value = child
|
||||||
|
yield tuple(lookup.split(LOOKUP_SEP))
|
||||||
|
yield from cls._get_expr_references(value)
|
||||||
|
else:
|
||||||
|
yield from cls._get_expr_references(child)
|
||||||
|
elif isinstance(expr, F):
|
||||||
|
yield tuple(expr.name.split(LOOKUP_SEP))
|
||||||
|
elif hasattr(expr, 'get_source_expressions'):
|
||||||
|
for src_expr in expr.get_source_expressions():
|
||||||
|
yield from cls._get_expr_references(src_expr)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _check_constraints(cls, databases):
|
def _check_constraints(cls, databases):
|
||||||
errors = []
|
errors = []
|
||||||
|
@ -1960,10 +1976,54 @@ class Model(metaclass=ModelBase):
|
||||||
id='models.W039',
|
id='models.W039',
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
fields = chain.from_iterable(
|
fields = set(chain.from_iterable(
|
||||||
(*constraint.fields, *constraint.include)
|
(*constraint.fields, *constraint.include)
|
||||||
for constraint in cls._meta.constraints if isinstance(constraint, UniqueConstraint)
|
for constraint in cls._meta.constraints if isinstance(constraint, UniqueConstraint)
|
||||||
)
|
))
|
||||||
|
references = set()
|
||||||
|
for constraint in cls._meta.constraints:
|
||||||
|
if isinstance(constraint, UniqueConstraint):
|
||||||
|
if (
|
||||||
|
connection.features.supports_partial_indexes or
|
||||||
|
'supports_partial_indexes' not in cls._meta.required_db_features
|
||||||
|
) and isinstance(constraint.condition, Q):
|
||||||
|
references.update(cls._get_expr_references(constraint.condition))
|
||||||
|
elif isinstance(constraint, CheckConstraint):
|
||||||
|
if (
|
||||||
|
connection.features.supports_table_check_constraints or
|
||||||
|
'supports_table_check_constraints' not in cls._meta.required_db_features
|
||||||
|
) and isinstance(constraint.check, Q):
|
||||||
|
references.update(cls._get_expr_references(constraint.check))
|
||||||
|
for field_name, *lookups in references:
|
||||||
|
# pk is an alias that won't be found by opts.get_field.
|
||||||
|
if field_name != 'pk':
|
||||||
|
fields.add(field_name)
|
||||||
|
if not lookups:
|
||||||
|
# If it has no lookups it cannot result in a JOIN.
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
if field_name == 'pk':
|
||||||
|
field = cls._meta.pk
|
||||||
|
else:
|
||||||
|
field = cls._meta.get_field(field_name)
|
||||||
|
if not field.is_relation or field.many_to_many or field.one_to_many:
|
||||||
|
continue
|
||||||
|
except FieldDoesNotExist:
|
||||||
|
continue
|
||||||
|
# JOIN must happen at the first lookup.
|
||||||
|
first_lookup = lookups[0]
|
||||||
|
if (
|
||||||
|
field.get_transform(first_lookup) is None and
|
||||||
|
field.get_lookup(first_lookup) is None
|
||||||
|
):
|
||||||
|
errors.append(
|
||||||
|
checks.Error(
|
||||||
|
"'constraints' refers to the joined field '%s'."
|
||||||
|
% LOOKUP_SEP.join([field_name] + lookups),
|
||||||
|
obj=cls,
|
||||||
|
id='models.E041',
|
||||||
|
)
|
||||||
|
)
|
||||||
errors.extend(cls._check_local_fields(fields, 'constraints'))
|
errors.extend(cls._check_local_fields(fields, 'constraints'))
|
||||||
return errors
|
return errors
|
||||||
|
|
||||||
|
|
|
@ -364,6 +364,7 @@ Models
|
||||||
non-key columns.
|
non-key columns.
|
||||||
* **models.W040**: ``<database>`` does not support indexes with non-key
|
* **models.W040**: ``<database>`` does not support indexes with non-key
|
||||||
columns.
|
columns.
|
||||||
|
* **models.E041**: ``constraints`` refers to the joined field ``<field name>``.
|
||||||
|
|
||||||
Security
|
Security
|
||||||
--------
|
--------
|
||||||
|
|
|
@ -1534,6 +1534,192 @@ class ConstraintsTests(TestCase):
|
||||||
constraints = [models.CheckConstraint(check=models.Q(age__gte=18), name='is_adult')]
|
constraints = [models.CheckConstraint(check=models.Q(age__gte=18), name='is_adult')]
|
||||||
self.assertEqual(Model.check(databases=self.databases), [])
|
self.assertEqual(Model.check(databases=self.databases), [])
|
||||||
|
|
||||||
|
def test_check_constraint_pointing_to_missing_field(self):
|
||||||
|
class Model(models.Model):
|
||||||
|
class Meta:
|
||||||
|
required_db_features = {'supports_table_check_constraints'}
|
||||||
|
constraints = [
|
||||||
|
models.CheckConstraint(
|
||||||
|
name='name', check=models.Q(missing_field=2),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertEqual(Model.check(databases=self.databases), [
|
||||||
|
Error(
|
||||||
|
"'constraints' refers to the nonexistent field "
|
||||||
|
"'missing_field'.",
|
||||||
|
obj=Model,
|
||||||
|
id='models.E012',
|
||||||
|
),
|
||||||
|
] if connection.features.supports_table_check_constraints else [])
|
||||||
|
|
||||||
|
@skipUnlessDBFeature('supports_table_check_constraints')
|
||||||
|
def test_check_constraint_pointing_to_reverse_fk(self):
|
||||||
|
class Model(models.Model):
|
||||||
|
parent = models.ForeignKey('self', models.CASCADE, related_name='parents')
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
constraints = [
|
||||||
|
models.CheckConstraint(name='name', check=models.Q(parents=3)),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertEqual(Model.check(databases=self.databases), [
|
||||||
|
Error(
|
||||||
|
"'constraints' refers to the nonexistent field 'parents'.",
|
||||||
|
obj=Model,
|
||||||
|
id='models.E012',
|
||||||
|
),
|
||||||
|
])
|
||||||
|
|
||||||
|
@skipUnlessDBFeature('supports_table_check_constraints')
|
||||||
|
def test_check_constraint_pointing_to_m2m_field(self):
|
||||||
|
class Model(models.Model):
|
||||||
|
m2m = models.ManyToManyField('self')
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
constraints = [
|
||||||
|
models.CheckConstraint(name='name', check=models.Q(m2m=2)),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertEqual(Model.check(databases=self.databases), [
|
||||||
|
Error(
|
||||||
|
"'constraints' refers to a ManyToManyField 'm2m', but "
|
||||||
|
"ManyToManyFields are not permitted in 'constraints'.",
|
||||||
|
obj=Model,
|
||||||
|
id='models.E013',
|
||||||
|
),
|
||||||
|
])
|
||||||
|
|
||||||
|
@skipUnlessDBFeature('supports_table_check_constraints')
|
||||||
|
def test_check_constraint_pointing_to_fk(self):
|
||||||
|
class Target(models.Model):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Model(models.Model):
|
||||||
|
fk_1 = models.ForeignKey(Target, models.CASCADE, related_name='target_1')
|
||||||
|
fk_2 = models.ForeignKey(Target, models.CASCADE, related_name='target_2')
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
constraints = [
|
||||||
|
models.CheckConstraint(
|
||||||
|
name='name',
|
||||||
|
check=models.Q(fk_1_id=2) | models.Q(fk_2=2),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertEqual(Model.check(databases=self.databases), [])
|
||||||
|
|
||||||
|
@skipUnlessDBFeature('supports_table_check_constraints')
|
||||||
|
def test_check_constraint_pointing_to_pk(self):
|
||||||
|
class Model(models.Model):
|
||||||
|
age = models.SmallIntegerField()
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
constraints = [
|
||||||
|
models.CheckConstraint(
|
||||||
|
name='name',
|
||||||
|
check=models.Q(pk__gt=5) & models.Q(age__gt=models.F('pk')),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertEqual(Model.check(databases=self.databases), [])
|
||||||
|
|
||||||
|
@skipUnlessDBFeature('supports_table_check_constraints')
|
||||||
|
def test_check_constraint_pointing_to_non_local_field(self):
|
||||||
|
class Parent(models.Model):
|
||||||
|
field1 = models.IntegerField()
|
||||||
|
|
||||||
|
class Child(Parent):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
constraints = [
|
||||||
|
models.CheckConstraint(name='name', check=models.Q(field1=1)),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertEqual(Child.check(databases=self.databases), [
|
||||||
|
Error(
|
||||||
|
"'constraints' refers to field 'field1' which is not local to "
|
||||||
|
"model 'Child'.",
|
||||||
|
hint='This issue may be caused by multi-table inheritance.',
|
||||||
|
obj=Child,
|
||||||
|
id='models.E016',
|
||||||
|
),
|
||||||
|
])
|
||||||
|
|
||||||
|
@skipUnlessDBFeature('supports_table_check_constraints')
|
||||||
|
def test_check_constraint_pointing_to_joined_fields(self):
|
||||||
|
class Model(models.Model):
|
||||||
|
name = models.CharField(max_length=10)
|
||||||
|
field1 = models.PositiveSmallIntegerField()
|
||||||
|
field2 = models.PositiveSmallIntegerField()
|
||||||
|
field3 = models.PositiveSmallIntegerField()
|
||||||
|
parent = models.ForeignKey('self', models.CASCADE)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
constraints = [
|
||||||
|
models.CheckConstraint(
|
||||||
|
name='name1', check=models.Q(
|
||||||
|
field1__lt=models.F('parent__field1') + models.F('parent__field2')
|
||||||
|
)
|
||||||
|
),
|
||||||
|
models.CheckConstraint(
|
||||||
|
name='name2', check=models.Q(name=Lower('parent__name'))
|
||||||
|
),
|
||||||
|
models.CheckConstraint(
|
||||||
|
name='name3', check=models.Q(parent__field3=models.F('field1'))
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
joined_fields = ['parent__field1', 'parent__field2', 'parent__field3', 'parent__name']
|
||||||
|
errors = Model.check(databases=self.databases)
|
||||||
|
expected_errors = [
|
||||||
|
Error(
|
||||||
|
"'constraints' refers to the joined field '%s'." % field_name,
|
||||||
|
obj=Model,
|
||||||
|
id='models.E041',
|
||||||
|
) for field_name in joined_fields
|
||||||
|
]
|
||||||
|
self.assertCountEqual(errors, expected_errors)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature('supports_table_check_constraints')
|
||||||
|
def test_check_constraint_pointing_to_joined_fields_complex_check(self):
|
||||||
|
class Model(models.Model):
|
||||||
|
name = models.PositiveSmallIntegerField()
|
||||||
|
field1 = models.PositiveSmallIntegerField()
|
||||||
|
field2 = models.PositiveSmallIntegerField()
|
||||||
|
parent = models.ForeignKey('self', models.CASCADE)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
constraints = [
|
||||||
|
models.CheckConstraint(
|
||||||
|
name='name',
|
||||||
|
check=models.Q(
|
||||||
|
(
|
||||||
|
models.Q(name='test') &
|
||||||
|
models.Q(field1__lt=models.F('parent__field1'))
|
||||||
|
) |
|
||||||
|
(
|
||||||
|
models.Q(name__startswith=Lower('parent__name')) &
|
||||||
|
models.Q(field1__gte=(
|
||||||
|
models.F('parent__field1') + models.F('parent__field2')
|
||||||
|
))
|
||||||
|
)
|
||||||
|
) | (models.Q(name='test1'))
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
joined_fields = ['parent__field1', 'parent__field2', 'parent__name']
|
||||||
|
errors = Model.check(databases=self.databases)
|
||||||
|
expected_errors = [
|
||||||
|
Error(
|
||||||
|
"'constraints' refers to the joined field '%s'." % field_name,
|
||||||
|
obj=Model,
|
||||||
|
id='models.E041',
|
||||||
|
) for field_name in joined_fields
|
||||||
|
]
|
||||||
|
self.assertCountEqual(errors, expected_errors)
|
||||||
|
|
||||||
def test_unique_constraint_with_condition(self):
|
def test_unique_constraint_with_condition(self):
|
||||||
class Model(models.Model):
|
class Model(models.Model):
|
||||||
age = models.IntegerField()
|
age = models.IntegerField()
|
||||||
|
@ -1578,6 +1764,52 @@ class ConstraintsTests(TestCase):
|
||||||
|
|
||||||
self.assertEqual(Model.check(databases=self.databases), [])
|
self.assertEqual(Model.check(databases=self.databases), [])
|
||||||
|
|
||||||
|
def test_unique_constraint_condition_pointing_to_missing_field(self):
|
||||||
|
class Model(models.Model):
|
||||||
|
age = models.SmallIntegerField()
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
required_db_features = {'supports_partial_indexes'}
|
||||||
|
constraints = [
|
||||||
|
models.UniqueConstraint(
|
||||||
|
name='name',
|
||||||
|
fields=['age'],
|
||||||
|
condition=models.Q(missing_field=2),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertEqual(Model.check(databases=self.databases), [
|
||||||
|
Error(
|
||||||
|
"'constraints' refers to the nonexistent field "
|
||||||
|
"'missing_field'.",
|
||||||
|
obj=Model,
|
||||||
|
id='models.E012',
|
||||||
|
),
|
||||||
|
] if connection.features.supports_partial_indexes else [])
|
||||||
|
|
||||||
|
def test_unique_constraint_condition_pointing_to_joined_fields(self):
|
||||||
|
class Model(models.Model):
|
||||||
|
age = models.SmallIntegerField()
|
||||||
|
parent = models.ForeignKey('self', models.CASCADE)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
required_db_features = {'supports_partial_indexes'}
|
||||||
|
constraints = [
|
||||||
|
models.UniqueConstraint(
|
||||||
|
name='name',
|
||||||
|
fields=['age'],
|
||||||
|
condition=models.Q(parent__age__lt=2),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertEqual(Model.check(databases=self.databases), [
|
||||||
|
Error(
|
||||||
|
"'constraints' refers to the joined field 'parent__age__lt'.",
|
||||||
|
obj=Model,
|
||||||
|
id='models.E041',
|
||||||
|
)
|
||||||
|
] if connection.features.supports_partial_indexes else [])
|
||||||
|
|
||||||
def test_deferrable_unique_constraint(self):
|
def test_deferrable_unique_constraint(self):
|
||||||
class Model(models.Model):
|
class Model(models.Model):
|
||||||
age = models.IntegerField()
|
age = models.IntegerField()
|
||||||
|
|
Loading…
Reference in New Issue