Fix migration planner to fully understand squashed migrations. And test.

This commit is contained in:
Andrew Godwin 2013-10-23 22:56:54 +01:00
parent 4cfbde71a3
commit 5ab8b5d72c
9 changed files with 207 additions and 37 deletions

View File

@ -11,7 +11,6 @@ class MigrationExecutor(object):
def __init__(self, connection, progress_callback=None):
self.connection = connection
self.loader = MigrationLoader(self.connection)
self.loader.load_disk()
self.recorder = MigrationRecorder(self.connection)
self.progress_callback = progress_callback
@ -20,7 +19,7 @@ class MigrationExecutor(object):
Given a set of targets, returns a list of (Migration instance, backwards?).
"""
plan = []
applied = self.recorder.applied_migrations()
applied = set(self.loader.applied_migrations)
for target in targets:
# If the target is (appname, None), that means unmigrate everything
if target[1] is None:
@ -87,7 +86,13 @@ class MigrationExecutor(object):
with self.connection.schema_editor() as schema_editor:
project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
migration.apply(project_state, schema_editor)
self.recorder.record_applied(migration.app_label, migration.name)
# For replacement migrations, record individual statuses
if migration.replaces:
for app_label, name in migration.replaces:
self.recorder.record_applied(app_label, name)
else:
self.recorder.record_applied(migration.app_label, migration.name)
# Report prgress
if self.progress_callback:
self.progress_callback("apply_success", migration)
@ -101,6 +106,12 @@ class MigrationExecutor(object):
with self.connection.schema_editor() as schema_editor:
project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
migration.unapply(project_state, schema_editor)
self.recorder.record_unapplied(migration.app_label, migration.name)
# For replacement migrations, record individual statuses
if migration.replaces:
for app_label, name in migration.replaces:
self.recorder.record_unapplied(app_label, name)
else:
self.recorder.record_unapplied(migration.app_label, migration.name)
# Report progress
if self.progress_callback:
self.progress_callback("unapply_success", migration)

View File

@ -1,9 +1,10 @@
import os
import sys
from importlib import import_module
from django.utils.functional import cached_property
from django.db.models.loading import cache
from django.db.migrations.recorder import MigrationRecorder
from django.db.migrations.graph import MigrationGraph
from django.utils import six
from django.conf import settings
@ -32,10 +33,12 @@ class MigrationLoader(object):
in memory.
"""
def __init__(self, connection):
def __init__(self, connection, load=True):
self.connection = connection
self.disk_migrations = None
self.applied_migrations = None
if load:
self.build_graph()
@classmethod
def migrations_module(cls, app_label):
@ -55,6 +58,7 @@ class MigrationLoader(object):
# Get the migrations module directory
app_label = app.__name__.split(".")[-2]
module_name = self.migrations_module(app_label)
was_loaded = module_name in sys.modules
try:
module = import_module(module_name)
except ImportError as e:
@ -71,6 +75,9 @@ class MigrationLoader(object):
# Module is not a package (e.g. migrations.py).
if not hasattr(module, '__path__'):
continue
# Force a reload if it's already loaded (tests need this)
if was_loaded:
six.moves.reload_module(module)
self.migrated_apps.add(app_label)
directory = os.path.dirname(module.__file__)
# Scan for .py[c|o] files
@ -107,9 +114,6 @@ class MigrationLoader(object):
def get_migration_by_prefix(self, app_label, name_prefix):
"Returns the migration(s) which match the given app label and name _prefix_"
# Make sure we have the disk data
if self.disk_migrations is None:
self.load_disk()
# Do the search
results = []
for l, n in self.disk_migrations:
@ -122,18 +126,17 @@ class MigrationLoader(object):
else:
return self.disk_migrations[results[0]]
@cached_property
def graph(self):
def build_graph(self):
"""
Builds a migration dependency graph using both the disk and database.
You'll need to rebuild the graph if you apply migrations. This isn't
usually a problem as generally migration stuff runs in a one-shot process.
"""
# Make sure we have the disk data
if self.disk_migrations is None:
self.load_disk()
# And the database data
if self.applied_migrations is None:
recorder = MigrationRecorder(self.connection)
self.applied_migrations = recorder.applied_migrations()
# Load disk data
self.load_disk()
# Load database data
recorder = MigrationRecorder(self.connection)
self.applied_migrations = recorder.applied_migrations()
# Do a first pass to separate out replacing and non-replacing migrations
normal = {}
replacing = {}
@ -152,12 +155,12 @@ class MigrationLoader(object):
# Carry out replacements if we can - that is, if all replaced migrations
# are either unapplied or missing.
for key, migration in replacing.items():
# Do the check
can_replace = True
for target in migration.replaces:
if target in self.applied_migrations:
can_replace = False
break
# Ensure this replacement migration is not in applied_migrations
self.applied_migrations.discard(key)
# Do the check. We can replace if all our replace targets are
# applied, or if all of them are unapplied.
applied_statuses = [(target in self.applied_migrations) for target in migration.replaces]
can_replace = all(applied_statuses) or (not any(applied_statuses))
if not can_replace:
continue
# Alright, time to replace. Step through the replaced migrations
@ -171,14 +174,16 @@ class MigrationLoader(object):
normal[child_key].dependencies.remove(replaced)
normal[child_key].dependencies.append(key)
normal[key] = migration
# Mark the replacement as applied if all its replaced ones are
if all(applied_statuses):
self.applied_migrations.add(key)
# Finally, make a graph and load everything into it
graph = MigrationGraph()
self.graph = MigrationGraph()
for key, migration in normal.items():
graph.add_node(key, migration)
self.graph.add_node(key, migration)
for key, migration in normal.items():
for parent in migration.dependencies:
graph.add_dependency(key, parent)
return graph
self.graph.add_dependency(key, parent)
class BadMigrationError(Exception):

