From b7b7df5fbcf44e6598396905136cab5a19e9faff Mon Sep 17 00:00:00 2001 From: Hasan Ramezani Date: Fri, 22 May 2020 15:27:43 +0200 Subject: [PATCH] Fixed #31530 -- Added system checks for invalid model field names in CheckConstraint.check and UniqueConstraint.condition. --- django/db/models/base.py | 66 +++++- docs/ref/checks.txt | 1 + tests/invalid_models_tests/test_models.py | 232 ++++++++++++++++++++++ 3 files changed, 296 insertions(+), 3 deletions(-) diff --git a/django/db/models/base.py b/django/db/models/base.py index 422b9e7e2d..df5b0ca137 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -28,7 +28,7 @@ from django.db.models.fields.related import ( from django.db.models.functions import Coalesce from django.db.models.manager import Manager from django.db.models.options import Options -from django.db.models.query import Q +from django.db.models.query import F, Q from django.db.models.signals import ( class_prepared, post_init, post_save, pre_init, pre_save, ) @@ -1878,6 +1878,22 @@ class Model(metaclass=ModelBase): return errors + @classmethod + def _get_expr_references(cls, expr): + if isinstance(expr, Q): + for child in expr.children: + if isinstance(child, tuple): + lookup, value = child + yield tuple(lookup.split(LOOKUP_SEP)) + yield from cls._get_expr_references(value) + else: + yield from cls._get_expr_references(child) + elif isinstance(expr, F): + yield tuple(expr.name.split(LOOKUP_SEP)) + elif hasattr(expr, 'get_source_expressions'): + for src_expr in expr.get_source_expressions(): + yield from cls._get_expr_references(src_expr) + @classmethod def _check_constraints(cls, databases): errors = [] @@ -1960,10 +1976,54 @@ class Model(metaclass=ModelBase): id='models.W039', ) ) - fields = chain.from_iterable( + fields = set(chain.from_iterable( (*constraint.fields, *constraint.include) for constraint in cls._meta.constraints if isinstance(constraint, UniqueConstraint) - ) + )) + references = set() + for constraint in cls._meta.constraints: + if isinstance(constraint, UniqueConstraint): + if ( + connection.features.supports_partial_indexes or + 'supports_partial_indexes' not in cls._meta.required_db_features + ) and isinstance(constraint.condition, Q): + references.update(cls._get_expr_references(constraint.condition)) + elif isinstance(constraint, CheckConstraint): + if ( + connection.features.supports_table_check_constraints or + 'supports_table_check_constraints' not in cls._meta.required_db_features + ) and isinstance(constraint.check, Q): + references.update(cls._get_expr_references(constraint.check)) + for field_name, *lookups in references: + # pk is an alias that won't be found by opts.get_field. + if field_name != 'pk': + fields.add(field_name) + if not lookups: + # If it has no lookups it cannot result in a JOIN. + continue + try: + if field_name == 'pk': + field = cls._meta.pk + else: + field = cls._meta.get_field(field_name) + if not field.is_relation or field.many_to_many or field.one_to_many: + continue + except FieldDoesNotExist: + continue + # JOIN must happen at the first lookup. + first_lookup = lookups[0] + if ( + field.get_transform(first_lookup) is None and + field.get_lookup(first_lookup) is None + ): + errors.append( + checks.Error( + "'constraints' refers to the joined field '%s'." + % LOOKUP_SEP.join([field_name] + lookups), + obj=cls, + id='models.E041', + ) + ) errors.extend(cls._check_local_fields(fields, 'constraints')) return errors diff --git a/docs/ref/checks.txt b/docs/ref/checks.txt index 792c3936f6..97d2c1d097 100644 --- a/docs/ref/checks.txt +++ b/docs/ref/checks.txt @@ -364,6 +364,7 @@ Models non-key columns. * **models.W040**: ```` does not support indexes with non-key columns. +* **models.E041**: ``constraints`` refers to the joined field ````. Security -------- diff --git a/tests/invalid_models_tests/test_models.py b/tests/invalid_models_tests/test_models.py index 7875705860..d9993c00cd 100644 --- a/tests/invalid_models_tests/test_models.py +++ b/tests/invalid_models_tests/test_models.py @@ -1534,6 +1534,192 @@ class ConstraintsTests(TestCase): constraints = [models.CheckConstraint(check=models.Q(age__gte=18), name='is_adult')] self.assertEqual(Model.check(databases=self.databases), []) + def test_check_constraint_pointing_to_missing_field(self): + class Model(models.Model): + class Meta: + required_db_features = {'supports_table_check_constraints'} + constraints = [ + models.CheckConstraint( + name='name', check=models.Q(missing_field=2), + ), + ] + + self.assertEqual(Model.check(databases=self.databases), [ + Error( + "'constraints' refers to the nonexistent field " + "'missing_field'.", + obj=Model, + id='models.E012', + ), + ] if connection.features.supports_table_check_constraints else []) + + @skipUnlessDBFeature('supports_table_check_constraints') + def test_check_constraint_pointing_to_reverse_fk(self): + class Model(models.Model): + parent = models.ForeignKey('self', models.CASCADE, related_name='parents') + + class Meta: + constraints = [ + models.CheckConstraint(name='name', check=models.Q(parents=3)), + ] + + self.assertEqual(Model.check(databases=self.databases), [ + Error( + "'constraints' refers to the nonexistent field 'parents'.", + obj=Model, + id='models.E012', + ), + ]) + + @skipUnlessDBFeature('supports_table_check_constraints') + def test_check_constraint_pointing_to_m2m_field(self): + class Model(models.Model): + m2m = models.ManyToManyField('self') + + class Meta: + constraints = [ + models.CheckConstraint(name='name', check=models.Q(m2m=2)), + ] + + self.assertEqual(Model.check(databases=self.databases), [ + Error( + "'constraints' refers to a ManyToManyField 'm2m', but " + "ManyToManyFields are not permitted in 'constraints'.", + obj=Model, + id='models.E013', + ), + ]) + + @skipUnlessDBFeature('supports_table_check_constraints') + def test_check_constraint_pointing_to_fk(self): + class Target(models.Model): + pass + + class Model(models.Model): + fk_1 = models.ForeignKey(Target, models.CASCADE, related_name='target_1') + fk_2 = models.ForeignKey(Target, models.CASCADE, related_name='target_2') + + class Meta: + constraints = [ + models.CheckConstraint( + name='name', + check=models.Q(fk_1_id=2) | models.Q(fk_2=2), + ), + ] + + self.assertEqual(Model.check(databases=self.databases), []) + + @skipUnlessDBFeature('supports_table_check_constraints') + def test_check_constraint_pointing_to_pk(self): + class Model(models.Model): + age = models.SmallIntegerField() + + class Meta: + constraints = [ + models.CheckConstraint( + name='name', + check=models.Q(pk__gt=5) & models.Q(age__gt=models.F('pk')), + ), + ] + + self.assertEqual(Model.check(databases=self.databases), []) + + @skipUnlessDBFeature('supports_table_check_constraints') + def test_check_constraint_pointing_to_non_local_field(self): + class Parent(models.Model): + field1 = models.IntegerField() + + class Child(Parent): + pass + + class Meta: + constraints = [ + models.CheckConstraint(name='name', check=models.Q(field1=1)), + ] + + self.assertEqual(Child.check(databases=self.databases), [ + Error( + "'constraints' refers to field 'field1' which is not local to " + "model 'Child'.", + hint='This issue may be caused by multi-table inheritance.', + obj=Child, + id='models.E016', + ), + ]) + + @skipUnlessDBFeature('supports_table_check_constraints') + def test_check_constraint_pointing_to_joined_fields(self): + class Model(models.Model): + name = models.CharField(max_length=10) + field1 = models.PositiveSmallIntegerField() + field2 = models.PositiveSmallIntegerField() + field3 = models.PositiveSmallIntegerField() + parent = models.ForeignKey('self', models.CASCADE) + + class Meta: + constraints = [ + models.CheckConstraint( + name='name1', check=models.Q( + field1__lt=models.F('parent__field1') + models.F('parent__field2') + ) + ), + models.CheckConstraint( + name='name2', check=models.Q(name=Lower('parent__name')) + ), + models.CheckConstraint( + name='name3', check=models.Q(parent__field3=models.F('field1')) + ), + ] + + joined_fields = ['parent__field1', 'parent__field2', 'parent__field3', 'parent__name'] + errors = Model.check(databases=self.databases) + expected_errors = [ + Error( + "'constraints' refers to the joined field '%s'." % field_name, + obj=Model, + id='models.E041', + ) for field_name in joined_fields + ] + self.assertCountEqual(errors, expected_errors) + + @skipUnlessDBFeature('supports_table_check_constraints') + def test_check_constraint_pointing_to_joined_fields_complex_check(self): + class Model(models.Model): + name = models.PositiveSmallIntegerField() + field1 = models.PositiveSmallIntegerField() + field2 = models.PositiveSmallIntegerField() + parent = models.ForeignKey('self', models.CASCADE) + + class Meta: + constraints = [ + models.CheckConstraint( + name='name', + check=models.Q( + ( + models.Q(name='test') & + models.Q(field1__lt=models.F('parent__field1')) + ) | + ( + models.Q(name__startswith=Lower('parent__name')) & + models.Q(field1__gte=( + models.F('parent__field1') + models.F('parent__field2') + )) + ) + ) | (models.Q(name='test1')) + ), + ] + + joined_fields = ['parent__field1', 'parent__field2', 'parent__name'] + errors = Model.check(databases=self.databases) + expected_errors = [ + Error( + "'constraints' refers to the joined field '%s'." % field_name, + obj=Model, + id='models.E041', + ) for field_name in joined_fields + ] + self.assertCountEqual(errors, expected_errors) + def test_unique_constraint_with_condition(self): class Model(models.Model): age = models.IntegerField() @@ -1578,6 +1764,52 @@ class ConstraintsTests(TestCase): self.assertEqual(Model.check(databases=self.databases), []) + def test_unique_constraint_condition_pointing_to_missing_field(self): + class Model(models.Model): + age = models.SmallIntegerField() + + class Meta: + required_db_features = {'supports_partial_indexes'} + constraints = [ + models.UniqueConstraint( + name='name', + fields=['age'], + condition=models.Q(missing_field=2), + ), + ] + + self.assertEqual(Model.check(databases=self.databases), [ + Error( + "'constraints' refers to the nonexistent field " + "'missing_field'.", + obj=Model, + id='models.E012', + ), + ] if connection.features.supports_partial_indexes else []) + + def test_unique_constraint_condition_pointing_to_joined_fields(self): + class Model(models.Model): + age = models.SmallIntegerField() + parent = models.ForeignKey('self', models.CASCADE) + + class Meta: + required_db_features = {'supports_partial_indexes'} + constraints = [ + models.UniqueConstraint( + name='name', + fields=['age'], + condition=models.Q(parent__age__lt=2), + ), + ] + + self.assertEqual(Model.check(databases=self.databases), [ + Error( + "'constraints' refers to the joined field 'parent__age__lt'.", + obj=Model, + id='models.E041', + ) + ] if connection.features.supports_partial_indexes else []) + def test_deferrable_unique_constraint(self): class Model(models.Model): age = models.IntegerField()