mirror of https://github.com/django/django.git
Don't make a second migration if there was a force-null-default addcol.
This commit is contained in:
parent
df800b1609
commit
ce05b8a69e
|
@ -188,15 +188,26 @@ class MigrationAutodetector(object):
|
||||||
continue
|
continue
|
||||||
# You can't just add NOT NULL fields with no default
|
# You can't just add NOT NULL fields with no default
|
||||||
if not field.null and not field.has_default():
|
if not field.null and not field.has_default():
|
||||||
|
field = field.clone()
|
||||||
field.default = self.questioner.ask_not_null_addition(field_name, model_name)
|
field.default = self.questioner.ask_not_null_addition(field_name, model_name)
|
||||||
self.add_to_migration(
|
self.add_to_migration(
|
||||||
app_label,
|
app_label,
|
||||||
operations.AddField(
|
operations.AddField(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
name=field_name,
|
name=field_name,
|
||||||
field=field,
|
field=field,
|
||||||
|
preserve_default=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.add_to_migration(
|
||||||
|
app_label,
|
||||||
|
operations.AddField(
|
||||||
|
model_name=model_name,
|
||||||
|
name=field_name,
|
||||||
|
field=field,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
# Old fields
|
# Old fields
|
||||||
for field_name in old_field_names - new_field_names:
|
for field_name in old_field_names - new_field_names:
|
||||||
self.add_to_migration(
|
self.add_to_migration(
|
||||||
|
@ -434,7 +445,8 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
|
||||||
"Adding a NOT NULL field to a model"
|
"Adding a NOT NULL field to a model"
|
||||||
choice = self._choice_input(
|
choice = self._choice_input(
|
||||||
"You are trying to add a non-nullable field '%s' to %s without a default;\n" % (field_name, model_name) +
|
"You are trying to add a non-nullable field '%s' to %s without a default;\n" % (field_name, model_name) +
|
||||||
"this is not possible. Please select a fix:",
|
"we can't do that (the database needs something to populate existing rows).\n" +
|
||||||
|
"Please select a fix:",
|
||||||
[
|
[
|
||||||
"Provide a one-off default now (will be set on all existing rows)",
|
"Provide a one-off default now (will be set on all existing rows)",
|
||||||
"Quit, and let me add a default in models.py",
|
"Quit, and let me add a default in models.py",
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from django.db import router
|
from django.db import router
|
||||||
|
from django.db.models.fields import NOT_PROVIDED
|
||||||
from .base import Operation
|
from .base import Operation
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,13 +8,20 @@ class AddField(Operation):
|
||||||
Adds a field to a model.
|
Adds a field to a model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name, name, field):
|
def __init__(self, model_name, name, field, preserve_default=True):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.name = name
|
self.name = name
|
||||||
self.field = field
|
self.field = field
|
||||||
|
self.preserve_default = preserve_default
|
||||||
|
|
||||||
def state_forwards(self, app_label, state):
|
def state_forwards(self, app_label, state):
|
||||||
state.models[app_label, self.model_name.lower()].fields.append((self.name, self.field))
|
# If preserve default is off, don't use the default for future state
|
||||||
|
if not self.preserve_default:
|
||||||
|
field = self.field.clone()
|
||||||
|
field.default = NOT_PROVIDED
|
||||||
|
else:
|
||||||
|
field = self.field
|
||||||
|
state.models[app_label, self.model_name.lower()].fields.append((self.name, field))
|
||||||
|
|
||||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||||
from_model = from_state.render().get_model(app_label, self.model_name)
|
from_model = from_state.render().get_model(app_label, self.model_name)
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from django.db import connection, models, migrations, router
|
from django.db import connection, models, migrations, router
|
||||||
|
from django.db.models.fields import NOT_PROVIDED
|
||||||
from django.db.transaction import atomic
|
from django.db.transaction import atomic
|
||||||
from django.db.utils import IntegrityError
|
from django.db.utils import IntegrityError
|
||||||
from django.db.migrations.state import ProjectState
|
from django.db.migrations.state import ProjectState
|
||||||
|
@ -130,10 +131,19 @@ class OperationTests(MigrationTestBase):
|
||||||
"""
|
"""
|
||||||
project_state = self.set_up_test_model("test_adfl")
|
project_state = self.set_up_test_model("test_adfl")
|
||||||
# Test the state alteration
|
# Test the state alteration
|
||||||
operation = migrations.AddField("Pony", "height", models.FloatField(null=True))
|
operation = migrations.AddField(
|
||||||
|
"Pony",
|
||||||
|
"height",
|
||||||
|
models.FloatField(null=True, default=5),
|
||||||
|
)
|
||||||
new_state = project_state.clone()
|
new_state = project_state.clone()
|
||||||
operation.state_forwards("test_adfl", new_state)
|
operation.state_forwards("test_adfl", new_state)
|
||||||
self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 4)
|
self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 4)
|
||||||
|
field = [
|
||||||
|
f for n, f in new_state.models["test_adfl", "pony"].fields
|
||||||
|
if n == "height"
|
||||||
|
][0]
|
||||||
|
self.assertEqual(field.default, 5)
|
||||||
# Test the database alteration
|
# Test the database alteration
|
||||||
self.assertColumnNotExists("test_adfl_pony", "height")
|
self.assertColumnNotExists("test_adfl_pony", "height")
|
||||||
with connection.schema_editor() as editor:
|
with connection.schema_editor() as editor:
|
||||||
|
@ -144,6 +154,28 @@ class OperationTests(MigrationTestBase):
|
||||||
operation.database_backwards("test_adfl", editor, new_state, project_state)
|
operation.database_backwards("test_adfl", editor, new_state, project_state)
|
||||||
self.assertColumnNotExists("test_adfl_pony", "height")
|
self.assertColumnNotExists("test_adfl_pony", "height")
|
||||||
|
|
||||||
|
def test_add_field_preserve_default(self):
|
||||||
|
"""
|
||||||
|
Tests the AddField operation's state alteration
|
||||||
|
when preserve_default = False.
|
||||||
|
"""
|
||||||
|
project_state = self.set_up_test_model("test_adflpd")
|
||||||
|
# Test the state alteration
|
||||||
|
operation = migrations.AddField(
|
||||||
|
"Pony",
|
||||||
|
"height",
|
||||||
|
models.FloatField(null=True, default=4),
|
||||||
|
preserve_default = False,
|
||||||
|
)
|
||||||
|
new_state = project_state.clone()
|
||||||
|
operation.state_forwards("test_adflpd", new_state)
|
||||||
|
self.assertEqual(len(new_state.models["test_adflpd", "pony"].fields), 4)
|
||||||
|
field = [
|
||||||
|
f for n, f in new_state.models["test_adflpd", "pony"].fields
|
||||||
|
if n == "height"
|
||||||
|
][0]
|
||||||
|
self.assertEqual(field.default, NOT_PROVIDED)
|
||||||
|
|
||||||
def test_add_field_m2m(self):
|
def test_add_field_m2m(self):
|
||||||
"""
|
"""
|
||||||
Tests the AddField operation with a ManyToManyField.
|
Tests the AddField operation with a ManyToManyField.
|
||||||
|
|
Loading…
Reference in New Issue