View File

@ -39,6 +39,11 @@ class Migration(object):
def __init__(self, name, app_label):
self.name = name
self.app_label = app_label
# Copy dependencies & other attrs as we might mutate them at runtime
self.operations = list(self.__class__.operations)
self.dependencies = list(self.__class__.dependencies)
self.run_before = list(self.__class__.run_before)
self.replaces = list(self.__class__.replaces)
def __eq__(self, other):
if not isinstance(other, Migration):

View File

@ -38,7 +38,58 @@ class ExecutorTests(TransactionTestCase):
# Are the tables there now?
self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
# Rebuild the graph to reflect the new DB state
executor.loader.build_graph()
# Alright, let's undo what we did
plan = executor.migration_plan([("migrations", None)])
self.assertEqual(
plan,
[
(executor.loader.graph.nodes["migrations", "0002_second"], True),
(executor.loader.graph.nodes["migrations", "0001_initial"], True),
],
)
executor.migrate([("migrations", None)])
# Are the tables gone?
self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
@override_settings(MIGRATION_MODULES={"migrations": "migrations.test_migrations_squashed"})
def test_run_with_squashed(self):
"""
Tests running a squashed migration from zero (should ignore what it replaces)
"""
executor = MigrationExecutor(connection)
executor.recorder.flush()
# Check our leaf node is the squashed one
leaves = [key for key in executor.loader.graph.leaf_nodes() if key[0] == "migrations"]
self.assertEqual(leaves, [("migrations", "0001_squashed_0002")])
# Check the plan
plan = executor.migration_plan([("migrations", "0001_squashed_0002")])
self.assertEqual(
plan,
[
(executor.loader.graph.nodes["migrations", "0001_squashed_0002"], False),
],
)
# Were the tables there before?
self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
# Alright, let's try running it
executor.migrate([("migrations", "0001_squashed_0002")])
# Are the tables there now?
self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
# Rebuild the graph to reflect the new DB state
executor.loader.build_graph()
# Alright, let's undo what we did. Should also just use squashed.
plan = executor.migration_plan([("migrations", None)])
self.assertEqual(
plan,
[
(executor.loader.graph.nodes["migrations", "0001_squashed_0002"], True),
],
)
executor.migrate([("migrations", None)])
# Are the tables gone?
self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
@ -70,6 +121,8 @@ class ExecutorTests(TransactionTestCase):
)
# Fake-apply all migrations
executor.migrate([("migrations", "0002_second"), ("sessions", "0001_initial")], fake=True)
# Rebuild the graph to reflect the new DB state
executor.loader.build_graph()
# Now plan a second time and make sure it's empty
plan = executor.migration_plan([("migrations", "0002_second"), ("sessions", "0001_initial")])
self.assertEqual(plan, [])

