First phase of loading migrations from disk

This commit is contained in:
Andrew Godwin 2013-05-10 16:00:55 +01:00
parent cb4b0de49e
commit 9ce8354672
9 changed files with 265 additions and 9 deletions

View File

@ -0,0 +1 @@
from .migration import Migration

View File

@ -1,7 +1,7 @@
from django.utils.datastructures import SortedSet
class MigrationsGraph(object):
class MigrationGraph(object):
"""
Represents the digraph of all migrations in a project.
@ -19,7 +19,7 @@ class MigrationsGraph(object):
replacing migration, and repoint any dependencies that pointed to the
replaced migrations to point to the replacing one.
A node should be a tuple: (applabel, migration_name) - but the code
A node should be a tuple: (app_path, migration_name) - but the code
here doesn't really care.
"""
@ -70,7 +70,7 @@ class MigrationsGraph(object):
return cache[(start, get_children)]
# If we've traversed here before, that's a circular dep
if start in path:
raise CircularDependencyException(path[path.index(start):] + [start])
raise CircularDependencyError(path[path.index(start):] + [start])
# Build our own results list, starting with us
results = []
results.append(start)
@ -88,8 +88,11 @@ class MigrationsGraph(object):
return results
return _dfs(start, get_children, [])
def __str__(self):
return "Graph: %s nodes, %s edges" % (len(self.nodes), sum(len(x) for x in self.dependencies.values()))
class CircularDependencyException(Exception):
class CircularDependencyError(Exception):
    """
    Signals a dependency cycle in the graph that cannot be resolved.
    """

View File

@ -0,0 +1,128 @@
import os
from django.utils.importlib import import_module
from django.db.models.loading import cache
from django.db.migrations.recorder import MigrationRecorder
from django.db.migrations.graph import MigrationGraph
class MigrationLoader(object):
    """
    Loads migration files from disk, and their status from the database.

    Migration files are expected to live in the "migrations" directory of
    an app. Their names are entirely unimportant from a code perspective,
    but will probably follow the 1234_name.py convention.

    On initialisation, this class will scan those directories, and open and
    read the python files, looking for a class called Migration, which should
    inherit from django.db.migrations.Migration. See
    django.db.migrations.migration for what that looks like.

    Some migrations will be marked as "replacing" another set of migrations.
    These are loaded into a separate set of migrations away from the main ones.
    If all the migrations they replace are either unapplied or missing from
    disk, then they are injected into the main set, replacing the named migrations.
    Any dependency pointers to the replaced migrations are re-pointed to the
    new migration.

    This does mean that this class MUST also talk to the database as well as
    to disk, but this is probably fine. We're already not just operating
    in memory.
    """

    def __init__(self, connection):
        """
        Stores the connection; nothing is read from disk or the database
        until load_disk() / build_graph() is called.
        """
        self.connection = connection
        # Both caches are None until populated:
        #  - disk_migrations: {(app_label, migration_name): Migration class}
        #  - applied_migrations: whatever MigrationRecorder.applied_migrations()
        #    returns (a set of (app, name) tuples)
        self.disk_migrations = None
        self.applied_migrations = None

    def load_disk(self):
        """
        Loads the migrations from all INSTALLED_APPS from disk.

        Populates self.disk_migrations, keyed by (app_label, migration_name).
        Apps without a "migrations" module are skipped; any other ImportError
        is propagated. Raises BadMigrationError if a migration module has no
        "Migration" attribute.
        """
        self.disk_migrations = {}
        for app in cache.get_apps():
            # Get the migrations module directory
            module_name = ".".join(app.__name__.split(".")[:-1] + ["migrations"])
            app_label = module_name.split(".")[-2]
            try:
                module = import_module(module_name)
            except ImportError as e:
                # I hate doing this, but I don't want to squash other import errors.
                # Might be better to try a directory check directly.
                # The two-part check matches both the Python 2 wording
                # ("No module named migrations") and the Python 3 wording
                # ("No module named 'app.migrations'").
                if "No module named" in str(e) and "migrations" in str(e):
                    continue
                # Anything else is a genuine failure inside the app's
                # migrations package; don't carry on with an unbound or
                # stale "module".
                raise
            directory = os.path.dirname(module.__file__)
            # Scan for .py[c|o] files; the set collapses foo.py / foo.pyc
            migration_names = set()
            for name in os.listdir(directory):
                if name.endswith((".py", ".pyc", ".pyo")):
                    import_name = name.rsplit(".", 1)[0]
                    # Skip empty stems (a file literally called ".py") and
                    # private/hidden/backup files such as __init__.py
                    if import_name and import_name[0] not in "_.~":
                        migration_names.add(import_name)
            # Load them
            for migration_name in migration_names:
                migration_module = import_module("%s.%s" % (module_name, migration_name))
                if not hasattr(migration_module, "Migration"):
                    raise BadMigrationError("Migration %s in app %s has no Migration class" % (migration_name, app_label))
                self.disk_migrations[app_label, migration_name] = migration_module.Migration

    def build_graph(self):
        """
        Builds a migration dependency graph using both the disk and database.

        Returns a MigrationGraph containing every non-replaced migration,
        with one dependency edge per entry in each migration's dependencies.
        """
        # Make sure we have the disk data
        if self.disk_migrations is None:
            self.load_disk()
        # And the database data
        if self.applied_migrations is None:
            recorder = MigrationRecorder(self.connection)
            self.applied_migrations = recorder.applied_migrations()
        # Do a first pass to separate out replacing and non-replacing migrations
        normal = {}
        replacing = {}
        for key, migration in self.disk_migrations.items():
            if migration.replaces:
                replacing[key] = migration
            else:
                normal[key] = migration
        # Calculate reverse dependencies - i.e., for each migration, what depends on it?
        # This is just for dependency re-pointing when applying replacements,
        # so we ignore run_before here.
        reverse_dependencies = {}
        for key, migration in normal.items():
            for parent in migration.dependencies:
                reverse_dependencies.setdefault(parent, set()).add(key)
        # Carry out replacements if we can - that is, if all replaced migrations
        # are either unapplied or missing.
        for key, migration in replacing.items():
            # "replaces" is documented as bare migration names within the
            # replacing migration's own app, whereas applied_migrations and
            # our dicts are keyed by (app, name). Normalise to full keys,
            # still accepting entries that are already (app, name) pairs.
            targets = []
            for target in migration.replaces:
                if isinstance(target, (tuple, list)):
                    targets.append(tuple(target))
                else:
                    targets.append((key[0], target))
            # Do the check
            if any(target in self.applied_migrations for target in targets):
                continue
            # Alright, time to replace. Step through the replaced migrations
            # and remove, repointing dependencies if needs be.
            for replaced in targets:
                normal.pop(replaced, None)
                for child_key in reverse_dependencies.get(replaced, ()):
                    child = normal.get(child_key)
                    if child is None:
                        # The dependent was itself replaced and is already
                        # gone, so there is nothing left to repoint.
                        continue
                    if replaced in child.dependencies:
                        child.dependencies.remove(replaced)
                    # Avoid a duplicate edge when the child depended on
                    # several of the replaced migrations.
                    if key not in child.dependencies:
                        child.dependencies.append(key)
            normal[key] = migration
        # Finally, make a graph and load everything into it
        graph = MigrationGraph()
        for key, migration in normal.items():
            graph.add_node(key, migration)
            for parent in migration.dependencies:
                graph.add_dependency(key, parent)
        return graph
