Fix migration planner to fully understand squashed migrations. And test.

This commit is contained in:
Andrew Godwin 2013-10-23 22:56:54 +01:00
parent 4cfbde71a3
commit 5ab8b5d72c
9 changed files with 207 additions and 37 deletions

View File

@ -11,7 +11,6 @@ class MigrationExecutor(object):
def __init__(self, connection, progress_callback=None): def __init__(self, connection, progress_callback=None):
self.connection = connection self.connection = connection
self.loader = MigrationLoader(self.connection) self.loader = MigrationLoader(self.connection)
self.loader.load_disk()
self.recorder = MigrationRecorder(self.connection) self.recorder = MigrationRecorder(self.connection)
self.progress_callback = progress_callback self.progress_callback = progress_callback
@ -20,7 +19,7 @@ class MigrationExecutor(object):
Given a set of targets, returns a list of (Migration instance, backwards?). Given a set of targets, returns a list of (Migration instance, backwards?).
""" """
plan = [] plan = []
applied = self.recorder.applied_migrations() applied = set(self.loader.applied_migrations)
for target in targets: for target in targets:
# If the target is (appname, None), that means unmigrate everything # If the target is (appname, None), that means unmigrate everything
if target[1] is None: if target[1] is None:
@ -87,7 +86,13 @@ class MigrationExecutor(object):
with self.connection.schema_editor() as schema_editor: with self.connection.schema_editor() as schema_editor:
project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False) project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
migration.apply(project_state, schema_editor) migration.apply(project_state, schema_editor)
# For replacement migrations, record individual statuses
if migration.replaces:
for app_label, name in migration.replaces:
self.recorder.record_applied(app_label, name)
else:
self.recorder.record_applied(migration.app_label, migration.name) self.recorder.record_applied(migration.app_label, migration.name)
# Report prgress
if self.progress_callback: if self.progress_callback:
self.progress_callback("apply_success", migration) self.progress_callback("apply_success", migration)
@ -101,6 +106,12 @@ class MigrationExecutor(object):
with self.connection.schema_editor() as schema_editor: with self.connection.schema_editor() as schema_editor:
project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False) project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
migration.unapply(project_state, schema_editor) migration.unapply(project_state, schema_editor)
# For replacement migrations, record individual statuses
if migration.replaces:
for app_label, name in migration.replaces:
self.recorder.record_unapplied(app_label, name)
else:
self.recorder.record_unapplied(migration.app_label, migration.name) self.recorder.record_unapplied(migration.app_label, migration.name)
# Report progress
if self.progress_callback: if self.progress_callback:
self.progress_callback("unapply_success", migration) self.progress_callback("unapply_success", migration)

View File

