Don't make a second migration if there was a force-null-default addcol.

This commit is contained in:
Andrew Godwin 2013-12-04 13:56:22 +00:00
parent df800b1609
commit ce05b8a69e
3 changed files with 63 additions and 11 deletions

View File

@ -188,15 +188,26 @@ class MigrationAutodetector(object):
continue continue
# You can't just add NOT NULL fields with no default # You can't just add NOT NULL fields with no default
if not field.null and not field.has_default(): if not field.null and not field.has_default():
field = field.clone()
field.default = self.questioner.ask_not_null_addition(field_name, model_name) field.default = self.questioner.ask_not_null_addition(field_name, model_name)
self.add_to_migration( self.add_to_migration(
app_label, app_label,
operations.AddField( operations.AddField(
model_name=model_name, model_name=model_name,
name=field_name, name=field_name,
field=field, field=field,
preserve_default=False,
)
)
else:
self.add_to_migration(
app_label,
operations.AddField(
model_name=model_name,
name=field_name,
field=field,
)
) )
)
# Old fields # Old fields
for field_name in old_field_names - new_field_names: for field_name in old_field_names - new_field_names:
self.add_to_migration( self.add_to_migration(
@ -434,7 +445,8 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
"Adding a NOT NULL field to a model" "Adding a NOT NULL field to a model"
choice = self._choice_input( choice = self._choice_input(
"You are trying to add a non-nullable field '%s' to %s without a default;\n" % (field_name, model_name) + "You are trying to add a non-nullable field '%s' to %s without a default;\n" % (field_name, model_name) +
"this is not possible. Please select a fix:", "we can't do that (the database needs something to populate existing rows).\n" +
"Please select a fix:",
[ [
"Provide a one-off default now (will be set on all existing rows)", "Provide a one-off default now (will be set on all existing rows)",
"Quit, and let me add a default in models.py", "Quit, and let me add a default in models.py",

View File

@ -1,4 +1,5 @@
from django.db import router from django.db import router
from django.db.models.fields import NOT_PROVIDED
from .base import Operation from .base import Operation
@ -7,13 +8,20 @@ class AddField(Operation):
Adds a field to a model. Adds a field to a model.
""" """
def __init__(self, model_name, name, field): def __init__(self, model_name, name, field, preserve_default=True):
self.model_name = model_name self.model_name = model_name
self.name = name self.name = name
self.field = field self.field = field
self.preserve_default = preserve_default
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
state.models[app_label, self.model_name.lower()].fields.append((self.name, self.field)) # If preserve default is off, don't use the default for future state
if not self.preserve_default:
field = self.field.clone()
field.default = NOT_PROVIDED
else:
field = self.field
state.models[app_label, self.model_name.lower()].fields.append((self.name, field))
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
from_model = from_state.render().get_model(app_label, self.model_name) from_model = from_state.render().get_model(app_label, self.model_name)

View File

@ -1,4 +1,5 @@
from django.db import connection, models, migrations, router from django.db import connection, models, migrations, router
from django.db.models.fields import NOT_PROVIDED
from django.db.transaction import atomic from django.db.transaction import atomic
from django.db.utils import IntegrityError from django.db.utils import IntegrityError
from django.db.migrations.state import ProjectState from django.db.migrations.state import ProjectState
@ -130,10 +131,19 @@ class OperationTests(MigrationTestBase):
""" """
project_state = self.set_up_test_model("test_adfl") project_state = self.set_up_test_model("test_adfl")
# Test the state alteration # Test the state alteration
operation = migrations.AddField("Pony", "height", models.FloatField(null=True)) operation = migrations.AddField(
"Pony",
"height",
models.FloatField(null=True, default=5),
)
new_state = project_state.clone() new_state = project_state.clone()
operation.state_forwards("test_adfl", new_state) operation.state_forwards("test_adfl", new_state)
self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 4) self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 4)
field = [
f for n, f in new_state.models["test_adfl", "pony"].fields
if n == "height"
][0]
self.assertEqual(field.default, 5)
# Test the database alteration # Test the database alteration
self.assertColumnNotExists("test_adfl_pony", "height") self.assertColumnNotExists("test_adfl_pony", "height")
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
@ -144,6 +154,28 @@ class OperationTests(MigrationTestBase):
operation.database_backwards("test_adfl", editor, new_state, project_state) operation.database_backwards("test_adfl", editor, new_state, project_state)
self.assertColumnNotExists("test_adfl_pony", "height") self.assertColumnNotExists("test_adfl_pony", "height")
def test_add_field_preserve_default(self):
"""
Tests the AddField operation's state alteration
when preserve_default = False.
"""
project_state = self.set_up_test_model("test_adflpd")
# Test the state alteration
operation = migrations.AddField(
"Pony",
"height",
models.FloatField(null=True, default=4),
preserve_default = False,
)
new_state = project_state.clone()
operation.state_forwards("test_adflpd", new_state)
self.assertEqual(len(new_state.models["test_adflpd", "pony"].fields), 4)
field = [
f for n, f in new_state.models["test_adflpd", "pony"].fields
if n == "height"
][0]
self.assertEqual(field.default, NOT_PROVIDED)
def test_add_field_m2m(self): def test_add_field_m2m(self):
""" """
Tests the AddField operation with a ManyToManyField. Tests the AddField operation with a ManyToManyField.