Fix altering of SERIAL columns and InnoDB being picky about FK changes

This commit is contained in:
Andrew Godwin 2013-12-11 14:19:05 +00:00
parent cee4fe7307
commit 5db028affb
5 changed files with 173 additions and 17 deletions

View File

@ -2,4 +2,57 @@ from django.db.backends.schema import BaseDatabaseSchemaEditor
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
pass
sql_create_sequence = "CREATE SEQUENCE %(sequence)s"
sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
sql_set_sequence_max = "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s"
def _alter_column_type_sql(self, table, column, type):
"""
Makes ALTER TYPE with SERIAL make sense.
"""
if type.lower() == "serial":
sequence_name = "%s_%s_seq" % (table, column)
return (
(
self.sql_alter_column_type % {
"column": self.quote_name(column),
"type": "integer",
},
[],
),
[
(
self.sql_delete_sequence % {
"sequence": sequence_name,
},
[],
),
(
self.sql_create_sequence % {
"sequence": sequence_name,
},
[],
),
(
self.sql_alter_column % {
"table": table,
"changes": self.sql_alter_column_default % {
"column": column,
"default": "nextval('%s')" % sequence_name,
}
},
[],
),
(
self.sql_set_sequence_max % {
"table": table,
"column": column,
"sequence": sequence_name,
},
[],
),
],
)
else:
return super(DatabaseSchemaEditor, self)._alter_column_type_sql(table, column, type)

View File

