Add unique_together altering operation

This commit is contained in:
Andrew Godwin 2013-07-02 11:19:02 +01:00
parent 310cdf492d
commit 67dcea711e
7 changed files with 84 additions and 17 deletions

View File

@ -1,2 +1,2 @@
from .models import CreateModel, DeleteModel, AlterModelTable from .models import CreateModel, DeleteModel, AlterModelTable, AlterUniqueTogether
from .fields import AddField, RemoveField, AlterField, RenameField from .fields import AddField, RemoveField, AlterField, RenameField

View File

@ -7,7 +7,7 @@ class AddField(Operation):
""" """
def __init__(self, model_name, name, field): def __init__(self, model_name, name, field):
self.model_name = model_name self.model_name = model_name.lower()
self.name = name self.name = name
self.field = field self.field = field
@ -33,7 +33,7 @@ class RemoveField(Operation):
""" """
def __init__(self, model_name, name): def __init__(self, model_name, name):
self.model_name = model_name self.model_name = model_name.lower()
self.name = name self.name = name
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
@ -62,7 +62,7 @@ class AlterField(Operation):
""" """
def __init__(self, model_name, name, field): def __init__(self, model_name, name, field):
self.model_name = model_name self.model_name = model_name.lower()
self.name = name self.name = name
self.field = field self.field = field
@ -93,7 +93,7 @@ class RenameField(Operation):
""" """
def __init__(self, model_name, old_name, new_name): def __init__(self, model_name, old_name, new_name):
self.model_name = model_name self.model_name = model_name.lower()
self.old_name = old_name self.old_name = old_name
self.new_name = new_name self.new_name = new_name

View File

@ -9,7 +9,7 @@ class CreateModel(Operation):
""" """
def __init__(self, name, fields, options=None, bases=None): def __init__(self, name, fields, options=None, bases=None):
self.name = name self.name = name.lower()
self.fields = fields self.fields = fields
self.options = options or {} self.options = options or {}
self.bases = bases or (models.Model,) self.bases = bases or (models.Model,)
@ -35,7 +35,7 @@ class DeleteModel(Operation):
""" """
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name.lower()
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
del state.models[app_label, self.name.lower()] del state.models[app_label, self.name.lower()]
@ -58,7 +58,7 @@ class AlterModelTable(Operation):
""" """
def __init__(self, name, table): def __init__(self, name, table):
self.name = name self.name = name.lower()
self.table = table self.table = table
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
@ -78,3 +78,33 @@ class AlterModelTable(Operation):
def describe(self): def describe(self):
return "Rename table for %s to %s" % (self.name, self.table) return "Rename table for %s to %s" % (self.name, self.table)
class AlterUniqueTogether(Operation):
"""
Changes the value of unique_together to the target one.
Input value of unique_together must be a set of tuples.
"""
def __init__(self, name, unique_together):
self.name = name.lower()
self.unique_together = set(tuple(cons) for cons in unique_together)
def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.name.lower()]
model_state.options["unique_together"] = self.unique_together
def database_forwards(self, app_label, schema_editor, from_state, to_state):
old_app_cache = from_state.render()
new_app_cache = to_state.render()
schema_editor.alter_unique_together(
new_app_cache.get_model(app_label, self.name),
getattr(old_app_cache.get_model(app_label, self.name)._meta, "unique_together", set()),
getattr(new_app_cache.get_model(app_label, self.name)._meta, "unique_together", set()),
)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
return self.database_forwards(app_label, schema_editor, from_state, to_state)
def describe(self):
return "Alter unique_together for %s (%s constraints)" % (self.name, len(self.unique_together))

View File

@ -80,8 +80,11 @@ class ModelState(object):
# Ignore some special options # Ignore some special options
if name in ["app_cache", "app_label"]: if name in ["app_cache", "app_label"]:
continue continue
if name in model._meta.original_attrs: elif name in model._meta.original_attrs:
options[name] = model._meta.original_attrs[name] if name == "unique_together":
options[name] = set(model._meta.original_attrs["unique_together"])
else:
options[name] = model._meta.original_attrs[name]
# Make our record # Make our record
bases = tuple(model for model in model.__bases__ if (not hasattr(model, "_meta") or not model._meta.abstract)) bases = tuple(model for model in model.__bases__ if (not hasattr(model, "_meta") or not model._meta.abstract))
if not bases: if not bases:
@ -116,6 +119,8 @@ class ModelState(object):
# First, make a Meta object # First, make a Meta object
meta_contents = {'app_label': self.app_label, "app_cache": app_cache} meta_contents = {'app_label': self.app_label, "app_cache": app_cache}
meta_contents.update(self.options) meta_contents.update(self.options)
if "unique_together" in meta_contents:
meta_contents["unique_together"] = list(meta_contents["unique_together"])
meta = type("Meta", tuple(), meta_contents) meta = type("Meta", tuple(), meta_contents)
# Then, work out our bases # Then, work out our bases
# TODO: Use the actual bases # TODO: Use the actual bases

View File

@ -83,7 +83,7 @@ class AutodetectorTests(TestCase):
# Right action? # Right action?
action = migration.operations[0] action = migration.operations[0]
self.assertEqual(action.__class__.__name__, "CreateModel") self.assertEqual(action.__class__.__name__, "CreateModel")
self.assertEqual(action.name, "Author") self.assertEqual(action.name, "author")
def test_old_model(self): def test_old_model(self):
"Tests deletion of old models" "Tests deletion of old models"
@ -100,7 +100,7 @@ class AutodetectorTests(TestCase):
# Right action? # Right action?
action = migration.operations[0] action = migration.operations[0]
self.assertEqual(action.__class__.__name__, "DeleteModel") self.assertEqual(action.__class__.__name__, "DeleteModel")
self.assertEqual(action.name, "Author") self.assertEqual(action.name, "author")
def test_add_field(self): def test_add_field(self):
"Tests autodetection of new fields" "Tests autodetection of new fields"

