From 91c470def50c4de420b0c6ee7debddc5bbd53ec8 Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Fri, 7 Jun 2013 17:56:43 +0100 Subject: [PATCH] Auto-naming for migrations and some writer fixes --- django/db/migrations/autodetector.py | 91 +++++++++++++++++++++++---- django/db/migrations/graph.py | 16 +++-- django/db/migrations/writer.py | 12 ++-- tests/migrations/test_autodetector.py | 33 ++++++++-- tests/migrations/test_writer.py | 41 ++++++++---- 5 files changed, 156 insertions(+), 37 deletions(-) diff --git a/django/db/migrations/autodetector.py b/django/db/migrations/autodetector.py index ddb14520d3..be3e1c561f 100644 --- a/django/db/migrations/autodetector.py +++ b/django/db/migrations/autodetector.py @@ -1,3 +1,4 @@ +import re from django.db.migrations import operations from django.db.migrations.migration import Migration @@ -11,7 +12,7 @@ class MigrationAutodetector(object): Note that this naturally operates on entire projects at a time, as it's likely that changes interact (for example, you can't add a ForeignKey without having a migration to add the table it - depends on first). A user interface may offer single-app detection + depends on first). A user interface may offer single-app usage if it wishes, with the caveat that it may not always be possible. """ @@ -21,8 +22,12 @@ class MigrationAutodetector(object): def changes(self): """ - Returns a set of migration plans which will achieve the - change from from_state to to_state. + Returns a dict of migration plans which will achieve the + change from from_state to to_state. The dict has app labels + as kays and a list of migrations as values. + + The resulting migrations aren't specially named, but the names + do matter for dependencies inside the set. """ # We'll store migrations as lists by app names for now self.migrations = {} @@ -53,17 +58,77 @@ class MigrationAutodetector(object): for app_label, migrations in self.migrations.items(): for m1, m2 in zip(migrations, migrations[1:]): m2.dependencies.append((app_label, m1.name)) - # Flatten and return - result = set() - for app_label, migrations in self.migrations.items(): - for migration in migrations: - subclass = type("Migration", (Migration,), migration) - instance = subclass(migration['name'], app_label) - result.add(instance) - return result + return self.migrations def add_to_migration(self, app_label, operation): migrations = self.migrations.setdefault(app_label, []) if not migrations: - migrations.append({"name": "auto_%i" % (len(migrations) + 1), "operations": [], "dependencies": []}) - migrations[-1]['operations'].append(operation) + subclass = type("Migration", (Migration,), {"operations": [], "dependencies": []}) + instance = subclass("auto_%i" % (len(migrations) + 1), app_label) + migrations.append(instance) + migrations[-1].operations.append(operation) + + @classmethod + def suggest_name(cls, ops): + """ + Given a set of operations, suggests a name for the migration + they might represent. Names not guaranteed to be unique; they + must be prefixed by a number or date. + """ + if len(ops) == 1: + if isinstance(ops[0], operations.CreateModel): + return ops[0].name.lower() + elif isinstance(ops[0], operations.DeleteModel): + return "delete_%s" % ops[0].name.lower() + elif all(isinstance(o, operations.CreateModel) for o in ops): + return "_".join(sorted(o.name.lower() for o in ops)) + return "auto" + + @classmethod + def parse_number(cls, name): + """ + Given a migration name, tries to extract a number from the + beginning of it. If no number found, returns None. + """ + if re.match(r"^\d+_", name): + return int(name.split("_")[0]) + return None + + @classmethod + def arrange_for_graph(cls, changes, graph): + """ + Takes in a result from changes() and a MigrationGraph, + and fixes the names and dependencies of the changes so they + extend the graph from the leaf nodes for each app. + """ + leaves = graph.leaf_nodes() + name_map = {} + for app_label, migrations in changes.items(): + if not migrations: + continue + # Find the app label's current leaf node + app_leaf = None + for leaf in leaves: + if leaf[0] == app_label: + app_leaf = leaf + break + # Work out the next number in the sequence + if app_leaf is None: + next_number = 1 + else: + next_number = (cls.parse_number(app_leaf[1]) or 0) + 1 + # Name each migration + for i, migration in enumerate(migrations): + if i == 0 and app_leaf: + migration.dependencies.append(app_leaf) + if i == 0 and not app_leaf: + new_name = "0001_initial" + else: + new_name = "%04i_%s" % (next_number, cls.suggest_name(migration.operations)) + name_map[(app_label, migration.name)] = (app_label, new_name) + migration.name = new_name + # Now fix dependencies + for app_label, migrations in changes.items(): + for migration in migrations: + migration.dependencies = [name_map.get(d, d) for d in migration.dependencies] + return changes diff --git a/django/db/migrations/graph.py b/django/db/migrations/graph.py index 620534bc22..c1c3ba75bb 100644 --- a/django/db/migrations/graph.py +++ b/django/db/migrations/graph.py @@ -120,14 +120,20 @@ class MigrationGraph(object): def __str__(self): return "Graph: %s nodes, %s edges" % (len(self.nodes), sum(len(x) for x in self.dependencies.values())) - def project_state(self, node, at_end=True): + def project_state(self, nodes, at_end=True): """ - Given a migration node, returns a complete ProjectState for it. + Given a migration node or nodes, returns a complete ProjectState for it. If at_end is False, returns the state before the migration has run. """ - plan = self.forwards_plan(node) - if not at_end: - plan = plan[:-1] + if not isinstance(nodes[0], tuple): + nodes = [nodes] + plan = [] + for node in nodes: + for migration in self.forwards_plan(node): + if migration not in plan: + if not at_end and migration in nodes: + continue + plan.append(migration) project_state = ProjectState() for node in plan: project_state = self.nodes[node].mutate_state(project_state) diff --git a/django/db/migrations/writer.py b/django/db/migrations/writer.py index b21c6c9648..f386cd847c 100644 --- a/django/db/migrations/writer.py +++ b/django/db/migrations/writer.py @@ -1,5 +1,7 @@ +from __future__ import unicode_literals import datetime import types +from django.utils import six from django.db import models @@ -36,11 +38,12 @@ class MigrationWriter(object): operation_strings.append("migrations.%s(%s\n )" % (name, "".join("\n %s," % arg for arg in arg_strings))) items["operations"] = "[%s\n ]" % "".join("\n %s," % s for s in operation_strings) # Format imports nicely + imports.discard("from django.db import models") if not imports: items["imports"] = "" else: items["imports"] = "\n".join(imports) + "\n" - return MIGRATION_TEMPLATE % items + return (MIGRATION_TEMPLATE % items).encode("utf8") @property def filename(self): @@ -84,16 +87,17 @@ class MigrationWriter(object): elif isinstance(value, (datetime.datetime, datetime.date)): return repr(value), set(["import datetime"]) # Simple types - elif isinstance(value, (int, long, float, str, unicode, bool, types.NoneType)): + elif isinstance(value, (int, long, float, six.binary_type, six.text_type, bool, types.NoneType)): return repr(value), set() # Django fields elif isinstance(value, models.Field): attr_name, path, args, kwargs = value.deconstruct() module, name = path.rsplit(".", 1) if module == "django.db.models": - imports = set() + imports = set(["from django.db import models"]) + name = "models.%s" % name else: - imports = set("import %s" % module) + imports = set(["import %s" % module]) name = path arg_strings = [] for arg in args: diff --git a/tests/migrations/test_autodetector.py b/tests/migrations/test_autodetector.py index 8e6a1e4160..1fc8f7aefb 100644 --- a/tests/migrations/test_autodetector.py +++ b/tests/migrations/test_autodetector.py @@ -2,6 +2,7 @@ from django.test import TransactionTestCase from django.db.migrations.autodetector import MigrationAutodetector from django.db.migrations.state import ProjectState, ModelState +from django.db.migrations.graph import MigrationGraph from django.db import models @@ -11,6 +12,8 @@ class AutodetectorTests(TransactionTestCase): """ author_empty = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True))]) + other_pony = ModelState("otherapp", "Pony", [("id", models.AutoField(primary_key=True))]) + other_stable = ModelState("otherapp", "Stable", [("id", models.AutoField(primary_key=True))]) def make_project_state(self, model_states): "Shortcut to make ProjectStates from lists of predefined models" @@ -19,6 +22,28 @@ class AutodetectorTests(TransactionTestCase): project_state.add_model_state(model_state) return project_state + def test_arrange_for_graph(self): + "Tests auto-naming of migrations for graph matching." + # Make a fake graph + graph = MigrationGraph() + graph.add_node(("testapp", "0001_initial"), None) + graph.add_node(("testapp", "0002_foobar"), None) + graph.add_node(("otherapp", "0001_initial"), None) + graph.add_dependency(("testapp", "0002_foobar"), ("testapp", "0001_initial")) + graph.add_dependency(("testapp", "0002_foobar"), ("otherapp", "0001_initial")) + # Use project state to make a new migration change set + before = self.make_project_state([]) + after = self.make_project_state([self.author_empty, self.other_pony, self.other_stable]) + autodetector = MigrationAutodetector(before, after) + changes = autodetector.changes() + # Run through arrange_for_graph + changes = autodetector.arrange_for_graph(changes, graph) + # Make sure there's a new name, deps match, etc. + self.assertEqual(changes["testapp"][0].name, "0003_author") + self.assertEqual(changes["testapp"][0].dependencies, [("testapp", "0002_foobar")]) + self.assertEqual(changes["otherapp"][0].name, "0002_pony_stable") + self.assertEqual(changes["otherapp"][0].dependencies, [("otherapp", "0001_initial")]) + def test_new_model(self): "Tests autodetection of new models" # Make state @@ -27,9 +52,9 @@ class AutodetectorTests(TransactionTestCase): autodetector = MigrationAutodetector(before, after) changes = autodetector.changes() # Right number of migrations? - self.assertEqual(len(changes), 1) + self.assertEqual(len(changes['testapp']), 1) # Right number of actions? - migration = changes.pop() + migration = changes['testapp'][0] self.assertEqual(len(migration.operations), 1) # Right action? action = migration.operations[0] @@ -44,9 +69,9 @@ class AutodetectorTests(TransactionTestCase): autodetector = MigrationAutodetector(before, after) changes = autodetector.changes() # Right number of migrations? - self.assertEqual(len(changes), 1) + self.assertEqual(len(changes['testapp']), 1) # Right number of actions? - migration = changes.pop() + migration = changes['testapp'][0] self.assertEqual(len(migration.operations), 1) # Right action? action = migration.operations[0] diff --git a/tests/migrations/test_writer.py b/tests/migrations/test_writer.py index 0581d6a4bd..c6ca100c1a 100644 --- a/tests/migrations/test_writer.py +++ b/tests/migrations/test_writer.py @@ -1,5 +1,6 @@ # encoding: utf8 import datetime +from django.utils import six from django.test import TransactionTestCase from django.db.migrations.writer import MigrationWriter from django.db import models, migrations @@ -10,23 +11,33 @@ class WriterTests(TransactionTestCase): Tests the migration writer (makes migration files from Migration instances) """ - def safe_exec(self, value, string): + def safe_exec(self, string, value=None): l = {} try: - exec(string, {}, l) - except: - self.fail("Could not serialize %r: failed to exec %r" % (value, string.strip())) + exec(string, globals(), l) + except Exception as e: + if value: + self.fail("Could not exec %r (from value %r): %s" % (string.strip(), value, e)) + else: + self.fail("Could not exec %r: %s" % (string.strip(), e)) return l - def assertSerializedEqual(self, value): + def serialize_round_trip(self, value): string, imports = MigrationWriter.serialize(value) - new_value = self.safe_exec(value, "%s\ntest_value_result = %s" % ("\n".join(imports), string))['test_value_result'] - self.assertEqual(new_value, value) + return self.safe_exec("%s\ntest_value_result = %s" % ("\n".join(imports), string), value)['test_value_result'] + + def assertSerializedEqual(self, value): + self.assertEqual(self.serialize_round_trip(value), value) def assertSerializedIs(self, value): - string, imports = MigrationWriter.serialize(value) - new_value = self.safe_exec(value, "%s\ntest_value_result = %s" % ("\n".join(imports), string))['test_value_result'] - self.assertIs(new_value, value) + self.assertIs(self.serialize_round_trip(value), value) + + def assertSerializedFieldEqual(self, value): + new_value = self.serialize_round_trip(value) + self.assertEqual(value.__class__, new_value.__class__) + self.assertEqual(value.max_length, new_value.max_length) + self.assertEqual(value.null, new_value.null) + self.assertEqual(value.unique, new_value.unique) def test_serialize(self): """ @@ -48,6 +59,9 @@ class WriterTests(TransactionTestCase): self.assertSerializedEqual(datetime.datetime.utcnow) self.assertSerializedEqual(datetime.date.today()) self.assertSerializedEqual(datetime.date.today) + # Django fields + self.assertSerializedFieldEqual(models.CharField(max_length=255)) + self.assertSerializedFieldEqual(models.TextField(null=True, blank=True)) def test_simple_migration(self): """ @@ -62,4 +76,9 @@ class WriterTests(TransactionTestCase): }) writer = MigrationWriter(migration) output = writer.as_string() - print output + # It should NOT be unicode. + self.assertIsInstance(output, six.binary_type, "Migration as_string returned unicode") + # We don't test the output formatting - that's too fragile. + # Just make sure it runs for now, and that things look alright. + result = self.safe_exec(output) + self.assertIn("Migration", result)