@ -498,6 +498,18 @@ class BaseDatabaseSchemaEditor(object):
"name": fk_name, "name": fk_name,
} }
) )
# Drop incoming FK constraints if we're a primary key and things are going
# to change.
if old_field.primary_key and new_field.primary_key and old_type != new_type:
for rel in new_field.model._meta.get_all_related_objects():
rel_fk_names = self._constraint_names(rel.model, [rel.field.column], foreign_key=True)
for fk_name in rel_fk_names:
self.execute(
self.sql_delete_fk % {
"table": self.quote_name(rel.model._meta.db_table),
"name": fk_name,
}
)
# Change check constraints? # Change check constraints?
if old_db_params['check'] != new_db_params['check'] and old_db_params['check']: if old_db_params['check'] != new_db_params['check'] and old_db_params['check']:
constraint_names = self._constraint_names(model, [old_field.column], check=True) constraint_names = self._constraint_names(model, [old_field.column], check=True)
@ -524,15 +536,12 @@ class BaseDatabaseSchemaEditor(object):
}) })
# Next, start accumulating actions to do # Next, start accumulating actions to do
actions = [] actions = []
post_actions = []
# Type change? # Type change?
if old_type != new_type: if old_type != new_type:
actions.append(( fragment, other_actions = self._alter_column_type_sql(model._meta.db_table, new_field.column, new_type)
self.sql_alter_column_type % { actions.append(fragment)
"column": self.quote_name(new_field.column), post_actions.extend(other_actions)
"type": new_type,
},
[],
))
# Default change? # Default change?
old_default = self.effective_default(old_field) old_default = self.effective_default(old_field)
new_default = self.effective_default(new_field) new_default = self.effective_default(new_field)
@ -596,6 +605,9 @@ class BaseDatabaseSchemaEditor(object):
}, },
params, params,
) )
if post_actions:
for sql, params in post_actions:
self.execute(sql, params)
# Added a unique? # Added a unique?
if not old_field.unique and new_field.unique: if not old_field.unique and new_field.unique:
self.execute( self.execute(
@ -619,7 +631,7 @@ class BaseDatabaseSchemaEditor(object):
# referring to us. # referring to us.
rels_to_update = [] rels_to_update = []
if old_field.primary_key and new_field.primary_key and old_type != new_type: if old_field.primary_key and new_field.primary_key and old_type != new_type:
rels_to_update.extend(model._meta.get_all_related_objects()) rels_to_update.extend(new_field.model._meta.get_all_related_objects())
# Changed to become primary key? # Changed to become primary key?
# Note that we don't detect unsetting of a PK, as we assume another field # Note that we don't detect unsetting of a PK, as we assume another field
# will always come along and replace it. # will always come along and replace it.
@ -647,8 +659,8 @@ class BaseDatabaseSchemaEditor(object):
} }
) )
# Update all referencing columns # Update all referencing columns
rels_to_update.extend(model._meta.get_all_related_objects()) rels_to_update.extend(new_field.model._meta.get_all_related_objects())
# Handle out type alters on the other end of rels from the PK stuff above # Handle our type alters on the other end of rels from the PK stuff above
for rel in rels_to_update: for rel in rels_to_update:
rel_db_params = rel.field.db_parameters(connection=self.connection) rel_db_params = rel.field.db_parameters(connection=self.connection)
rel_type = rel_db_params['type'] rel_type = rel_db_params['type']
@ -672,6 +684,18 @@ class BaseDatabaseSchemaEditor(object):
"to_column": self.quote_name(new_field.rel.get_related_field().column), "to_column": self.quote_name(new_field.rel.get_related_field().column),
} }
) )
# Rebuild FKs that pointed to us if we previously had to drop them
if old_field.primary_key and new_field.primary_key and old_type != new_type:
for rel in new_field.model._meta.get_all_related_objects():
self.execute(
self.sql_create_fk % {
"table": self.quote_name(rel.model._meta.db_table),
"name": self._create_index_name(rel.model, [rel.field.column], suffix="_fk"),
"column": self.quote_name(rel.field.column),
"to_table": self.quote_name(model._meta.db_table),
"to_column": self.quote_name(new_field.column),
}
)
# Does it have check constraints we need to add? # Does it have check constraints we need to add?
if old_db_params['check'] != new_db_params['check'] and new_db_params['check']: if old_db_params['check'] != new_db_params['check'] and new_db_params['check']:
self.execute( self.execute(
@ -686,6 +710,27 @@ class BaseDatabaseSchemaEditor(object):
if self.connection.features.connection_persists_old_columns: if self.connection.features.connection_persists_old_columns:
self.connection.close() self.connection.close()
def _alter_column_type_sql(self, table, column, type):
"""
Hook to specialise column type alteration for different backends,
for cases when a creation type is different to an alteration type
(e.g. SERIAL in PostgreSQL, PostGIS fields).
Should return two things; an SQL fragment of (sql, params) to insert
into an ALTER TABLE statement, and a list of extra (sql, params) tuples
to run once the field is altered.
"""
return (
(
self.sql_alter_column_type % {
"column": self.quote_name(column),
"type": type,
},
[],
),
[],
)
def _alter_many_to_many(self, model, old_field, new_field, strict): def _alter_many_to_many(self, model, old_field, new_field, strict):
""" """
Alters M2Ms to repoint their to= endpoints. Alters M2Ms to repoint their to= endpoints.

View File

@ -24,9 +24,10 @@ class AddField(Operation):
state.models[app_label, self.model_name.lower()].fields.append((self.name, field)) state.models[app_label, self.model_name.lower()].fields.append((self.name, field))
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
from_model = from_state.render().get_model(app_label, self.model_name)
to_model = to_state.render().get_model(app_label, self.model_name) to_model = to_state.render().get_model(app_label, self.model_name)
if router.allow_migrate(schema_editor.connection.alias, to_model): if router.allow_migrate(schema_editor.connection.alias, to_model):
schema_editor.add_field(to_model, to_model._meta.get_field_by_name(self.name)[0]) schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0])
def database_backwards(self, app_label, schema_editor, from_state, to_state): def database_backwards(self, app_label, schema_editor, from_state, to_state):
from_model = from_state.render().get_model(app_label, self.model_name) from_model = from_state.render().get_model(app_label, self.model_name)
@ -73,9 +74,10 @@ class RemoveField(Operation):
schema_editor.remove_field(from_model, from_model._meta.get_field_by_name(self.name)[0]) schema_editor.remove_field(from_model, from_model._meta.get_field_by_name(self.name)[0])
def database_backwards(self, app_label, schema_editor, from_state, to_state): def database_backwards(self, app_label, schema_editor, from_state, to_state):
from_model = from_state.render().get_model(app_label, self.model_name)
to_model = to_state.render().get_model(app_label, self.model_name) to_model = to_state.render().get_model(app_label, self.model_name)
if router.allow_migrate(schema_editor.connection.alias, to_model): if router.allow_migrate(schema_editor.connection.alias, to_model):
schema_editor.add_field(to_model, to_model._meta.get_field_by_name(self.name)[0]) schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0])
def describe(self): def describe(self):
return "Remove field %s from %s" % (self.name, self.model_name) return "Remove field %s from %s" % (self.name, self.model_name)
@ -107,7 +109,7 @@ class AlterField(Operation):
to_model = to_state.render().get_model(app_label, self.model_name) to_model = to_state.render().get_model(app_label, self.model_name)
if router.allow_migrate(schema_editor.connection.alias, to_model): if router.allow_migrate(schema_editor.connection.alias, to_model):
schema_editor.alter_field( schema_editor.alter_field(
to_model, from_model,
from_model._meta.get_field_by_name(self.name)[0], from_model._meta.get_field_by_name(self.name)[0],
to_model._meta.get_field_by_name(self.name)[0], to_model._meta.get_field_by_name(self.name)[0],
) )
@ -153,7 +155,7 @@ class RenameField(Operation):
to_model = to_state.render().get_model(app_label, self.model_name) to_model = to_state.render().get_model(app_label, self.model_name)
if router.allow_migrate(schema_editor.connection.alias, to_model): if router.allow_migrate(schema_editor.connection.alias, to_model):
schema_editor.alter_field( schema_editor.alter_field(
to_model, from_model,
from_model._meta.get_field_by_name(self.old_name)[0], from_model._meta.get_field_by_name(self.old_name)[0],
to_model._meta.get_field_by_name(self.new_name)[0], to_model._meta.get_field_by_name(self.new_name)[0],
) )
@ -163,7 +165,7 @@ class RenameField(Operation):
to_model = to_state.render().get_model(app_label, self.model_name) to_model = to_state.render().get_model(app_label, self.model_name)
if router.allow_migrate(schema_editor.connection.alias, to_model): if router.allow_migrate(schema_editor.connection.alias, to_model):
schema_editor.alter_field( schema_editor.alter_field(
to_model, from_model,
from_model._meta.get_field_by_name(self.new_name)[0], from_model._meta.get_field_by_name(self.new_name)[0],
to_model._meta.get_field_by_name(self.old_name)[0], to_model._meta.get_field_by_name(self.old_name)[0],
) )

