Add an Executor for end-to-end running
This commit is contained in:
parent
7f9a0b7061
commit
e6f7f4533c
|
@ -0,0 +1,68 @@
|
||||||
|
from .loader import MigrationLoader
|
||||||
|
from .recorder import MigrationRecorder
|
||||||
|
|
||||||
|
|
||||||
|
class MigrationExecutor(object):
|
||||||
|
"""
|
||||||
|
End-to-end migration execution - loads migrations, and runs them
|
||||||
|
up or down to a specified set of targets.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, connection):
|
||||||
|
self.connection = connection
|
||||||
|
self.loader = MigrationLoader(self.connection)
|
||||||
|
self.recorder = MigrationRecorder(self.connection)
|
||||||
|
|
||||||
|
def migration_plan(self, targets):
|
||||||
|
"""
|
||||||
|
Given a set of targets, returns a list of (Migration instance, backwards?).
|
||||||
|
"""
|
||||||
|
plan = []
|
||||||
|
applied = self.recorder.applied_migrations()
|
||||||
|
for target in targets:
|
||||||
|
# If the migration is already applied, do backwards mode,
|
||||||
|
# otherwise do forwards mode.
|
||||||
|
if target in applied:
|
||||||
|
for migration in self.loader.graph.backwards_plan(target)[:-1]:
|
||||||
|
if migration in applied:
|
||||||
|
plan.append((self.loader.graph.nodes[migration], True))
|
||||||
|
applied.remove(migration)
|
||||||
|
else:
|
||||||
|
for migration in self.loader.graph.forwards_plan(target):
|
||||||
|
if migration not in applied:
|
||||||
|
plan.append((self.loader.graph.nodes[migration], False))
|
||||||
|
applied.add(migration)
|
||||||
|
return plan
|
||||||
|
|
||||||
|
def migrate(self, targets):
|
||||||
|
"""
|
||||||
|
Migrates the database up to the given targets.
|
||||||
|
"""
|
||||||
|
plan = self.migration_plan(targets)
|
||||||
|
for migration, backwards in plan:
|
||||||
|
if not backwards:
|
||||||
|
self.apply_migration(migration)
|
||||||
|
else:
|
||||||
|
self.unapply_migration(migration)
|
||||||
|
|
||||||
|
def apply_migration(self, migration):
|
||||||
|
"""
|
||||||
|
Runs a migration forwards.
|
||||||
|
"""
|
||||||
|
print "Applying %s" % migration
|
||||||
|
with self.connection.schema_editor() as schema_editor:
|
||||||
|
project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
|
||||||
|
migration.apply(project_state, schema_editor)
|
||||||
|
self.recorder.record_applied(migration.app_label, migration.name)
|
||||||
|
print "Finished %s" % migration
|
||||||
|
|
||||||
|
def unapply_migration(self, migration):
|
||||||
|
"""
|
||||||
|
Runs a migration backwards.
|
||||||
|
"""
|
||||||
|
print "Unapplying %s" % migration
|
||||||
|
with self.connection.schema_editor() as schema_editor:
|
||||||
|
project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
|
||||||
|
migration.unapply(project_state, schema_editor)
|
||||||
|
self.recorder.record_unapplied(migration.app_label, migration.name)
|
||||||
|
print "Finished %s" % migration
|
|
@ -36,6 +36,17 @@ class Migration(object):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.app_label = app_label
|
self.app_label = app_label
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, Migration):
|
||||||
|
return False
|
||||||
|
return (self.name == other.name) and (self.app_label == other.app_label)
|
||||||
|
|
||||||
|
def __ne__(self, other):
|
||||||
|
return not (self == other)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<Migration %s.%s>" % (self.app_label, self.name)
|
||||||
|
|
||||||
def mutate_state(self, project_state):
|
def mutate_state(self, project_state):
|
||||||
"""
|
"""
|
||||||
Takes a ProjectState and returns a new one with the migration's
|
Takes a ProjectState and returns a new one with the migration's
|
||||||
|
@ -45,3 +56,40 @@ class Migration(object):
|
||||||
for operation in self.operations:
|
for operation in self.operations:
|
||||||
operation.state_forwards(self.app_label, new_state)
|
operation.state_forwards(self.app_label, new_state)
|
||||||
return new_state
|
return new_state
|
||||||
|
|
||||||
|
def apply(self, project_state, schema_editor):
|
||||||
|
"""
|
||||||
|
Takes a project_state representing all migrations prior to this one
|
||||||
|
and a schema_editor for a live database and applies the migration
|
||||||
|
in a forwards order.
|
||||||
|
|
||||||
|
Returns the resulting project state for efficient re-use by following
|
||||||
|
Migrations.
|
||||||
|
"""
|
||||||
|
for operation in self.operations:
|
||||||
|
# Get the state after the operation has run
|
||||||
|
new_state = project_state.clone()
|
||||||
|
operation.state_forwards(self.app_label, new_state)
|
||||||
|
# Run the operation
|
||||||
|
operation.database_forwards(self.app_label, schema_editor, project_state, new_state)
|
||||||
|
# Switch states
|
||||||
|
project_state = new_state
|
||||||
|
return project_state
|
||||||
|
|
||||||
|
def unapply(self, project_state, schema_editor):
|
||||||
|
"""
|
||||||
|
Takes a project_state representing all migrations prior to this one
|
||||||
|
and a schema_editor for a live database and applies the migration
|
||||||
|
in a reverse order.
|
||||||
|
"""
|
||||||
|
# We need to pre-calculate the stack of project states
|
||||||
|
to_run = []
|
||||||
|
for operation in self.operations:
|
||||||
|
new_state = project_state.clone()
|
||||||
|
operation.state_forwards(self.app_label, new_state)
|
||||||
|
to_run.append((operation, project_state, new_state))
|
||||||
|
project_state = new_state
|
||||||
|
# Now run them in reverse
|
||||||
|
to_run.reverse()
|
||||||
|
for operation, to_state, from_state in to_run:
|
||||||
|
operation.database_backwards(self.app_label, schema_editor, from_state, to_state)
|
||||||
|
|
|
@ -16,13 +16,13 @@ class AddField(Operation):
|
||||||
|
|
||||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||||
app_cache = to_state.render()
|
app_cache = to_state.render()
|
||||||
model = app_cache.get_model(app_label, self.name)
|
model = app_cache.get_model(app_label, self.model_name)
|
||||||
schema_editor.add_field(model, model._meta.get_field_by_name(self.name))
|
schema_editor.add_field(model, model._meta.get_field_by_name(self.name)[0])
|
||||||
|
|
||||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||||
app_cache = from_state.render()
|
app_cache = from_state.render()
|
||||||
model = app_cache.get_model(app_label, self.name)
|
model = app_cache.get_model(app_label, self.model_name)
|
||||||
schema_editor.remove_field(model, model._meta.get_field_by_name(self.name))
|
schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)[0])
|
||||||
|
|
||||||
|
|
||||||
class RemoveField(Operation):
|
class RemoveField(Operation):
|
||||||
|
@ -43,10 +43,10 @@ class RemoveField(Operation):
|
||||||
|
|
||||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||||
app_cache = from_state.render()
|
app_cache = from_state.render()
|
||||||
model = app_cache.get_model(app_label, self.name)
|
model = app_cache.get_model(app_label, self.model_name)
|
||||||
schema_editor.remove_field(model, model._meta.get_field_by_name(self.name))
|
schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)[0])
|
||||||
|
|
||||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||||
app_cache = to_state.render()
|
app_cache = to_state.render()
|
||||||
model = app_cache.get_model(app_label, self.name)
|
model = app_cache.get_model(app_label, self.model_name)
|
||||||
schema_editor.add_field(model, model._meta.get_field_by_name(self.name))
|
schema_editor.add_field(model, model._meta.get_field_by_name(self.name)[0])
|
||||||
|
|
|
@ -11,7 +11,7 @@ class Migration(migrations.Migration):
|
||||||
|
|
||||||
migrations.RemoveField("Author", "silly_field"),
|
migrations.RemoveField("Author", "silly_field"),
|
||||||
|
|
||||||
migrations.AddField("Author", "important", models.BooleanField()),
|
migrations.AddField("Author", "rating", models.IntegerField(default=0)),
|
||||||
|
|
||||||
migrations.CreateModel(
|
migrations.CreateModel(
|
||||||
"Book",
|
"Book",
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
from django.test import TransactionTestCase
|
||||||
|
from django.db import connection
|
||||||
|
from django.db.migrations.executor import MigrationExecutor
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutorTests(TransactionTestCase):
|
||||||
|
"""
|
||||||
|
Tests the migration executor (full end-to-end running).
|
||||||
|
|
||||||
|
Bear in mind that if these are failing you should fix the other
|
||||||
|
test failures first, as they may be propagating into here.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_run(self):
|
||||||
|
"""
|
||||||
|
Tests running a simple set of migrations.
|
||||||
|
"""
|
||||||
|
executor = MigrationExecutor(connection)
|
||||||
|
# Let's look at the plan first and make sure it's up to scratch
|
||||||
|
plan = executor.migration_plan([("migrations", "0002_second")])
|
||||||
|
self.assertEqual(
|
||||||
|
plan,
|
||||||
|
[
|
||||||
|
(executor.loader.graph.nodes["migrations", "0001_initial"], False),
|
||||||
|
(executor.loader.graph.nodes["migrations", "0002_second"], False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
# Were the tables there before?
|
||||||
|
self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
|
||||||
|
self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
|
||||||
|
# Alright, let's try running it
|
||||||
|
executor.migrate([("migrations", "0002_second")])
|
||||||
|
# Are the tables there now?
|
||||||
|
self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
|
||||||
|
self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
|
|
@ -54,7 +54,7 @@ class LoaderTests(TransactionTestCase):
|
||||||
author_state = project_state.models["migrations", "author"]
|
author_state = project_state.models["migrations", "author"]
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
[x for x, y in author_state.fields],
|
[x for x, y in author_state.fields],
|
||||||
["id", "name", "slug", "age", "important"]
|
["id", "name", "slug", "age", "rating"]
|
||||||
)
|
)
|
||||||
|
|
||||||
book_state = project_state.models["migrations", "book"]
|
book_state = project_state.models["migrations", "book"]
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from django.test import TransactionTestCase
|
from django.test import TransactionTestCase
|
||||||
from django.db import connection, models, migrations
|
from django.db import connection, models, migrations
|
||||||
from django.db.migrations.state import ProjectState, ModelState
|
from django.db.migrations.state import ProjectState
|
||||||
|
|
||||||
|
|
||||||
class OperationTests(TransactionTestCase):
|
class OperationTests(TransactionTestCase):
|
||||||
|
@ -16,6 +16,12 @@ class OperationTests(TransactionTestCase):
|
||||||
def assertTableNotExists(self, table):
|
def assertTableNotExists(self, table):
|
||||||
self.assertNotIn(table, connection.introspection.get_table_list(connection.cursor()))
|
self.assertNotIn(table, connection.introspection.get_table_list(connection.cursor()))
|
||||||
|
|
||||||
|
def assertColumnExists(self, table, column):
|
||||||
|
self.assertIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])
|
||||||
|
|
||||||
|
def assertColumnNotExists(self, table, column):
|
||||||
|
self.assertNotIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])
|
||||||
|
|
||||||
def set_up_test_model(self, app_label):
|
def set_up_test_model(self, app_label):
|
||||||
"""
|
"""
|
||||||
Creates a test model state and database table.
|
Creates a test model state and database table.
|
||||||
|
@ -82,3 +88,23 @@ class OperationTests(TransactionTestCase):
|
||||||
with connection.schema_editor() as editor:
|
with connection.schema_editor() as editor:
|
||||||
operation.database_backwards("test_dlmo", editor, new_state, project_state)
|
operation.database_backwards("test_dlmo", editor, new_state, project_state)
|
||||||
self.assertTableExists("test_dlmo_pony")
|
self.assertTableExists("test_dlmo_pony")
|
||||||
|
|
||||||
|
def test_add_field(self):
|
||||||
|
"""
|
||||||
|
Tests the AddField operation.
|
||||||
|
"""
|
||||||
|
project_state = self.set_up_test_model("test_adfl")
|
||||||
|
# Test the state alteration
|
||||||
|
operation = migrations.AddField("Pony", "height", models.FloatField(null=True))
|
||||||
|
new_state = project_state.clone()
|
||||||
|
operation.state_forwards("test_adfl", new_state)
|
||||||
|
self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 3)
|
||||||
|
# Test the database alteration
|
||||||
|
self.assertColumnNotExists("test_adfl_pony", "height")
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_forwards("test_adfl", editor, project_state, new_state)
|
||||||
|
self.assertColumnExists("test_adfl_pony", "height")
|
||||||
|
# And test reversal
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_backwards("test_adfl", editor, new_state, project_state)
|
||||||
|
self.assertColumnNotExists("test_adfl_pony", "height")
|
||||||
|
|
Loading…
Reference in New Issue