From 5ab8b5d72c5277833908bc57b0765682e5aadc0b Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Wed, 23 Oct 2013 22:56:54 +0100 Subject: [PATCH] Fix migration planner to fully understand squashed migrations. And test. --- django/db/migrations/executor.py | 19 +++++-- django/db/migrations/loader.py | 53 ++++++++++--------- django/db/migrations/migration.py | 5 ++ tests/migrations/test_executor.py | 53 +++++++++++++++++++ tests/migrations/test_loader.py | 31 +++++++---- .../test_migrations_squashed/0001_initial.py | 27 ++++++++++ .../0001_squashed_0002.py | 32 +++++++++++ .../test_migrations_squashed/0002_second.py | 24 +++++++++ .../test_migrations_squashed/__init__.py | 0 9 files changed, 207 insertions(+), 37 deletions(-) create mode 100644 tests/migrations/test_migrations_squashed/0001_initial.py create mode 100644 tests/migrations/test_migrations_squashed/0001_squashed_0002.py create mode 100644 tests/migrations/test_migrations_squashed/0002_second.py create mode 100644 tests/migrations/test_migrations_squashed/__init__.py diff --git a/django/db/migrations/executor.py b/django/db/migrations/executor.py index 420fbca8b8..a9b0c0f755 100644 --- a/django/db/migrations/executor.py +++ b/django/db/migrations/executor.py @@ -11,7 +11,6 @@ class MigrationExecutor(object): def __init__(self, connection, progress_callback=None): self.connection = connection self.loader = MigrationLoader(self.connection) - self.loader.load_disk() self.recorder = MigrationRecorder(self.connection) self.progress_callback = progress_callback @@ -20,7 +19,7 @@ class MigrationExecutor(object): Given a set of targets, returns a list of (Migration instance, backwards?). """ plan = [] - applied = self.recorder.applied_migrations() + applied = set(self.loader.applied_migrations) for target in targets: # If the target is (appname, None), that means unmigrate everything if target[1] is None: @@ -87,7 +86,13 @@ class MigrationExecutor(object): with self.connection.schema_editor() as schema_editor: project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False) migration.apply(project_state, schema_editor) - self.recorder.record_applied(migration.app_label, migration.name) + # For replacement migrations, record individual statuses + if migration.replaces: + for app_label, name in migration.replaces: + self.recorder.record_applied(app_label, name) + else: + self.recorder.record_applied(migration.app_label, migration.name) + # Report prgress if self.progress_callback: self.progress_callback("apply_success", migration) @@ -101,6 +106,12 @@ class MigrationExecutor(object): with self.connection.schema_editor() as schema_editor: project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False) migration.unapply(project_state, schema_editor) - self.recorder.record_unapplied(migration.app_label, migration.name) + # For replacement migrations, record individual statuses + if migration.replaces: + for app_label, name in migration.replaces: + self.recorder.record_unapplied(app_label, name) + else: + self.recorder.record_unapplied(migration.app_label, migration.name) + # Report progress if self.progress_callback: self.progress_callback("unapply_success", migration) diff --git a/django/db/migrations/loader.py b/django/db/migrations/loader.py index c9be3841b9..16557074a1 100644 --- a/django/db/migrations/loader.py +++ b/django/db/migrations/loader.py @@ -1,9 +1,10 @@ import os +import sys from importlib import import_module -from django.utils.functional import cached_property from django.db.models.loading import cache from django.db.migrations.recorder import MigrationRecorder from django.db.migrations.graph import MigrationGraph +from django.utils import six from django.conf import settings @@ -32,10 +33,12 @@ class MigrationLoader(object): in memory. """ - def __init__(self, connection): + def __init__(self, connection, load=True): self.connection = connection self.disk_migrations = None self.applied_migrations = None + if load: + self.build_graph() @classmethod def migrations_module(cls, app_label): @@ -55,6 +58,7 @@ class MigrationLoader(object): # Get the migrations module directory app_label = app.__name__.split(".")[-2] module_name = self.migrations_module(app_label) + was_loaded = module_name in sys.modules try: module = import_module(module_name) except ImportError as e: @@ -71,6 +75,9 @@ class MigrationLoader(object): # Module is not a package (e.g. migrations.py). if not hasattr(module, '__path__'): continue + # Force a reload if it's already loaded (tests need this) + if was_loaded: + six.moves.reload_module(module) self.migrated_apps.add(app_label) directory = os.path.dirname(module.__file__) # Scan for .py[c|o] files @@ -107,9 +114,6 @@ class MigrationLoader(object): def get_migration_by_prefix(self, app_label, name_prefix): "Returns the migration(s) which match the given app label and name _prefix_" - # Make sure we have the disk data - if self.disk_migrations is None: - self.load_disk() # Do the search results = [] for l, n in self.disk_migrations: @@ -122,18 +126,17 @@ class MigrationLoader(object): else: return self.disk_migrations[results[0]] - @cached_property - def graph(self): + def build_graph(self): """ Builds a migration dependency graph using both the disk and database. + You'll need to rebuild the graph if you apply migrations. This isn't + usually a problem as generally migration stuff runs in a one-shot process. """ - # Make sure we have the disk data - if self.disk_migrations is None: - self.load_disk() - # And the database data - if self.applied_migrations is None: - recorder = MigrationRecorder(self.connection) - self.applied_migrations = recorder.applied_migrations() + # Load disk data + self.load_disk() + # Load database data + recorder = MigrationRecorder(self.connection) + self.applied_migrations = recorder.applied_migrations() # Do a first pass to separate out replacing and non-replacing migrations normal = {} replacing = {} @@ -152,12 +155,12 @@ class MigrationLoader(object): # Carry out replacements if we can - that is, if all replaced migrations # are either unapplied or missing. for key, migration in replacing.items(): - # Do the check - can_replace = True - for target in migration.replaces: - if target in self.applied_migrations: - can_replace = False - break + # Ensure this replacement migration is not in applied_migrations + self.applied_migrations.discard(key) + # Do the check. We can replace if all our replace targets are + # applied, or if all of them are unapplied. + applied_statuses = [(target in self.applied_migrations) for target in migration.replaces] + can_replace = all(applied_statuses) or (not any(applied_statuses)) if not can_replace: continue # Alright, time to replace. Step through the replaced migrations @@ -171,14 +174,16 @@ class MigrationLoader(object): normal[child_key].dependencies.remove(replaced) normal[child_key].dependencies.append(key) normal[key] = migration + # Mark the replacement as applied if all its replaced ones are + if all(applied_statuses): + self.applied_migrations.add(key) # Finally, make a graph and load everything into it - graph = MigrationGraph() + self.graph = MigrationGraph() for key, migration in normal.items(): - graph.add_node(key, migration) + self.graph.add_node(key, migration) for key, migration in normal.items(): for parent in migration.dependencies: - graph.add_dependency(key, parent) - return graph + self.graph.add_dependency(key, parent) class BadMigrationError(Exception): diff --git a/django/db/migrations/migration.py b/django/db/migrations/migration.py index 06da84179b..c0ed1a9564 100644 --- a/django/db/migrations/migration.py +++ b/django/db/migrations/migration.py @@ -39,6 +39,11 @@ class Migration(object): def __init__(self, name, app_label): self.name = name self.app_label = app_label + # Copy dependencies & other attrs as we might mutate them at runtime + self.operations = list(self.__class__.operations) + self.dependencies = list(self.__class__.dependencies) + self.run_before = list(self.__class__.run_before) + self.replaces = list(self.__class__.replaces) def __eq__(self, other): if not isinstance(other, Migration): diff --git a/tests/migrations/test_executor.py b/tests/migrations/test_executor.py index dbdea900a5..c81a8bfa99 100644 --- a/tests/migrations/test_executor.py +++ b/tests/migrations/test_executor.py @@ -38,7 +38,58 @@ class ExecutorTests(TransactionTestCase): # Are the tables there now? self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor())) self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor())) + # Rebuild the graph to reflect the new DB state + executor.loader.build_graph() # Alright, let's undo what we did + plan = executor.migration_plan([("migrations", None)]) + self.assertEqual( + plan, + [ + (executor.loader.graph.nodes["migrations", "0002_second"], True), + (executor.loader.graph.nodes["migrations", "0001_initial"], True), + ], + ) + executor.migrate([("migrations", None)]) + # Are the tables gone? + self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor())) + self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor())) + + @override_settings(MIGRATION_MODULES={"migrations": "migrations.test_migrations_squashed"}) + def test_run_with_squashed(self): + """ + Tests running a squashed migration from zero (should ignore what it replaces) + """ + executor = MigrationExecutor(connection) + executor.recorder.flush() + # Check our leaf node is the squashed one + leaves = [key for key in executor.loader.graph.leaf_nodes() if key[0] == "migrations"] + self.assertEqual(leaves, [("migrations", "0001_squashed_0002")]) + # Check the plan + plan = executor.migration_plan([("migrations", "0001_squashed_0002")]) + self.assertEqual( + plan, + [ + (executor.loader.graph.nodes["migrations", "0001_squashed_0002"], False), + ], + ) + # Were the tables there before? + self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor())) + self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor())) + # Alright, let's try running it + executor.migrate([("migrations", "0001_squashed_0002")]) + # Are the tables there now? + self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor())) + self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor())) + # Rebuild the graph to reflect the new DB state + executor.loader.build_graph() + # Alright, let's undo what we did. Should also just use squashed. + plan = executor.migration_plan([("migrations", None)]) + self.assertEqual( + plan, + [ + (executor.loader.graph.nodes["migrations", "0001_squashed_0002"], True), + ], + ) executor.migrate([("migrations", None)]) # Are the tables gone? self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor())) @@ -70,6 +121,8 @@ class ExecutorTests(TransactionTestCase): ) # Fake-apply all migrations executor.migrate([("migrations", "0002_second"), ("sessions", "0001_initial")], fake=True) + # Rebuild the graph to reflect the new DB state + executor.loader.build_graph() # Now plan a second time and make sure it's empty plan = executor.migration_plan([("migrations", "0002_second"), ("sessions", "0001_initial")]) self.assertEqual(plan, []) diff --git a/tests/migrations/test_loader.py b/tests/migrations/test_loader.py index c7fe16cf9a..80ac9ffd73 100644 --- a/tests/migrations/test_loader.py +++ b/tests/migrations/test_loader.py @@ -82,21 +82,34 @@ class LoaderTests(TestCase): migration_loader.get_migration_by_prefix("migrations", "blarg") def test_load_import_error(self): - migration_loader = MigrationLoader(connection) - with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.import_error"}): with self.assertRaises(ImportError): - migration_loader.load_disk() + MigrationLoader(connection) def test_load_module_file(self): - migration_loader = MigrationLoader(connection) - with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.file"}): - migration_loader.load_disk() + MigrationLoader(connection) @skipIf(six.PY2, "PY2 doesn't load empty dirs.") def test_load_empty_dir(self): - migration_loader = MigrationLoader(connection) - with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.namespace"}): - migration_loader.load_disk() + MigrationLoader(connection) + + @override_settings(MIGRATION_MODULES={"migrations": "migrations.test_migrations_squashed"}) + def test_loading_squashed(self): + "Tests loading a squashed migration" + migration_loader = MigrationLoader(connection) + recorder = MigrationRecorder(connection) + # Loading with nothing applied should just give us the one node + self.assertEqual( + len(migration_loader.graph.nodes), + 1, + ) + # However, fake-apply one migration and it should now use the old two + recorder.record_applied("migrations", "0001_initial") + migration_loader.build_graph() + self.assertEqual( + len(migration_loader.graph.nodes), + 2, + ) + recorder.flush() diff --git a/tests/migrations/test_migrations_squashed/0001_initial.py b/tests/migrations/test_migrations_squashed/0001_initial.py new file mode 100644 index 0000000000..344bebdfe3 --- /dev/null +++ b/tests/migrations/test_migrations_squashed/0001_initial.py @@ -0,0 +1,27 @@ +from django.db import migrations, models + + +class Migration(migrations.Migration): + + operations = [ + + migrations.CreateModel( + "Author", + [ + ("id", models.AutoField(primary_key=True)), + ("name", models.CharField(max_length=255)), + ("slug", models.SlugField(null=True)), + ("age", models.IntegerField(default=0)), + ("silly_field", models.BooleanField(default=False)), + ], + ), + + migrations.CreateModel( + "Tribble", + [ + ("id", models.AutoField(primary_key=True)), + ("fluffy", models.BooleanField(default=True)), + ], + ) + + ] diff --git a/tests/migrations/test_migrations_squashed/0001_squashed_0002.py b/tests/migrations/test_migrations_squashed/0001_squashed_0002.py new file mode 100644 index 0000000000..742be641aa --- /dev/null +++ b/tests/migrations/test_migrations_squashed/0001_squashed_0002.py @@ -0,0 +1,32 @@ +from django.db import migrations, models + + +class Migration(migrations.Migration): + + replaces = [ + ("migrations", "0001_initial"), + ("migrations", "0002_second"), + ] + + operations = [ + + migrations.CreateModel( + "Author", + [ + ("id", models.AutoField(primary_key=True)), + ("name", models.CharField(max_length=255)), + ("slug", models.SlugField(null=True)), + ("age", models.IntegerField(default=0)), + ("rating", models.IntegerField(default=0)), + ], + ), + + migrations.CreateModel( + "Book", + [ + ("id", models.AutoField(primary_key=True)), + ("author", models.ForeignKey("migrations.Author", null=True)), + ], + ), + + ] diff --git a/tests/migrations/test_migrations_squashed/0002_second.py b/tests/migrations/test_migrations_squashed/0002_second.py new file mode 100644 index 0000000000..ace9a83347 --- /dev/null +++ b/tests/migrations/test_migrations_squashed/0002_second.py @@ -0,0 +1,24 @@ +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [("migrations", "0001_initial")] + + operations = [ + + migrations.DeleteModel("Tribble"), + + migrations.RemoveField("Author", "silly_field"), + + migrations.AddField("Author", "rating", models.IntegerField(default=0)), + + migrations.CreateModel( + "Book", + [ + ("id", models.AutoField(primary_key=True)), + ("author", models.ForeignKey("migrations.Author", null=True)), + ], + ) + + ] diff --git a/tests/migrations/test_migrations_squashed/__init__.py b/tests/migrations/test_migrations_squashed/__init__.py new file mode 100644 index 0000000000..e69de29bb2