Start adding operations that work and tests for them

This commit is contained in:
Andrew Godwin 2013-05-29 17:47:10 +01:00
parent 264f8650e3
commit d0ecefc2c9
16 changed files with 221 additions and 56 deletions

View File

@ -272,7 +272,7 @@ class BaseDatabaseSchemaEditor(object):
"new_tablespace": self.quote_name(new_db_tablespace),
})
def create_field(self, model, field, keep_default=False):
def add_field(self, model, field, keep_default=False):
"""
Creates a field on a model.
Usually involves adding a column, but may involve adding a
@ -325,7 +325,7 @@ class BaseDatabaseSchemaEditor(object):
}
)
def delete_field(self, model, field):
def remove_field(self, model, field):
"""
Removes a field from a model. Usually involves deleting a column,
but for M2Ms may involve deleting a table.

View File

@ -73,7 +73,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
if restore_pk_field:
restore_pk_field.primary_key = True
def create_field(self, model, field):
def add_field(self, model, field):
"""
Creates a field on a model.
Usually involves adding a column, but may involve adding a
@ -89,7 +89,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
raise ValueError("You cannot add a null=False column without a default value on SQLite.")
self._remake_table(model, create_fields=[field])
def delete_field(self, model, field):
def remove_field(self, model, field):
"""
Removes a field from a model. Usually involves deleting a column,
but for M2Ms may involve deleting a table.

View File

@ -1 +1,2 @@
from .migration import Migration
from .operations import *

View File

@ -1,4 +1,5 @@
from django.utils.datastructures import SortedSet
from django.db.migrations.state import ProjectState
class MigrationGraph(object):
@ -33,8 +34,10 @@ class MigrationGraph(object):
self.nodes[node] = implementation
def add_dependency(self, child, parent):
self.nodes[child] = None
self.nodes[parent] = None
if child not in self.nodes:
raise KeyError("Dependency references nonexistent child node %r" % (child,))
if parent not in self.nodes:
raise KeyError("Dependency references nonexistent parent node %r" % (parent,))
self.dependencies.setdefault(child, set()).add(parent)
self.dependents.setdefault(parent, set()).add(child)
@ -117,6 +120,16 @@ class MigrationGraph(object):
def __str__(self):
return "Graph: %s nodes, %s edges" % (len(self.nodes), sum(len(x) for x in self.dependencies.values()))
def project_state(self, node):
"""
Given a migration node, returns a complete ProjectState for it.
"""
plan = self.forwards_plan(node)
project_state = ProjectState()
for node in plan:
project_state = self.nodes[node].mutate_state(project_state)
return project_state
class CircularDependencyError(Exception):
"""

View File

@ -1,5 +1,6 @@
import os
from django.utils.importlib import import_module
from django.utils.functional import cached_property
from django.db.models.loading import cache
from django.db.migrations.recorder import MigrationRecorder
from django.db.migrations.graph import MigrationGraph
@ -64,9 +65,10 @@ class MigrationLoader(object):
migration_module = import_module("%s.%s" % (module_name, migration_name))
if not hasattr(migration_module, "Migration"):
raise BadMigrationError("Migration %s in app %s has no Migration class" % (migration_name, app_label))
self.disk_migrations[app_label, migration_name] = migration_module.Migration
self.disk_migrations[app_label, migration_name] = migration_module.Migration(migration_name, app_label)
def build_graph(self):
@cached_property
def graph(self):
"""
Builds a migration dependency graph using both the disk and database.
"""
@ -116,6 +118,7 @@ class MigrationLoader(object):
graph = MigrationGraph()
for key, migration in normal.items():
graph.add_node(key, migration)
for key, migration in normal.items():
for parent in migration.dependencies:
graph.add_dependency(key, parent)
return graph

View File

@ -10,6 +10,9 @@ class Migration(object):
- dependencies: A list of tuples of (app_path, migration_name)
- run_before: A list of tuples of (app_path, migration_name)
- replaces: A list of migration_names
Note that all migrations come out of migrations and into the Loader or
Graph as instances, having been initialised with their app label and name.
"""
# Operations to apply during this migration, in order.
@ -28,3 +31,17 @@ class Migration(object):
# non-empty, this migration will only be applied if all these migrations
# are not applied.
replaces = []
def __init__(self, name, app_label):
self.name = name
self.app_label = app_label
def mutate_state(self, project_state):
"""
Takes a ProjectState and returns a new one with the migration's
operations applied to it.
"""
new_state = project_state.clone()
for operation in self.operations:
operation.state_forwards(self.app_label, new_state)
return new_state

View File

@ -1 +1,2 @@
from .models import CreateModel, DeleteModel
from .fields import AddField, RemoveField

View File