View File

@ -1,5 +1,6 @@
from django.test import TestCase from django.test import TestCase
from django.db import connection, models, migrations from django.db import connection, models, migrations
from django.db.utils import IntegrityError
from django.db.migrations.state import ProjectState from django.db.migrations.state import ProjectState
@ -38,6 +39,7 @@ class OperationTests(TestCase):
[ [
("id", models.AutoField(primary_key=True)), ("id", models.AutoField(primary_key=True)),
("pink", models.BooleanField(default=True)), ("pink", models.BooleanField(default=True)),
("weight", models.FloatField()),
], ],
) )
project_state = ProjectState() project_state = ProjectState()
@ -50,7 +52,7 @@ class OperationTests(TestCase):
def test_create_model(self): def test_create_model(self):
""" """
Tests the CreateModel operation. Tests the CreateModel operation.
Most other tests use this as part of setup, so check failures here first. Most other tests use this operation as part of setup, so check failures here first.
""" """
operation = migrations.CreateModel( operation = migrations.CreateModel(
"Pony", "Pony",
@ -63,7 +65,7 @@ class OperationTests(TestCase):
project_state = ProjectState() project_state = ProjectState()
new_state = project_state.clone() new_state = project_state.clone()
operation.state_forwards("test_crmo", new_state) operation.state_forwards("test_crmo", new_state)
self.assertEqual(new_state.models["test_crmo", "pony"].name, "Pony") self.assertEqual(new_state.models["test_crmo", "pony"].name, "pony")
self.assertEqual(len(new_state.models["test_crmo", "pony"].fields), 2) self.assertEqual(len(new_state.models["test_crmo", "pony"].fields), 2)
# Test the database alteration # Test the database alteration
self.assertTableNotExists("test_crmo_pony") self.assertTableNotExists("test_crmo_pony")
@ -110,7 +112,7 @@ class OperationTests(TestCase):
operation = migrations.AddField("Pony", "height", models.FloatField(null=True)) operation = migrations.AddField("Pony", "height", models.FloatField(null=True))
new_state = project_state.clone() new_state = project_state.clone()
operation.state_forwards("test_adfl", new_state) operation.state_forwards("test_adfl", new_state)
self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 3) self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 4)
# Test the database alteration # Test the database alteration
self.assertColumnNotExists("test_adfl_pony", "height") self.assertColumnNotExists("test_adfl_pony", "height")
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
@ -130,7 +132,7 @@ class OperationTests(TestCase):
operation = migrations.RemoveField("Pony", "pink") operation = migrations.RemoveField("Pony", "pink")
new_state = project_state.clone() new_state = project_state.clone()
operation.state_forwards("test_rmfl", new_state) operation.state_forwards("test_rmfl", new_state)
self.assertEqual(len(new_state.models["test_rmfl", "pony"].fields), 1) self.assertEqual(len(new_state.models["test_rmfl", "pony"].fields), 2)
# Test the database alteration # Test the database alteration
self.assertColumnExists("test_rmfl_pony", "pink") self.assertColumnExists("test_rmfl_pony", "pink")
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
@ -208,3 +210,33 @@ class OperationTests(TestCase):
operation.database_backwards("test_rnfl", editor, new_state, project_state) operation.database_backwards("test_rnfl", editor, new_state, project_state)
self.assertColumnExists("test_rnfl_pony", "pink") self.assertColumnExists("test_rnfl_pony", "pink")
self.assertColumnNotExists("test_rnfl_pony", "blue") self.assertColumnNotExists("test_rnfl_pony", "blue")
def test_alter_unique_together(self):
"""
Tests the AlterUniqueTogether operation.
"""
project_state = self.set_up_test_model("test_alunto")
# Test the state alteration
operation = migrations.AlterUniqueTogether("Pony", [("pink", "weight")])
new_state = project_state.clone()
operation.state_forwards("test_alunto", new_state)
self.assertEqual(len(project_state.models["test_alunto", "pony"].options.get("unique_together", set())), 0)
self.assertEqual(len(new_state.models["test_alunto", "pony"].options.get("unique_together", set())), 1)
# Make sure we can insert duplicate rows
cursor = connection.cursor()
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
cursor.execute("DELETE FROM test_alunto_pony")
# Test the database alteration
with connection.schema_editor() as editor:
operation.database_forwards("test_alunto", editor, project_state, new_state)
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
with self.assertRaises(IntegrityError):
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
cursor.execute("DELETE FROM test_alunto_pony")
# And test reversal
with connection.schema_editor() as editor:
operation.database_backwards("test_alunto", editor, new_state, project_state)
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)")
cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)")
cursor.execute("DELETE FROM test_alunto_pony")

View File

@ -44,7 +44,7 @@ class StateTests(TestCase):
self.assertEqual(author_state.fields[1][1].max_length, 255) self.assertEqual(author_state.fields[1][1].max_length, 255)
self.assertEqual(author_state.fields[2][1].null, False) self.assertEqual(author_state.fields[2][1].null, False)
self.assertEqual(author_state.fields[3][1].null, True) self.assertEqual(author_state.fields[3][1].null, True)
self.assertEqual(author_state.options, {"unique_together": ["name", "bio"]}) self.assertEqual(author_state.options, {"unique_together": set(("name", "bio"))})
self.assertEqual(author_state.bases, (models.Model, )) self.assertEqual(author_state.bases, (models.Model, ))
self.assertEqual(book_state.app_label, "migrations") self.assertEqual(book_state.app_label, "migrations")