View File

@ -1,3 +1,4 @@
import unittest
from django.db import connection, models, migrations, router from django.db import connection, models, migrations, router
from django.db.models.fields import NOT_PROVIDED from django.db.models.fields import NOT_PROVIDED
from django.db.transaction import atomic from django.db.transaction import atomic
@ -13,7 +14,7 @@ class OperationTests(MigrationTestBase):
both forwards and backwards. both forwards and backwards.
""" """
def set_up_test_model(self, app_label, second_model=False): def set_up_test_model(self, app_label, second_model=False, related_model=False):
""" """
Creates a test model state and database table. Creates a test model state and database table.
""" """
@ -38,6 +39,14 @@ class OperationTests(MigrationTestBase):
)] )]
if second_model: if second_model:
operations.append(migrations.CreateModel("Stable", [("id", models.AutoField(primary_key=True))])) operations.append(migrations.CreateModel("Stable", [("id", models.AutoField(primary_key=True))]))
if related_model:
operations.append(migrations.CreateModel(
"Rider",
[
("id", models.AutoField(primary_key=True)),
("pony", models.ForeignKey("Pony")),
],
))
project_state = ProjectState() project_state = ProjectState()
for operation in operations: for operation in operations:
operation.state_forwards(app_label, project_state) operation.state_forwards(app_label, project_state)
@ -269,6 +278,52 @@ class OperationTests(MigrationTestBase):
operation.database_backwards("test_alfl", editor, new_state, project_state) operation.database_backwards("test_alfl", editor, new_state, project_state)
self.assertColumnNotNull("test_alfl_pony", "pink") self.assertColumnNotNull("test_alfl_pony", "pink")
def test_alter_field_pk(self):
"""
Tests the AlterField operation on primary keys (for things like PostgreSQL's SERIAL weirdness)
"""
project_state = self.set_up_test_model("test_alflpk")
# Test the state alteration
operation = migrations.AlterField("Pony", "id", models.IntegerField(primary_key=True))
new_state = project_state.clone()
operation.state_forwards("test_alflpk", new_state)
self.assertIsInstance(project_state.models["test_alflpk", "pony"].get_field_by_name("id"), models.AutoField)
self.assertIsInstance(new_state.models["test_alflpk", "pony"].get_field_by_name("id"), models.IntegerField)
# Test the database alteration
with connection.schema_editor() as editor:
operation.database_forwards("test_alflpk", editor, project_state, new_state)
# And test reversal
with connection.schema_editor() as editor:
operation.database_backwards("test_alflpk", editor, new_state, project_state)
@unittest.skipUnless(connection.features.supports_foreign_keys, "No FK support")
def test_alter_field_pk_fk(self):
"""
Tests the AlterField operation on primary keys changes any FKs pointing to it.
"""
project_state = self.set_up_test_model("test_alflpkfk", related_model=True)
# Test the state alteration
operation = migrations.AlterField("Pony", "id", models.FloatField(primary_key=True))
new_state = project_state.clone()
operation.state_forwards("test_alflpkfk", new_state)
self.assertIsInstance(project_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.AutoField)
self.assertIsInstance(new_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.FloatField)
# Test the database alteration
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
self.assertEqual(id_type, fk_type)
with connection.schema_editor() as editor:
operation.database_forwards("test_alflpkfk", editor, project_state, new_state)
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
self.assertEqual(id_type, fk_type)
# And test reversal
with connection.schema_editor() as editor:
operation.database_backwards("test_alflpkfk", editor, new_state, project_state)
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
self.assertEqual(id_type, fk_type)
def test_rename_field(self): def test_rename_field(self):
""" """
Tests the RenameField operation. Tests the RenameField operation.

View File

@ -636,6 +636,7 @@ class SchemaTests(TransactionTestCase):
# Alter to change the PK # Alter to change the PK
new_field = SlugField(primary_key=True) new_field = SlugField(primary_key=True)
new_field.set_attributes_from_name("slug") new_field.set_attributes_from_name("slug")
new_field.model = Tag
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
editor.remove_field(Tag, Tag._meta.get_field_by_name("id")[0]) editor.remove_field(Tag, Tag._meta.get_field_by_name("id")[0])
editor.alter_field( editor.alter_field(