@ -15,21 +15,21 @@ class Operation(object):
# Some operations are impossible to reverse, like deleting data.
reversible = True
def state_forwards(self, app, state):
def state_forwards(self, app_label, state):
"""
Takes the state from the previous migration, and mutates it
so that it matches what this migration would perform.
"""
raise NotImplementedError()
def database_forwards(self, app, schema_editor, from_state, to_state):
def database_forwards(self, app_label, schema_editor, from_state, to_state):
"""
Performs the mutation on the database schema in the normal
(forwards) direction.
"""
raise NotImplementedError()
def database_backwards(self, app, schema_editor, from_state, to_state):
def database_backwards(self, app_label, schema_editor, from_state, to_state):
"""
Performs the mutation on the database schema in the reverse
direction - e.g. if this were CreateModel, it would in fact

View File

@ -0,0 +1,52 @@
from .base import Operation
class AddField(Operation):
"""
Adds a field to a model.
"""
def __init__(self, model_name, name, instance):
self.model_name = model_name
self.name = name
self.instance = instance
def state_forwards(self, app_label, state):
state.models[app_label, self.model_name.lower()].fields.append((self.name, self.instance))
def database_forwards(self, app_label, schema_editor, from_state, to_state):
app_cache = to_state.render()
model = app_cache.get_model(app_label, self.name)
schema_editor.add_field(model, model._meta.get_field_by_name(self.name))
def database_backwards(self, app_label, schema_editor, from_state, to_state):
app_cache = from_state.render()
model = app_cache.get_model(app_label, self.name)
schema_editor.remove_field(model, model._meta.get_field_by_name(self.name))
class RemoveField(Operation):
"""
Removes a field from a model.
"""
def __init__(self, model_name, name):
self.model_name = model_name
self.name = name
def state_forwards(self, app_label, state):
new_fields = []
for name, instance in state.models[app_label, self.model_name.lower()].fields:
if name != self.name:
new_fields.append((name, instance))
state.models[app_label, self.model_name.lower()].fields = new_fields
def database_forwards(self, app_label, schema_editor, from_state, to_state):
app_cache = from_state.render()
model = app_cache.get_model(app_label, self.name)
schema_editor.remove_field(model, model._meta.get_field_by_name(self.name))
def database_backwards(self, app_label, schema_editor, from_state, to_state):
app_cache = to_state.render()
model = app_cache.get_model(app_label, self.name)
schema_editor.add_field(model, model._meta.get_field_by_name(self.name))

View File

@ -1,4 +1,5 @@
from .base import Operation
from django.db import models
from django.db.migrations.state import ModelState
@ -7,20 +8,39 @@ class CreateModel(Operation):
Create a model's table.
"""
def __init__(self, name):
def __init__(self, name, fields, options=None, bases=None):
self.name = name
self.fields = fields
self.options = options or {}
self.bases = bases or (models.Model,)
def state_forwards(self, app, state):
state.models[app, self.name.lower()] = ModelState(state, app, self.name)
def state_forwards(self, app_label, state):
state.models[app_label, self.name.lower()] = ModelState(app_label, self.name, self.fields, self.options, self.bases)
def database_forwards(self, app, schema_editor, from_state, to_state):
app_cache = to_state.render()
schema_editor.create_model(app_cache.get_model(app, self.name))
def database_backwards(self, app, schema_editor, from_state, to_state):
"""
Performs the mutation on the database schema in the reverse
direction - e.g. if this were CreateModel, it would in fact
drop the model's table.
"""
raise NotImplementedError()
app_cache = from_state.render()
schema_editor.delete_model(app_cache.get_model(app, self.name))
class DeleteModel(Operation):
"""
Drops a model's table.
"""
def __init__(self, name):
self.name = name
def state_forwards(self, app_label, state):
del state.models[app_label, self.name.lower()]
def database_forwards(self, app_label, schema_editor, from_state, to_state):
app_cache = from_state.render()
schema_editor.delete_model(app_cache.get_model(app_label, self.name))
def database_backwards(self, app_label, schema_editor, from_state, to_state):
app_cache = to_state.render()
schema_editor.create_model(app_cache.get_model(app_label, self.name))

View File

