diff --git a/django/db/migrations/operations/__init__.py b/django/db/migrations/operations/__init__.py index 925b05fff3..afa5c85cdc 100644 --- a/django/db/migrations/operations/__init__.py +++ b/django/db/migrations/operations/__init__.py @@ -1,2 +1,2 @@ -from .models import CreateModel, DeleteModel, AlterModelTable +from .models import CreateModel, DeleteModel, AlterModelTable, AlterUniqueTogether from .fields import AddField, RemoveField, AlterField, RenameField diff --git a/django/db/migrations/operations/fields.py b/django/db/migrations/operations/fields.py index cc4f4a43df..37e0c063e1 100644 --- a/django/db/migrations/operations/fields.py +++ b/django/db/migrations/operations/fields.py @@ -7,7 +7,7 @@ class AddField(Operation): """ def __init__(self, model_name, name, field): - self.model_name = model_name + self.model_name = model_name.lower() self.name = name self.field = field @@ -33,7 +33,7 @@ class RemoveField(Operation): """ def __init__(self, model_name, name): - self.model_name = model_name + self.model_name = model_name.lower() self.name = name def state_forwards(self, app_label, state): @@ -62,7 +62,7 @@ class AlterField(Operation): """ def __init__(self, model_name, name, field): - self.model_name = model_name + self.model_name = model_name.lower() self.name = name self.field = field @@ -93,7 +93,7 @@ class RenameField(Operation): """ def __init__(self, model_name, old_name, new_name): - self.model_name = model_name + self.model_name = model_name.lower() self.old_name = old_name self.new_name = new_name diff --git a/django/db/migrations/operations/models.py b/django/db/migrations/operations/models.py index c73ff179d4..7279a163f0 100644 --- a/django/db/migrations/operations/models.py +++ b/django/db/migrations/operations/models.py @@ -9,7 +9,7 @@ class CreateModel(Operation): """ def __init__(self, name, fields, options=None, bases=None): - self.name = name + self.name = name.lower() self.fields = fields self.options = options or {} self.bases = bases or (models.Model,) @@ -35,7 +35,7 @@ class DeleteModel(Operation): """ def __init__(self, name): - self.name = name + self.name = name.lower() def state_forwards(self, app_label, state): del state.models[app_label, self.name.lower()] @@ -58,7 +58,7 @@ class AlterModelTable(Operation): """ def __init__(self, name, table): - self.name = name + self.name = name.lower() self.table = table def state_forwards(self, app_label, state): @@ -78,3 +78,33 @@ class AlterModelTable(Operation): def describe(self): return "Rename table for %s to %s" % (self.name, self.table) + + +class AlterUniqueTogether(Operation): + """ + Changes the value of unique_together to the target one. + Input value of unique_together must be a set of tuples. + """ + + def __init__(self, name, unique_together): + self.name = name.lower() + self.unique_together = set(tuple(cons) for cons in unique_together) + + def state_forwards(self, app_label, state): + model_state = state.models[app_label, self.name.lower()] + model_state.options["unique_together"] = self.unique_together + + def database_forwards(self, app_label, schema_editor, from_state, to_state): + old_app_cache = from_state.render() + new_app_cache = to_state.render() + schema_editor.alter_unique_together( + new_app_cache.get_model(app_label, self.name), + getattr(old_app_cache.get_model(app_label, self.name)._meta, "unique_together", set()), + getattr(new_app_cache.get_model(app_label, self.name)._meta, "unique_together", set()), + ) + + def database_backwards(self, app_label, schema_editor, from_state, to_state): + return self.database_forwards(app_label, schema_editor, from_state, to_state) + + def describe(self): + return "Alter unique_together for %s (%s constraints)" % (self.name, len(self.unique_together)) diff --git a/django/db/migrations/state.py b/django/db/migrations/state.py index 4ecdb18896..8f0078d731 100644 --- a/django/db/migrations/state.py +++ b/django/db/migrations/state.py @@ -80,8 +80,11 @@ class ModelState(object): # Ignore some special options if name in ["app_cache", "app_label"]: continue - if name in model._meta.original_attrs: - options[name] = model._meta.original_attrs[name] + elif name in model._meta.original_attrs: + if name == "unique_together": + options[name] = set(model._meta.original_attrs["unique_together"]) + else: + options[name] = model._meta.original_attrs[name] # Make our record bases = tuple(model for model in model.__bases__ if (not hasattr(model, "_meta") or not model._meta.abstract)) if not bases: @@ -116,6 +119,8 @@ class ModelState(object): # First, make a Meta object meta_contents = {'app_label': self.app_label, "app_cache": app_cache} meta_contents.update(self.options) + if "unique_together" in meta_contents: + meta_contents["unique_together"] = list(meta_contents["unique_together"]) meta = type("Meta", tuple(), meta_contents) # Then, work out our bases # TODO: Use the actual bases diff --git a/tests/migrations/test_autodetector.py b/tests/migrations/test_autodetector.py index 540e84e8df..659b45dbd2 100644 --- a/tests/migrations/test_autodetector.py +++ b/tests/migrations/test_autodetector.py @@ -83,7 +83,7 @@ class AutodetectorTests(TestCase): # Right action? action = migration.operations[0] self.assertEqual(action.__class__.__name__, "CreateModel") - self.assertEqual(action.name, "Author") + self.assertEqual(action.name, "author") def test_old_model(self): "Tests deletion of old models" @@ -100,7 +100,7 @@ class AutodetectorTests(TestCase): # Right action? action = migration.operations[0] self.assertEqual(action.__class__.__name__, "DeleteModel") - self.assertEqual(action.name, "Author") + self.assertEqual(action.name, "author") def test_add_field(self): "Tests autodetection of new fields" diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py index 2e72e11954..b2912de53c 100644 --- a/tests/migrations/test_operations.py +++ b/tests/migrations/test_operations.py @@ -1,5 +1,6 @@ from django.test import TestCase from django.db import connection, models, migrations +from django.db.utils import IntegrityError from django.db.migrations.state import ProjectState @@ -38,6 +39,7 @@ class OperationTests(TestCase): [ ("id", models.AutoField(primary_key=True)), ("pink", models.BooleanField(default=True)), + ("weight", models.FloatField()), ], ) project_state = ProjectState() @@ -50,7 +52,7 @@ class OperationTests(TestCase): def test_create_model(self): """ Tests the CreateModel operation. - Most other tests use this as part of setup, so check failures here first. + Most other tests use this operation as part of setup, so check failures here first. """ operation = migrations.CreateModel( "Pony", @@ -63,7 +65,7 @@ class OperationTests(TestCase): project_state = ProjectState() new_state = project_state.clone() operation.state_forwards("test_crmo", new_state) - self.assertEqual(new_state.models["test_crmo", "pony"].name, "Pony") + self.assertEqual(new_state.models["test_crmo", "pony"].name, "pony") self.assertEqual(len(new_state.models["test_crmo", "pony"].fields), 2) # Test the database alteration self.assertTableNotExists("test_crmo_pony") @@ -110,7 +112,7 @@ class OperationTests(TestCase): operation = migrations.AddField("Pony", "height", models.FloatField(null=True)) new_state = project_state.clone() operation.state_forwards("test_adfl", new_state) - self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 3) + self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 4) # Test the database alteration self.assertColumnNotExists("test_adfl_pony", "height") with connection.schema_editor() as editor: @@ -130,7 +132,7 @@ class OperationTests(TestCase): operation = migrations.RemoveField("Pony", "pink") new_state = project_state.clone() operation.state_forwards("test_rmfl", new_state) - self.assertEqual(len(new_state.models["test_rmfl", "pony"].fields), 1) + self.assertEqual(len(new_state.models["test_rmfl", "pony"].fields), 2) # Test the database alteration self.assertColumnExists("test_rmfl_pony", "pink") with connection.schema_editor() as editor: @@ -208,3 +210,33 @@ class OperationTests(TestCase): operation.database_backwards("test_rnfl", editor, new_state, project_state) self.assertColumnExists("test_rnfl_pony", "pink") self.assertColumnNotExists("test_rnfl_pony", "blue") + + def test_alter_unique_together(self): + """ + Tests the AlterUniqueTogether operation. + """ + project_state = self.set_up_test_model("test_alunto") + # Test the state alteration + operation = migrations.AlterUniqueTogether("Pony", [("pink", "weight")]) + new_state = project_state.clone() + operation.state_forwards("test_alunto", new_state) + self.assertEqual(len(project_state.models["test_alunto", "pony"].options.get("unique_together", set())), 0) + self.assertEqual(len(new_state.models["test_alunto", "pony"].options.get("unique_together", set())), 1) + # Make sure we can insert duplicate rows + cursor = connection.cursor() + cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)") + cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)") + cursor.execute("DELETE FROM test_alunto_pony") + # Test the database alteration + with connection.schema_editor() as editor: + operation.database_forwards("test_alunto", editor, project_state, new_state) + cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)") + with self.assertRaises(IntegrityError): + cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)") + cursor.execute("DELETE FROM test_alunto_pony") + # And test reversal + with connection.schema_editor() as editor: + operation.database_backwards("test_alunto", editor, new_state, project_state) + cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (1, 1, 1)") + cursor.execute("INSERT INTO test_alunto_pony (id, pink, weight) VALUES (2, 1, 1)") + cursor.execute("DELETE FROM test_alunto_pony") diff --git a/tests/migrations/test_state.py b/tests/migrations/test_state.py index c6930873ef..e5b3fbfa08 100644 --- a/tests/migrations/test_state.py +++ b/tests/migrations/test_state.py @@ -44,7 +44,7 @@ class StateTests(TestCase): self.assertEqual(author_state.fields[1][1].max_length, 255) self.assertEqual(author_state.fields[2][1].null, False) self.assertEqual(author_state.fields[3][1].null, True) - self.assertEqual(author_state.options, {"unique_together": ["name", "bio"]}) + self.assertEqual(author_state.options, {"unique_together": set(("name", "bio"))}) self.assertEqual(author_state.bases, (models.Model, )) self.assertEqual(book_state.app_label, "migrations")