@ -1,9 +1,10 @@
import os import os
import sys
from importlib import import_module from importlib import import_module
from django.utils.functional import cached_property
from django.db.models.loading import cache from django.db.models.loading import cache
from django.db.migrations.recorder import MigrationRecorder from django.db.migrations.recorder import MigrationRecorder
from django.db.migrations.graph import MigrationGraph from django.db.migrations.graph import MigrationGraph
from django.utils import six
from django.conf import settings from django.conf import settings
@ -32,10 +33,12 @@ class MigrationLoader(object):
in memory. in memory.
""" """
def __init__(self, connection): def __init__(self, connection, load=True):
self.connection = connection self.connection = connection
self.disk_migrations = None self.disk_migrations = None
self.applied_migrations = None self.applied_migrations = None
if load:
self.build_graph()
@classmethod @classmethod
def migrations_module(cls, app_label): def migrations_module(cls, app_label):
@ -55,6 +58,7 @@ class MigrationLoader(object):
# Get the migrations module directory # Get the migrations module directory
app_label = app.__name__.split(".")[-2] app_label = app.__name__.split(".")[-2]
module_name = self.migrations_module(app_label) module_name = self.migrations_module(app_label)
was_loaded = module_name in sys.modules
try: try:
module = import_module(module_name) module = import_module(module_name)
except ImportError as e: except ImportError as e:
@ -71,6 +75,9 @@ class MigrationLoader(object):
# Module is not a package (e.g. migrations.py). # Module is not a package (e.g. migrations.py).
if not hasattr(module, '__path__'): if not hasattr(module, '__path__'):
continue continue
# Force a reload if it's already loaded (tests need this)
if was_loaded:
six.moves.reload_module(module)
self.migrated_apps.add(app_label) self.migrated_apps.add(app_label)
directory = os.path.dirname(module.__file__) directory = os.path.dirname(module.__file__)
# Scan for .py[c|o] files # Scan for .py[c|o] files
@ -107,9 +114,6 @@ class MigrationLoader(object):
def get_migration_by_prefix(self, app_label, name_prefix): def get_migration_by_prefix(self, app_label, name_prefix):
"Returns the migration(s) which match the given app label and name _prefix_" "Returns the migration(s) which match the given app label and name _prefix_"
# Make sure we have the disk data
if self.disk_migrations is None:
self.load_disk()
# Do the search # Do the search
results = [] results = []
for l, n in self.disk_migrations: for l, n in self.disk_migrations:
@ -122,16 +126,15 @@ class MigrationLoader(object):
else: else:
return self.disk_migrations[results[0]] return self.disk_migrations[results[0]]
@cached_property def build_graph(self):
def graph(self):
""" """
Builds a migration dependency graph using both the disk and database. Builds a migration dependency graph using both the disk and database.
You'll need to rebuild the graph if you apply migrations. This isn't
usually a problem as generally migration stuff runs in a one-shot process.
""" """
# Make sure we have the disk data # Load disk data
if self.disk_migrations is None:
self.load_disk() self.load_disk()
# And the database data # Load database data
if self.applied_migrations is None:
recorder = MigrationRecorder(self.connection) recorder = MigrationRecorder(self.connection)
self.applied_migrations = recorder.applied_migrations() self.applied_migrations = recorder.applied_migrations()
# Do a first pass to separate out replacing and non-replacing migrations # Do a first pass to separate out replacing and non-replacing migrations
@ -152,12 +155,12 @@ class MigrationLoader(object):
# Carry out replacements if we can - that is, if all replaced migrations # Carry out replacements if we can - that is, if all replaced migrations
# are either unapplied or missing. # are either unapplied or missing.
for key, migration in replacing.items(): for key, migration in replacing.items():
# Do the check # Ensure this replacement migration is not in applied_migrations
can_replace = True self.applied_migrations.discard(key)
for target in migration.replaces: # Do the check. We can replace if all our replace targets are
if target in self.applied_migrations: # applied, or if all of them are unapplied.
can_replace = False applied_statuses = [(target in self.applied_migrations) for target in migration.replaces]
break can_replace = all(applied_statuses) or (not any(applied_statuses))
if not can_replace: if not can_replace:
continue continue
# Alright, time to replace. Step through the replaced migrations # Alright, time to replace. Step through the replaced migrations
@ -171,14 +174,16 @@ class MigrationLoader(object):
normal[child_key].dependencies.remove(replaced) normal[child_key].dependencies.remove(replaced)
normal[child_key].dependencies.append(key) normal[child_key].dependencies.append(key)
normal[key] = migration normal[key] = migration
# Mark the replacement as applied if all its replaced ones are
if all(applied_statuses):
self.applied_migrations.add(key)
# Finally, make a graph and load everything into it # Finally, make a graph and load everything into it
graph = MigrationGraph() self.graph = MigrationGraph()
for key, migration in normal.items(): for key, migration in normal.items():
graph.add_node(key, migration) self.graph.add_node(key, migration)
for key, migration in normal.items(): for key, migration in normal.items():
for parent in migration.dependencies: for parent in migration.dependencies:
graph.add_dependency(key, parent) self.graph.add_dependency(key, parent)
return graph
class BadMigrationError(Exception): class BadMigrationError(Exception):

View File

@ -39,6 +39,11 @@ class Migration(object):
def __init__(self, name, app_label): def __init__(self, name, app_label):
self.name = name self.name = name
self.app_label = app_label self.app_label = app_label
# Copy dependencies & other attrs as we might mutate them at runtime
self.operations = list(self.__class__.operations)
self.dependencies = list(self.__class__.dependencies)
self.run_before = list(self.__class__.run_before)
self.replaces = list(self.__class__.replaces)
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, Migration): if not isinstance(other, Migration):

View File

