diff --git a/django/db/migrations/operations/base.py b/django/db/migrations/operations/base.py index 3fb1002c445..2448284a2bf 100644 --- a/django/db/migrations/operations/base.py +++ b/django/db/migrations/operations/base.py @@ -80,6 +80,16 @@ class Operation: """ return "%s: %s" % (self.__class__.__name__, self._constructor_args) + def model_to_key(self, model): + """ + Take either a model class or an 'app_label.ModelName' string and return + (app_label, model_name). + """ + if isinstance(model, str): + return tuple(model.lower().split('.', 1)) + else: + return model._meta.app_label, model._meta.model_name + def references_model(self, name, app_label=None): """ Return True if there is a chance this operation references the given diff --git a/django/db/migrations/operations/fields.py b/django/db/migrations/operations/fields.py index a41b3444e50..05e57da0004 100644 --- a/django/db/migrations/operations/fields.py +++ b/django/db/migrations/operations/fields.py @@ -26,10 +26,25 @@ class FieldOperation(Operation): return self.is_same_model_operation(operation) and self.name_lower == operation.name_lower def references_model(self, name, app_label=None): - return name.lower() == self.model_name_lower + name_lower = name.lower() + if name_lower == self.model_name_lower: + return True + field = getattr(self, 'field', None) + if field and field.remote_field: + remote_app_label, remote_model_name = self.model_to_key(field.remote_field.model) + if (remote_model_name == name_lower and app_label is None or + not remote_app_label or remote_app_label == app_label): + return True + through = getattr(field.remote_field, 'through', None) + if through and self.model_to_key(through) == (app_label, name_lower): + through_app_label, through_model_name = self.model_to_key(through) + if (through_model_name == name_lower and app_label is None or + not through_app_label or through_app_label == app_label): + return True + return False def references_field(self, model_name, name, app_label=None): - return self.references_model(model_name) and name.lower() == self.name_lower + return self.references_model(model_name, app_label) and name.lower() == self.name_lower def reduce(self, operation, in_between, app_label=None): return ( diff --git a/django/db/migrations/operations/models.py b/django/db/migrations/operations/models.py index bd1c66a01d4..24147751a4e 100644 --- a/django/db/migrations/operations/models.py +++ b/django/db/migrations/operations/models.py @@ -115,21 +115,11 @@ class CreateModel(ModelOperation): # Now go over all the models and check against them for model in models_to_check: model_app_label, model_name = self.model_to_key(model) - if model_name.lower() == name_lower: - if app_label is None or not model_app_label or model_app_label == app_label: - return True + if (model_name == name_lower and app_label is None or + not model_app_label or model_app_label == app_label): + return True return False - def model_to_key(self, model): - """ - Take either a model class or an "app_label.ModelName" string - and return (app_label, object_name). - """ - if isinstance(model, str): - return model.split(".", 1) - else: - return model._meta.app_label, model._meta.object_name - def reduce(self, operation, in_between, app_label=None): if (isinstance(operation, DeleteModel) and self.name_lower == operation.name_lower and @@ -157,18 +147,6 @@ class CreateModel(ModelOperation): ] elif isinstance(operation, FieldOperation) and self.name_lower == operation.model_name_lower: if isinstance(operation, AddField): - # Don't allow optimizations of FKs through models they reference - if hasattr(operation.field, "remote_field") and operation.field.remote_field: - for between in in_between: - # Check that it doesn't point to the model - app_label, object_name = self.model_to_key(operation.field.remote_field.model) - if between.references_model(object_name, app_label): - return False - # Check that it's not through the model - if getattr(operation.field.remote_field, "through", None): - app_label, object_name = self.model_to_key(operation.field.remote_field.through) - if between.references_model(object_name, app_label): - return False return [ CreateModel( self.name, diff --git a/tests/migrations/test_autodetector.py b/tests/migrations/test_autodetector.py index 5ad32e610cd..34c9dac9892 100644 --- a/tests/migrations/test_autodetector.py +++ b/tests/migrations/test_autodetector.py @@ -1150,10 +1150,9 @@ class AutodetectorTests(TestCase): changes = self.get_changes([], [self.author_with_publisher, self.publisher]) # Right number/type of migrations? self.assertNumberMigrations(changes, 'testapp', 1) - self.assertOperationTypes(changes, 'testapp', 0, ["CreateModel", "CreateModel", "AddField"]) - self.assertOperationAttributes(changes, "testapp", 0, 0, name="Author") - self.assertOperationAttributes(changes, "testapp", 0, 1, name="Publisher") - self.assertOperationAttributes(changes, "testapp", 0, 2, name="publisher") + self.assertOperationTypes(changes, 'testapp', 0, ["CreateModel", "CreateModel"]) + self.assertOperationAttributes(changes, "testapp", 0, 0, name="Publisher") + self.assertOperationAttributes(changes, "testapp", 0, 1, name="Author") self.assertMigrationDependencies(changes, 'testapp', 0, []) def test_circular_fk_dependency(self): @@ -1907,13 +1906,12 @@ class AutodetectorTests(TestCase): # Right number/type of migrations? self.assertNumberMigrations(changes, "testapp", 1) self.assertOperationTypes(changes, "testapp", 0, [ - "CreateModel", "CreateModel", "CreateModel", "AddField", "AddField" + 'CreateModel', 'CreateModel', 'CreateModel', 'AddField', ]) - self.assertOperationAttributes(changes, 'testapp', 0, 0, name="Author") - self.assertOperationAttributes(changes, 'testapp', 0, 1, name="Contract") - self.assertOperationAttributes(changes, 'testapp', 0, 2, name="Publisher") - self.assertOperationAttributes(changes, 'testapp', 0, 3, model_name='contract', name='publisher') - self.assertOperationAttributes(changes, 'testapp', 0, 4, model_name='author', name='publishers') + self.assertOperationAttributes(changes, 'testapp', 0, 0, name='Author') + self.assertOperationAttributes(changes, 'testapp', 0, 1, name='Publisher') + self.assertOperationAttributes(changes, 'testapp', 0, 2, name='Contract') + self.assertOperationAttributes(changes, 'testapp', 0, 3, model_name='author', name='publishers') def test_many_to_many_removed_before_through_model(self): """ diff --git a/tests/migrations/test_optimizer.py b/tests/migrations/test_optimizer.py index 7a1876a5084..6c1c5f5c6b6 100644 --- a/tests/migrations/test_optimizer.py +++ b/tests/migrations/test_optimizer.py @@ -300,16 +300,91 @@ class OptimizerTests(SimpleTestCase): ], ) - def test_create_model_add_field_not_through_fk(self): + def test_create_model_reordering(self): """ - AddField should NOT optimize into CreateModel if it's an FK to a model - that's between them. + AddField optimizes into CreateModel if it's a FK to a model that's + between them (and there's no FK in the other direction), by changing + the order of the CreateModel operations. + """ + self.assertOptimizesTo( + [ + migrations.CreateModel('Foo', [('name', models.CharField(max_length=255))]), + migrations.CreateModel('Link', [('url', models.TextField())]), + migrations.AddField('Foo', 'link', models.ForeignKey('migrations.Link', models.CASCADE)), + ], + [ + migrations.CreateModel('Link', [('url', models.TextField())]), + migrations.CreateModel('Foo', [ + ('name', models.CharField(max_length=255)), + ('link', models.ForeignKey('migrations.Link', models.CASCADE)) + ]), + ], + ) + + def test_create_model_reordering_circular_fk(self): + """ + CreateModel reordering behavior doesn't result in an infinite loop if + there are FKs in both directions. + """ + self.assertOptimizesTo( + [ + migrations.CreateModel('Bar', [('url', models.TextField())]), + migrations.CreateModel('Foo', [('name', models.CharField(max_length=255))]), + migrations.AddField('Bar', 'foo_fk', models.ForeignKey('migrations.Foo', models.CASCADE)), + migrations.AddField('Foo', 'bar_fk', models.ForeignKey('migrations.Bar', models.CASCADE)), + ], + [ + migrations.CreateModel('Foo', [('name', models.CharField(max_length=255))]), + migrations.CreateModel('Bar', [ + ('url', models.TextField()), + ('foo_fk', models.ForeignKey('migrations.Foo', models.CASCADE)), + ]), + migrations.AddField('Foo', 'bar_fk', models.ForeignKey('migrations.Foo', models.CASCADE)), + ], + ) + + def test_create_model_no_reordering_for_unrelated_fk(self): + """ + CreateModel order remains unchanged if the later AddField operation + isn't a FK between them. """ self.assertDoesNotOptimize( [ - migrations.CreateModel("Foo", [("name", models.CharField(max_length=255))]), - migrations.CreateModel("Link", [("url", models.TextField())]), - migrations.AddField("Foo", "link", models.ForeignKey("migrations.Link", models.CASCADE)), + migrations.CreateModel('Foo', [('name', models.CharField(max_length=255))]), + migrations.CreateModel('Link', [('url', models.TextField())]), + migrations.AddField('Other', 'link', models.ForeignKey('migrations.Link', models.CASCADE)), + ], + ) + + def test_create_model_no_reordering_of_inherited_model(self): + """ + A CreateModel that inherits from another isn't reordered to avoid + moving it earlier than its parent CreateModel operation. + """ + self.assertOptimizesTo( + [ + migrations.CreateModel('Other', [('foo', models.CharField(max_length=255))]), + migrations.CreateModel('ParentModel', [('bar', models.CharField(max_length=255))]), + migrations.CreateModel( + 'ChildModel', + [('baz', models.CharField(max_length=255))], + bases=('migrations.parentmodel',), + ), + migrations.AddField('Other', 'fk', models.ForeignKey('migrations.ChildModel', models.CASCADE)), + ], + [ + migrations.CreateModel('ParentModel', [('bar', models.CharField(max_length=255))]), + migrations.CreateModel( + 'ChildModel', + [('baz', models.CharField(max_length=255))], + bases=('migrations.parentmodel',), + ), + migrations.CreateModel( + 'Other', [ + ('foo', models.CharField(max_length=255)), + ('fk', models.ForeignKey('migrations.ChildModel', models.CASCADE)), + ] + ), ], )