Add an Executor for end-to-end running

This commit is contained in:
Andrew Godwin 2013-05-30 18:08:58 +01:00
parent 7f9a0b7061
commit e6f7f4533c
7 changed files with 188 additions and 11 deletions

View File

@ -0,0 +1,68 @@
from .loader import MigrationLoader
from .recorder import MigrationRecorder
class MigrationExecutor(object):
"""
End-to-end migration execution - loads migrations, and runs them
up or down to a specified set of targets.
"""
def __init__(self, connection):
self.connection = connection
self.loader = MigrationLoader(self.connection)
self.recorder = MigrationRecorder(self.connection)
def migration_plan(self, targets):
"""
Given a set of targets, returns a list of (Migration instance, backwards?).
"""
plan = []
applied = self.recorder.applied_migrations()
for target in targets:
# If the migration is already applied, do backwards mode,
# otherwise do forwards mode.
if target in applied:
for migration in self.loader.graph.backwards_plan(target)[:-1]:
if migration in applied:
plan.append((self.loader.graph.nodes[migration], True))
applied.remove(migration)
else:
for migration in self.loader.graph.forwards_plan(target):
if migration not in applied:
plan.append((self.loader.graph.nodes[migration], False))
applied.add(migration)
return plan
def migrate(self, targets):
"""
Migrates the database up to the given targets.
"""
plan = self.migration_plan(targets)
for migration, backwards in plan:
if not backwards:
self.apply_migration(migration)
else:
self.unapply_migration(migration)
def apply_migration(self, migration):
"""
Runs a migration forwards.
"""
print "Applying %s" % migration
with self.connection.schema_editor() as schema_editor:
project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
migration.apply(project_state, schema_editor)
self.recorder.record_applied(migration.app_label, migration.name)
print "Finished %s" % migration
def unapply_migration(self, migration):
"""
Runs a migration backwards.
"""
print "Unapplying %s" % migration
with self.connection.schema_editor() as schema_editor:
project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
migration.unapply(project_state, schema_editor)
self.recorder.record_unapplied(migration.app_label, migration.name)
print "Finished %s" % migration

View File

@ -36,6 +36,17 @@ class Migration(object):
self.name = name self.name = name
self.app_label = app_label self.app_label = app_label
def __eq__(self, other):
if not isinstance(other, Migration):
return False
return (self.name == other.name) and (self.app_label == other.app_label)
def __ne__(self, other):
return not (self == other)
def __repr__(self):
return "<Migration %s.%s>" % (self.app_label, self.name)
def mutate_state(self, project_state): def mutate_state(self, project_state):
""" """
Takes a ProjectState and returns a new one with the migration's Takes a ProjectState and returns a new one with the migration's
@ -45,3 +56,40 @@ class Migration(object):
for operation in self.operations: for operation in self.operations:
operation.state_forwards(self.app_label, new_state) operation.state_forwards(self.app_label, new_state)
return new_state return new_state
def apply(self, project_state, schema_editor):
"""
Takes a project_state representing all migrations prior to this one
and a schema_editor for a live database and applies the migration
in a forwards order.
Returns the resulting project state for efficient re-use by following
Migrations.
"""
for operation in self.operations:
# Get the state after the operation has run
new_state = project_state.clone()
operation.state_forwards(self.app_label, new_state)
# Run the operation
operation.database_forwards(self.app_label, schema_editor, project_state, new_state)
# Switch states
project_state = new_state
return project_state
def unapply(self, project_state, schema_editor):
"""
Takes a project_state representing all migrations prior to this one
and a schema_editor for a live database and applies the migration
in a reverse order.
"""
# We need to pre-calculate the stack of project states
to_run = []
for operation in self.operations:
new_state = project_state.clone()
operation.state_forwards(self.app_label, new_state)
to_run.append((operation, project_state, new_state))
project_state = new_state
# Now run them in reverse
to_run.reverse()
for operation, to_state, from_state in to_run:
operation.database_backwards(self.app_label, schema_editor, from_state, to_state)

View File

