Fixed #23609 -- Fixed IntegrityError that prevented altering a NULL column into a NOT NULL one due to existing rows

Thanks to Simon Charette, Loic Bistuer and Tim Graham for the review.
This commit is contained in:
Markus Holtermann 2014-10-07 01:53:21 +02:00 committed by Loic Bistuer
parent 15d350fbce
commit f633ba778d
9 changed files with 256 additions and 36 deletions

View File

@ -44,6 +44,7 @@ class BaseDatabaseSchemaEditor(object):
sql_alter_column_no_default = "ALTER COLUMN %(column)s DROP DEFAULT" sql_alter_column_no_default = "ALTER COLUMN %(column)s DROP DEFAULT"
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE" sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE"
sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s" sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"
sql_update_with_default = "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)" sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)"
sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
@ -533,12 +534,19 @@ class BaseDatabaseSchemaEditor(object):
}) })
# Next, start accumulating actions to do # Next, start accumulating actions to do
actions = [] actions = []
null_actions = []
post_actions = [] post_actions = []
# Type change? # Type change?
if old_type != new_type: if old_type != new_type:
fragment, other_actions = self._alter_column_type_sql(model._meta.db_table, new_field.column, new_type) fragment, other_actions = self._alter_column_type_sql(model._meta.db_table, new_field.column, new_type)
actions.append(fragment) actions.append(fragment)
post_actions.extend(other_actions) post_actions.extend(other_actions)
# When changing a column NULL constraint to NOT NULL with a given
# default value, we need to perform 4 steps:
# 1. Add a default for new incoming writes
# 2. Update existing NULL rows with new default
# 3. Replace NULL constraint with NOT NULL
# 4. Drop the default again.
# Default change? # Default change?
old_default = self.effective_default(old_field) old_default = self.effective_default(old_field)
new_default = self.effective_default(new_field) new_default = self.effective_default(new_field)
@ -573,7 +581,7 @@ class BaseDatabaseSchemaEditor(object):
# Nullability change? # Nullability change?
if old_field.null != new_field.null: if old_field.null != new_field.null:
if new_field.null: if new_field.null:
actions.append(( null_actions.append((
self.sql_alter_column_null % { self.sql_alter_column_null % {
"column": self.quote_name(new_field.column), "column": self.quote_name(new_field.column),
"type": new_type, "type": new_type,
@ -581,14 +589,23 @@ class BaseDatabaseSchemaEditor(object):
[], [],
)) ))
else: else:
actions.append(( null_actions.append((
self.sql_alter_column_not_null % { self.sql_alter_column_not_null % {
"column": self.quote_name(new_field.column), "column": self.quote_name(new_field.column),
"type": new_type, "type": new_type,
}, },
[], [],
)) ))
if actions: # Only if we have a default and there is a change from NULL to NOT NULL
four_way_default_alteration = (
new_field.has_default() and
(old_field.null and not new_field.null)
)
if actions or null_actions:
if not four_way_default_alteration:
# If we don't have to do a 4-way default alteration we can
# directly run a (NOT) NULL alteration
actions = actions + null_actions
# Combine actions together if we can (e.g. postgres) # Combine actions together if we can (e.g. postgres)
if self.connection.features.supports_combined_alters: if self.connection.features.supports_combined_alters:
sql, params = tuple(zip(*actions)) sql, params = tuple(zip(*actions))
@ -602,6 +619,26 @@ class BaseDatabaseSchemaEditor(object):
}, },
params, params,
) )
if four_way_default_alteration:
# Update existing rows with default value
self.execute(
self.sql_update_with_default % {
"table": self.quote_name(model._meta.db_table),
"column": self.quote_name(new_field.column),
"default": "%s",
},
[new_default],
)
# Since we didn't run a NOT NULL change before we need to do it
# now
for sql, params in null_actions:
self.execute(
self.sql_alter_column % {
"table": self.quote_name(model._meta.db_table),
"changes": sql,
},
params,
)
if post_actions: if post_actions:
for sql, params in post_actions: for sql, params in post_actions:
self.execute(sql, params) self.execute(sql, params)

View File

@ -78,7 +78,14 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
del body[old_field.name] del body[old_field.name]
del mapping[old_field.column] del mapping[old_field.column]
body[new_field.name] = new_field body[new_field.name] = new_field
mapping[new_field.column] = self.quote_name(old_field.column) if old_field.null and not new_field.null:
case_sql = "coalesce(%(col)s, %(default)s)" % {
'col': self.quote_name(old_field.column),
'default': self.quote_value(self.effective_default(new_field))
}
mapping[new_field.column] = case_sql
else:
mapping[new_field.column] = self.quote_name(old_field.column)
rename_mapping[old_field.name] = new_field.name rename_mapping[old_field.name] = new_field.name
# Remove any deleted fields # Remove any deleted fields
for field in delete_fields: for field in delete_fields:

View File

@ -6,8 +6,8 @@ import datetime
from itertools import chain from itertools import chain
from django.utils import six from django.utils import six
from django.db import models
from django.conf import settings from django.conf import settings
from django.db import models
from django.db.migrations import operations from django.db.migrations import operations
from django.db.migrations.migration import Migration from django.db.migrations.migration import Migration
from django.db.migrations.questioner import MigrationQuestioner from django.db.migrations.questioner import MigrationQuestioner
@ -838,7 +838,6 @@ class MigrationAutodetector(object):
for app_label, model_name, field_name in sorted(self.old_field_keys.intersection(self.new_field_keys)): for app_label, model_name, field_name in sorted(self.old_field_keys.intersection(self.new_field_keys)):
# Did the field change? # Did the field change?
old_model_name = self.renamed_models.get((app_label, model_name), model_name) old_model_name = self.renamed_models.get((app_label, model_name), model_name)
new_model_state = self.to_state.models[app_label, model_name]
old_field_name = self.renamed_fields.get((app_label, model_name, field_name), field_name) old_field_name = self.renamed_fields.get((app_label, model_name, field_name), field_name)
old_field = self.old_apps.get_model(app_label, old_model_name)._meta.get_field_by_name(old_field_name)[0] old_field = self.old_apps.get_model(app_label, old_model_name)._meta.get_field_by_name(old_field_name)[0]
new_field = self.new_apps.get_model(app_label, model_name)._meta.get_field_by_name(field_name)[0] new_field = self.new_apps.get_model(app_label, model_name)._meta.get_field_by_name(field_name)[0]
@ -854,12 +853,23 @@ class MigrationAutodetector(object):
old_field_dec = self.deep_deconstruct(old_field) old_field_dec = self.deep_deconstruct(old_field)
new_field_dec = self.deep_deconstruct(new_field) new_field_dec = self.deep_deconstruct(new_field)
if old_field_dec != new_field_dec: if old_field_dec != new_field_dec:
preserve_default = True
if (old_field.null and not new_field.null and not new_field.has_default() and
not isinstance(new_field, models.ManyToManyField)):
field = new_field.clone()
new_default = self.questioner.ask_not_null_alteration(field_name, model_name)
if new_default is not models.NOT_PROVIDED:
field.default = new_default
preserve_default = False
else:
field = new_field
self.add_operation( self.add_operation(
app_label, app_label,
operations.AlterField( operations.AlterField(
model_name=model_name, model_name=model_name,
name=field_name, name=field_name,
field=new_model_state.get_field_by_name(field_name), field=field,
preserve_default=preserve_default,
) )
) )

View File

@ -104,14 +104,20 @@ class AlterField(Operation):
Alters a field's database column (e.g. null, max_length) to the provided new field Alters a field's database column (e.g. null, max_length) to the provided new field
""" """
def __init__(self, model_name, name, field): def __init__(self, model_name, name, field, preserve_default=True):
self.model_name = model_name self.model_name = model_name
self.name = name self.name = name
self.field = field self.field = field
self.preserve_default = preserve_default
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
if not self.preserve_default:
field = self.field.clone()
field.default = NOT_PROVIDED
else:
field = self.field
state.models[app_label, self.model_name.lower()].fields = [ state.models[app_label, self.model_name.lower()].fields = [
(n, self.field if n == self.name else f) for n, f in state.models[app_label, self.model_name.lower()].fields (n, field if n == self.name else f) for n, f in state.models[app_label, self.model_name.lower()].fields
] ]
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
@ -128,7 +134,11 @@ class AlterField(Operation):
from_field.rel.to = to_field.rel.to from_field.rel.to = to_field.rel.to
elif to_field.rel and isinstance(to_field.rel.to, six.string_types): elif to_field.rel and isinstance(to_field.rel.to, six.string_types):
to_field.rel.to = from_field.rel.to to_field.rel.to = from_field.rel.to
if not self.preserve_default:
to_field.default = self.field.default
schema_editor.alter_field(from_model, from_field, to_field) schema_editor.alter_field(from_model, from_field, to_field)
if not self.preserve_default:
to_field.default = NOT_PROVIDED
def database_backwards(self, app_label, schema_editor, from_state, to_state): def database_backwards(self, app_label, schema_editor, from_state, to_state):
self.database_forwards(app_label, schema_editor, from_state, to_state) self.database_forwards(app_label, schema_editor, from_state, to_state)

