A bit of an autodetector and a bit of a writer

This commit is contained in:
Andrew Godwin 2013-06-07 15:28:38 +01:00
parent 3c296382b8
commit 4492f06408
5 changed files with 280 additions and 0 deletions

View File

@ -0,0 +1,69 @@
from django.db.migrations import operations
from django.db.migrations.migration import Migration
class AutoDetector(object):
"""
Takes a pair of ProjectStates, and compares them to see what the
first would need doing to make it match the second (the second
usually being the project's current state).
Note that this naturally operates on entire projects at a time,
as it's likely that changes interact (for example, you can't
add a ForeignKey without having a migration to add the table it
depends on first). A user interface may offer single-app detection
if it wishes, with the caveat that it may not always be possible.
"""
def __init__(self, from_state, to_state):
self.from_state = from_state
self.to_state = to_state
def changes(self):
"""
Returns a set of migration plans which will achieve the
change from from_state to to_state.
"""
# We'll store migrations as lists by app names for now
self.migrations = {}
# Stage one: Adding models.
added_models = set(self.to_state.keys()) - set(self.from_state.keys())
for app_label, model_name in added_models:
model_state = self.to_state[app_label, model_name]
self.add_to_migration(
app_label,
operations.CreateModel(
model_state.name,
model_state.fields,
model_state.options,
model_state.bases,
)
)
# Removing models
removed_models = set(self.from_state.keys()) - set(self.to_state.keys())
for app_label, model_name in removed_models:
model_state = self.from_state[app_label, model_name]
self.add_to_migration(
app_label,
operations.DeleteModel(
model_state.name,
)
)
# Alright, now sort out and return the migrations
for app_label, migrations in self.migrations.items():
for m1, m2 in zip(migrations, migrations[1:]):
m2.dependencies.append((app_label, m1.name))
# Flatten and return
result = set()
for app_label, migrations in self.migrations.items():
for migration in migrations:
subclass = type("Migration", (Migration,), migration)
instance = subclass(migration['name'], app_label)
result.append(instance)
return result
def add_to_migration(self, app_label, operation):
migrations = self.migrations.setdefault(app_label, [])
if not migrations:
migrations.append({"name": "temp-%i" % len(migrations) + 1, "operations": [], "dependencies": []})
migrations[-1].operations.append(operation)

View File

