diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py index 51c95afd96..debfc760b5 100644 --- a/tests/migrations/test_operations.py +++ b/tests/migrations/test_operations.py @@ -14,7 +14,7 @@ class OperationTests(MigrationTestBase): both forwards and backwards. """ - def set_up_test_model(self, app_label, second_model=False, related_model=False): + def set_up_test_model(self, app_label, second_model=False, related_model=False, mti_model=False): """ Creates a test model state and database table. """ @@ -38,7 +38,12 @@ class OperationTests(MigrationTestBase): ], )] if second_model: - operations.append(migrations.CreateModel("Stable", [("id", models.AutoField(primary_key=True))])) + operations.append(migrations.CreateModel( + "Stable", + [ + ("id", models.AutoField(primary_key=True)), + ] + )) if related_model: operations.append(migrations.CreateModel( "Rider", @@ -47,6 +52,21 @@ class OperationTests(MigrationTestBase): ("pony", models.ForeignKey("Pony")), ], )) + if mti_model: + operations.append(migrations.CreateModel( + "ShetlandPony", + fields=[ + ('pony_ptr', models.OneToOneField( + auto_created=True, + primary_key=True, + to_field='id', + serialize=False, + to='Pony', + )), + ("cuteness", models.IntegerField(default=1)), + ], + bases=['%s.Pony' % app_label], + )) project_state = ProjectState() for operation in operations: operation.state_forwards(app_label, project_state) @@ -495,7 +515,7 @@ class OperationTests(MigrationTestBase): Tests the RunPython operation """ - project_state = self.set_up_test_model("test_runpython") + project_state = self.set_up_test_model("test_runpython", mti_model=True) # Create the operation def inner_method(models, schema_editor): @@ -533,7 +553,34 @@ class OperationTests(MigrationTestBase): no_reverse_operation.database_forwards("test_runpython", editor, project_state, new_state) with self.assertRaises(NotImplementedError): no_reverse_operation.database_backwards("test_runpython", editor, new_state, project_state) + self.assertEqual(project_state.render().get_model("test_runpython", "Pony").objects.count(), 2) + def create_ponies(models, schema_editor): + Pony = models.get_model("test_runpython", "Pony") + pony1 = Pony.objects.create(pink=1, weight=3.55) + self.assertIsNot(pony1.pk, None) + pony2 = Pony.objects.create(weight=5) + self.assertIsNot(pony2.pk, None) + self.assertNotEqual(pony1.pk, pony2.pk) + + operation = migrations.RunPython(create_ponies) + with connection.schema_editor() as editor: + operation.database_forwards("test_runpython", editor, project_state, new_state) + self.assertEqual(project_state.render().get_model("test_runpython", "Pony").objects.count(), 4) + + def create_shetlandponies(models, schema_editor): + ShetlandPony = models.get_model("test_runpython", "ShetlandPony") + pony1 = ShetlandPony.objects.create(weight=4.0) + self.assertIsNot(pony1.pk, None) + pony2 = ShetlandPony.objects.create(weight=5.0) + self.assertIsNot(pony2.pk, None) + self.assertNotEqual(pony1.pk, pony2.pk) + + operation = migrations.RunPython(create_shetlandponies) + with connection.schema_editor() as editor: + operation.database_forwards("test_runpython", editor, project_state, new_state) + self.assertEqual(project_state.render().get_model("test_runpython", "Pony").objects.count(), 6) + self.assertEqual(project_state.render().get_model("test_runpython", "ShetlandPony").objects.count(), 2) class MigrateNothingRouter(object): """