diff --git a/django/db/backends/schema.py b/django/db/backends/schema.py index d282e0898b..21eeefab82 100644 --- a/django/db/backends/schema.py +++ b/django/db/backends/schema.py @@ -272,7 +272,7 @@ class BaseDatabaseSchemaEditor(object): "new_tablespace": self.quote_name(new_db_tablespace), }) - def create_field(self, model, field, keep_default=False): + def add_field(self, model, field, keep_default=False): """ Creates a field on a model. Usually involves adding a column, but may involve adding a @@ -325,7 +325,7 @@ class BaseDatabaseSchemaEditor(object): } ) - def delete_field(self, model, field): + def remove_field(self, model, field): """ Removes a field from a model. Usually involves deleting a column, but for M2Ms may involve deleting a table. diff --git a/django/db/backends/sqlite3/schema.py b/django/db/backends/sqlite3/schema.py index de32dfd893..19bffc7520 100644 --- a/django/db/backends/sqlite3/schema.py +++ b/django/db/backends/sqlite3/schema.py @@ -73,7 +73,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): if restore_pk_field: restore_pk_field.primary_key = True - def create_field(self, model, field): + def add_field(self, model, field): """ Creates a field on a model. Usually involves adding a column, but may involve adding a @@ -89,7 +89,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): raise ValueError("You cannot add a null=False column without a default value on SQLite.") self._remake_table(model, create_fields=[field]) - def delete_field(self, model, field): + def remove_field(self, model, field): """ Removes a field from a model. Usually involves deleting a column, but for M2Ms may involve deleting a table. diff --git a/django/db/migrations/__init__.py b/django/db/migrations/__init__.py index 154e728341..e072786473 100644 --- a/django/db/migrations/__init__.py +++ b/django/db/migrations/__init__.py @@ -1 +1,2 @@ from .migration import Migration +from .operations import * diff --git a/django/db/migrations/graph.py b/django/db/migrations/graph.py index 8d23b36cb7..8e2446ca99 100644 --- a/django/db/migrations/graph.py +++ b/django/db/migrations/graph.py @@ -1,4 +1,5 @@ from django.utils.datastructures import SortedSet +from django.db.migrations.state import ProjectState class MigrationGraph(object): @@ -33,8 +34,10 @@ class MigrationGraph(object): self.nodes[node] = implementation def add_dependency(self, child, parent): - self.nodes[child] = None - self.nodes[parent] = None + if child not in self.nodes: + raise KeyError("Dependency references nonexistent child node %r" % (child,)) + if parent not in self.nodes: + raise KeyError("Dependency references nonexistent parent node %r" % (parent,)) self.dependencies.setdefault(child, set()).add(parent) self.dependents.setdefault(parent, set()).add(child) @@ -117,6 +120,16 @@ class MigrationGraph(object): def __str__(self): return "Graph: %s nodes, %s edges" % (len(self.nodes), sum(len(x) for x in self.dependencies.values())) + def project_state(self, node): + """ + Given a migration node, returns a complete ProjectState for it. + """ + plan = self.forwards_plan(node) + project_state = ProjectState() + for node in plan: + project_state = self.nodes[node].mutate_state(project_state) + return project_state + class CircularDependencyError(Exception): """ diff --git a/django/db/migrations/loader.py b/django/db/migrations/loader.py index 4d191714cb..ce9fb7c8de 100644 --- a/django/db/migrations/loader.py +++ b/django/db/migrations/loader.py @@ -1,5 +1,6 @@ import os from django.utils.importlib import import_module +from django.utils.functional import cached_property from django.db.models.loading import cache from django.db.migrations.recorder import MigrationRecorder from django.db.migrations.graph import MigrationGraph @@ -64,9 +65,10 @@ class MigrationLoader(object): migration_module = import_module("%s.%s" % (module_name, migration_name)) if not hasattr(migration_module, "Migration"): raise BadMigrationError("Migration %s in app %s has no Migration class" % (migration_name, app_label)) - self.disk_migrations[app_label, migration_name] = migration_module.Migration + self.disk_migrations[app_label, migration_name] = migration_module.Migration(migration_name, app_label) - def build_graph(self): + @cached_property + def graph(self): """ Builds a migration dependency graph using both the disk and database. """ @@ -116,6 +118,7 @@ class MigrationLoader(object): graph = MigrationGraph() for key, migration in normal.items(): graph.add_node(key, migration) + for key, migration in normal.items(): for parent in migration.dependencies: graph.add_dependency(key, parent) return graph diff --git a/django/db/migrations/migration.py b/django/db/migrations/migration.py index afbcf65297..a8b744a9b4 100644 --- a/django/db/migrations/migration.py +++ b/django/db/migrations/migration.py @@ -10,6 +10,9 @@ class Migration(object): - dependencies: A list of tuples of (app_path, migration_name) - run_before: A list of tuples of (app_path, migration_name) - replaces: A list of migration_names + + Note that all migrations come out of migrations and into the Loader or + Graph as instances, having been initialised with their app label and name. """ # Operations to apply during this migration, in order. @@ -28,3 +31,17 @@ class Migration(object): # non-empty, this migration will only be applied if all these migrations # are not applied. replaces = [] + + def __init__(self, name, app_label): + self.name = name + self.app_label = app_label + + def mutate_state(self, project_state): + """ + Takes a ProjectState and returns a new one with the migration's + operations applied to it. + """ + new_state = project_state.clone() + for operation in self.operations: + operation.state_forwards(self.app_label, new_state) + return new_state diff --git a/django/db/migrations/operations/__init__.py b/django/db/migrations/operations/__init__.py index 4fb70b0418..0aa7e2d119 100644 --- a/django/db/migrations/operations/__init__.py +++ b/django/db/migrations/operations/__init__.py @@ -1 +1,2 @@ from .models import CreateModel, DeleteModel +from .fields import AddField, RemoveField diff --git a/django/db/migrations/operations/base.py b/django/db/migrations/operations/base.py index b24b45a09a..f1b30d79f5 100644 --- a/django/db/migrations/operations/base.py +++ b/django/db/migrations/operations/base.py @@ -15,21 +15,21 @@ class Operation(object): # Some operations are impossible to reverse, like deleting data. reversible = True - def state_forwards(self, app, state): + def state_forwards(self, app_label, state): """ Takes the state from the previous migration, and mutates it so that it matches what this migration would perform. """ raise NotImplementedError() - def database_forwards(self, app, schema_editor, from_state, to_state): + def database_forwards(self, app_label, schema_editor, from_state, to_state): """ Performs the mutation on the database schema in the normal (forwards) direction. """ raise NotImplementedError() - def database_backwards(self, app, schema_editor, from_state, to_state): + def database_backwards(self, app_label, schema_editor, from_state, to_state): """ Performs the mutation on the database schema in the reverse direction - e.g. if this were CreateModel, it would in fact diff --git a/django/db/migrations/operations/fields.py b/django/db/migrations/operations/fields.py new file mode 100644 index 0000000000..2ecf77f7ef --- /dev/null +++ b/django/db/migrations/operations/fields.py @@ -0,0 +1,52 @@ +from .base import Operation + + +class AddField(Operation): + """ + Adds a field to a model. + """ + + def __init__(self, model_name, name, instance): + self.model_name = model_name + self.name = name + self.instance = instance + + def state_forwards(self, app_label, state): + state.models[app_label, self.model_name.lower()].fields.append((self.name, self.instance)) + + def database_forwards(self, app_label, schema_editor, from_state, to_state): + app_cache = to_state.render() + model = app_cache.get_model(app_label, self.name) + schema_editor.add_field(model, model._meta.get_field_by_name(self.name)) + + def database_backwards(self, app_label, schema_editor, from_state, to_state): + app_cache = from_state.render() + model = app_cache.get_model(app_label, self.name) + schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)) + + +class RemoveField(Operation): + """ + Removes a field from a model. + """ + + def __init__(self, model_name, name): + self.model_name = model_name + self.name = name + + def state_forwards(self, app_label, state): + new_fields = [] + for name, instance in state.models[app_label, self.model_name.lower()].fields: + if name != self.name: + new_fields.append((name, instance)) + state.models[app_label, self.model_name.lower()].fields = new_fields + + def database_forwards(self, app_label, schema_editor, from_state, to_state): + app_cache = from_state.render() + model = app_cache.get_model(app_label, self.name) + schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)) + + def database_backwards(self, app_label, schema_editor, from_state, to_state): + app_cache = to_state.render() + model = app_cache.get_model(app_label, self.name) + schema_editor.add_field(model, model._meta.get_field_by_name(self.name)) diff --git a/django/db/migrations/operations/models.py b/django/db/migrations/operations/models.py index fd709e26fa..22d24f1eed 100644 --- a/django/db/migrations/operations/models.py +++ b/django/db/migrations/operations/models.py @@ -1,4 +1,5 @@ from .base import Operation +from django.db import models from django.db.migrations.state import ModelState @@ -7,20 +8,39 @@ class CreateModel(Operation): Create a model's table. """ - def __init__(self, name): + def __init__(self, name, fields, options=None, bases=None): self.name = name + self.fields = fields + self.options = options or {} + self.bases = bases or (models.Model,) - def state_forwards(self, app, state): - state.models[app, self.name.lower()] = ModelState(state, app, self.name) + def state_forwards(self, app_label, state): + state.models[app_label, self.name.lower()] = ModelState(app_label, self.name, self.fields, self.options, self.bases) def database_forwards(self, app, schema_editor, from_state, to_state): app_cache = to_state.render() schema_editor.create_model(app_cache.get_model(app, self.name)) def database_backwards(self, app, schema_editor, from_state, to_state): - """ - Performs the mutation on the database schema in the reverse - direction - e.g. if this were CreateModel, it would in fact - drop the model's table. - """ - raise NotImplementedError() + app_cache = from_state.render() + schema_editor.delete_model(app_cache.get_model(app, self.name)) + + +class DeleteModel(Operation): + """ + Drops a model's table. + """ + + def __init__(self, name): + self.name = name + + def state_forwards(self, app_label, state): + del state.models[app_label, self.name.lower()] + + def database_forwards(self, app_label, schema_editor, from_state, to_state): + app_cache = from_state.render() + schema_editor.delete_model(app_cache.get_model(app_label, self.name)) + + def database_backwards(self, app_label, schema_editor, from_state, to_state): + app_cache = to_state.render() + schema_editor.create_model(app_cache.get_model(app_label, self.name)) diff --git a/django/db/migrations/state.py b/django/db/migrations/state.py index 44ee166121..d189e8709e 100644 --- a/django/db/migrations/state.py +++ b/django/db/migrations/state.py @@ -21,7 +21,7 @@ class ProjectState(object): def clone(self): "Returns an exact copy of this ProjectState" return ProjectState( - models = dict((k, v.copy()) for k, v in self.models.items()) + models = dict((k, v.clone()) for k, v in self.models.items()) ) def render(self): @@ -49,12 +49,15 @@ class ModelState(object): mutate this one and then render it into a Model as required. """ - def __init__(self, app_label, name, fields=None, options=None, bases=None): + def __init__(self, app_label, name, fields, options=None, bases=None): self.app_label = app_label self.name = name - self.fields = fields or [] + self.fields = fields self.options = options or {} self.bases = bases or (models.Model, ) + # Sanity-check that fields is NOT a dict. It must be ordered. + if isinstance(self.fields, dict): + raise ValueError("ModelState.fields cannot be a dict - it must be a list of 2-tuples.") @classmethod def from_model(cls, model): diff --git a/tests/migrations/migrations/0001_initial.py b/tests/migrations/migrations/0001_initial.py index bd613aa95e..e2ed8559a6 100644 --- a/tests/migrations/migrations/0001_initial.py +++ b/tests/migrations/migrations/0001_initial.py @@ -1,5 +1,27 @@ -from django.db import migrations +from django.db import migrations, models class Migration(migrations.Migration): - pass + + operations = [ + + migrations.CreateModel( + "Author", + [ + ("id", models.AutoField(primary_key=True)), + ("name", models.CharField(max_length=255)), + ("slug", models.SlugField(null=True)), + ("age", models.IntegerField(default=0)), + ("silly_field", models.BooleanField()), + ], + ), + + migrations.CreateModel( + "Tribble", + [ + ("id", models.AutoField(primary_key=True)), + ("fluffy", models.BooleanField(default=True)), + ], + ) + + ] diff --git a/tests/migrations/migrations/0002_second.py b/tests/migrations/migrations/0002_second.py index f4d3ba9902..fbaef11f71 100644 --- a/tests/migrations/migrations/0002_second.py +++ b/tests/migrations/migrations/0002_second.py @@ -1,6 +1,24 @@ -from django.db import migrations +from django.db import migrations, models class Migration(migrations.Migration): dependencies = [("migrations", "0001_initial")] + + operations = [ + + migrations.DeleteModel("Tribble"), + + migrations.RemoveField("Author", "silly_field"), + + migrations.AddField("Author", "important", models.BooleanField()), + + migrations.CreateModel( + "Book", + [ + ("id", models.AutoField(primary_key=True)), + ("author", models.ForeignKey("migrations.Author", null=True)), + ], + ) + + ] diff --git a/tests/migrations/test_graph.py b/tests/migrations/test_graph.py index b35d04fb8a..207cc45741 100644 --- a/tests/migrations/test_graph.py +++ b/tests/migrations/test_graph.py @@ -1,11 +1,8 @@ -from django.test import TransactionTestCase, TestCase -from django.db import connection +from django.test import TestCase from django.db.migrations.graph import MigrationGraph, CircularDependencyError -from django.db.migrations.loader import MigrationLoader -from django.db.migrations.recorder import MigrationRecorder -class GraphTests(TransactionTestCase): +class GraphTests(TestCase): """ Tests the digraph structure. """ @@ -117,20 +114,3 @@ class GraphTests(TransactionTestCase): CircularDependencyError, graph.forwards_plan, ("app_a", "0003"), ) - - -class LoaderTests(TransactionTestCase): - """ - Tests the disk and database loader. - """ - - def test_load(self): - """ - Makes sure the loader can load the migrations for the test apps. - """ - migration_loader = MigrationLoader(connection) - graph = migration_loader.build_graph() - self.assertEqual( - graph.forwards_plan(("migrations", "0002_second")), - [("migrations", "0001_initial"), ("migrations", "0002_second")], - ) diff --git a/tests/migrations/test_loader.py b/tests/migrations/test_loader.py index f8f31734f1..badace57cc 100644 --- a/tests/migrations/test_loader.py +++ b/tests/migrations/test_loader.py @@ -1,11 +1,12 @@ -from django.test import TestCase +from django.test import TestCase, TransactionTestCase from django.db import connection +from django.db.migrations.loader import MigrationLoader from django.db.migrations.recorder import MigrationRecorder class RecorderTests(TestCase): """ - Tests the disk and database loader. + Tests recording migrations as applied or not. """ def test_apply(self): @@ -27,3 +28,37 @@ class RecorderTests(TestCase): recorder.applied_migrations(), set(), ) + + +class LoaderTests(TransactionTestCase): + """ + Tests the disk and database loader, and running through migrations + in memory. + """ + + def test_load(self): + """ + Makes sure the loader can load the migrations for the test apps, + and then render them out to a new AppCache. + """ + # Load and test the plan + migration_loader = MigrationLoader(connection) + self.assertEqual( + migration_loader.graph.forwards_plan(("migrations", "0002_second")), + [("migrations", "0001_initial"), ("migrations", "0002_second")], + ) + # Now render it out! + project_state = migration_loader.graph.project_state(("migrations", "0002_second")) + self.assertEqual(len(project_state.models), 2) + + author_state = project_state.models["migrations", "author"] + self.assertEqual( + [x for x, y in author_state.fields], + ["id", "name", "slug", "age", "important"] + ) + + book_state = project_state.models["migrations", "book"] + self.assertEqual( + [x for x, y in book_state.fields], + ["id", "author"] + ) diff --git a/tests/schema/tests.py b/tests/schema/tests.py index 752f9a5d0b..f643f3ed68 100644 --- a/tests/schema/tests.py +++ b/tests/schema/tests.py @@ -132,7 +132,7 @@ class SchemaTests(TransactionTestCase): else: self.fail("No FK constraint for author_id found") - def test_create_field(self): + def test_add_field(self): """ Tests adding fields to models """ @@ -146,7 +146,7 @@ class SchemaTests(TransactionTestCase): new_field = IntegerField(null=True) new_field.set_attributes_from_name("age") with connection.schema_editor() as editor: - editor.create_field( + editor.add_field( Author, new_field, ) @@ -251,7 +251,7 @@ class SchemaTests(TransactionTestCase): connection.rollback() # Add the field with connection.schema_editor() as editor: - editor.create_field( + editor.add_field( Author, new_field, ) @@ -260,7 +260,7 @@ class SchemaTests(TransactionTestCase): self.assertEqual(columns['tag_id'][0], "IntegerField") # Remove the M2M table again with connection.schema_editor() as editor: - editor.delete_field( + editor.remove_field( Author, new_field, ) @@ -530,7 +530,7 @@ class SchemaTests(TransactionTestCase): ) # Add a unique column, verify that creates an implicit index with connection.schema_editor() as editor: - editor.create_field( + editor.add_field( Book, BookWithSlug._meta.get_field_by_name("slug")[0], ) @@ -568,7 +568,7 @@ class SchemaTests(TransactionTestCase): new_field = SlugField(primary_key=True) new_field.set_attributes_from_name("slug") with connection.schema_editor() as editor: - editor.delete_field(Tag, Tag._meta.get_field_by_name("id")[0]) + editor.remove_field(Tag, Tag._meta.get_field_by_name("id")[0]) editor.alter_field( Tag, Tag._meta.get_field_by_name("slug")[0],