View File

@ -82,21 +82,34 @@ class LoaderTests(TestCase):
migration_loader.get_migration_by_prefix("migrations", "blarg")
def test_load_import_error(self):
migration_loader = MigrationLoader(connection)
with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.import_error"}):
with self.assertRaises(ImportError):
migration_loader.load_disk()
MigrationLoader(connection)
def test_load_module_file(self):
migration_loader = MigrationLoader(connection)
with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.file"}):
migration_loader.load_disk()
MigrationLoader(connection)
@skipIf(six.PY2, "PY2 doesn't load empty dirs.")
def test_load_empty_dir(self):
migration_loader = MigrationLoader(connection)
with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.namespace"}):
migration_loader.load_disk()
MigrationLoader(connection)
@override_settings(MIGRATION_MODULES={"migrations": "migrations.test_migrations_squashed"})
def test_loading_squashed(self):
"Tests loading a squashed migration"
migration_loader = MigrationLoader(connection)
recorder = MigrationRecorder(connection)
# Loading with nothing applied should just give us the one node
self.assertEqual(
len(migration_loader.graph.nodes),
1,
)
# However, fake-apply one migration and it should now use the old two
recorder.record_applied("migrations", "0001_initial")
migration_loader.build_graph()
self.assertEqual(
len(migration_loader.graph.nodes),
2,
)
recorder.flush()

View File

@ -0,0 +1,27 @@
from django.db import migrations, models
class Migration(migrations.Migration):
operations = [
migrations.CreateModel(
"Author",
[
("id", models.AutoField(primary_key=True)),
("name", models.CharField(max_length=255)),
("slug", models.SlugField(null=True)),
("age", models.IntegerField(default=0)),
("silly_field", models.BooleanField(default=False)),
],
),
migrations.CreateModel(
"Tribble",
[
("id", models.AutoField(primary_key=True)),
("fluffy", models.BooleanField(default=True)),
],
)
]

View File

@ -0,0 +1,32 @@
from django.db import migrations, models
class Migration(migrations.Migration):
replaces = [
("migrations", "0001_initial"),
("migrations", "0002_second"),
]
operations = [
migrations.CreateModel(
"Author",
[
("id", models.AutoField(primary_key=True)),
("name", models.CharField(max_length=255)),
("slug", models.SlugField(null=True)),
("age", models.IntegerField(default=0)),
("rating", models.IntegerField(default=0)),
],
),
migrations.CreateModel(
"Book",
[
("id", models.AutoField(primary_key=True)),
("author", models.ForeignKey("migrations.Author", null=True)),
],
),
]

View File

@ -0,0 +1,24 @@
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [("migrations", "0001_initial")]
operations = [
migrations.DeleteModel("Tribble"),
migrations.RemoveField("Author", "silly_field"),
migrations.AddField("Author", "rating", models.IntegerField(default=0)),
migrations.CreateModel(
"Book",
[
("id", models.AutoField(primary_key=True)),
("author", models.ForeignKey("migrations.Author", null=True)),
],
)
]