diff --git a/django/db/migrations/writer.py b/django/db/migrations/writer.py index 3c4bbb04271..6d77c6121a2 100644 --- a/django/db/migrations/writer.py +++ b/django/db/migrations/writer.py @@ -46,10 +46,38 @@ class OperationWriter(object): self.buff = [] def serialize(self): + + def _write(_arg_name, _arg_value): + if (_arg_name in self.operation.serialization_expand_args and + isinstance(_arg_value, (list, tuple, dict))): + if isinstance(_arg_value, dict): + self.feed('%s={' % _arg_name) + self.indent() + for key, value in _arg_value.items(): + key_string, key_imports = MigrationWriter.serialize(key) + arg_string, arg_imports = MigrationWriter.serialize(value) + self.feed('%s: %s,' % (key_string, arg_string)) + imports.update(key_imports) + imports.update(arg_imports) + self.unindent() + self.feed('},') + else: + self.feed('%s=[' % _arg_name) + self.indent() + for item in _arg_value: + arg_string, arg_imports = MigrationWriter.serialize(item) + self.feed('%s,' % arg_string) + imports.update(arg_imports) + self.unindent() + self.feed('],') + else: + arg_string, arg_imports = MigrationWriter.serialize(_arg_value) + self.feed('%s=%s,' % (_arg_name, arg_string)) + imports.update(arg_imports) + imports = set() name, args, kwargs = self.operation.deconstruct() argspec = inspect.getargspec(self.operation.__init__) - normalized_kwargs = inspect.getcallargs(self.operation.__init__, *args, **kwargs) # See if this operation is in django.db.migrations. If it is, # We can just use the fact we already have that imported, @@ -61,34 +89,20 @@ class OperationWriter(object): self.feed('%s.%s(' % (self.operation.__class__.__module__, name)) self.indent() - for arg_name in argspec.args[1:]: - arg_value = normalized_kwargs[arg_name] - if (arg_name in self.operation.serialization_expand_args and - isinstance(arg_value, (list, tuple, dict))): - if isinstance(arg_value, dict): - self.feed('%s={' % arg_name) - self.indent() - for key, value in arg_value.items(): - key_string, key_imports = MigrationWriter.serialize(key) - arg_string, arg_imports = MigrationWriter.serialize(value) - self.feed('%s: %s,' % (key_string, arg_string)) - imports.update(key_imports) - imports.update(arg_imports) - self.unindent() - self.feed('},') - else: - self.feed('%s=[' % arg_name) - self.indent() - for item in arg_value: - arg_string, arg_imports = MigrationWriter.serialize(item) - self.feed('%s,' % arg_string) - imports.update(arg_imports) - self.unindent() - self.feed('],') - else: - arg_string, arg_imports = MigrationWriter.serialize(arg_value) - self.feed('%s=%s,' % (arg_name, arg_string)) - imports.update(arg_imports) + + # Start at one because argspec includes "self" + for i, arg in enumerate(args, 1): + arg_value = arg + arg_name = argspec.args[i] + _write(arg_name, arg_value) + + i = len(args) + # Only iterate over remaining arguments + for arg_name in argspec.args[i + 1:]: + if arg_name in kwargs: + arg_value = kwargs[arg_name] + _write(arg_name, arg_value) + self.unindent() self.feed('),') return self.render(), imports diff --git a/tests/custom_migration_operations/operations.py b/tests/custom_migration_operations/operations.py index 3a4127d7533..bd62280f81a 100644 --- a/tests/custom_migration_operations/operations.py +++ b/tests/custom_migration_operations/operations.py @@ -31,3 +31,64 @@ class TestOperation(Operation): class CreateModel(TestOperation): pass + + +class ArgsOperation(TestOperation): + def __init__(self, arg1, arg2): + self.arg1, self.arg2 = arg1, arg2 + + def deconstruct(self): + return ( + self.__class__.__name__, + [self.arg1, self.arg2], + {} + ) + + +class KwargsOperation(TestOperation): + def __init__(self, kwarg1=None, kwarg2=None): + self.kwarg1, self.kwarg2 = kwarg1, kwarg2 + + def deconstruct(self): + kwargs = {} + if self.kwarg1 is not None: + kwargs['kwarg1'] = self.kwarg1 + if self.kwarg2 is not None: + kwargs['kwarg2'] = self.kwarg2 + return ( + self.__class__.__name__, + [], + kwargs + ) + + +class ArgsKwargsOperation(TestOperation): + def __init__(self, arg1, arg2, kwarg1=None, kwarg2=None): + self.arg1, self.arg2 = arg1, arg2 + self.kwarg1, self.kwarg2 = kwarg1, kwarg2 + + def deconstruct(self): + kwargs = {} + if self.kwarg1 is not None: + kwargs['kwarg1'] = self.kwarg1 + if self.kwarg2 is not None: + kwargs['kwarg2'] = self.kwarg2 + return ( + self.__class__.__name__, + [self.arg1, self.arg2], + kwargs, + ) + + +class ExpandArgsOperation(TestOperation): + serialization_expand_args = ['arg'] + + def __init__(self, arg): + self.arg = arg + + def deconstruct(self): + return ( + self.__class__.__name__, + [self.arg], + {} + ) diff --git a/tests/migrations/test_writer.py b/tests/migrations/test_writer.py index edcc5b285d1..2519e43f427 100644 --- a/tests/migrations/test_writer.py +++ b/tests/migrations/test_writer.py @@ -10,8 +10,8 @@ import unittest from django.core.validators import RegexValidator, EmailValidator from django.db import models, migrations -from django.db.migrations.writer import MigrationWriter, SettingsReference -from django.test import TestCase, ignore_warnings +from django.db.migrations.writer import MigrationWriter, OperationWriter, SettingsReference +from django.test import SimpleTestCase, TestCase, ignore_warnings from django.conf import settings from django.utils import datetime_safe, six from django.utils.deconstruct import deconstructible @@ -30,6 +30,79 @@ class TestModel1(object): thing = models.FileField(upload_to=upload_to) +class OperationWriterTests(SimpleTestCase): + + def test_empty_signature(self): + operation = custom_migration_operations.operations.TestOperation() + writer = OperationWriter(operation) + writer.indentation = 0 + buff, imports = writer.serialize() + self.assertEqual(imports, {'import custom_migration_operations.operations'}) + self.assertEqual( + buff, + 'custom_migration_operations.operations.TestOperation(\n' + '),' + ) + + def test_args_signature(self): + operation = custom_migration_operations.operations.ArgsOperation(1, 2) + writer = OperationWriter(operation) + writer.indentation = 0 + buff, imports = writer.serialize() + self.assertEqual(imports, {'import custom_migration_operations.operations'}) + self.assertEqual( + buff, + 'custom_migration_operations.operations.ArgsOperation(\n' + ' arg1=1,\n' + ' arg2=2,\n' + '),' + ) + + def test_kwargs_signature(self): + operation = custom_migration_operations.operations.KwargsOperation(kwarg1=1) + writer = OperationWriter(operation) + writer.indentation = 0 + buff, imports = writer.serialize() + self.assertEqual(imports, {'import custom_migration_operations.operations'}) + self.assertEqual( + buff, + 'custom_migration_operations.operations.KwargsOperation(\n' + ' kwarg1=1,\n' + '),' + ) + + def test_args_kwargs_signature(self): + operation = custom_migration_operations.operations.ArgsKwargsOperation(1, 2, kwarg2=4) + writer = OperationWriter(operation) + writer.indentation = 0 + buff, imports = writer.serialize() + self.assertEqual(imports, {'import custom_migration_operations.operations'}) + self.assertEqual( + buff, + 'custom_migration_operations.operations.ArgsKwargsOperation(\n' + ' arg1=1,\n' + ' arg2=2,\n' + ' kwarg2=4,\n' + '),' + ) + + def test_expand_args_signature(self): + operation = custom_migration_operations.operations.ExpandArgsOperation([1, 2]) + writer = OperationWriter(operation) + writer.indentation = 0 + buff, imports = writer.serialize() + self.assertEqual(imports, {'import custom_migration_operations.operations'}) + self.assertEqual( + buff, + 'custom_migration_operations.operations.ExpandArgsOperation(\n' + ' arg=[\n' + ' 1,\n' + ' 2,\n' + ' ],\n' + '),' + ) + + class WriterTests(TestCase): """ Tests the migration writer (makes migration files from Migration instances)