Fixed #23609 -- Fixed IntegrityError that prevented altering a NULL column into a NOT NULL one due to existing rows

Thanks to Simon Charette, Loic Bistuer and Tim Graham for the review.
This commit is contained in:
Markus Holtermann 2014-10-07 01:53:21 +02:00 committed by Loic Bistuer
parent 15d350fbce
commit f633ba778d
9 changed files with 256 additions and 36 deletions

View File

@ -44,6 +44,7 @@ class BaseDatabaseSchemaEditor(object):
sql_alter_column_no_default = "ALTER COLUMN %(column)s DROP DEFAULT"
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE"
sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"
sql_update_with_default = "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)"
sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
@ -533,12 +534,19 @@ class BaseDatabaseSchemaEditor(object):
})
# Next, start accumulating actions to do
actions = []
null_actions = []
post_actions = []
# Type change?
if old_type != new_type:
fragment, other_actions = self._alter_column_type_sql(model._meta.db_table, new_field.column, new_type)
actions.append(fragment)
post_actions.extend(other_actions)
# When changing a column NULL constraint to NOT NULL with a given
# default value, we need to perform 4 steps:
# 1. Add a default for new incoming writes
# 2. Update existing NULL rows with new default
# 3. Replace NULL constraint with NOT NULL
# 4. Drop the default again.
# Default change?
old_default = self.effective_default(old_field)
new_default = self.effective_default(new_field)
@ -573,7 +581,7 @@ class BaseDatabaseSchemaEditor(object):
# Nullability change?
if old_field.null != new_field.null:
if new_field.null:
actions.append((
null_actions.append((
self.sql_alter_column_null % {
"column": self.quote_name(new_field.column),
"type": new_type,
@ -581,14 +589,23 @@ class BaseDatabaseSchemaEditor(object):
[],
))
else:
actions.append((
null_actions.append((
self.sql_alter_column_not_null % {
"column": self.quote_name(new_field.column),
"type": new_type,
},
[],
))
if actions:
# Only if we have a default and there is a change from NULL to NOT NULL
four_way_default_alteration = (
new_field.has_default() and
(old_field.null and not new_field.null)
)
if actions or null_actions:
if not four_way_default_alteration:
# If we don't have to do a 4-way default alteration we can
# directly run a (NOT) NULL alteration
actions = actions + null_actions
# Combine actions together if we can (e.g. postgres)
if self.connection.features.supports_combined_alters:
sql, params = tuple(zip(*actions))
@ -602,6 +619,26 @@ class BaseDatabaseSchemaEditor(object):
},
params,
)
if four_way_default_alteration:
# Update existing rows with default value
self.execute(
self.sql_update_with_default % {
"table": self.quote_name(model._meta.db_table),
"column": self.quote_name(new_field.column),
"default": "%s",
},
[new_default],
)
# Since we didn't run a NOT NULL change before we need to do it
# now
for sql, params in null_actions:
self.execute(
self.sql_alter_column % {
"table": self.quote_name(model._meta.db_table),
"changes": sql,
},
params,
)
if post_actions:
for sql, params in post_actions:
self.execute(sql, params)

View File

@ -78,6 +78,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
del body[old_field.name]
del mapping[old_field.column]
body[new_field.name] = new_field
if old_field.null and not new_field.null:
case_sql = "coalesce(%(col)s, %(default)s)" % {
'col': self.quote_name(old_field.column),
'default': self.quote_value(self.effective_default(new_field))
}
mapping[new_field.column] = case_sql
else:
mapping[new_field.column] = self.quote_name(old_field.column)
rename_mapping[old_field.name] = new_field.name
# Remove any deleted fields

View File

@ -6,8 +6,8 @@ import datetime
from itertools import chain
from django.utils import six
from django.db import models
from django.conf import settings
from django.db import models
from django.db.migrations import operations
from django.db.migrations.migration import Migration
from django.db.migrations.questioner import MigrationQuestioner
@ -838,7 +838,6 @@ class MigrationAutodetector(object):
for app_label, model_name, field_name in sorted(self.old_field_keys.intersection(self.new_field_keys)):
# Did the field change?
old_model_name = self.renamed_models.get((app_label, model_name), model_name)
new_model_state = self.to_state.models[app_label, model_name]
old_field_name = self.renamed_fields.get((app_label, model_name, field_name), field_name)
old_field = self.old_apps.get_model(app_label, old_model_name)._meta.get_field_by_name(old_field_name)[0]
new_field = self.new_apps.get_model(app_label, model_name)._meta.get_field_by_name(field_name)[0]
@ -854,12 +853,23 @@ class MigrationAutodetector(object):
old_field_dec = self.deep_deconstruct(old_field)
new_field_dec = self.deep_deconstruct(new_field)
if old_field_dec != new_field_dec:
preserve_default = True
if (old_field.null and not new_field.null and not new_field.has_default() and
not isinstance(new_field, models.ManyToManyField)):
field = new_field.clone()
new_default = self.questioner.ask_not_null_alteration(field_name, model_name)
if new_default is not models.NOT_PROVIDED:
field.default = new_default
preserve_default = False
else:
field = new_field
self.add_operation(
app_label,
operations.AlterField(
model_name=model_name,
name=field_name,
field=new_model_state.get_field_by_name(field_name),
field=field,
preserve_default=preserve_default,
)
)

View File

@ -104,14 +104,20 @@ class AlterField(Operation):
Alters a field's database column (e.g. null, max_length) to the provided new field
"""
def __init__(self, model_name, name, field):
def __init__(self, model_name, name, field, preserve_default=True):
self.model_name = model_name
self.name = name
self.field = field
self.preserve_default = preserve_default
def state_forwards(self, app_label, state):
if not self.preserve_default:
field = self.field.clone()
field.default = NOT_PROVIDED
else:
field = self.field
state.models[app_label, self.model_name.lower()].fields = [
(n, self.field if n == self.name else f) for n, f in state.models[app_label, self.model_name.lower()].fields
(n, field if n == self.name else f) for n, f in state.models[app_label, self.model_name.lower()].fields
]
def database_forwards(self, app_label, schema_editor, from_state, to_state):
@ -128,7 +134,11 @@ class AlterField(Operation):
from_field.rel.to = to_field.rel.to
elif to_field.rel and isinstance(to_field.rel.to, six.string_types):
to_field.rel.to = from_field.rel.to
if not self.preserve_default:
to_field.default = self.field.default
schema_editor.alter_field(from_model, from_field, to_field)
if not self.preserve_default:
to_field.default = NOT_PROVIDED
def database_backwards(self, app_label, schema_editor, from_state, to_state):
self.database_forwards(app_label, schema_editor, from_state, to_state)

View File

@ -1,10 +1,11 @@
from __future__ import unicode_literals
from __future__ import print_function, unicode_literals
import importlib
import os
import sys
from django.apps import apps
from django.db.models.fields import NOT_PROVIDED
from django.utils import datetime_safe, six, timezone
from django.utils.six.moves import input
@ -55,6 +56,11 @@ class MigrationQuestioner(object):
# None means quit
return None
def ask_not_null_alteration(self, field_name, model_name):
"Changing a NULL field to NOT NULL"
# None means quit
return None
def ask_rename(self, model_name, old_name, new_name, field_instance):
"Was this field really renamed?"
return self.defaults.get("ask_rename", False)
@ -92,24 +98,9 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
pass
result = input("Please select a valid option: ")
def ask_not_null_addition(self, field_name, model_name):
"Adding a NOT NULL field to a model"
if not self.dry_run:
choice = self._choice_input(
"You are trying to add a non-nullable field '%s' to %s without a default;\n" % (field_name, model_name) +
"we can't do that (the database needs something to populate existing rows).\n" +
"Please select a fix:",
[
"Provide a one-off default now (will be set on all existing rows)",
"Quit, and let me add a default in models.py",
]
)
if choice == 2:
sys.exit(3)
else:
def _ask_default(self):
print("Please enter the default value now, as valid Python")
print("The datetime and django.utils.timezone modules are "
"available, so you can do e.g. timezone.now()")
print("The datetime and django.utils.timezone modules are available, so you can do e.g. timezone.now()")
while True:
if six.PY3:
# Six does not correctly abstract over the fact that
@ -127,6 +118,47 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
return eval(code, {}, {"datetime": datetime_safe, "timezone": timezone})
except (SyntaxError, NameError) as e:
print("Invalid input: %s" % e)
def ask_not_null_addition(self, field_name, model_name):
"Adding a NOT NULL field to a model"
if not self.dry_run:
choice = self._choice_input(
"You are trying to add a non-nullable field '%s' to %s without a default; "
"we can't do that (the database needs something to populate existing rows).\n"
"Please select a fix:" % (field_name, model_name),
[
"Provide a one-off default now (will be set on all existing rows)",
"Quit, and let me add a default in models.py",
]
)
if choice == 2:
sys.exit(3)
else:
return self._ask_default()
return None
def ask_not_null_alteration(self, field_name, model_name):
"Changing a NULL field to NOT NULL"
if not self.dry_run:
choice = self._choice_input(
"You are trying to change the nullable field '%s' on %s to non-nullable "
"without a default; we can't do that (the database needs something to "
"populate existing rows).\n"
"Please select a fix:" % (field_name, model_name),
[
"Provide a one-off default now (will be set on all existing rows)",
("Ignore for now, and let me handle existing rows with NULL myself "
"(e.g. adding a RunPython or RunSQL operation in the new migration "
"file before the AlterField operation)"),
"Quit, and let me add a default in models.py",
]
)
if choice == 2:
return NOT_PROVIDED
elif choice == 3:
sys.exit(3)
else:
return self._ask_default()
return None
def ask_rename(self, model_name, old_name, new_name, field_instance):

View File

@ -137,7 +137,7 @@ or if it is temporary and just for this migration (``False``) - usually
because the migration is adding a non-nullable field to a table and needs
a default value to put into existing rows. It does not effect the behavior
of setting defaults in the database directly - Django never sets database
defaults, and always applies them in the Django ORM code.
defaults and always applies them in the Django ORM code.
RemoveField
-----------
@ -153,16 +153,28 @@ from any data loss, which of course is irreversible).
AlterField
----------
.. class:: AlterField(model_name, name, field)
.. class:: AlterField(model_name, name, field, preserve_default=True)
Alters a field's definition, including changes to its type,
:attr:`~django.db.models.Field.null`, :attr:`~django.db.models.Field.unique`,
:attr:`~django.db.models.Field.db_column` and other field attributes.
The ``preserve_default`` argument indicates whether the field's default
value is permanent and should be baked into the project state (``True``),
or if it is temporary and just for this migration (``False``) - usually
because the migration is altering a nullable field to a non-nullable one and
needs a default value to put into existing rows. It does not effect the
behavior of setting defaults in the database directly - Django never sets
database defaults and always applies them in the Django ORM code.
Note that not all changes are possible on all databases - for example, you
cannot change a text-type field like ``models.TextField()`` into a number-type
field like ``models.IntegerField()`` on most databases.
.. versionchanged:: 1.7.1
The ``preserve_default`` argument was added.
RenameField
-----------

View File

@ -106,3 +106,7 @@ Bugfixes
* Made :func:`~django.utils.http.urlsafe_base64_decode` return the proper
type (byte string) on Python 3 (:ticket:`23333`).
* Added a prompt to the migrations questioner when removing the null constraint
from a field to prevent an IntegrityError on existing NULL rows
(:ticket:`23609`).

View File

@ -26,6 +26,7 @@ class AutodetectorTests(TestCase):
author_empty = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True))])
author_name = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True)), ("name", models.CharField(max_length=200))])
author_name_null = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True)), ("name", models.CharField(max_length=200, null=True))])
author_name_longer = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True)), ("name", models.CharField(max_length=400))])
author_name_renamed = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True)), ("names", models.CharField(max_length=200))])
author_name_default = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True)), ("name", models.CharField(max_length=200, default='Ada Lovelace'))])
@ -302,6 +303,80 @@ class AutodetectorTests(TestCase):
action = migration.operations[0]
self.assertEqual(action.__class__.__name__, "AlterField")
self.assertEqual(action.name, "name")
self.assertTrue(action.preserve_default)
def test_alter_field_to_not_null_with_default(self):
"#23609 - Tests autodetection of nullable to non-nullable alterations"
class CustomQuestioner(MigrationQuestioner):
def ask_not_null_alteration(self, field_name, model_name):
raise Exception("Should not have prompted for not null addition")
# Make state
before = self.make_project_state([self.author_name_null])
after = self.make_project_state([self.author_name_default])
autodetector = MigrationAutodetector(before, after, CustomQuestioner())
changes = autodetector._detect_changes()
# Right number of migrations?
self.assertEqual(len(changes['testapp']), 1)
# Right number of actions?
migration = changes['testapp'][0]
self.assertEqual(len(migration.operations), 1)
# Right action?
action = migration.operations[0]
self.assertEqual(action.__class__.__name__, "AlterField")
self.assertEqual(action.name, "name")
self.assertTrue(action.preserve_default)
self.assertEqual(action.field.default, 'Ada Lovelace')
def test_alter_field_to_not_null_without_default(self):
"#23609 - Tests autodetection of nullable to non-nullable alterations"
class CustomQuestioner(MigrationQuestioner):
def ask_not_null_alteration(self, field_name, model_name):
# Ignore for now, and let me handle existing rows with NULL
# myself (e.g. adding a RunPython or RunSQL operation in the new
# migration file before the AlterField operation)
return models.NOT_PROVIDED
# Make state
before = self.make_project_state([self.author_name_null])
after = self.make_project_state([self.author_name])
autodetector = MigrationAutodetector(before, after, CustomQuestioner())
changes = autodetector._detect_changes()
# Right number of migrations?
self.assertEqual(len(changes['testapp']), 1)
# Right number of actions?
migration = changes['testapp'][0]
self.assertEqual(len(migration.operations), 1)
# Right action?
action = migration.operations[0]
self.assertEqual(action.__class__.__name__, "AlterField")
self.assertEqual(action.name, "name")
self.assertTrue(action.preserve_default)
self.assertIs(action.field.default, models.NOT_PROVIDED)
def test_alter_field_to_not_null_oneoff_default(self):
"#23609 - Tests autodetection of nullable to non-nullable alterations"
class CustomQuestioner(MigrationQuestioner):
def ask_not_null_alteration(self, field_name, model_name):
# Provide a one-off default now (will be set on all existing rows)
return 'Some Name'
# Make state
before = self.make_project_state([self.author_name_null])
after = self.make_project_state([self.author_name])
autodetector = MigrationAutodetector(before, after, CustomQuestioner())
changes = autodetector._detect_changes()
# Right number of migrations?
self.assertEqual(len(changes['testapp']), 1)
# Right number of actions?
migration = changes['testapp'][0]
self.assertEqual(len(migration.operations), 1)
# Right action?
action = migration.operations[0]
self.assertEqual(action.__class__.__name__, "AlterField")
self.assertEqual(action.name, "name")
self.assertFalse(action.preserve_default)
self.assertEqual(action.field.default, "Some Name")
def test_rename_field(self):
"Tests autodetection of renamed fields"

View File

@ -3,7 +3,8 @@ import unittest
from django.test import TransactionTestCase
from django.db import connection, DatabaseError, IntegrityError, OperationalError
from django.db.models.fields import IntegerField, TextField, CharField, SlugField, BooleanField, BinaryField
from django.db.models.fields import (BinaryField, BooleanField, CharField, IntegerField,
PositiveIntegerField, SlugField, TextField)
from django.db.models.fields.related import ManyToManyField, ForeignKey
from django.db.transaction import atomic
from .models import (Author, AuthorWithM2M, Book, BookWithLongName,
@ -415,6 +416,38 @@ class SchemaTests(TransactionTestCase):
self.assertEqual(columns['name'][0], "TextField")
self.assertEqual(bool(columns['name'][1][6]), False)
def test_alter_null_to_not_null(self):
"""
#23609 - Tests handling of default values when altering from NULL to NOT NULL.
"""
# Create the table
with connection.schema_editor() as editor:
editor.create_model(Author)
# Ensure the field is right to begin with
columns = self.column_classes(Author)
self.assertTrue(columns['height'][1][6])
# Create some test data
Author.objects.create(name='Not null author', height=12)
Author.objects.create(name='Null author')
# Verify null value
self.assertEqual(Author.objects.get(name='Not null author').height, 12)
self.assertIsNone(Author.objects.get(name='Null author').height)
# Alter the height field to NOT NULL with default
new_field = PositiveIntegerField(default=42)
new_field.set_attributes_from_name("height")
with connection.schema_editor() as editor:
editor.alter_field(
Author,
Author._meta.get_field_by_name("height")[0],
new_field
)
# Ensure the field is right afterwards
columns = self.column_classes(Author)
self.assertFalse(columns['height'][1][6])
# Verify default value
self.assertEqual(Author.objects.get(name='Not null author').height, 12)
self.assertEqual(Author.objects.get(name='Null author').height, 42)
@unittest.skipUnless(connection.features.supports_foreign_keys, "No FK support")
def test_alter_fk(self):
"""