Start adding operations that work and tests for them

This commit is contained in:
Andrew Godwin 2013-05-29 17:47:10 +01:00
parent 264f8650e3
commit d0ecefc2c9
16 changed files with 221 additions and 56 deletions

View File

@ -272,7 +272,7 @@ class BaseDatabaseSchemaEditor(object):
"new_tablespace": self.quote_name(new_db_tablespace), "new_tablespace": self.quote_name(new_db_tablespace),
}) })
def create_field(self, model, field, keep_default=False): def add_field(self, model, field, keep_default=False):
""" """
Creates a field on a model. Creates a field on a model.
Usually involves adding a column, but may involve adding a Usually involves adding a column, but may involve adding a
@ -325,7 +325,7 @@ class BaseDatabaseSchemaEditor(object):
} }
) )
def delete_field(self, model, field): def remove_field(self, model, field):
""" """
Removes a field from a model. Usually involves deleting a column, Removes a field from a model. Usually involves deleting a column,
but for M2Ms may involve deleting a table. but for M2Ms may involve deleting a table.

View File

@ -73,7 +73,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
if restore_pk_field: if restore_pk_field:
restore_pk_field.primary_key = True restore_pk_field.primary_key = True
def create_field(self, model, field): def add_field(self, model, field):
""" """
Creates a field on a model. Creates a field on a model.
Usually involves adding a column, but may involve adding a Usually involves adding a column, but may involve adding a
@ -89,7 +89,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
raise ValueError("You cannot add a null=False column without a default value on SQLite.") raise ValueError("You cannot add a null=False column without a default value on SQLite.")
self._remake_table(model, create_fields=[field]) self._remake_table(model, create_fields=[field])
def delete_field(self, model, field): def remove_field(self, model, field):
""" """
Removes a field from a model. Usually involves deleting a column, Removes a field from a model. Usually involves deleting a column,
but for M2Ms may involve deleting a table. but for M2Ms may involve deleting a table.

View File

@ -1 +1,2 @@
from .migration import Migration from .migration import Migration
from .operations import *

View File

@ -1,4 +1,5 @@
from django.utils.datastructures import SortedSet from django.utils.datastructures import SortedSet
from django.db.migrations.state import ProjectState
class MigrationGraph(object): class MigrationGraph(object):
@ -33,8 +34,10 @@ class MigrationGraph(object):
self.nodes[node] = implementation self.nodes[node] = implementation
def add_dependency(self, child, parent): def add_dependency(self, child, parent):
self.nodes[child] = None if child not in self.nodes:
self.nodes[parent] = None raise KeyError("Dependency references nonexistent child node %r" % (child,))
if parent not in self.nodes:
raise KeyError("Dependency references nonexistent parent node %r" % (parent,))
self.dependencies.setdefault(child, set()).add(parent) self.dependencies.setdefault(child, set()).add(parent)
self.dependents.setdefault(parent, set()).add(child) self.dependents.setdefault(parent, set()).add(child)
@ -117,6 +120,16 @@ class MigrationGraph(object):
def __str__(self): def __str__(self):
return "Graph: %s nodes, %s edges" % (len(self.nodes), sum(len(x) for x in self.dependencies.values())) return "Graph: %s nodes, %s edges" % (len(self.nodes), sum(len(x) for x in self.dependencies.values()))
def project_state(self, node):
"""
Given a migration node, returns a complete ProjectState for it.
"""
plan = self.forwards_plan(node)
project_state = ProjectState()
for node in plan:
project_state = self.nodes[node].mutate_state(project_state)
return project_state
class CircularDependencyError(Exception): class CircularDependencyError(Exception):
""" """

View File

@ -1,5 +1,6 @@
import os import os
from django.utils.importlib import import_module from django.utils.importlib import import_module
from django.utils.functional import cached_property
from django.db.models.loading import cache from django.db.models.loading import cache
from django.db.migrations.recorder import MigrationRecorder from django.db.migrations.recorder import MigrationRecorder
from django.db.migrations.graph import MigrationGraph from django.db.migrations.graph import MigrationGraph
@ -64,9 +65,10 @@ class MigrationLoader(object):
migration_module = import_module("%s.%s" % (module_name, migration_name)) migration_module = import_module("%s.%s" % (module_name, migration_name))
if not hasattr(migration_module, "Migration"): if not hasattr(migration_module, "Migration"):
raise BadMigrationError("Migration %s in app %s has no Migration class" % (migration_name, app_label)) raise BadMigrationError("Migration %s in app %s has no Migration class" % (migration_name, app_label))
self.disk_migrations[app_label, migration_name] = migration_module.Migration self.disk_migrations[app_label, migration_name] = migration_module.Migration(migration_name, app_label)
def build_graph(self): @cached_property
def graph(self):
""" """
Builds a migration dependency graph using both the disk and database. Builds a migration dependency graph using both the disk and database.
""" """
@ -116,6 +118,7 @@ class MigrationLoader(object):
graph = MigrationGraph() graph = MigrationGraph()
for key, migration in normal.items(): for key, migration in normal.items():
graph.add_node(key, migration) graph.add_node(key, migration)
for key, migration in normal.items():
for parent in migration.dependencies: for parent in migration.dependencies:
graph.add_dependency(key, parent) graph.add_dependency(key, parent)
return graph return graph