@ -21,7 +21,7 @@ class ProjectState(object):
def clone(self):
"Returns an exact copy of this ProjectState"
return ProjectState(
models = dict((k, v.copy()) for k, v in self.models.items())
models = dict((k, v.clone()) for k, v in self.models.items())
)
def render(self):
@ -49,12 +49,15 @@ class ModelState(object):
mutate this one and then render it into a Model as required.
"""
def __init__(self, app_label, name, fields=None, options=None, bases=None):
def __init__(self, app_label, name, fields, options=None, bases=None):
self.app_label = app_label
self.name = name
self.fields = fields or []
self.fields = fields
self.options = options or {}
self.bases = bases or (models.Model, )
# Sanity-check that fields is NOT a dict. It must be ordered.
if isinstance(self.fields, dict):
raise ValueError("ModelState.fields cannot be a dict - it must be a list of 2-tuples.")
@classmethod
def from_model(cls, model):

View File

@ -1,5 +1,27 @@
from django.db import migrations
from django.db import migrations, models
class Migration(migrations.Migration):
pass
operations = [
migrations.CreateModel(
"Author",
[
("id", models.AutoField(primary_key=True)),
("name", models.CharField(max_length=255)),
("slug", models.SlugField(null=True)),
("age", models.IntegerField(default=0)),
("silly_field", models.BooleanField()),
],
),
migrations.CreateModel(
"Tribble",
[
("id", models.AutoField(primary_key=True)),
("fluffy", models.BooleanField(default=True)),
],
)
]

View File

@ -1,6 +1,24 @@
from django.db import migrations
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [("migrations", "0001_initial")]
operations = [
migrations.DeleteModel("Tribble"),
migrations.RemoveField("Author", "silly_field"),
migrations.AddField("Author", "important", models.BooleanField()),
migrations.CreateModel(
"Book",
[
("id", models.AutoField(primary_key=True)),
("author", models.ForeignKey("migrations.Author", null=True)),
],
)
]

View File

@ -1,11 +1,8 @@
from django.test import TransactionTestCase, TestCase
from django.db import connection
from django.test import TestCase
from django.db.migrations.graph import MigrationGraph, CircularDependencyError
from django.db.migrations.loader import MigrationLoader
from django.db.migrations.recorder import MigrationRecorder
class GraphTests(TransactionTestCase):
class GraphTests(TestCase):
"""
Tests the digraph structure.
"""
@ -117,20 +114,3 @@ class GraphTests(TransactionTestCase):
CircularDependencyError,
graph.forwards_plan, ("app_a", "0003"),
)
class LoaderTests(TransactionTestCase):
"""
Tests the disk and database loader.
"""
def test_load(self):
"""
Makes sure the loader can load the migrations for the test apps.
"""
migration_loader = MigrationLoader(connection)
graph = migration_loader.build_graph()
self.assertEqual(
graph.forwards_plan(("migrations", "0002_second")),
[("migrations", "0001_initial"), ("migrations", "0002_second")],
)

View File

@ -1,11 +1,12 @@
from django.test import TestCase
from django.test import TestCase, TransactionTestCase
from django.db import connection
from django.db.migrations.loader import MigrationLoader
from django.db.migrations.recorder import MigrationRecorder
class RecorderTests(TestCase):
"""
Tests the disk and database loader.
Tests recording migrations as applied or not.
"""
def test_apply(self):
@ -27,3 +28,37 @@ class RecorderTests(TestCase):
recorder.applied_migrations(),
set(),
)
class LoaderTests(TransactionTestCase):
"""
Tests the disk and database loader, and running through migrations
in memory.
"""
def test_load(self):
"""
Makes sure the loader can load the migrations for the test apps,
and then render them out to a new AppCache.
"""
# Load and test the plan
migration_loader = MigrationLoader(connection)
self.assertEqual(
migration_loader.graph.forwards_plan(("migrations", "0002_second")),
[("migrations", "0001_initial"), ("migrations", "0002_second")],
)
# Now render it out!
project_state = migration_loader.graph.project_state(("migrations", "0002_second"))
self.assertEqual(len(project_state.models), 2)
author_state = project_state.models["migrations", "author"]
self.assertEqual(
[x for x, y in author_state.fields],
["id", "name", "slug", "age", "important"]
)
book_state = project_state.models["migrations", "book"]
self.assertEqual(
[x for x, y in book_state.fields],
["id", "author"]
)

View File

@ -132,7 +132,7 @@ class SchemaTests(TransactionTestCase):
else:
self.fail("No FK constraint for author_id found")
def test_create_field(self):
def test_add_field(self):
"""
Tests adding fields to models
"""
@ -146,7 +146,7 @@ class SchemaTests(TransactionTestCase):
new_field = IntegerField(null=True)
new_field.set_attributes_from_name("age")
with connection.schema_editor() as editor:
editor.create_field(
editor.add_field(
Author,
new_field,
)
@ -251,7 +251,7 @@ class SchemaTests(TransactionTestCase):
connection.rollback()
# Add the field
with connection.schema_editor() as editor:
editor.create_field(
editor.add_field(
Author,
new_field,
)
@ -260,7 +260,7 @@ class SchemaTests(TransactionTestCase):
self.assertEqual(columns['tag_id'][0], "IntegerField")
# Remove the M2M table again
with connection.schema_editor() as editor:
editor.delete_field(
editor.remove_field(
Author,
new_field,
)
@ -530,7 +530,7 @@ class SchemaTests(TransactionTestCase):
)
# Add a unique column, verify that creates an implicit index
with connection.schema_editor() as editor:
editor.create_field(
editor.add_field(
Book,
BookWithSlug._meta.get_field_by_name("slug")[0],
)
@ -568,7 +568,7 @@ class SchemaTests(TransactionTestCase):
new_field = SlugField(primary_key=True)
new_field.set_attributes_from_name("slug")
with connection.schema_editor() as editor:
editor.delete_field(Tag, Tag._meta.get_field_by_name("id")[0])
editor.remove_field(Tag, Tag._meta.get_field_by_name("id")[0])
editor.alter_field(
Tag,
Tag._meta.get_field_by_name("slug")[0],