View File

@ -1,10 +1,11 @@
from __future__ import unicode_literals from __future__ import print_function, unicode_literals
import importlib import importlib
import os import os
import sys import sys
from django.apps import apps from django.apps import apps
from django.db.models.fields import NOT_PROVIDED
from django.utils import datetime_safe, six, timezone from django.utils import datetime_safe, six, timezone
from django.utils.six.moves import input from django.utils.six.moves import input
@ -55,6 +56,11 @@ class MigrationQuestioner(object):
# None means quit # None means quit
return None return None
def ask_not_null_alteration(self, field_name, model_name):
"Changing a NULL field to NOT NULL"
# None means quit
return None
def ask_rename(self, model_name, old_name, new_name, field_instance): def ask_rename(self, model_name, old_name, new_name, field_instance):
"Was this field really renamed?" "Was this field really renamed?"
return self.defaults.get("ask_rename", False) return self.defaults.get("ask_rename", False)
@ -92,13 +98,34 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
pass pass
result = input("Please select a valid option: ") result = input("Please select a valid option: ")
def _ask_default(self):
print("Please enter the default value now, as valid Python")
print("The datetime and django.utils.timezone modules are available, so you can do e.g. timezone.now()")
while True:
if six.PY3:
# Six does not correctly abstract over the fact that
# py3 input returns a unicode string, while py2 raw_input
# returns a bytestring.
code = input(">>> ")
else:
code = input(">>> ").decode(sys.stdin.encoding)
if not code:
print("Please enter some code, or 'exit' (with no quotes) to exit.")
elif code == "exit":
sys.exit(1)
else:
try:
return eval(code, {}, {"datetime": datetime_safe, "timezone": timezone})
except (SyntaxError, NameError) as e:
print("Invalid input: %s" % e)
def ask_not_null_addition(self, field_name, model_name): def ask_not_null_addition(self, field_name, model_name):
"Adding a NOT NULL field to a model" "Adding a NOT NULL field to a model"
if not self.dry_run: if not self.dry_run:
choice = self._choice_input( choice = self._choice_input(
"You are trying to add a non-nullable field '%s' to %s without a default;\n" % (field_name, model_name) + "You are trying to add a non-nullable field '%s' to %s without a default; "
"we can't do that (the database needs something to populate existing rows).\n" + "we can't do that (the database needs something to populate existing rows).\n"
"Please select a fix:", "Please select a fix:" % (field_name, model_name),
[ [
"Provide a one-off default now (will be set on all existing rows)", "Provide a one-off default now (will be set on all existing rows)",
"Quit, and let me add a default in models.py", "Quit, and let me add a default in models.py",
@ -107,26 +134,31 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
if choice == 2: if choice == 2:
sys.exit(3) sys.exit(3)
else: else:
print("Please enter the default value now, as valid Python") return self._ask_default()
print("The datetime and django.utils.timezone modules are " return None
"available, so you can do e.g. timezone.now()")
while True: def ask_not_null_alteration(self, field_name, model_name):
if six.PY3: "Changing a NULL field to NOT NULL"
# Six does not correctly abstract over the fact that if not self.dry_run:
# py3 input returns a unicode string, while py2 raw_input choice = self._choice_input(
# returns a bytestring. "You are trying to change the nullable field '%s' on %s to non-nullable "
code = input(">>> ") "without a default; we can't do that (the database needs something to "
else: "populate existing rows).\n"
code = input(">>> ").decode(sys.stdin.encoding) "Please select a fix:" % (field_name, model_name),
if not code: [
print("Please enter some code, or 'exit' (with no quotes) to exit.") "Provide a one-off default now (will be set on all existing rows)",
elif code == "exit": ("Ignore for now, and let me handle existing rows with NULL myself "
sys.exit(1) "(e.g. adding a RunPython or RunSQL operation in the new migration "
else: "file before the AlterField operation)"),
try: "Quit, and let me add a default in models.py",
return eval(code, {}, {"datetime": datetime_safe, "timezone": timezone}) ]
except (SyntaxError, NameError) as e: )
print("Invalid input: %s" % e) if choice == 2:
return NOT_PROVIDED
elif choice == 3:
sys.exit(3)
else:
return self._ask_default()
return None return None
def ask_rename(self, model_name, old_name, new_name, field_instance): def ask_rename(self, model_name, old_name, new_name, field_instance):

View File

@ -137,7 +137,7 @@ or if it is temporary and just for this migration (``False``) - usually
because the migration is adding a non-nullable field to a table and needs because the migration is adding a non-nullable field to a table and needs
a default value to put into existing rows. It does not effect the behavior a default value to put into existing rows. It does not effect the behavior
of setting defaults in the database directly - Django never sets database of setting defaults in the database directly - Django never sets database
defaults, and always applies them in the Django ORM code. defaults and always applies them in the Django ORM code.
RemoveField RemoveField
----------- -----------
@ -153,16 +153,28 @@ from any data loss, which of course is irreversible).
AlterField AlterField
---------- ----------
.. class:: AlterField(model_name, name, field) .. class:: AlterField(model_name, name, field, preserve_default=True)
Alters a field's definition, including changes to its type, Alters a field's definition, including changes to its type,
:attr:`~django.db.models.Field.null`, :attr:`~django.db.models.Field.unique`, :attr:`~django.db.models.Field.null`, :attr:`~django.db.models.Field.unique`,
:attr:`~django.db.models.Field.db_column` and other field attributes. :attr:`~django.db.models.Field.db_column` and other field attributes.
The ``preserve_default`` argument indicates whether the field's default
value is permanent and should be baked into the project state (``True``),
or if it is temporary and just for this migration (``False``) - usually
because the migration is altering a nullable field to a non-nullable one and
needs a default value to put into existing rows. It does not effect the
behavior of setting defaults in the database directly - Django never sets
database defaults and always applies them in the Django ORM code.
Note that not all changes are possible on all databases - for example, you Note that not all changes are possible on all databases - for example, you
cannot change a text-type field like ``models.TextField()`` into a number-type cannot change a text-type field like ``models.TextField()`` into a number-type
field like ``models.IntegerField()`` on most databases. field like ``models.IntegerField()`` on most databases.
.. versionchanged:: 1.7.1
The ``preserve_default`` argument was added.
RenameField RenameField
----------- -----------

View File

@ -106,3 +106,7 @@ Bugfixes
* Made :func:`~django.utils.http.urlsafe_base64_decode` return the proper * Made :func:`~django.utils.http.urlsafe_base64_decode` return the proper
type (byte string) on Python 3 (:ticket:`23333`). type (byte string) on Python 3 (:ticket:`23333`).
* Added a prompt to the migrations questioner when removing the null constraint
from a field to prevent an IntegrityError on existing NULL rows
(:ticket:`23609`).

View File

@ -26,6 +26,7 @@ class AutodetectorTests(TestCase):
author_empty = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True))]) author_empty = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True))])
author_name = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True)), ("name", models.CharField(max_length=200))]) author_name = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True)), ("name", models.CharField(max_length=200))])
author_name_null = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True)), ("name", models.CharField(max_length=200, null=True))])
author_name_longer = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True)), ("name", models.CharField(max_length=400))]) author_name_longer = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True)), ("name", models.CharField(max_length=400))])
author_name_renamed = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True)), ("names", models.CharField(max_length=200))]) author_name_renamed = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True)), ("names", models.CharField(max_length=200))])
author_name_default = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True)), ("name", models.CharField(max_length=200, default='Ada Lovelace'))]) author_name_default = ModelState("testapp", "Author", [("id", models.AutoField(primary_key=True)), ("name", models.CharField(max_length=200, default='Ada Lovelace'))])
@ -302,6 +303,80 @@ class AutodetectorTests(TestCase):
action = migration.operations[0] action = migration.operations[0]
self.assertEqual(action.__class__.__name__, "AlterField") self.assertEqual(action.__class__.__name__, "AlterField")
self.assertEqual(action.name, "name") self.assertEqual(action.name, "name")
self.assertTrue(action.preserve_default)
def test_alter_field_to_not_null_with_default(self):
"#23609 - Tests autodetection of nullable to non-nullable alterations"
class CustomQuestioner(MigrationQuestioner):
def ask_not_null_alteration(self, field_name, model_name):
raise Exception("Should not have prompted for not null addition")
# Make state
before = self.make_project_state([self.author_name_null])
after = self.make_project_state([self.author_name_default])
autodetector = MigrationAutodetector(before, after, CustomQuestioner())
changes = autodetector._detect_changes()
# Right number of migrations?
self.assertEqual(len(changes['testapp']), 1)
# Right number of actions?
migration = changes['testapp'][0]
self.assertEqual(len(migration.operations), 1)
# Right action?
action = migration.operations[0]
self.assertEqual(action.__class__.__name__, "AlterField")
self.assertEqual(action.name, "name")
self.assertTrue(action.preserve_default)
self.assertEqual(action.field.default, 'Ada Lovelace')
def test_alter_field_to_not_null_without_default(self):
"#23609 - Tests autodetection of nullable to non-nullable alterations"
class CustomQuestioner(MigrationQuestioner):
def ask_not_null_alteration(self, field_name, model_name):
# Ignore for now, and let me handle existing rows with NULL
# myself (e.g. adding a RunPython or RunSQL operation in the new
# migration file before the AlterField operation)
return models.NOT_PROVIDED
# Make state
before = self.make_project_state([self.author_name_null])
after = self.make_project_state([self.author_name])
autodetector = MigrationAutodetector(before, after, CustomQuestioner())
changes = autodetector._detect_changes()
# Right number of migrations?
self.assertEqual(len(changes['testapp']), 1)
# Right number of actions?
migration = changes['testapp'][0]
self.assertEqual(len(migration.operations), 1)
# Right action?
action = migration.operations[0]
self.assertEqual(action.__class__.__name__, "AlterField")
self.assertEqual(action.name, "name")
self.assertTrue(action.preserve_default)
self.assertIs(action.field.default, models.NOT_PROVIDED)
def test_alter_field_to_not_null_oneoff_default(self):
"#23609 - Tests autodetection of nullable to non-nullable alterations"
class CustomQuestioner(MigrationQuestioner):
def ask_not_null_alteration(self, field_name, model_name):
# Provide a one-off default now (will be set on all existing rows)
return 'Some Name'
# Make state
before = self.make_project_state([self.author_name_null])
after = self.make_project_state([self.author_name])
autodetector = MigrationAutodetector(before, after, CustomQuestioner())
changes = autodetector._detect_changes()
# Right number of migrations?
self.assertEqual(len(changes['testapp']), 1)
# Right number of actions?
migration = changes['testapp'][0]
self.assertEqual(len(migration.operations), 1)
# Right action?
action = migration.operations[0]
self.assertEqual(action.__class__.__name__, "AlterField")
self.assertEqual(action.name, "name")
self.assertFalse(action.preserve_default)
self.assertEqual(action.field.default, "Some Name")
def test_rename_field(self): def test_rename_field(self):
"Tests autodetection of renamed fields" "Tests autodetection of renamed fields"

