diff --git a/django/db/backends/schema.py b/django/db/backends/schema.py index 8bc16dab2d..282dfe6374 100644 --- a/django/db/backends/schema.py +++ b/django/db/backends/schema.py @@ -44,6 +44,7 @@ class BaseDatabaseSchemaEditor(object): sql_alter_column_no_default = "ALTER COLUMN %(column)s DROP DEFAULT" sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE" sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s" + sql_update_with_default = "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL" sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)" sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" @@ -533,12 +534,19 @@ class BaseDatabaseSchemaEditor(object): }) # Next, start accumulating actions to do actions = [] + null_actions = [] post_actions = [] # Type change? if old_type != new_type: fragment, other_actions = self._alter_column_type_sql(model._meta.db_table, new_field.column, new_type) actions.append(fragment) post_actions.extend(other_actions) + # When changing a column NULL constraint to NOT NULL with a given + # default value, we need to perform 4 steps: + # 1. Add a default for new incoming writes + # 2. Update existing NULL rows with new default + # 3. Replace NULL constraint with NOT NULL + # 4. Drop the default again. # Default change? old_default = self.effective_default(old_field) new_default = self.effective_default(new_field) @@ -573,7 +581,7 @@ class BaseDatabaseSchemaEditor(object): # Nullability change? if old_field.null != new_field.null: if new_field.null: - actions.append(( + null_actions.append(( self.sql_alter_column_null % { "column": self.quote_name(new_field.column), "type": new_type, @@ -581,14 +589,23 @@ class BaseDatabaseSchemaEditor(object): [], )) else: - actions.append(( + null_actions.append(( self.sql_alter_column_not_null % { "column": self.quote_name(new_field.column), "type": new_type, }, [], )) - if actions: + # Only if we have a default and there is a change from NULL to NOT NULL + four_way_default_alteration = ( + new_field.has_default() and + (old_field.null and not new_field.null) + ) + if actions or null_actions: + if not four_way_default_alteration: + # If we don't have to do a 4-way default alteration we can + # directly run a (NOT) NULL alteration + actions = actions + null_actions # Combine actions together if we can (e.g. postgres) if self.connection.features.supports_combined_alters: sql, params = tuple(zip(*actions)) @@ -602,6 +619,26 @@ class BaseDatabaseSchemaEditor(object): }, params, ) + if four_way_default_alteration: + # Update existing rows with default value + self.execute( + self.sql_update_with_default % { + "table": self.quote_name(model._meta.db_table), + "column": self.quote_name(new_field.column), + "default": "%s", + }, + [new_default], + ) + # Since we didn't run a NOT NULL change before we need to do it + # now + for sql, params in null_actions: + self.execute( + self.sql_alter_column % { + "table": self.quote_name(model._meta.db_table), + "changes": sql, + }, + params, + ) if post_actions: for sql, params in post_actions: self.execute(sql, params) diff --git a/django/db/backends/sqlite3/schema.py b/django/db/backends/sqlite3/schema.py index cbc39cd3de..f229288ea9 100644 --- a/django/db/backends/sqlite3/schema.py +++ b/django/db/backends/sqlite3/schema.py @@ -78,7 +78,14 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): del body[old_field.name] del mapping[old_field.column] body[new_field.name] = new_field - mapping[new_field.column] = self.quote_name(old_field.column) + if old_field.null and not new_field.null: + case_sql = "coalesce(%(col)s, %(default)s)" % { + 'col': self.quote_name(old_field.column), + 'default': self.quote_value(self.effective_default(new_field)) + } + mapping[new_field.column] = case_sql + else: + mapping[new_field.column] = self.quote_name(old_field.column) rename_mapping[old_field.name] = new_field.name # Remove any deleted fields for field in delete_fields: diff --git a/django/db/migrations/autodetector.py b/django/db/migrations/autodetector.py index 3e3fe83eb8..9674222aee 100644 --- a/django/db/migrations/autodetector.py +++ b/django/db/migrations/autodetector.py @@ -6,8 +6,8 @@ import datetime from itertools import chain from django.utils import six -from django.db import models from django.conf import settings +from django.db import models from django.db.migrations import operations from django.db.migrations.migration import Migration from django.db.migrations.questioner import MigrationQuestioner @@ -838,7 +838,6 @@ class MigrationAutodetector(object): for app_label, model_name, field_name in sorted(self.old_field_keys.intersection(self.new_field_keys)): # Did the field change? old_model_name = self.renamed_models.get((app_label, model_name), model_name) - new_model_state = self.to_state.models[app_label, model_name] old_field_name = self.renamed_fields.get((app_label, model_name, field_name), field_name) old_field = self.old_apps.get_model(app_label, old_model_name)._meta.get_field_by_name(old_field_name)[0] new_field = self.new_apps.get_model(app_label, model_name)._meta.get_field_by_name(field_name)[0] @@ -854,12 +853,23 @@ class MigrationAutodetector(object): old_field_dec = self.deep_deconstruct(old_field) new_field_dec = self.deep_deconstruct(new_field) if old_field_dec != new_field_dec: + preserve_default = True + if (old_field.null and not new_field.null and not new_field.has_default() and + not isinstance(new_field, models.ManyToManyField)): + field = new_field.clone() + new_default = self.questioner.ask_not_null_alteration(field_name, model_name) + if new_default is not models.NOT_PROVIDED: + field.default = new_default + preserve_default = False + else: + field = new_field self.add_operation( app_label, operations.AlterField( model_name=model_name, name=field_name, - field=new_model_state.get_field_by_name(field_name), + field=field, + preserve_default=preserve_default, ) ) diff --git a/django/db/migrations/operations/fields.py b/django/db/migrations/operations/fields.py index 55f27c1b6e..500a8397cb 100644 --- a/django/db/migrations/operations/fields.py +++ b/django/db/migrations/operations/fields.py @@ -104,14 +104,20 @@ class AlterField(Operation): Alters a field's database column (e.g. null, max_length) to the provided new field """ - def __init__(self, model_name, name, field): + def __init__(self, model_name, name, field, preserve_default=True): self.model_name = model_name self.name = name self.field = field + self.preserve_default = preserve_default def state_forwards(self, app_label, state): + if not self.preserve_default: + field = self.field.clone() + field.default = NOT_PROVIDED + else: + field = self.field state.models[app_label, self.model_name.lower()].fields = [ - (n, self.field if n == self.name else f) for n, f in state.models[app_label, self.model_name.lower()].fields + (n, field if n == self.name else f) for n, f in state.models[app_label, self.model_name.lower()].fields ] def database_forwards(self, app_label, schema_editor, from_state, to_state): @@ -128,7 +134,11 @@ class AlterField(Operation): from_field.rel.to = to_field.rel.to elif to_field.rel and isinstance(to_field.rel.to, six.string_types): to_field.rel.to = from_field.rel.to + if not self.preserve_default: + to_field.default = self.field.default schema_editor.alter_field(from_model, from_field, to_field) + if not self.preserve_default: + to_field.default = NOT_PROVIDED def database_backwards(self, app_label, schema_editor, from_state, to_state): self.database_forwards(app_label, schema_editor, from_state, to_state) diff --git a/django/db/migrations/questioner.py b/django/db/migrations/questioner.py index 3b31796cbc..54a4c80635 100644 --- a/django/db/migrations/questioner.py +++ b/django/db/migrations/questioner.py @@ -1,10 +1,11 @@ -from __future__ import unicode_literals +from __future__ import print_function, unicode_literals import importlib import os import sys from django.apps import apps +from django.db.models.fields import NOT_PROVIDED from django.utils import datetime_safe, six, timezone from django.utils.six.moves import input @@ -55,6 +56,11 @@ class MigrationQuestioner(object): # None means quit return None + def ask_not_null_alteration(self, field_name, model_name): + "Changing a NULL field to NOT NULL" + # None means quit + return None + def ask_rename(self, model_name, old_name, new_name, field_instance): "Was this field really renamed?" return self.defaults.get("ask_rename", False) @@ -92,13 +98,34 @@ class InteractiveMigrationQuestioner(MigrationQuestioner): pass result = input("Please select a valid option: ") + def _ask_default(self): + print("Please enter the default value now, as valid Python") + print("The datetime and django.utils.timezone modules are available, so you can do e.g. timezone.now()") + while True: + if six.PY3: + # Six does not correctly abstract over the fact that + # py3 input returns a unicode string, while py2 raw_input + # returns a bytestring. + code = input(">>> ") + else: + code = input(">>> ").decode(sys.stdin.encoding) + if not code: + print("Please enter some code, or 'exit' (with no quotes) to exit.") + elif code == "exit": + sys.exit(1) + else: + try: + return eval(code, {}, {"datetime": datetime_safe, "timezone": timezone}) + except (SyntaxError, NameError) as e: + print("Invalid input: %s" % e) + def ask_not_null_addition(self, field_name, model_name): "Adding a NOT NULL field to a model" if not self.dry_run: choice = self._choice_input( - "You are trying to add a non-nullable field '%s' to %s without a default;\n" % (field_name, model_name) + - "we can't do that (the database needs something to populate existing rows).\n" + - "Please select a fix:", + "You are trying to add a non-nullable field '%s' to %s without a default; " + "we can't do that (the database needs something to populate existing rows).\n" + "Please select a fix:" % (field_name, model_name), [ "Provide a one-off default now (will be set on all existing rows)", "Quit, and let me add a default in models.py", @@ -107,26 +134,31 @@ class InteractiveMigrationQuestioner(MigrationQuestioner): if choice == 2: sys.exit(3) else: - print("Please enter the default value now, as valid Python") - print("The datetime and django.utils.timezone modules are " - "available, so you can do e.g. timezone.now()") - while True: - if six.PY3: - # Six does not correctly abstract over the fact that - # py3 input returns a unicode string, while py2 raw_input - # returns a bytestring. - code = input(">>> ") - else: - code = input(">>> ").decode(sys.stdin.encoding) - if not code: - print("Please enter some code, or 'exit' (with no quotes) to exit.") - elif code == "exit": - sys.exit(1) - else: - try: - return eval(code, {}, {"datetime": datetime_safe, "timezone": timezone}) - except (SyntaxError, NameError) as e: - print("Invalid input: %s" % e) + return self._ask_default() + return None + + def ask_not_null_alteration(self, field_name, model_name): + "Changing a NULL field to NOT NULL" + if not self.dry_run: + choice = self._choice_input( + "You are trying to change the nullable field '%s' on %s to non-nullable " + "without a default; we can't do that (the database needs something to " + "populate existing rows).\n" + "Please select a fix:" % (field_name, model_name), + [ + "Provide a one-off default now (will be set on all existing rows)", + ("Ignore for now, and let me handle existing rows with NULL myself " + "(e.g. adding a RunPython or RunSQL operation in the new migration " + "file before the AlterField operation)"), + "Quit, and let me add a default in models.py", + ] + ) + if choice == 2: + return NOT_PROVIDED + elif choice == 3: + sys.exit(3) + else: + return self._ask_default() return None def ask_rename(self, model_name, old_name, new_name, field_instance): diff --git a/docs/ref/migration-operations.txt b/docs/ref/migration-operations.txt index 6998bdb574..ef5f2a44d9 100644 --- a/docs/ref/migration-operations.txt +++ b/docs/ref/migration-operations.txt @@ -137,7 +137,7 @@ or if it is temporary and just for this migration (``False``) - usually because the migration is adding a non-nullable field to a table and needs a default value to put into existing rows. It does not effect the behavior of setting defaults in the database directly - Django never sets database -defaults, and always applies them in the Django ORM code. +defaults and always applies them in the Django ORM code. RemoveField ----------- @@ -153,16 +153,28 @@ from any data loss, which of course is irreversible). AlterField ---------- -.. class:: AlterField(model_name, name, field) +.. class:: AlterField(model_name, name, field, preserve_default=True) Alters a field's definition, including changes to its type, :attr:`~django.db.models.Field.null`, :attr:`~django.db.models.Field.unique`, :attr:`~django.db.models.Field.db_column` and other field attributes. +The ``preserve_default`` argument indicates whether the field's default +value is permanent and should be baked into the project state (``True``), +or if it is temporary and just for this migration (``False``) - usually +because the migration is altering a nullable field to a non-nullable one and +needs a default value to put into existing rows. It does not effect the +behavior of setting defaults in the database directly - Django never sets +database defaults and always applies them in the Django ORM code. + Note that not all changes are possible on all databases - for example, you cannot change a text-type field like ``models.TextField()`` into a number-type field like ``models.IntegerField()`` on most databases. +.. versionchanged:: 1.7.1 + + The ``preserve_default`` argument was added. + RenameField ----------- diff --git a/docs/releases/1.7.1.txt b/docs/releases/1.7.1.txt index 01edc4c8f5..0d6cd3d3b7 100644 --- a/docs/releases/1.7.1.txt +++ b/docs/releases/1.7.1.txt @@ -106,3 +106,7 @@ Bugfixes * Made :func:`~django.utils.http.urlsafe_base64_decode` return the proper type (byte string) on Python 3 (:ticket:`23333`). + +* Added a prompt to the migrations questioner when removing the null constraint + from a field to prevent an IntegrityError on existing NULL rows + (:ticket:`23609`). diff --git a/tests/migrations/test_autodetector.py b/tests/migrations/test_autodetector.py index 4138d859d1..25b11e4f08 100644 --- a/tests/migrations/test_autodetector.py +++ b/tests/migrations/test_autodetector.py @@ -26,6 +26,7 @@ class AutodetectorTests(TestCase): author_empty = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True))]) author_name = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True)), ("name", models.CharField(max_length=200))]) + author_name_null = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True)), ("name", models.CharField(max_length=200, null=True))]) author_name_longer = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True)), ("name", models.CharField(max_length=400))]) author_name_renamed = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True)), ("names", models.CharField(max_length=200))]) author_name_default = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True)), ("name", models.CharField(max_length=200, default='Ada Lovelace'))]) @@ -302,6 +303,80 @@ class AutodetectorTests(TestCase): action = migration.operations[0] self.assertEqual(action.__class__.__name__, "AlterField") self.assertEqual(action.name, "name") + self.assertTrue(action.preserve_default) + + def test_alter_field_to_not_null_with_default(self): + "#23609 - Tests autodetection of nullable to non-nullable alterations" + class CustomQuestioner(MigrationQuestioner): + def ask_not_null_alteration(self, field_name, model_name): + raise Exception("Should not have prompted for not null addition") + + # Make state + before = self.make_project_state([self.author_name_null]) + after = self.make_project_state([self.author_name_default]) + autodetector = MigrationAutodetector(before, after, CustomQuestioner()) + changes = autodetector._detect_changes() + # Right number of migrations? + self.assertEqual(len(changes['testapp']), 1) + # Right number of actions? + migration = changes['testapp'][0] + self.assertEqual(len(migration.operations), 1) + # Right action? + action = migration.operations[0] + self.assertEqual(action.__class__.__name__, "AlterField") + self.assertEqual(action.name, "name") + self.assertTrue(action.preserve_default) + self.assertEqual(action.field.default, 'Ada Lovelace') + + def test_alter_field_to_not_null_without_default(self): + "#23609 - Tests autodetection of nullable to non-nullable alterations" + class CustomQuestioner(MigrationQuestioner): + def ask_not_null_alteration(self, field_name, model_name): + # Ignore for now, and let me handle existing rows with NULL + # myself (e.g. adding a RunPython or RunSQL operation in the new + # migration file before the AlterField operation) + return models.NOT_PROVIDED + + # Make state + before = self.make_project_state([self.author_name_null]) + after = self.make_project_state([self.author_name]) + autodetector = MigrationAutodetector(before, after, CustomQuestioner()) + changes = autodetector._detect_changes() + # Right number of migrations? + self.assertEqual(len(changes['testapp']), 1) + # Right number of actions? + migration = changes['testapp'][0] + self.assertEqual(len(migration.operations), 1) + # Right action? + action = migration.operations[0] + self.assertEqual(action.__class__.__name__, "AlterField") + self.assertEqual(action.name, "name") + self.assertTrue(action.preserve_default) + self.assertIs(action.field.default, models.NOT_PROVIDED) + + def test_alter_field_to_not_null_oneoff_default(self): + "#23609 - Tests autodetection of nullable to non-nullable alterations" + class CustomQuestioner(MigrationQuestioner): + def ask_not_null_alteration(self, field_name, model_name): + # Provide a one-off default now (will be set on all existing rows) + return 'Some Name' + + # Make state + before = self.make_project_state([self.author_name_null]) + after = self.make_project_state([self.author_name]) + autodetector = MigrationAutodetector(before, after, CustomQuestioner()) + changes = autodetector._detect_changes() + # Right number of migrations? + self.assertEqual(len(changes['testapp']), 1) + # Right number of actions? + migration = changes['testapp'][0] + self.assertEqual(len(migration.operations), 1) + # Right action? + action = migration.operations[0] + self.assertEqual(action.__class__.__name__, "AlterField") + self.assertEqual(action.name, "name") + self.assertFalse(action.preserve_default) + self.assertEqual(action.field.default, "Some Name") def test_rename_field(self): "Tests autodetection of renamed fields" diff --git a/tests/schema/tests.py b/tests/schema/tests.py index 88290338a1..2042f00bcd 100644 --- a/tests/schema/tests.py +++ b/tests/schema/tests.py @@ -3,7 +3,8 @@ import unittest from django.test import TransactionTestCase from django.db import connection, DatabaseError, IntegrityError, OperationalError -from django.db.models.fields import IntegerField, TextField, CharField, SlugField, BooleanField, BinaryField +from django.db.models.fields import (BinaryField, BooleanField, CharField, IntegerField, + PositiveIntegerField, SlugField, TextField) from django.db.models.fields.related import ManyToManyField, ForeignKey from django.db.transaction import atomic from .models import (Author, AuthorWithM2M, Book, BookWithLongName, @@ -415,6 +416,38 @@ class SchemaTests(TransactionTestCase): self.assertEqual(columns['name'][0], "TextField") self.assertEqual(bool(columns['name'][1][6]), False) + def test_alter_null_to_not_null(self): + """ + #23609 - Tests handling of default values when altering from NULL to NOT NULL. + """ + # Create the table + with connection.schema_editor() as editor: + editor.create_model(Author) + # Ensure the field is right to begin with + columns = self.column_classes(Author) + self.assertTrue(columns['height'][1][6]) + # Create some test data + Author.objects.create(name='Not null author', height=12) + Author.objects.create(name='Null author') + # Verify null value + self.assertEqual(Author.objects.get(name='Not null author').height, 12) + self.assertIsNone(Author.objects.get(name='Null author').height) + # Alter the height field to NOT NULL with default + new_field = PositiveIntegerField(default=42) + new_field.set_attributes_from_name("height") + with connection.schema_editor() as editor: + editor.alter_field( + Author, + Author._meta.get_field_by_name("height")[0], + new_field + ) + # Ensure the field is right afterwards + columns = self.column_classes(Author) + self.assertFalse(columns['height'][1][6]) + # Verify default value + self.assertEqual(Author.objects.get(name='Not null author').height, 12) + self.assertEqual(Author.objects.get(name='Null author').height, 42) + @unittest.skipUnless(connection.features.supports_foreign_keys, "No FK support") def test_alter_fk(self): """