class BadMigrationError(Exception):
    """
    Signals that a migration is bad (unreadable, wrongly formatted, etc.)
    """

View File

@ -0,0 +1,30 @@
class Migration(object):
    """
    Base class that every migration derives from.

    A migration file imports this as django.db.migrations.Migration and
    defines a subclass of it, which must itself be named Migration. The
    subclass overrides whichever of these class attributes it needs:

     - operations: list of Operation instances, probably taken from
       django.db.migrations.operations
     - dependencies: list of (app_path, migration_name) tuples
     - run_before: list of (app_path, migration_name) tuples
     - replaces: list of migration_names

    Every attribute defaults to an empty list.
    """

    # What this migration does: operations are applied in list order.
    operations = []

    # Migrations that have to run before this one, given as
    # (app, migration_name) pairs.
    dependencies = []

    # Migrations that have to run after this one - in other words, this
    # migration gets added to their dependencies. Handy for forcing a
    # third-party app's migrations to run after your AUTH_USER
    # replacement, for example.
    run_before = []

    # Names of migrations in this app that this one replaces. When the
    # list is non-empty, this migration is only applied if none of the
    # listed migrations has been applied.
    replaces = []

View File

@ -0,0 +1,64 @@
import datetime
from django.db import models
from django.db.models.loading import BaseAppCache
class MigrationRecorder(object):
    """
    Deals with storing migration records in the database.

    Because this table is actually itself used for dealing with model
    creation, it's the one thing we can't do normally via syncdb or migrations.
    We manually handle table creation/schema updating (using schema backend)
    and then have a floating model to do queries with.

    If a migration is unapplied its row is removed from the table. Having
    a row in the table always means a migration is applied.
    """

    class Migration(models.Model):
        app = models.CharField(max_length=255)
        name = models.CharField(max_length=255)
        applied = models.DateTimeField(default=datetime.datetime.utcnow)

        class Meta:
            # A private app cache, so this floating model is registered
            # apart from the shared one.
            app_cache = BaseAppCache()
            app_label = "migrations"
            db_table = "django_migrations"

    def __init__(self, connection):
        """
        Stores the connection whose database holds the migration records.
        """
        self.connection = connection

    @property
    def migration_qs(self):
        """
        A queryset over the Migration model bound to this recorder's
        connection. The table is created on self.connection, so every
        read and write must target that same database rather than
        whichever one the default manager would pick.
        """
        return self.Migration.objects.using(self.connection.alias)

    def ensure_schema(self):
        """
        Ensures the table exists and has the correct schema.
        """
        # If the table's there, that's fine - we've never changed its schema
        # in the codebase.
        if self.Migration._meta.db_table in self.connection.introspection.get_table_list(self.connection.cursor()):
            return
        # Make the table
        editor = self.connection.schema_editor()
        editor.start()
        editor.create_model(self.Migration)
        editor.commit()

    def applied_migrations(self):
        """
        Returns a set of (app, name) of applied migrations.
        """
        self.ensure_schema()
        return set(tuple(x) for x in self.migration_qs.values_list("app", "name"))

    def record_applied(self, app, name):
        """
        Records that a migration was applied.
        """
        self.ensure_schema()
        self.migration_qs.create(app=app, name=name)

    def record_unapplied(self, app, name):
        """
        Records that a migration was unapplied.
        """
        self.ensure_schema()
        self.migration_qs.filter(app=app, name=name).delete()

