Fixed #24155 -- Maintained kwargs and import order in migration writer

Thanks Tomas Dobrovolny for the report and Tim Graham for the review.
This commit is contained in:
Markus Holtermann 2015-01-17 00:55:41 +01:00
parent bd691f4586
commit 7f20041bca
2 changed files with 37 additions and 7 deletions

View File

@ -99,7 +99,7 @@ class OperationWriter(object):
i = len(args) i = len(args)
# Only iterate over remaining arguments # Only iterate over remaining arguments
for arg_name in argspec.args[i + 1:]: for arg_name in argspec.args[i + 1:]:
if arg_name in kwargs: if arg_name in kwargs: # Don't sort to maintain signature order
arg_value = kwargs[arg_name] arg_value = kwargs[arg_name]
_write(arg_name, arg_value) _write(arg_name, arg_value)
@ -138,7 +138,7 @@ class MigrationWriter(object):
"replaces_str": "", "replaces_str": "",
} }
imports = set() imports = {"from django.db import migrations, models"}
# Deconstruct operations # Deconstruct operations
operations = [] operations = []
@ -169,14 +169,17 @@ class MigrationWriter(object):
imports.remove(line) imports.remove(line)
self.needs_manual_porting = True self.needs_manual_porting = True
imports.discard("from django.db import models") imports.discard("from django.db import models")
items["imports"] = "\n".join(imports) + "\n" if imports else "" # Sort imports by the package / module to be imported (the part after
# "from" in "from ... import ..." or after "import" in "import ...").
sorted_imports = sorted(imports, key=lambda i: i.split()[1])
items["imports"] = "\n".join(sorted_imports) + "\n" if imports else ""
if migration_imports: if migration_imports:
items["imports"] += ( items["imports"] += (
"\n\n# Functions from the following migrations need manual " "\n\n# Functions from the following migrations need manual "
"copying.\n# Move them and any dependencies into this file, " "copying.\n# Move them and any dependencies into this file, "
"then update the\n# RunPython operations to refer to the local " "then update the\n# RunPython operations to refer to the local "
"versions:\n# %s" "versions:\n# %s"
) % "\n# ".join(migration_imports) ) % "\n# ".join(sorted(migration_imports))
# If there's a replaces, make a string for it # If there's a replaces, make a string for it
if self.migration.replaces: if self.migration.replaces:
items['replaces_str'] = "\n replaces = %s\n" % self.serialize(self.migration.replaces)[0] items['replaces_str'] = "\n replaces = %s\n" % self.serialize(self.migration.replaces)[0]
@ -244,7 +247,7 @@ class MigrationWriter(object):
arg_string, arg_imports = cls.serialize(arg) arg_string, arg_imports = cls.serialize(arg)
strings.append(arg_string) strings.append(arg_string)
imports.update(arg_imports) imports.update(arg_imports)
for kw, arg in kwargs.items(): for kw, arg in sorted(kwargs.items()):
arg_string, arg_imports = cls.serialize(arg) arg_string, arg_imports = cls.serialize(arg)
imports.update(arg_imports) imports.update(arg_imports)
strings.append("%s=%s" % (kw, arg_string)) strings.append("%s=%s" % (kw, arg_string))
@ -297,7 +300,7 @@ class MigrationWriter(object):
elif isinstance(value, dict): elif isinstance(value, dict):
imports = set() imports = set()
strings = [] strings = []
for k, v in value.items(): for k, v in sorted(value.items()):
k_string, k_imports = cls.serialize(k) k_string, k_imports = cls.serialize(k)
v_string, v_imports = cls.serialize(v) v_string, v_imports = cls.serialize(v)
imports.update(k_imports) imports.update(k_imports)
@ -443,7 +446,6 @@ MIGRATION_TEMPLATE = """\
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import unicode_literals from __future__ import unicode_literals
from django.db import models, migrations
%(imports)s %(imports)s
class Migration(migrations.Migration): class Migration(migrations.Migration):

View File

@ -216,7 +216,15 @@ class WriterTests(TestCase):
def test_serialize_fields(self): def test_serialize_fields(self):
self.assertSerializedFieldEqual(models.CharField(max_length=255)) self.assertSerializedFieldEqual(models.CharField(max_length=255))
self.assertSerializedResultEqual(
models.CharField(max_length=255),
("models.CharField(max_length=255)", {"from django.db import models"})
)
self.assertSerializedFieldEqual(models.TextField(null=True, blank=True)) self.assertSerializedFieldEqual(models.TextField(null=True, blank=True))
self.assertSerializedResultEqual(
models.TextField(null=True, blank=True),
("models.TextField(blank=True, null=True)", {'from django.db import models'})
)
def test_serialize_settings(self): def test_serialize_settings(self):
self.assertSerializedEqual(SettingsReference(settings.AUTH_USER_MODEL, "AUTH_USER_MODEL")) self.assertSerializedEqual(SettingsReference(settings.AUTH_USER_MODEL, "AUTH_USER_MODEL"))
@ -419,6 +427,26 @@ class WriterTests(TestCase):
result['custom_migration_operations'].more_operations.TestOperation result['custom_migration_operations'].more_operations.TestOperation
) )
def test_sorted_imports(self):
"""
#24155 - Tests ordering of imports.
"""
migration = type(str("Migration"), (migrations.Migration,), {
"operations": [
migrations.AddField("mymodel", "myfield", models.DateTimeField(
default=datetime.datetime(2012, 1, 1, 1, 1, tzinfo=utc),
)),
]
})
writer = MigrationWriter(migration)
output = writer.as_string().decode('utf-8')
self.assertIn(
"import datetime\n"
"from django.db import migrations, models\n"
"from django.utils.timezone import utc\n",
output
)
def test_deconstruct_class_arguments(self): def test_deconstruct_class_arguments(self):
# Yes, it doesn't make sense to use a class as a default for a # Yes, it doesn't make sense to use a class as a default for a
# CharField. It does make sense for custom fields though, for example # CharField. It does make sense for custom fields though, for example