Fixed #22788 -- Ensured custom migration operations can be written.

This inspects the migration operation, and if it is not in the
django.db.migrations module, it adds the relevant imports to the
migration writer and uses the correct class name.
This commit is contained in:
Matthew Schinckel 2014-06-11 22:30:52 +09:30 committed by Tim Graham
parent 37a8f5aeed
commit bb39037fcb
5 changed files with 80 additions and 2 deletions

View File

@ -10,7 +10,7 @@ import sys
import types import types
from django.apps import apps from django.apps import apps
from django.db import models from django.db import models, migrations
from django.db.migrations.loader import MigrationLoader from django.db.migrations.loader import MigrationLoader
from django.utils import datetime_safe, six from django.utils import datetime_safe, six
from django.utils.encoding import force_text from django.utils.encoding import force_text
@ -44,7 +44,15 @@ class OperationWriter(object):
argspec = inspect.getargspec(self.operation.__init__) argspec = inspect.getargspec(self.operation.__init__)
normalized_kwargs = inspect.getcallargs(self.operation.__init__, *args, **kwargs) normalized_kwargs = inspect.getcallargs(self.operation.__init__, *args, **kwargs)
self.feed('migrations.%s(' % name) # See if this operation is in django.db.migrations. If it is,
# We can just use the fact we already have that imported,
# otherwise, we need to add an import for the operation class.
if getattr(migrations, name, None) == self.operation.__class__:
self.feed('migrations.%s(' % name)
else:
imports.add('import %s' % (self.operation.__class__.__module__))
self.feed('%s.%s(' % (self.operation.__class__.__module__, name))
self.indent() self.indent()
for arg_name in argspec.args[1:]: for arg_name in argspec.args[1:]:
arg_value = normalized_kwargs[arg_name] arg_value = normalized_kwargs[arg_name]

View File

@ -0,0 +1,22 @@
from django.db.migrations.operations.base import Operation
class TestOperation(Operation):
def __init__(self):
pass
@property
def reversible(self):
return True
def state_forwards(self, app_label, state):
pass
def database_forwards(self, app_label, schema_editor, from_state, to_state):
pass
def state_backwards(self, app_label, state):
pass
def database_backwards(self, app_label, schema_editor, from_state, to_state):
pass

View File

@ -0,0 +1,26 @@
from django.db.migrations.operations.base import Operation
class TestOperation(Operation):
def __init__(self):
pass
@property
def reversible(self):
return True
def state_forwards(self, app_label, state):
pass
def database_forwards(self, app_label, schema_editor, from_state, to_state):
pass
def state_backwards(self, app_label, state):
pass
def database_backwards(self, app_label, schema_editor, from_state, to_state):
pass
class CreateModel(TestOperation):
pass

View File

@ -16,6 +16,9 @@ from django.utils.deconstruct import deconstructible
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.utils.timezone import get_default_timezone from django.utils.timezone import get_default_timezone
import custom_migration_operations.operations
import custom_migration_operations.more_operations
class TestModel1(object): class TestModel1(object):
def upload_to(self): def upload_to(self):
@ -222,3 +225,22 @@ class WriterTests(TestCase):
expected_path = os.path.join(base_dir, *(app.split('.') + ['migrations', '0001_initial.py'])) expected_path = os.path.join(base_dir, *(app.split('.') + ['migrations', '0001_initial.py']))
writer = MigrationWriter(migration) writer = MigrationWriter(migration)
self.assertEqual(writer.path, expected_path) self.assertEqual(writer.path, expected_path)
def test_custom_operation(self):
migration = type(str("Migration"), (migrations.Migration,), {
"operations": [
custom_migration_operations.operations.TestOperation(),
custom_migration_operations.operations.CreateModel(),
migrations.CreateModel("MyModel", (), {}, (models.Model,)),
custom_migration_operations.more_operations.TestOperation()
],
"dependencies": []
})
writer = MigrationWriter(migration)
output = writer.as_string()
result = self.safe_exec(output)
self.assertIn("custom_migration_operations", result)
self.assertNotEqual(
result['custom_migration_operations'].operations.TestOperation,
result['custom_migration_operations'].more_operations.TestOperation
)