View File

@ -3,7 +3,8 @@ import unittest
from django.test import TransactionTestCase from django.test import TransactionTestCase
from django.db import connection, DatabaseError, IntegrityError, OperationalError from django.db import connection, DatabaseError, IntegrityError, OperationalError
from django.db.models.fields import IntegerField, TextField, CharField, SlugField, BooleanField, BinaryField from django.db.models.fields import (BinaryField, BooleanField, CharField, IntegerField,
PositiveIntegerField, SlugField, TextField)
from django.db.models.fields.related import ManyToManyField, ForeignKey from django.db.models.fields.related import ManyToManyField, ForeignKey
from django.db.transaction import atomic from django.db.transaction import atomic
from .models import (Author, AuthorWithM2M, Book, BookWithLongName, from .models import (Author, AuthorWithM2M, Book, BookWithLongName,
@ -415,6 +416,38 @@ class SchemaTests(TransactionTestCase):
self.assertEqual(columns['name'][0], "TextField") self.assertEqual(columns['name'][0], "TextField")
self.assertEqual(bool(columns['name'][1][6]), False) self.assertEqual(bool(columns['name'][1][6]), False)
def test_alter_null_to_not_null(self):
"""
#23609 - Tests handling of default values when altering from NULL to NOT NULL.
"""
# Create the table
with connection.schema_editor() as editor:
editor.create_model(Author)
# Ensure the field is right to begin with
columns = self.column_classes(Author)
self.assertTrue(columns['height'][1][6])
# Create some test data
Author.objects.create(name='Not null author', height=12)
Author.objects.create(name='Null author')
# Verify null value
self.assertEqual(Author.objects.get(name='Not null author').height, 12)
self.assertIsNone(Author.objects.get(name='Null author').height)
# Alter the height field to NOT NULL with default
new_field = PositiveIntegerField(default=42)
new_field.set_attributes_from_name("height")
with connection.schema_editor() as editor:
editor.alter_field(
Author,
Author._meta.get_field_by_name("height")[0],
new_field
)
# Ensure the field is right afterwards
columns = self.column_classes(Author)
self.assertFalse(columns['height'][1][6])
# Verify default value
self.assertEqual(Author.objects.get(name='Not null author').height, 12)
self.assertEqual(Author.objects.get(name='Null author').height, 42)
@unittest.skipUnless(connection.features.supports_foreign_keys, "No FK support") @unittest.skipUnless(connection.features.supports_foreign_keys, "No FK support")
def test_alter_fk(self): def test_alter_fk(self):
""" """