@ -38,7 +38,58 @@ class ExecutorTests(TransactionTestCase):
# Are the tables there now? # Are the tables there now?
self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor())) self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor())) self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
# Rebuild the graph to reflect the new DB state
executor.loader.build_graph()
# Alright, let's undo what we did # Alright, let's undo what we did
plan = executor.migration_plan([("migrations", None)])
self.assertEqual(
plan,
[
(executor.loader.graph.nodes["migrations", "0002_second"], True),
(executor.loader.graph.nodes["migrations", "0001_initial"], True),
],
)
executor.migrate([("migrations", None)])
# Are the tables gone?
self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
@override_settings(MIGRATION_MODULES={"migrations": "migrations.test_migrations_squashed"})
def test_run_with_squashed(self):
"""
Tests running a squashed migration from zero (should ignore what it replaces)
"""
executor = MigrationExecutor(connection)
executor.recorder.flush()
# Check our leaf node is the squashed one
leaves = [key for key in executor.loader.graph.leaf_nodes() if key[0] == "migrations"]
self.assertEqual(leaves, [("migrations", "0001_squashed_0002")])
# Check the plan
plan = executor.migration_plan([("migrations", "0001_squashed_0002")])
self.assertEqual(
plan,
[
(executor.loader.graph.nodes["migrations", "0001_squashed_0002"], False),
],
)
# Were the tables there before?
self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
# Alright, let's try running it
executor.migrate([("migrations", "0001_squashed_0002")])
# Are the tables there now?
self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
# Rebuild the graph to reflect the new DB state
executor.loader.build_graph()
# Alright, let's undo what we did. Should also just use squashed.
plan = executor.migration_plan([("migrations", None)])
self.assertEqual(
plan,
[
(executor.loader.graph.nodes["migrations", "0001_squashed_0002"], True),
],
)
executor.migrate([("migrations", None)]) executor.migrate([("migrations", None)])
# Are the tables gone? # Are the tables gone?
self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor())) self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
@ -70,6 +121,8 @@ class ExecutorTests(TransactionTestCase):
) )
# Fake-apply all migrations # Fake-apply all migrations
executor.migrate([("migrations", "0002_second"), ("sessions", "0001_initial")], fake=True) executor.migrate([("migrations", "0002_second"), ("sessions", "0001_initial")], fake=True)
# Rebuild the graph to reflect the new DB state
executor.loader.build_graph()
# Now plan a second time and make sure it's empty # Now plan a second time and make sure it's empty
plan = executor.migration_plan([("migrations", "0002_second"), ("sessions", "0001_initial")]) plan = executor.migration_plan([("migrations", "0002_second"), ("sessions", "0001_initial")])
self.assertEqual(plan, []) self.assertEqual(plan, [])

View File

@ -82,21 +82,34 @@ class LoaderTests(TestCase):
migration_loader.get_migration_by_prefix("migrations", "blarg") migration_loader.get_migration_by_prefix("migrations", "blarg")
def test_load_import_error(self): def test_load_import_error(self):
migration_loader = MigrationLoader(connection)
with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.import_error"}): with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.import_error"}):
with self.assertRaises(ImportError): with self.assertRaises(ImportError):
migration_loader.load_disk() MigrationLoader(connection)
def test_load_module_file(self): def test_load_module_file(self):
migration_loader = MigrationLoader(connection)
with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.file"}): with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.file"}):
migration_loader.load_disk() MigrationLoader(connection)
@skipIf(six.PY2, "PY2 doesn't load empty dirs.") @skipIf(six.PY2, "PY2 doesn't load empty dirs.")
def test_load_empty_dir(self): def test_load_empty_dir(self):
migration_loader = MigrationLoader(connection)
with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.namespace"}): with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.namespace"}):
migration_loader.load_disk() MigrationLoader(connection)
@override_settings(MIGRATION_MODULES={"migrations": "migrations.test_migrations_squashed"})
def test_loading_squashed(self):
"Tests loading a squashed migration"
migration_loader = MigrationLoader(connection)
recorder = MigrationRecorder(connection)
# Loading with nothing applied should just give us the one node
self.assertEqual(
len(migration_loader.graph.nodes),
1,
)
# However, fake-apply one migration and it should now use the old two
recorder.record_applied("migrations", "0001_initial")
migration_loader.build_graph()
self.assertEqual(
len(migration_loader.graph.nodes),
2,
)
recorder.flush()

View File

@ -0,0 +1,27 @@
from django.db import migrations, models
class Migration(migrations.Migration):
operations = [
migrations.CreateModel(
"Author",
[
("id", models.AutoField(primary_key=True)),
("name", models.CharField(max_length=255)),
("slug", models.SlugField(null=True)),
("age", models.IntegerField(default=0)),
("silly_field", models.BooleanField(default=False)),
],
),
migrations.CreateModel(
"Tribble",
[
("id", models.AutoField(primary_key=True)),
("fluffy", models.BooleanField(default=True)),
],
)
]

View File

@ -0,0 +1,32 @@
from django.db import migrations, models
class Migration(migrations.Migration):
replaces = [
("migrations", "0001_initial"),
("migrations", "0002_second"),
]
operations = [
migrations.CreateModel(
"Author",
[
("id", models.AutoField(primary_key=True)),
("name", models.CharField(max_length=255)),
("slug", models.SlugField(null=True)),
("age", models.IntegerField(default=0)),
("rating", models.IntegerField(default=0)),
],
),
migrations.CreateModel(
"Book",
[
("id", models.AutoField(primary_key=True)),
("author", models.ForeignKey("migrations.Author", null=True)),
],
),
]

View File

@ -0,0 +1,24 @@
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [("migrations", "0001_initial")]
operations = [
migrations.DeleteModel("Tribble"),
migrations.RemoveField("Author", "silly_field"),
migrations.AddField("Author", "rating", models.IntegerField(default=0)),
migrations.CreateModel(
"Book",
[
("id", models.AutoField(primary_key=True)),
("author", models.ForeignKey("migrations.Author", null=True)),
],
)
]