diff --git a/django/db/migrations/writer.py b/django/db/migrations/writer.py index 88145733ec..dee867929d 100644 --- a/django/db/migrations/writer.py +++ b/django/db/migrations/writer.py @@ -10,7 +10,7 @@ import sys import types from django.apps import apps -from django.db import models +from django.db import models, migrations from django.db.migrations.loader import MigrationLoader from django.utils import datetime_safe, six from django.utils.encoding import force_text @@ -44,7 +44,15 @@ class OperationWriter(object): argspec = inspect.getargspec(self.operation.__init__) normalized_kwargs = inspect.getcallargs(self.operation.__init__, *args, **kwargs) - self.feed('migrations.%s(' % name) + # See if this operation is in django.db.migrations. If it is, + # We can just use the fact we already have that imported, + # otherwise, we need to add an import for the operation class. + if getattr(migrations, name, None) == self.operation.__class__: + self.feed('migrations.%s(' % name) + else: + imports.add('import %s' % (self.operation.__class__.__module__)) + self.feed('%s.%s(' % (self.operation.__class__.__module__, name)) + self.indent() for arg_name in argspec.args[1:]: arg_value = normalized_kwargs[arg_name] diff --git a/tests/custom_migration_operations/__init__.py b/tests/custom_migration_operations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/custom_migration_operations/more_operations.py b/tests/custom_migration_operations/more_operations.py new file mode 100644 index 0000000000..6fe3d1cf93 --- /dev/null +++ b/tests/custom_migration_operations/more_operations.py @@ -0,0 +1,22 @@ +from django.db.migrations.operations.base import Operation + + +class TestOperation(Operation): + def __init__(self): + pass + + @property + def reversible(self): + return True + + def state_forwards(self, app_label, state): + pass + + def database_forwards(self, app_label, schema_editor, from_state, to_state): + pass + + def state_backwards(self, app_label, state): + pass + + def database_backwards(self, app_label, schema_editor, from_state, to_state): + pass diff --git a/tests/custom_migration_operations/operations.py b/tests/custom_migration_operations/operations.py new file mode 100644 index 0000000000..fc084f8412 --- /dev/null +++ b/tests/custom_migration_operations/operations.py @@ -0,0 +1,26 @@ +from django.db.migrations.operations.base import Operation + + +class TestOperation(Operation): + def __init__(self): + pass + + @property + def reversible(self): + return True + + def state_forwards(self, app_label, state): + pass + + def database_forwards(self, app_label, schema_editor, from_state, to_state): + pass + + def state_backwards(self, app_label, state): + pass + + def database_backwards(self, app_label, schema_editor, from_state, to_state): + pass + + +class CreateModel(TestOperation): + pass diff --git a/tests/migrations/test_writer.py b/tests/migrations/test_writer.py index 925a9d6538..cd59938314 100644 --- a/tests/migrations/test_writer.py +++ b/tests/migrations/test_writer.py @@ -16,6 +16,9 @@ from django.utils.deconstruct import deconstructible from django.utils.translation import ugettext_lazy as _ from django.utils.timezone import get_default_timezone +import custom_migration_operations.operations +import custom_migration_operations.more_operations + class TestModel1(object): def upload_to(self): @@ -222,3 +225,22 @@ class WriterTests(TestCase): expected_path = os.path.join(base_dir, *(app.split('.') + ['migrations', '0001_initial.py'])) writer = MigrationWriter(migration) self.assertEqual(writer.path, expected_path) + + def test_custom_operation(self): + migration = type(str("Migration"), (migrations.Migration,), { + "operations": [ + custom_migration_operations.operations.TestOperation(), + custom_migration_operations.operations.CreateModel(), + migrations.CreateModel("MyModel", (), {}, (models.Model,)), + custom_migration_operations.more_operations.TestOperation() + ], + "dependencies": [] + }) + writer = MigrationWriter(migration) + output = writer.as_string() + result = self.safe_exec(output) + self.assertIn("custom_migration_operations", result) + self.assertNotEqual( + result['custom_migration_operations'].operations.TestOperation, + result['custom_migration_operations'].more_operations.TestOperation + )