View File

@ -10,6 +10,9 @@ class Migration(object):
- dependencies: A list of tuples of (app_path, migration_name) - dependencies: A list of tuples of (app_path, migration_name)
- run_before: A list of tuples of (app_path, migration_name) - run_before: A list of tuples of (app_path, migration_name)
- replaces: A list of migration_names - replaces: A list of migration_names
Note that all migrations come out of migrations and into the Loader or
Graph as instances, having been initialised with their app label and name.
""" """
# Operations to apply during this migration, in order. # Operations to apply during this migration, in order.
@ -28,3 +31,17 @@ class Migration(object):
# non-empty, this migration will only be applied if all these migrations # non-empty, this migration will only be applied if all these migrations
# are not applied. # are not applied.
replaces = [] replaces = []
def __init__(self, name, app_label):
self.name = name
self.app_label = app_label
def mutate_state(self, project_state):
"""
Takes a ProjectState and returns a new one with the migration's
operations applied to it.
"""
new_state = project_state.clone()
for operation in self.operations:
operation.state_forwards(self.app_label, new_state)
return new_state

View File

@ -1 +1,2 @@
from .models import CreateModel, DeleteModel from .models import CreateModel, DeleteModel
from .fields import AddField, RemoveField

View File

@ -15,21 +15,21 @@ class Operation(object):
# Some operations are impossible to reverse, like deleting data. # Some operations are impossible to reverse, like deleting data.
reversible = True reversible = True
def state_forwards(self, app, state): def state_forwards(self, app_label, state):
""" """
Takes the state from the previous migration, and mutates it Takes the state from the previous migration, and mutates it
so that it matches what this migration would perform. so that it matches what this migration would perform.
""" """
raise NotImplementedError() raise NotImplementedError()
def database_forwards(self, app, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
""" """
Performs the mutation on the database schema in the normal Performs the mutation on the database schema in the normal
(forwards) direction. (forwards) direction.
""" """
raise NotImplementedError() raise NotImplementedError()
def database_backwards(self, app, schema_editor, from_state, to_state): def database_backwards(self, app_label, schema_editor, from_state, to_state):
""" """
Performs the mutation on the database schema in the reverse Performs the mutation on the database schema in the reverse
direction - e.g. if this were CreateModel, it would in fact direction - e.g. if this were CreateModel, it would in fact

View File

@ -0,0 +1,52 @@
from .base import Operation
class AddField(Operation):
"""
Adds a field to a model.
"""
def __init__(self, model_name, name, instance):
self.model_name = model_name
self.name = name
self.instance = instance
def state_forwards(self, app_label, state):
state.models[app_label, self.model_name.lower()].fields.append((self.name, self.instance))
def database_forwards(self, app_label, schema_editor, from_state, to_state):
app_cache = to_state.render()
model = app_cache.get_model(app_label, self.name)
schema_editor.add_field(model, model._meta.get_field_by_name(self.name))
def database_backwards(self, app_label, schema_editor, from_state, to_state):
app_cache = from_state.render()
model = app_cache.get_model(app_label, self.name)
schema_editor.remove_field(model, model._meta.get_field_by_name(self.name))
class RemoveField(Operation):
"""
Removes a field from a model.
"""
def __init__(self, model_name, name):
self.model_name = model_name
self.name = name
def state_forwards(self, app_label, state):
new_fields = []
for name, instance in state.models[app_label, self.model_name.lower()].fields:
if name != self.name:
new_fields.append((name, instance))
state.models[app_label, self.model_name.lower()].fields = new_fields
def database_forwards(self, app_label, schema_editor, from_state, to_state):
app_cache = from_state.render()
model = app_cache.get_model(app_label, self.name)
schema_editor.remove_field(model, model._meta.get_field_by_name(self.name))
def database_backwards(self, app_label, schema_editor, from_state, to_state):
app_cache = to_state.render()
model = app_cache.get_model(app_label, self.name)
schema_editor.add_field(model, model._meta.get_field_by_name(self.name))

View File

@ -1,4 +1,5 @@
from .base import Operation from .base import Operation
from django.db import models
from django.db.migrations.state import ModelState from django.db.migrations.state import ModelState
@ -7,20 +8,39 @@ class CreateModel(Operation):
Create a model's table. Create a model's table.
""" """
def __init__(self, name): def __init__(self, name, fields, options=None, bases=None):
self.name = name self.name = name
self.fields = fields
self.options = options or {}
self.bases = bases or (models.Model,)
def state_forwards(self, app, state): def state_forwards(self, app_label, state):
state.models[app, self.name.lower()] = ModelState(state, app, self.name) state.models[app_label, self.name.lower()] = ModelState(app_label, self.name, self.fields, self.options, self.bases)
def database_forwards(self, app, schema_editor, from_state, to_state): def database_forwards(self, app, schema_editor, from_state, to_state):
app_cache = to_state.render() app_cache = to_state.render()
schema_editor.create_model(app_cache.get_model(app, self.name)) schema_editor.create_model(app_cache.get_model(app, self.name))
def database_backwards(self, app, schema_editor, from_state, to_state): def database_backwards(self, app, schema_editor, from_state, to_state):
""" app_cache = from_state.render()
Performs the mutation on the database schema in the reverse schema_editor.delete_model(app_cache.get_model(app, self.name))
direction - e.g. if this were CreateModel, it would in fact
drop the model's table.
""" class DeleteModel(Operation):
raise NotImplementedError() """
Drops a model's table.
"""
def __init__(self, name):
self.name = name
def state_forwards(self, app_label, state):
del state.models[app_label, self.name.lower()]
def database_forwards(self, app_label, schema_editor, from_state, to_state):
app_cache = from_state.render()
schema_editor.delete_model(app_cache.get_model(app_label, self.name))
def database_backwards(self, app_label, schema_editor, from_state, to_state):
app_cache = to_state.render()
schema_editor.create_model(app_cache.get_model(app_label, self.name))

View File

@ -21,7 +21,7 @@ class ProjectState(object):
def clone(self): def clone(self):
"Returns an exact copy of this ProjectState" "Returns an exact copy of this ProjectState"
return ProjectState( return ProjectState(
models = dict((k, v.copy()) for k, v in self.models.items()) models = dict((k, v.clone()) for k, v in self.models.items())
) )
def render(self): def render(self):
@ -49,12 +49,15 @@ class ModelState(object):
mutate this one and then render it into a Model as required. mutate this one and then render it into a Model as required.
""" """
def __init__(self, app_label, name, fields=None, options=None, bases=None): def __init__(self, app_label, name, fields, options=None, bases=None):
self.app_label = app_label self.app_label = app_label
self.name = name self.name = name
self.fields = fields or [] self.fields = fields
self.options = options or {} self.options = options or {}
self.bases = bases or (models.Model, ) self.bases = bases or (models.Model, )
# Sanity-check that fields is NOT a dict. It must be ordered.
if isinstance(self.fields, dict):
raise ValueError("ModelState.fields cannot be a dict - it must be a list of 2-tuples.")
@classmethod @classmethod
def from_model(cls, model): def from_model(cls, model):

View File

@ -1,5 +1,27 @@
from django.db import migrations from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
pass
operations = [
migrations.CreateModel(
"Author",
[
("id", models.AutoField(primary_key=True)),
("name", models.CharField(max_length=255)),
("slug", models.SlugField(null=True)),
("age", models.IntegerField(default=0)),
("silly_field", models.BooleanField()),
],
),
migrations.CreateModel(
"Tribble",
[
("id", models.AutoField(primary_key=True)),
("fluffy", models.BooleanField(default=True)),
],
)
]

View File

@ -1,6 +1,24 @@
from django.db import migrations from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [("migrations", "0001_initial")] dependencies = [("migrations", "0001_initial")]
operations = [
migrations.DeleteModel("Tribble"),
migrations.RemoveField("Author", "silly_field"),
migrations.AddField("Author", "important", models.BooleanField()),
migrations.CreateModel(
"Book",
[
("id", models.AutoField(primary_key=True)),
("author", models.ForeignKey("migrations.Author", null=True)),
],
)
]

View File

@ -1,11 +1,8 @@
from django.test import TransactionTestCase, TestCase from django.test import TestCase
from django.db import connection
from django.db.migrations.graph import MigrationGraph, CircularDependencyError from django.db.migrations.graph import MigrationGraph, CircularDependencyError
from django.db.migrations.loader import MigrationLoader
from django.db.migrations.recorder import MigrationRecorder
class GraphTests(TransactionTestCase): class GraphTests(TestCase):
""" """
Tests the digraph structure. Tests the digraph structure.
""" """
@ -117,20 +114,3 @@ class GraphTests(TransactionTestCase):
CircularDependencyError, CircularDependencyError,
graph.forwards_plan, ("app_a", "0003"), graph.forwards_plan, ("app_a", "0003"),
) )
class LoaderTests(TransactionTestCase):
"""
Tests the disk and database loader.
"""
def test_load(self):
"""
Makes sure the loader can load the migrations for the test apps.
"""
migration_loader = MigrationLoader(connection)
graph = migration_loader.build_graph()
self.assertEqual(
graph.forwards_plan(("migrations", "0002_second")),
[("migrations", "0001_initial"), ("migrations", "0002_second")],
)