View File

@ -0,0 +1,5 @@
from django.db import migrations
class Migration(migrations.Migration):
    """
    Test migration that declares nothing of its own.
    """

View File

@ -0,0 +1,6 @@
from django.db import migrations
class Migration(migrations.Migration):
    """
    Test migration whose only declaration is a single dependency on
    ("migrations", "0001_initial").
    """

    dependencies = [
        ("migrations", "0001_initial"),
    ]

View File

View File

@ -1,5 +1,7 @@
from django.test import TransactionTestCase
from django.db.migrations.graph import MigrationsGraph, CircularDependencyException
from django.db import connection
from django.db.migrations.graph import MigrationGraph, CircularDependencyError
from django.db.migrations.loader import MigrationLoader
class GraphTests(TransactionTestCase):
@ -16,7 +18,7 @@ class GraphTests(TransactionTestCase):
app_b: 0001 <-- 0002 <-/
"""
# Build graph
graph = MigrationsGraph()
graph = MigrationGraph()
graph.add_dependency(("app_a", "0004"), ("app_a", "0003"))
graph.add_dependency(("app_a", "0003"), ("app_a", "0002"))
graph.add_dependency(("app_a", "0002"), ("app_a", "0001"))
@ -54,7 +56,7 @@ class GraphTests(TransactionTestCase):
app_c: \ 0001 <-- 0002 <-
"""
# Build graph
graph = MigrationsGraph()
graph = MigrationGraph()
graph.add_dependency(("app_a", "0004"), ("app_a", "0003"))
graph.add_dependency(("app_a", "0003"), ("app_a", "0002"))
graph.add_dependency(("app_a", "0002"), ("app_a", "0001"))
@ -85,7 +87,7 @@ class GraphTests(TransactionTestCase):
Tests a circular dependency graph.
"""
# Build graph
graph = MigrationsGraph()
graph = MigrationGraph()
graph.add_dependency(("app_a", "0003"), ("app_a", "0002"))
graph.add_dependency(("app_a", "0002"), ("app_a", "0001"))
graph.add_dependency(("app_a", "0001"), ("app_b", "0002"))
@ -93,6 +95,23 @@ class GraphTests(TransactionTestCase):
graph.add_dependency(("app_b", "0001"), ("app_a", "0003"))
# Test whole graph
self.assertRaises(
CircularDependencyException,
CircularDependencyError,
graph.forwards_plan, ("app_a", "0003"),
)
class LoaderTests(TransactionTestCase):
    """
    Exercises the loader that reads migrations from disk and the database.
    """

    def test_load(self):
        """
        Checks that the test apps' migrations load and that the resulting
        graph orders them correctly.
        """
        first = ("migrations", "0001_initial")
        second = ("migrations", "0002_second")
        graph = MigrationLoader(connection).build_graph()
        # Planning for the second migration must pull in the first one,
        # which it depends on, ahead of it.
        plan = graph.forwards_plan(second)
        self.assertEqual(plan, [first, second])