diff --git a/django/db/migrations/autodetector.py b/django/db/migrations/autodetector.py index 083351a6f10..f8ec49ab0a7 100644 --- a/django/db/migrations/autodetector.py +++ b/django/db/migrations/autodetector.py @@ -188,15 +188,26 @@ class MigrationAutodetector(object): continue # You can't just add NOT NULL fields with no default if not field.null and not field.has_default(): + field = field.clone() field.default = self.questioner.ask_not_null_addition(field_name, model_name) - self.add_to_migration( - app_label, - operations.AddField( - model_name=model_name, - name=field_name, - field=field, + self.add_to_migration( + app_label, + operations.AddField( + model_name=model_name, + name=field_name, + field=field, + preserve_default=False, + ) + ) + else: + self.add_to_migration( + app_label, + operations.AddField( + model_name=model_name, + name=field_name, + field=field, + ) ) - ) # Old fields for field_name in old_field_names - new_field_names: self.add_to_migration( @@ -434,7 +445,8 @@ class InteractiveMigrationQuestioner(MigrationQuestioner): "Adding a NOT NULL field to a model" choice = self._choice_input( "You are trying to add a non-nullable field '%s' to %s without a default;\n" % (field_name, model_name) + - "this is not possible. Please select a fix:", + "we can't do that (the database needs something to populate existing rows).\n" + + "Please select a fix:", [ "Provide a one-off default now (will be set on all existing rows)", "Quit, and let me add a default in models.py", diff --git a/django/db/migrations/operations/fields.py b/django/db/migrations/operations/fields.py index b609110f853..c5f0bd1e2b2 100644 --- a/django/db/migrations/operations/fields.py +++ b/django/db/migrations/operations/fields.py @@ -1,4 +1,5 @@ from django.db import router +from django.db.models.fields import NOT_PROVIDED from .base import Operation @@ -7,13 +8,20 @@ class AddField(Operation): Adds a field to a model. """ - def __init__(self, model_name, name, field): + def __init__(self, model_name, name, field, preserve_default=True): self.model_name = model_name self.name = name self.field = field + self.preserve_default = preserve_default def state_forwards(self, app_label, state): - state.models[app_label, self.model_name.lower()].fields.append((self.name, self.field)) + # If preserve default is off, don't use the default for future state + if not self.preserve_default: + field = self.field.clone() + field.default = NOT_PROVIDED + else: + field = self.field + state.models[app_label, self.model_name.lower()].fields.append((self.name, field)) def database_forwards(self, app_label, schema_editor, from_state, to_state): from_model = from_state.render().get_model(app_label, self.model_name) diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py index fb386050303..2dea14b445a 100644 --- a/tests/migrations/test_operations.py +++ b/tests/migrations/test_operations.py @@ -1,4 +1,5 @@ from django.db import connection, models, migrations, router +from django.db.models.fields import NOT_PROVIDED from django.db.transaction import atomic from django.db.utils import IntegrityError from django.db.migrations.state import ProjectState @@ -130,10 +131,19 @@ class OperationTests(MigrationTestBase): """ project_state = self.set_up_test_model("test_adfl") # Test the state alteration - operation = migrations.AddField("Pony", "height", models.FloatField(null=True)) + operation = migrations.AddField( + "Pony", + "height", + models.FloatField(null=True, default=5), + ) new_state = project_state.clone() operation.state_forwards("test_adfl", new_state) self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 4) + field = [ + f for n, f in new_state.models["test_adfl", "pony"].fields + if n == "height" + ][0] + self.assertEqual(field.default, 5) # Test the database alteration self.assertColumnNotExists("test_adfl_pony", "height") with connection.schema_editor() as editor: @@ -144,6 +154,28 @@ class OperationTests(MigrationTestBase): operation.database_backwards("test_adfl", editor, new_state, project_state) self.assertColumnNotExists("test_adfl_pony", "height") + def test_add_field_preserve_default(self): + """ + Tests the AddField operation's state alteration + when preserve_default = False. + """ + project_state = self.set_up_test_model("test_adflpd") + # Test the state alteration + operation = migrations.AddField( + "Pony", + "height", + models.FloatField(null=True, default=4), + preserve_default = False, + ) + new_state = project_state.clone() + operation.state_forwards("test_adflpd", new_state) + self.assertEqual(len(new_state.models["test_adflpd", "pony"].fields), 4) + field = [ + f for n, f in new_state.models["test_adflpd", "pony"].fields + if n == "height" + ][0] + self.assertEqual(field.default, NOT_PROVIDED) + def test_add_field_m2m(self): """ Tests the AddField operation with a ManyToManyField.