Add test for creating M2Ms

This commit is contained in:
Andrew Godwin 2013-07-25 16:31:34 +01:00
parent f8297f6323
commit a758c9c186
3 changed files with 38 additions and 13 deletions

View File

@ -9,7 +9,7 @@ class CreateModel(Operation):
""" """
def __init__(self, name, fields, options=None, bases=None): def __init__(self, name, fields, options=None, bases=None):
self.name = name.lower() self.name = name
self.fields = fields self.fields = fields
self.options = options or {} self.options = options or {}
self.bases = bases or (models.Model,) self.bases = bases or (models.Model,)
@ -35,7 +35,7 @@ class DeleteModel(Operation):
""" """
def __init__(self, name): def __init__(self, name):
self.name = name.lower() self.name = name
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
del state.models[app_label, self.name.lower()] del state.models[app_label, self.name.lower()]
@ -58,7 +58,7 @@ class AlterModelTable(Operation):
""" """
def __init__(self, name, table): def __init__(self, name, table):
self.name = name.lower() self.name = name
self.table = table self.table = table
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
@ -87,7 +87,7 @@ class AlterUniqueTogether(Operation):
""" """
def __init__(self, name, unique_together): def __init__(self, name, unique_together):
self.name = name.lower() self.name = name
self.unique_together = set(tuple(cons) for cons in unique_together) self.unique_together = set(tuple(cons) for cons in unique_together)
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
@ -117,7 +117,7 @@ class AlterIndexTogether(Operation):
""" """
def __init__(self, name, index_together): def __init__(self, name, index_together):
self.name = name.lower() self.name = name
self.index_together = set(tuple(cons) for cons in index_together) self.index_together = set(tuple(cons) for cons in index_together)
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):

View File

@ -85,7 +85,7 @@ class AutodetectorTests(TestCase):
# Right action? # Right action?
action = migration.operations[0] action = migration.operations[0]
self.assertEqual(action.__class__.__name__, "CreateModel") self.assertEqual(action.__class__.__name__, "CreateModel")
self.assertEqual(action.name, "author") self.assertEqual(action.name, "Author")
def test_old_model(self): def test_old_model(self):
"Tests deletion of old models" "Tests deletion of old models"
@ -102,7 +102,7 @@ class AutodetectorTests(TestCase):
# Right action? # Right action?
action = migration.operations[0] action = migration.operations[0]
self.assertEqual(action.__class__.__name__, "DeleteModel") self.assertEqual(action.__class__.__name__, "DeleteModel")
self.assertEqual(action.name, "author") self.assertEqual(action.name, "Author")
def test_add_field(self): def test_add_field(self):
"Tests autodetection of new fields" "Tests autodetection of new fields"

View File

@ -12,24 +12,28 @@ class OperationTests(MigrationTestBase):
both forwards and backwards. both forwards and backwards.
""" """
def set_up_test_model(self, app_label): def set_up_test_model(self, app_label, second_model=False):
""" """
Creates a test model state and database table. Creates a test model state and database table.
""" """
# Make the "current" state # Make the "current" state
creation = migrations.CreateModel( operations = [migrations.CreateModel(
"Pony", "Pony",
[ [
("id", models.AutoField(primary_key=True)), ("id", models.AutoField(primary_key=True)),
("pink", models.IntegerField(default=3)), ("pink", models.IntegerField(default=3)),
("weight", models.FloatField()), ("weight", models.FloatField()),
], ],
) )]
if second_model:
operations.append(migrations.CreateModel("Stable", [("id", models.AutoField(primary_key=True))]))
project_state = ProjectState() project_state = ProjectState()
creation.state_forwards(app_label, project_state) for operation in operations:
operation.state_forwards(app_label, project_state)
# Set up the database # Set up the database
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
creation.database_forwards(app_label, editor, ProjectState(), project_state) for operation in operations:
operation.database_forwards(app_label, editor, ProjectState(), project_state)
return project_state return project_state
def test_create_model(self): def test_create_model(self):
@ -48,7 +52,7 @@ class OperationTests(MigrationTestBase):
project_state = ProjectState() project_state = ProjectState()
new_state = project_state.clone() new_state = project_state.clone()
operation.state_forwards("test_crmo", new_state) operation.state_forwards("test_crmo", new_state)
self.assertEqual(new_state.models["test_crmo", "pony"].name, "pony") self.assertEqual(new_state.models["test_crmo", "pony"].name, "Pony")
self.assertEqual(len(new_state.models["test_crmo", "pony"].fields), 2) self.assertEqual(len(new_state.models["test_crmo", "pony"].fields), 2)
# Test the database alteration # Test the database alteration
self.assertTableNotExists("test_crmo_pony") self.assertTableNotExists("test_crmo_pony")
@ -106,6 +110,27 @@ class OperationTests(MigrationTestBase):
operation.database_backwards("test_adfl", editor, new_state, project_state) operation.database_backwards("test_adfl", editor, new_state, project_state)
self.assertColumnNotExists("test_adfl_pony", "height") self.assertColumnNotExists("test_adfl_pony", "height")
def test_add_field_m2m(self):
"""
Tests the AddField operation with a ManyToManyField.
"""
project_state = self.set_up_test_model("test_adflmm", second_model=True)
# Test the state alteration
operation = migrations.AddField("Pony", "stables", models.ManyToManyField("Stable"))
new_state = project_state.clone()
operation.state_forwards("test_adflmm", new_state)
self.assertEqual(len(new_state.models["test_adflmm", "pony"].fields), 4)
# Test the database alteration
self.assertTableNotExists("test_adflmm_pony_stables")
with connection.schema_editor() as editor:
operation.database_forwards("test_adflmm", editor, project_state, new_state)
self.assertTableExists("test_adflmm_pony_stables")
self.assertColumnNotExists("test_adflmm_pony", "stables")
# And test reversal
with connection.schema_editor() as editor:
operation.database_backwards("test_adflmm", editor, new_state, project_state)
self.assertTableNotExists("test_adflmm_pony_stables")
def test_remove_field(self): def test_remove_field(self):
""" """
Tests the RemoveField operation. Tests the RemoveField operation.