View File

@ -1,11 +1,12 @@
from django.test import TestCase from django.test import TestCase, TransactionTestCase
from django.db import connection from django.db import connection
from django.db.migrations.loader import MigrationLoader
from django.db.migrations.recorder import MigrationRecorder from django.db.migrations.recorder import MigrationRecorder
class RecorderTests(TestCase): class RecorderTests(TestCase):
""" """
Tests the disk and database loader. Tests recording migrations as applied or not.
""" """
def test_apply(self): def test_apply(self):
@ -27,3 +28,37 @@ class RecorderTests(TestCase):
recorder.applied_migrations(), recorder.applied_migrations(),
set(), set(),
) )
class LoaderTests(TransactionTestCase):
"""
Tests the disk and database loader, and running through migrations
in memory.
"""
def test_load(self):
"""
Makes sure the loader can load the migrations for the test apps,
and then render them out to a new AppCache.
"""
# Load and test the plan
migration_loader = MigrationLoader(connection)
self.assertEqual(
migration_loader.graph.forwards_plan(("migrations", "0002_second")),
[("migrations", "0001_initial"), ("migrations", "0002_second")],
)
# Now render it out!
project_state = migration_loader.graph.project_state(("migrations", "0002_second"))
self.assertEqual(len(project_state.models), 2)
author_state = project_state.models["migrations", "author"]
self.assertEqual(
[x for x, y in author_state.fields],
["id", "name", "slug", "age", "important"]
)
book_state = project_state.models["migrations", "book"]
self.assertEqual(
[x for x, y in book_state.fields],
["id", "author"]
)

