diff --git a/django/db/migrations/operations/fields.py b/django/db/migrations/operations/fields.py index 34f1d2b64d..9faeb81318 100644 --- a/django/db/migrations/operations/fields.py +++ b/django/db/migrations/operations/fields.py @@ -38,6 +38,33 @@ class FieldOperation(Operation): # a field referencing the specified model. return True + def references_field(self, model_name, name, app_label=None): + model_name_lower = model_name.lower() + # Check if this operation locally references the field. + if model_name_lower == self.model_name_lower: + if name == self.name: + return True + elif self.field and hasattr(self.field, 'from_fields') and name in self.field.from_fields: + return True + # Check if this operation remotely references the field. + if self.field: + model_tuple = ModelTuple(app_label, model_name_lower) + remote_field = self.field.remote_field + if remote_field: + if (ModelTuple.from_model(remote_field.model) == model_tuple and + (not hasattr(self.field, 'to_fields') or + name in self.field.to_fields or None in self.field.to_fields)): + return True + through = getattr(remote_field, 'through', None) + if (through and ModelTuple.from_model(through) == model_tuple and + (getattr(remote_field, 'through_fields', None) is None or + name in remote_field.through_fields)): + return True + return False + # Refuse the temptation to guess. This operation could be performed on + # a field referencing the specified model. + return True + def reduce(self, operation, in_between, app_label=None): return ( super().reduce(operation, in_between, app_label=app_label) or diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py index b1581042f7..5dfe4acb4d 100644 --- a/tests/migrations/test_operations.py +++ b/tests/migrations/test_operations.py @@ -4,6 +4,7 @@ from django.core.exceptions import FieldDoesNotExist from django.db import connection, migrations, models, transaction from django.db.migrations.migration import Migration from django.db.migrations.operations import CreateModel +from django.db.migrations.operations.fields import FieldOperation from django.db.migrations.state import ModelState, ProjectState from django.db.models.fields import NOT_PROVIDED from django.db.transaction import atomic @@ -2753,3 +2754,60 @@ class TestCreateModel(SimpleTestCase): def test_references_model_mixin(self): CreateModel('name', [], bases=(Mixin, models.Model)).references_model('other_model') + + +class FieldOperationTests(SimpleTestCase): + def test_references_model(self): + operation = FieldOperation('MoDel', 'field') + # When missing a field declaration always assume it's referencing. + self.assertIs(operation.references_model('Whatever'), True) + operation.field = models.ForeignKey('Other', models.CASCADE) + # Model name match. + self.assertIs(operation.references_model('mOdEl'), True) + # Referenced field. + self.assertIs(operation.references_model('oTher'), True) + # Doesn't reference. + self.assertIs(operation.references_model('Whatever'), False) + + def test_references_field_missing_field(self): + operation = FieldOperation('MoDel', 'field') + self.assertIs(operation.references_field('Whatever', 'missing'), True) + + def test_references_field_by_name(self): + operation = FieldOperation('MoDel', 'field', models.BooleanField(default=False)) + self.assertIs(operation.references_field('model', 'field'), True) + + def test_references_field_by_remote_field_model(self): + operation = FieldOperation('Model', 'field', models.ForeignKey('Other', models.CASCADE)) + self.assertIs(operation.references_field('Other', 'whatever'), True) + self.assertIs(operation.references_field('Missing', 'whatever'), False) + + def test_references_field_by_from_fields(self): + operation = FieldOperation( + 'Model', 'field', models.fields.related.ForeignObject('Other', models.CASCADE, ['from'], ['to']) + ) + self.assertIs(operation.references_field('Model', 'from'), True) + self.assertIs(operation.references_field('Model', 'to'), False) + self.assertIs(operation.references_field('Other', 'from'), False) + self.assertIs(operation.references_field('Model', 'to'), False) + + def test_references_field_by_to_fields(self): + operation = FieldOperation('Model', 'field', models.ForeignKey('Other', models.CASCADE, to_field='field')) + self.assertIs(operation.references_field('Other', 'field'), True) + self.assertIs(operation.references_field('Other', 'whatever'), False) + self.assertIs(operation.references_field('Missing', 'whatever'), False) + + def test_references_field_by_through(self): + operation = FieldOperation('Model', 'field', models.ManyToManyField('Other', through='Through')) + self.assertIs(operation.references_field('Other', 'whatever'), True) + self.assertIs(operation.references_field('Through', 'whatever'), True) + self.assertIs(operation.references_field('Missing', 'whatever'), False) + + def test_reference_field_by_through_fields(self): + operation = FieldOperation( + 'Model', 'field', models.ManyToManyField('Other', through='Through', through_fields=('first', 'second')) + ) + self.assertIs(operation.references_field('Other', 'whatever'), True) + self.assertIs(operation.references_field('Through', 'whatever'), False) + self.assertIs(operation.references_field('Through', 'first'), True) + self.assertIs(operation.references_field('Through', 'second'), True)