@ -16,13 +16,13 @@ class AddField(Operation):
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
app_cache = to_state.render() app_cache = to_state.render()
model = app_cache.get_model(app_label, self.name) model = app_cache.get_model(app_label, self.model_name)
schema_editor.add_field(model, model._meta.get_field_by_name(self.name)) schema_editor.add_field(model, model._meta.get_field_by_name(self.name)[0])
def database_backwards(self, app_label, schema_editor, from_state, to_state): def database_backwards(self, app_label, schema_editor, from_state, to_state):
app_cache = from_state.render() app_cache = from_state.render()
model = app_cache.get_model(app_label, self.name) model = app_cache.get_model(app_label, self.model_name)
schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)) schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)[0])
class RemoveField(Operation): class RemoveField(Operation):
@ -43,10 +43,10 @@ class RemoveField(Operation):
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
app_cache = from_state.render() app_cache = from_state.render()
model = app_cache.get_model(app_label, self.name) model = app_cache.get_model(app_label, self.model_name)
schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)) schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)[0])
def database_backwards(self, app_label, schema_editor, from_state, to_state): def database_backwards(self, app_label, schema_editor, from_state, to_state):
app_cache = to_state.render() app_cache = to_state.render()
model = app_cache.get_model(app_label, self.name) model = app_cache.get_model(app_label, self.model_name)
schema_editor.add_field(model, model._meta.get_field_by_name(self.name)) schema_editor.add_field(model, model._meta.get_field_by_name(self.name)[0])

View File

@ -11,7 +11,7 @@ class Migration(migrations.Migration):
migrations.RemoveField("Author", "silly_field"), migrations.RemoveField("Author", "silly_field"),
migrations.AddField("Author", "important", models.BooleanField()), migrations.AddField("Author", "rating", models.IntegerField(default=0)),
migrations.CreateModel( migrations.CreateModel(
"Book", "Book",

View File

@ -0,0 +1,35 @@
from django.test import TransactionTestCase
from django.db import connection
from django.db.migrations.executor import MigrationExecutor
class ExecutorTests(TransactionTestCase):
"""
Tests the migration executor (full end-to-end running).
Bear in mind that if these are failing you should fix the other
test failures first, as they may be propagating into here.
"""
def test_run(self):
"""
Tests running a simple set of migrations.
"""
executor = MigrationExecutor(connection)
# Let's look at the plan first and make sure it's up to scratch
plan = executor.migration_plan([("migrations", "0002_second")])
self.assertEqual(
plan,
[
(executor.loader.graph.nodes["migrations", "0001_initial"], False),
(executor.loader.graph.nodes["migrations", "0002_second"], False),
],
)
# Were the tables there before?
self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
# Alright, let's try running it
executor.migrate([("migrations", "0002_second")])
# Are the tables there now?
self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))

View File

@ -54,7 +54,7 @@ class LoaderTests(TransactionTestCase):
author_state = project_state.models["migrations", "author"] author_state = project_state.models["migrations", "author"]
self.assertEqual( self.assertEqual(
[x for x, y in author_state.fields], [x for x, y in author_state.fields],
["id", "name", "slug", "age", "important"] ["id", "name", "slug", "age", "rating"]
) )
book_state = project_state.models["migrations", "book"] book_state = project_state.models["migrations", "book"]

View File

@ -1,6 +1,6 @@
from django.test import TransactionTestCase from django.test import TransactionTestCase
from django.db import connection, models, migrations from django.db import connection, models, migrations
from django.db.migrations.state import ProjectState, ModelState from django.db.migrations.state import ProjectState
class OperationTests(TransactionTestCase): class OperationTests(TransactionTestCase):
@ -16,6 +16,12 @@ class OperationTests(TransactionTestCase):
def assertTableNotExists(self, table): def assertTableNotExists(self, table):
self.assertNotIn(table, connection.introspection.get_table_list(connection.cursor())) self.assertNotIn(table, connection.introspection.get_table_list(connection.cursor()))
def assertColumnExists(self, table, column):
self.assertIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])
def assertColumnNotExists(self, table, column):
self.assertNotIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])
def set_up_test_model(self, app_label): def set_up_test_model(self, app_label):
""" """
Creates a test model state and database table. Creates a test model state and database table.
@ -82,3 +88,23 @@ class OperationTests(TransactionTestCase):
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
operation.database_backwards("test_dlmo", editor, new_state, project_state) operation.database_backwards("test_dlmo", editor, new_state, project_state)
self.assertTableExists("test_dlmo_pony") self.assertTableExists("test_dlmo_pony")
def test_add_field(self):
"""
Tests the AddField operation.
"""
project_state = self.set_up_test_model("test_adfl")
# Test the state alteration
operation = migrations.AddField("Pony", "height", models.FloatField(null=True))
new_state = project_state.clone()
operation.state_forwards("test_adfl", new_state)
self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 3)
# Test the database alteration
self.assertColumnNotExists("test_adfl_pony", "height")
with connection.schema_editor() as editor:
operation.database_forwards("test_adfl", editor, project_state, new_state)
self.assertColumnExists("test_adfl_pony", "height")
# And test reversal
with connection.schema_editor() as editor:
operation.database_backwards("test_adfl", editor, new_state, project_state)
self.assertColumnNotExists("test_adfl_pony", "height")