View File

@ -132,7 +132,7 @@ class SchemaTests(TransactionTestCase):
else: else:
self.fail("No FK constraint for author_id found") self.fail("No FK constraint for author_id found")
def test_create_field(self): def test_add_field(self):
""" """
Tests adding fields to models Tests adding fields to models
""" """
@ -146,7 +146,7 @@ class SchemaTests(TransactionTestCase):
new_field = IntegerField(null=True) new_field = IntegerField(null=True)
new_field.set_attributes_from_name("age") new_field.set_attributes_from_name("age")
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
editor.create_field( editor.add_field(
Author, Author,
new_field, new_field,
) )
@ -251,7 +251,7 @@ class SchemaTests(TransactionTestCase):
connection.rollback() connection.rollback()
# Add the field # Add the field
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
editor.create_field( editor.add_field(
Author, Author,
new_field, new_field,
) )
@ -260,7 +260,7 @@ class SchemaTests(TransactionTestCase):
self.assertEqual(columns['tag_id'][0], "IntegerField") self.assertEqual(columns['tag_id'][0], "IntegerField")
# Remove the M2M table again # Remove the M2M table again
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
editor.delete_field( editor.remove_field(
Author, Author,
new_field, new_field,
) )
@ -530,7 +530,7 @@ class SchemaTests(TransactionTestCase):
) )
# Add a unique column, verify that creates an implicit index # Add a unique column, verify that creates an implicit index
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
editor.create_field( editor.add_field(
Book, Book,
BookWithSlug._meta.get_field_by_name("slug")[0], BookWithSlug._meta.get_field_by_name("slug")[0],
) )
@ -568,7 +568,7 @@ class SchemaTests(TransactionTestCase):
new_field = SlugField(primary_key=True) new_field = SlugField(primary_key=True)
new_field.set_attributes_from_name("slug") new_field.set_attributes_from_name("slug")
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
editor.delete_field(Tag, Tag._meta.get_field_by_name("id")[0]) editor.remove_field(Tag, Tag._meta.get_field_by_name("id")[0])
editor.alter_field( editor.alter_field(
Tag, Tag,
Tag._meta.get_field_by_name("slug")[0], Tag._meta.get_field_by_name("slug")[0],