From 1aa3e09c2043c88a760e8b73fb95dc8f1ffef50e Mon Sep 17 00:00:00 2001 From: Claude Paroz Date: Wed, 5 Nov 2014 20:53:39 +0100 Subject: [PATCH] Fixed #23745 -- Reused states as much as possible in migrations Thanks Tim Graham and Markus Holtermann for the reviews. --- django/db/migrations/executor.py | 24 ++++---- django/db/migrations/migration.py | 19 +++--- django/db/migrations/operations/fields.py | 4 ++ django/db/migrations/operations/models.py | 17 ++++-- django/db/migrations/state.py | 72 +++++++++++++++++++++-- tests/migrations/test_autodetector.py | 2 +- tests/migrations/test_executor.py | 2 +- tests/migrations/test_operations.py | 5 +- tests/migrations/test_state.py | 54 ++++++++--------- 9 files changed, 137 insertions(+), 62 deletions(-) diff --git a/django/db/migrations/executor.py b/django/db/migrations/executor.py index 2e42e7c6d58..a7989c4107d 100644 --- a/django/db/migrations/executor.py +++ b/django/db/migrations/executor.py @@ -98,13 +98,13 @@ class MigrationExecutor(object): self.progress_callback("apply_start", migration, fake) if not fake: # Test to see if this is an already-applied initial migration - if self.detect_soft_applied(state, migration): + applied, state = self.detect_soft_applied(state, migration) + if applied: fake = True else: # Alright, do it normally with self.connection.schema_editor() as schema_editor: - project_state = self.loader.project_state((migration.app_label, migration.name), at_end=False) - migration.apply(project_state, schema_editor) + state = migration.apply(state, schema_editor) # For replacement migrations, record individual statuses if migration.replaces: for app_label, name in migration.replaces: @@ -124,8 +124,7 @@ class MigrationExecutor(object): self.progress_callback("unapply_start", migration, fake) if not fake: with self.connection.schema_editor() as schema_editor: - project_state = self.loader.project_state((migration.app_label, migration.name), at_end=False) - migration.unapply(project_state, schema_editor) + state = migration.unapply(state, schema_editor) # For replacement migrations, record individual statuses if migration.replaces: for app_label, name in migration.replaces: @@ -143,12 +142,15 @@ class MigrationExecutor(object): tables it would create exist. This is intended only for use on initial migrations (as it only looks for CreateModel). """ - project_state = self.loader.project_state((migration.app_label, migration.name), at_end=True) - apps = project_state.apps - found_create_migration = False # Bail if the migration isn't the first one in its app if [name for app, name in migration.dependencies if app == migration.app_label]: - return False + return False, project_state + if project_state is None: + after_state = self.loader.project_state((migration.app_label, migration.name), at_end=True) + else: + after_state = migration.mutate_state(project_state) + apps = after_state.apps + found_create_migration = False # Make sure all create model are done for operation in migration.operations: if isinstance(operation, migrations.CreateModel): @@ -158,8 +160,8 @@ class MigrationExecutor(object): # main app cache, as it's not a direct dependency. model = global_apps.get_model(model._meta.swapped) if model._meta.db_table not in self.connection.introspection.table_names(self.connection.cursor()): - return False + return False, project_state found_create_migration = True # If we get this far and we found at least one CreateModel migration, # the migration is considered implicitly applied. - return found_create_migration + return found_create_migration, after_state diff --git a/django/db/migrations/migration.py b/django/db/migrations/migration.py index ff9d5264104..168c43c5652 100644 --- a/django/db/migrations/migration.py +++ b/django/db/migrations/migration.py @@ -97,19 +97,17 @@ class Migration(object): schema_editor.collected_sql.append("-- %s" % operation.describe()) schema_editor.collected_sql.append("--") continue - # Get the state after the operation has run - new_state = project_state.clone() - operation.state_forwards(self.app_label, new_state) + # Save the state before the operation has run + old_state = project_state.clone() + operation.state_forwards(self.app_label, project_state) # Run the operation if not schema_editor.connection.features.can_rollback_ddl and operation.atomic: # We're forcing a transaction on a non-transactional-DDL backend with atomic(schema_editor.connection.alias): - operation.database_forwards(self.app_label, schema_editor, project_state, new_state) + operation.database_forwards(self.app_label, schema_editor, old_state, project_state) else: # Normal behaviour - operation.database_forwards(self.app_label, schema_editor, project_state, new_state) - # Switch states - project_state = new_state + operation.database_forwards(self.app_label, schema_editor, old_state, project_state) return project_state def unapply(self, project_state, schema_editor, collect_sql=False): @@ -133,10 +131,9 @@ class Migration(object): # If it's irreversible, error out if not operation.reversible: raise Migration.IrreversibleError("Operation %s in %s is not reversible" % (operation, self)) - new_state = project_state.clone() - operation.state_forwards(self.app_label, new_state) - to_run.append((operation, project_state, new_state)) - project_state = new_state + old_state = project_state.clone() + operation.state_forwards(self.app_label, project_state) + to_run.append((operation, old_state, project_state)) # Now run them in reverse to_run.reverse() for operation, to_state, from_state in to_run: diff --git a/django/db/migrations/operations/fields.py b/django/db/migrations/operations/fields.py index ba123de9c36..54251cf6edc 100644 --- a/django/db/migrations/operations/fields.py +++ b/django/db/migrations/operations/fields.py @@ -38,6 +38,7 @@ class AddField(Operation): else: field = self.field state.models[app_label, self.model_name.lower()].fields.append((self.name, field)) + state.reload_model(app_label, self.model_name) def database_forwards(self, app_label, schema_editor, from_state, to_state): to_model = to_state.apps.get_model(app_label, self.model_name) @@ -94,6 +95,7 @@ class RemoveField(Operation): if name != self.name: new_fields.append((name, instance)) state.models[app_label, self.model_name.lower()].fields = new_fields + state.reload_model(app_label, self.model_name) def database_forwards(self, app_label, schema_editor, from_state, to_state): from_model = from_state.apps.get_model(app_label, self.model_name) @@ -150,6 +152,7 @@ class AlterField(Operation): state.models[app_label, self.model_name.lower()].fields = [ (n, field if n == self.name else f) for n, f in state.models[app_label, self.model_name.lower()].fields ] + state.reload_model(app_label, self.model_name) def database_forwards(self, app_label, schema_editor, from_state, to_state): to_model = to_state.apps.get_model(app_label, self.model_name) @@ -220,6 +223,7 @@ class RenameField(Operation): [self.new_name if n == self.old_name else n for n in together] for together in options[option] ] + state.reload_model(app_label, self.model_name) def database_forwards(self, app_label, schema_editor, from_state, to_state): to_model = to_state.apps.get_model(app_label, self.model_name) diff --git a/django/db/migrations/operations/models.py b/django/db/migrations/operations/models.py index eda30f88e8f..f07f667c51f 100644 --- a/django/db/migrations/operations/models.py +++ b/django/db/migrations/operations/models.py @@ -39,14 +39,14 @@ class CreateModel(Operation): ) def state_forwards(self, app_label, state): - state.models[app_label, self.name.lower()] = ModelState( + state.add_model(ModelState( app_label, self.name, list(self.fields), dict(self.options), tuple(self.bases), list(self.managers), - ) + )) def database_forwards(self, app_label, schema_editor, from_state, to_state): model = to_state.apps.get_model(app_label, self.name) @@ -98,7 +98,7 @@ class DeleteModel(Operation): ) def state_forwards(self, app_label, state): - del state.models[app_label, self.name.lower()] + state.remove_model(app_label, self.name) def database_forwards(self, app_label, schema_editor, from_state, to_state): model = from_state.apps.get_model(app_label, self.name) @@ -141,12 +141,13 @@ class RenameModel(Operation): # Get all of the related objects we need to repoint apps = state.apps model = apps.get_model(app_label, self.old_name) + model._meta.apps = apps related_objects = model._meta.get_all_related_objects() related_m2m_objects = model._meta.get_all_related_many_to_many_objects() # Rename the model state.models[app_label, self.new_name.lower()] = state.models[app_label, self.old_name.lower()] state.models[app_label, self.new_name.lower()].name = self.new_name - del state.models[app_label, self.old_name.lower()] + state.remove_model(app_label, self.old_name) # Repoint the FKs and M2Ms pointing to us for related_object in (related_objects + related_m2m_objects): # Use the new related key for self referential related objects. @@ -164,7 +165,8 @@ class RenameModel(Operation): field.rel.to = "%s.%s" % (app_label, self.new_name) new_fields.append((name, field)) state.models[related_key].fields = new_fields - del state.apps # FIXME: this should be replaced by a logic in state (update_model?) + state.reload_model(*related_key) + state.reload_model(app_label, self.new_name) def database_forwards(self, app_label, schema_editor, from_state, to_state): new_model = to_state.apps.get_model(app_label, self.new_name) @@ -235,6 +237,7 @@ class AlterModelTable(Operation): def state_forwards(self, app_label, state): state.models[app_label, self.name.lower()].options["db_table"] = self.table + state.reload_model(app_label, self.name) def database_forwards(self, app_label, schema_editor, from_state, to_state): new_model = to_state.apps.get_model(app_label, self.name) @@ -290,6 +293,7 @@ class AlterUniqueTogether(Operation): def state_forwards(self, app_label, state): model_state = state.models[app_label, self.name.lower()] model_state.options[self.option_name] = self.unique_together + state.reload_model(app_label, self.name) def database_forwards(self, app_label, schema_editor, from_state, to_state): new_model = to_state.apps.get_model(app_label, self.name) @@ -337,6 +341,7 @@ class AlterIndexTogether(Operation): def state_forwards(self, app_label, state): model_state = state.models[app_label, self.name.lower()] model_state.options[self.option_name] = self.index_together + state.reload_model(app_label, self.name) def database_forwards(self, app_label, schema_editor, from_state, to_state): new_model = to_state.apps.get_model(app_label, self.name) @@ -381,6 +386,7 @@ class AlterOrderWithRespectTo(Operation): def state_forwards(self, app_label, state): model_state = state.models[app_label, self.name.lower()] model_state.options['order_with_respect_to'] = self.order_with_respect_to + state.reload_model(app_label, self.name) def database_forwards(self, app_label, schema_editor, from_state, to_state): to_model = to_state.apps.get_model(app_label, self.name) @@ -451,6 +457,7 @@ class AlterModelOptions(Operation): for key in self.ALTER_OPTION_KEYS: if key not in self.options and key in model_state.options: del model_state.options[key] + state.reload_model(app_label, self.name) def database_forwards(self, app_label, schema_editor, from_state, to_state): pass diff --git a/django/db/migrations/state.py b/django/db/migrations/state.py index 8898fc4e9a9..8a45a0d2f29 100644 --- a/django/db/migrations/state.py +++ b/django/db/migrations/state.py @@ -1,4 +1,6 @@ from __future__ import unicode_literals +from collections import OrderedDict +import copy from django.apps import AppConfig from django.apps.registry import Apps, apps as global_apps @@ -30,15 +32,52 @@ class ProjectState(object): # Apps to include from main registry, usually unmigrated ones self.real_apps = real_apps or [] - def add_model_state(self, model_state): - self.models[(model_state.app_label, model_state.name.lower())] = model_state + def add_model(self, model_state): + app_label, model_name = model_state.app_label, model_state.name.lower() + self.models[(app_label, model_name)] = model_state + if 'apps' in self.__dict__: # hasattr would cache the property + self.reload_model(app_label, model_name) + + def remove_model(self, app_label, model_name): + model_name = model_name.lower() + del self.models[app_label, model_name] + if 'apps' in self.__dict__: # hasattr would cache the property + self.apps.unregister_model(app_label, model_name) + + def reload_model(self, app_label, model_name): + if 'apps' in self.__dict__: # hasattr would cache the property + # Get relations before reloading the models, as _meta.apps may change + model_name = model_name.lower() + try: + related_old = { + f.model for f in + self.apps.get_model(app_label, model_name)._meta.get_all_related_objects() + } + except LookupError: + related_old = set() + self._reload_one_model(app_label, model_name) + # Reload models if there are relations + model = self.apps.get_model(app_label, model_name) + related_m2m = {f.rel.to for f, _ in model._meta.get_m2m_with_model()} + for rel_model in related_old.union(related_m2m): + self._reload_one_model(rel_model._meta.app_label, rel_model._meta.model_name) + if related_m2m: + # Re-render this model after related models have been reloaded + self._reload_one_model(app_label, model_name) + + def _reload_one_model(self, app_label, model_name): + self.apps.unregister_model(app_label, model_name) + self.models[app_label, model_name].render(self.apps) def clone(self): "Returns an exact copy of this ProjectState" - return ProjectState( + new_state = ProjectState( models={k: v.clone() for k, v in self.models.items()}, real_apps=self.real_apps, ) + if 'apps' in self.__dict__: + new_state.apps = self.apps.clone() + return new_state @cached_property def apps(self): @@ -147,6 +186,31 @@ class StateApps(Apps): else: do_pending_lookups(model) + def clone(self): + """ + Return a clone of this registry, mainly used by the migration framework. + """ + clone = StateApps([], {}) + clone.all_models = copy.deepcopy(self.all_models) + clone.app_configs = copy.deepcopy(self.app_configs) + return clone + + def register_model(self, app_label, model): + self.all_models[app_label][model._meta.model_name] = model + if app_label not in self.app_configs: + self.app_configs[app_label] = AppConfigStub(app_label) + self.app_configs[app_label].models = OrderedDict() + self.app_configs[app_label].models[model._meta.model_name] = model + self.clear_cache() + + def unregister_model(self, app_label, model_name): + try: + del self.all_models[app_label][model_name] + del self.app_configs[app_label].models[model_name] + except KeyError: + pass + self.clear_cache() + class ModelState(object): """ @@ -368,7 +432,7 @@ class ModelState(object): for mgr_name, manager in self.managers: body[mgr_name] = manager - # Then, make a Model object + # Then, make a Model object (apps.register_model is called in __new__) return type( str(self.name), bases, diff --git a/tests/migrations/test_autodetector.py b/tests/migrations/test_autodetector.py index cfb59d48513..9c878a2beae 100644 --- a/tests/migrations/test_autodetector.py +++ b/tests/migrations/test_autodetector.py @@ -410,7 +410,7 @@ class AutodetectorTests(TestCase): "Shortcut to make ProjectStates from lists of predefined models" project_state = ProjectState() for model_state in model_states: - project_state.add_model_state(model_state.clone()) + project_state.add_model(model_state.clone()) return project_state def test_arrange_for_graph(self): diff --git a/tests/migrations/test_executor.py b/tests/migrations/test_executor.py index 3a7a779f8cb..78ee131d18c 100644 --- a/tests/migrations/test_executor.py +++ b/tests/migrations/test_executor.py @@ -223,7 +223,7 @@ class ExecutorTests(MigrationTestBase): global_apps.get_app_config("migrations").models["author"] = migrations_apps.get_model("migrations", "author") try: migration = executor.loader.get_migration("auth", "0001_initial") - self.assertEqual(executor.detect_soft_applied(None, migration), True) + self.assertEqual(executor.detect_soft_applied(None, migration)[0], True) finally: connection.introspection.table_names = old_table_names del global_apps.get_app_config("migrations").models["author"] diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py index b52a4df5859..90ec456f84b 100644 --- a/tests/migrations/test_operations.py +++ b/tests/migrations/test_operations.py @@ -828,12 +828,13 @@ class OperationTests(OperationTestBase): ]) self.assertTableExists("test_rmflmm_pony_stables") + with_field_state = project_state.clone() operations = [migrations.RemoveField("Pony", "stables")] - self.apply_operations("test_rmflmm", project_state, operations=operations) + project_state = self.apply_operations("test_rmflmm", project_state, operations=operations) self.assertTableNotExists("test_rmflmm_pony_stables") # And test reversal - self.unapply_operations("test_rmflmm", project_state, operations=operations) + self.unapply_operations("test_rmflmm", with_field_state, operations=operations) self.assertTableExists("test_rmflmm_pony_stables") def test_remove_field_m2m_with_through(self): diff --git a/tests/migrations/test_state.py b/tests/migrations/test_state.py index 143ce49f8da..6f510813afb 100644 --- a/tests/migrations/test_state.py +++ b/tests/migrations/test_state.py @@ -159,7 +159,7 @@ class StateTests(TestCase): Tests rendering a ProjectState into an Apps. """ project_state = ProjectState() - project_state.add_model_state(ModelState( + project_state.add_model(ModelState( app_label="migrations", name="Tag", fields=[ @@ -168,7 +168,7 @@ class StateTests(TestCase): ("hidden", models.BooleanField()), ], )) - project_state.add_model_state(ModelState( + project_state.add_model(ModelState( app_label="migrations", name="SubTag", fields=[ @@ -187,7 +187,7 @@ class StateTests(TestCase): base_mgr = models.Manager() mgr1 = FoodManager('a', 'b') mgr2 = FoodManager('x', 'y', c=3, d=4) - project_state.add_model_state(ModelState( + project_state.add_model(ModelState( app_label="migrations", name="Food", fields=[ @@ -324,21 +324,21 @@ class StateTests(TestCase): # Make a ProjectState and render it project_state = ProjectState() - project_state.add_model_state(ModelState.from_model(A)) - project_state.add_model_state(ModelState.from_model(B)) - project_state.add_model_state(ModelState.from_model(C)) - project_state.add_model_state(ModelState.from_model(D)) - project_state.add_model_state(ModelState.from_model(E)) - project_state.add_model_state(ModelState.from_model(F)) + project_state.add_model(ModelState.from_model(A)) + project_state.add_model(ModelState.from_model(B)) + project_state.add_model(ModelState.from_model(C)) + project_state.add_model(ModelState.from_model(D)) + project_state.add_model(ModelState.from_model(E)) + project_state.add_model(ModelState.from_model(F)) final_apps = project_state.apps self.assertEqual(len(final_apps.get_models()), 6) # Now make an invalid ProjectState and make sure it fails project_state = ProjectState() - project_state.add_model_state(ModelState.from_model(A)) - project_state.add_model_state(ModelState.from_model(B)) - project_state.add_model_state(ModelState.from_model(C)) - project_state.add_model_state(ModelState.from_model(F)) + project_state.add_model(ModelState.from_model(A)) + project_state.add_model(ModelState.from_model(B)) + project_state.add_model(ModelState.from_model(C)) + project_state.add_model(ModelState.from_model(F)) with self.assertRaises(InvalidBasesError): project_state.apps @@ -358,8 +358,8 @@ class StateTests(TestCase): # Make a ProjectState and render it project_state = ProjectState() - project_state.add_model_state(ModelState.from_model(A)) - project_state.add_model_state(ModelState.from_model(B)) + project_state.add_model(ModelState.from_model(A)) + project_state.add_model(ModelState.from_model(B)) self.assertEqual(len(project_state.apps.get_models()), 2) def test_equality(self): @@ -369,7 +369,7 @@ class StateTests(TestCase): # Test two things that should be equal project_state = ProjectState() - project_state.add_model_state(ModelState( + project_state.add_model(ModelState( "migrations", "Tag", [ @@ -388,7 +388,7 @@ class StateTests(TestCase): # Make a very small change (max_len 99) and see if that affects it project_state = ProjectState() - project_state.add_model_state(ModelState( + project_state.add_model(ModelState( "migrations", "Tag", [ @@ -428,20 +428,20 @@ class StateTests(TestCase): # Make a valid ProjectState and render it project_state = ProjectState() - project_state.add_model_state(ModelState.from_model(Author)) - project_state.add_model_state(ModelState.from_model(Book)) - project_state.add_model_state(ModelState.from_model(Magazine)) + project_state.add_model(ModelState.from_model(Author)) + project_state.add_model(ModelState.from_model(Book)) + project_state.add_model(ModelState.from_model(Magazine)) self.assertEqual(len(project_state.apps.get_models()), 3) # now make an invalid one with a ForeignKey project_state = ProjectState() - project_state.add_model_state(ModelState.from_model(Book)) + project_state.add_model(ModelState.from_model(Book)) with self.assertRaises(ValueError): project_state.apps # and another with ManyToManyField project_state = ProjectState() - project_state.add_model_state(ModelState.from_model(Magazine)) + project_state.add_model(ModelState.from_model(Magazine)) with self.assertRaises(ValueError): project_state.apps @@ -461,13 +461,13 @@ class StateTests(TestCase): # If we just stick it into an empty state it should fail project_state = ProjectState() - project_state.add_model_state(ModelState.from_model(TestModel)) + project_state.add_model(ModelState.from_model(TestModel)) with self.assertRaises(ValueError): project_state.apps # If we include the real app it should succeed project_state = ProjectState(real_apps=["contenttypes"]) - project_state.add_model_state(ModelState.from_model(TestModel)) + project_state.add_model(ModelState.from_model(TestModel)) rendered_state = project_state.apps self.assertEqual( len([x for x in rendered_state.get_models() if x._meta.app_label == "migrations"]), @@ -498,8 +498,8 @@ class StateTests(TestCase): # Make a valid ProjectState and render it project_state = ProjectState() - project_state.add_model_state(ModelState.from_model(Author)) - project_state.add_model_state(ModelState.from_model(Book)) + project_state.add_model(ModelState.from_model(Author)) + project_state.add_model(ModelState.from_model(Book)) self.assertEqual( [name for name, field in project_state.models["migrations", "book"].fields], ["id", "author"], @@ -534,6 +534,6 @@ class ModelStateTests(TestCase): self.assertEqual(repr(state), "") project_state = ProjectState() - project_state.add_model_state(state) + project_state.add_model(state) with self.assertRaisesMessage(InvalidBasesError, "Cannot resolve bases for []"): project_state.apps