From bb39037fcbe07a4c4060764533b5c03a4018bf81 Mon Sep 17 00:00:00 2001 From: Matthew Schinckel Date: Wed, 11 Jun 2014 22:30:52 +0930 Subject: [PATCH] Fixed #22788 -- Ensured custom migration operations can be written. This inspects the migration operation, and if it is not in the django.db.migrations module, it adds the relevant imports to the migration writer and uses the correct class name. --- django/db/migrations/writer.py | 12 +++++++-- tests/custom_migration_operations/__init__.py | 0 .../more_operations.py | 22 ++++++++++++++++ .../custom_migration_operations/operations.py | 26 +++++++++++++++++++ tests/migrations/test_writer.py | 22 ++++++++++++++++ 5 files changed, 80 insertions(+), 2 deletions(-) create mode 100644 tests/custom_migration_operations/__init__.py create mode 100644 tests/custom_migration_operations/more_operations.py create mode 100644 tests/custom_migration_operations/operations.py diff --git a/django/db/migrations/writer.py b/django/db/migrations/writer.py index 88145733ec..dee867929d 100644 --- a/django/db/migrations/writer.py +++ b/django/db/migrations/writer.py @@ -10,7 +10,7 @@ import sys import types from django.apps import apps -from django.db import models +from django.db import models, migrations from django.db.migrations.loader import MigrationLoader from django.utils import datetime_safe, six from django.utils.encoding import force_text @@ -44,7 +44,15 @@ class OperationWriter(object): argspec = inspect.getargspec(self.operation.__init__) normalized_kwargs = inspect.getcallargs(self.operation.__init__, *args, **kwargs) - self.feed('migrations.%s(' % name) + # See if this operation is in django.db.migrations. If it is, + # We can just use the fact we already have that imported, + # otherwise, we need to add an import for the operation class. + if getattr(migrations, name, None) == self.operation.__class__: + self.feed('migrations.%s(' % name) + else: + imports.add('import %s' % (self.operation.__class__.__module__)) + self.feed('%s.%s(' % (self.operation.__class__.__module__, name)) + self.indent() for arg_name in argspec.args[1:]: arg_value = normalized_kwargs[arg_name] diff --git a/tests/custom_migration_operations/__init__.py b/tests/custom_migration_operations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/custom_migration_operations/more_operations.py b/tests/custom_migration_operations/more_operations.py new file mode 100644 index 0000000000..6fe3d1cf93 --- /dev/null +++ b/tests/custom_migration_operations/more_operations.py @@ -0,0 +1,22 @@ +from django.db.migrations.operations.base import Operation + + +class TestOperation(Operation): + def __init__(self): + pass + + @property + def reversible(self): + return True + + def state_forwards(self, app_label, state): + pass + + def database_forwards(self, app_label, schema_editor, from_state, to_state): + pass + + def state_backwards(self, app_label, state): + pass + + def database_backwards(self, app_label, schema_editor, from_state, to_state): + pass diff --git a/tests/custom_migration_operations/operations.py b/tests/custom_migration_operations/operations.py new file mode 100644 index 0000000000..fc084f8412 --- /dev/null +++ b/tests/custom_migration_operations/operations.py @@ -0,0 +1,26 @@ +from django.db.migrations.operations.base import Operation + + +class TestOperation(Operation): + def __init__(self): + pass + + @property + def reversible(self): + return True + + def state_forwards(self, app_label, state): + pass + + def database_forwards(self, app_label, schema_editor, from_state, to_state): + pass + + def state_backwards(self, app_label, state): + pass + + def database_backwards(self, app_label, schema_editor, from_state, to_state): + pass + + +class CreateModel(TestOperation): + pass diff --git a/tests/migrations/test_writer.py b/tests/migrations/test_writer.py index 925a9d6538..cd59938314 100644 --- a/tests/migrations/test_writer.py +++ b/tests/migrations/test_writer.py @@ -16,6 +16,9 @@ from django.utils.deconstruct import deconstructible from django.utils.translation import ugettext_lazy as _ from django.utils.timezone import get_default_timezone +import custom_migration_operations.operations +import custom_migration_operations.more_operations + class TestModel1(object): def upload_to(self): @@ -222,3 +225,22 @@ class WriterTests(TestCase): expected_path = os.path.join(base_dir, *(app.split('.') + ['migrations', '0001_initial.py'])) writer = MigrationWriter(migration) self.assertEqual(writer.path, expected_path) + + def test_custom_operation(self): + migration = type(str("Migration"), (migrations.Migration,), { + "operations": [ + custom_migration_operations.operations.TestOperation(), + custom_migration_operations.operations.CreateModel(), + migrations.CreateModel("MyModel", (), {}, (models.Model,)), + custom_migration_operations.more_operations.TestOperation() + ], + "dependencies": [] + }) + writer = MigrationWriter(migration) + output = writer.as_string() + result = self.safe_exec(output) + self.assertIn("custom_migration_operations", result) + self.assertNotEqual( + result['custom_migration_operations'].operations.TestOperation, + result['custom_migration_operations'].more_operations.TestOperation + )