@ -15,6 +15,24 @@ class Operation(object):
# Some operations are impossible to reverse, like deleting data.
reversible = True
def __new__(cls, *args, **kwargs):
# We capture the arguments to make returning them trivial
self = object.__new__(cls)
self._constructor_args = (args, kwargs)
return self
def deconstruct(self):
"""
Returns a 3-tuple of class import path (or just name if it lives
under django.db.migrations), positional arguments, and keyword
arguments.
"""
return (
self.__class__.__name__,
self._constructor_args[0],
self._constructor_args[1],
)
def state_forwards(self, app_label, state):
"""
Takes the state from the previous migration, and mutates it

View File

@ -0,0 +1,123 @@
import datetime
import types
class MigrationWriter(object):
"""
Takes a Migration instance and is able to produce the contents
of the migration file from it.
"""
def __init__(self, migration):
self.migration = migration
def as_string(self):
"""
Returns a string of the file contents.
"""
items = {
"dependencies": repr(self.migration.dependencies),
}
imports = set()
# Deconstruct operations
operation_strings = []
for operation in self.migration.operations:
name, args, kwargs = operation.deconstruct()
arg_strings = []
for arg in args:
arg_string, arg_imports = self.serialize(arg)
arg_strings.append(arg_string)
imports.update(arg_imports)
for kw, arg in kwargs.items():
arg_string, arg_imports = self.serialize(arg)
imports.update(arg_imports)
arg_strings.append("%s = %s" % (kw, arg_string))
operation_strings.append("migrations.%s(%s\n )" % (name, "".join("\n %s," % arg for arg in arg_strings)))
items["operations"] = "[%s\n ]" % "".join("\n %s," % s for s in operation_strings)
# Format imports nicely
if not imports:
items["imports"] = ""
else:
items["imports"] = "\n".join(imports) + "\n"
return MIGRATION_TEMPLATE % items
@property
def filename(self):
return "%s.py" % self.migration.name
@classmethod
def serialize(cls, value):
"""
Serializes the value to a string that's parsable by Python, along
with any needed imports to make that string work.
More advanced than repr() as it can encode things
like datetime.datetime.now.
"""
# Sequences
if isinstance(value, (list, set, tuple)):
imports = set()
strings = []
for item in value:
item_string, item_imports = cls.serialize(item)
imports.update(item_imports)
strings.append(item_string)
if isinstance(value, set):
format = "set([%s])"
elif isinstance(value, tuple):
format = "(%s,)"
else:
format = "[%s]"
return format % (", ".join(strings)), imports
# Dictionaries
elif isinstance(value, dict):
imports = set()
strings = []
for k, v in value.items():
k_string, k_imports = cls.serialize(k)
v_string, v_imports = cls.serialize(v)
imports.update(k_imports)
imports.update(v_imports)
strings.append((k_string, v_string))
return "{%s}" % (", ".join(["%s: %s" % (k, v) for k, v in strings])), imports
# Datetimes
elif isinstance(value, (datetime.datetime, datetime.date)):
return repr(value), set(["import datetime"])
# Simple types
elif isinstance(value, (int, long, float, str, unicode, bool, types.NoneType)):
return repr(value), set()
# Functions
elif isinstance(value, (types.FunctionType, types.BuiltinFunctionType)):
# Special-cases, as these don't have im_class
special_cases = [
(datetime.datetime.now, "datetime.datetime.now", ["import datetime"]),
(datetime.datetime.utcnow, "datetime.datetime.utcnow", ["import datetime"]),
(datetime.date.today, "datetime.date.today", ["import datetime"]),
]
for func, string, imports in special_cases:
if func == value: # For some reason "utcnow is not utcnow"
return string, set(imports)
# Method?
if hasattr(value, "im_class"):
klass = value.im_class
module = klass.__module__
return "%s.%s.%s" % (module, klass.__name__, value.__name__), set(["import %s" % module])
else:
module = value.__module__
if module is None:
raise ValueError("Cannot serialize function %r: No module" % value)
return "%s.%s" % (module, value.__name__), set(["import %s" % module])
# Uh oh.
else:
raise ValueError("Cannot serialize: %r" % value)
MIGRATION_TEMPLATE = """# encoding: utf8
from django.db import models, migrations
%(imports)s
class Migration(migrations.Migration):
dependencies = %(dependencies)s
operations = %(operations)s
"""

View File

@ -68,6 +68,12 @@ class OperationTests(TransactionTestCase):
with connection.schema_editor() as editor:
operation.database_backwards("test_crmo", editor, new_state, project_state)
self.assertTableNotExists("test_crmo_pony")
# And deconstruction
definition = operation.deconstruct()
self.assertEqual(definition[0], "CreateModel")
self.assertEqual(len(definition[1]), 2)
self.assertEqual(len(definition[2]), 0)
self.assertEqual(definition[1][0], "Pony")
def test_delete_model(self):
"""

View File

@ -0,0 +1,64 @@
# encoding: utf8
import datetime
from django.test import TransactionTestCase
from django.db.migrations.writer import MigrationWriter
from django.db import migrations
class WriterTests(TransactionTestCase):
"""
Tests the migration writer (makes migration files from Migration instances)
"""
def safe_exec(self, value, string):
l = {}
try:
exec(string, {}, l)
except:
self.fail("Could not serialize %r: failed to exec %r" % (value, string.strip()))
return l
def assertSerializedEqual(self, value):
string, imports = MigrationWriter.serialize(value)
new_value = self.safe_exec(value, "%s\ntest_value_result = %s" % ("\n".join(imports), string))['test_value_result']
self.assertEqual(new_value, value)
def assertSerializedIs(self, value):
string, imports = MigrationWriter.serialize(value)
new_value = self.safe_exec(value, "%s\ntest_value_result = %s" % ("\n".join(imports), string))['test_value_result']
self.assertIs(new_value, value)
def test_serialize(self):
"""
Tests various different forms of the serializer.
This does not care about formatting, just that the parsed result is
correct, so we always exec() the result and check that.
"""
# Basic values
self.assertSerializedEqual(1)
self.assertSerializedEqual(None)
self.assertSerializedEqual("foobar")
self.assertSerializedEqual(u"föobár")
self.assertSerializedEqual({1: 2})
self.assertSerializedEqual(["a", 2, True, None])
self.assertSerializedEqual(set([2, 3, "eighty"]))
self.assertSerializedEqual({"lalalala": ["yeah", "no", "maybe"]})
# Datetime stuff
self.assertSerializedEqual(datetime.datetime.utcnow())
self.assertSerializedEqual(datetime.datetime.utcnow)
self.assertSerializedEqual(datetime.date.today())
self.assertSerializedEqual(datetime.date.today)
def test_simple_migration(self):
"""
Tests serializing a simple migration.
"""
migration = type("Migration", (migrations.Migration,), {
"operations": [
migrations.DeleteModel("MyModel"),
],
"dependencies": [("testapp", "some_other_one")],
})
writer = MigrationWriter(migration)
output = writer.as_string()
print output