Fix altering of SERIAL columns and InnoDB being picky about FK changes
This commit is contained in:
parent
cee4fe7307
commit
5db028affb
|
@ -2,4 +2,57 @@ from django.db.backends.schema import BaseDatabaseSchemaEditor
|
||||||
|
|
||||||
|
|
||||||
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||||
pass
|
|
||||||
|
sql_create_sequence = "CREATE SEQUENCE %(sequence)s"
|
||||||
|
sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
|
||||||
|
sql_set_sequence_max = "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s"
|
||||||
|
|
||||||
|
def _alter_column_type_sql(self, table, column, type):
|
||||||
|
"""
|
||||||
|
Makes ALTER TYPE with SERIAL make sense.
|
||||||
|
"""
|
||||||
|
if type.lower() == "serial":
|
||||||
|
sequence_name = "%s_%s_seq" % (table, column)
|
||||||
|
return (
|
||||||
|
(
|
||||||
|
self.sql_alter_column_type % {
|
||||||
|
"column": self.quote_name(column),
|
||||||
|
"type": "integer",
|
||||||
|
},
|
||||||
|
[],
|
||||||
|
),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
self.sql_delete_sequence % {
|
||||||
|
"sequence": sequence_name,
|
||||||
|
},
|
||||||
|
[],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
self.sql_create_sequence % {
|
||||||
|
"sequence": sequence_name,
|
||||||
|
},
|
||||||
|
[],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
self.sql_alter_column % {
|
||||||
|
"table": table,
|
||||||
|
"changes": self.sql_alter_column_default % {
|
||||||
|
"column": column,
|
||||||
|
"default": "nextval('%s')" % sequence_name,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
self.sql_set_sequence_max % {
|
||||||
|
"table": table,
|
||||||
|
"column": column,
|
||||||
|
"sequence": sequence_name,
|
||||||
|
},
|
||||||
|
[],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return super(DatabaseSchemaEditor, self)._alter_column_type_sql(table, column, type)
|
||||||
|
|
|
@ -498,6 +498,18 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
"name": fk_name,
|
"name": fk_name,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
# Drop incoming FK constraints if we're a primary key and things are going
|
||||||
|
# to change.
|
||||||
|
if old_field.primary_key and new_field.primary_key and old_type != new_type:
|
||||||
|
for rel in new_field.model._meta.get_all_related_objects():
|
||||||
|
rel_fk_names = self._constraint_names(rel.model, [rel.field.column], foreign_key=True)
|
||||||
|
for fk_name in rel_fk_names:
|
||||||
|
self.execute(
|
||||||
|
self.sql_delete_fk % {
|
||||||
|
"table": self.quote_name(rel.model._meta.db_table),
|
||||||
|
"name": fk_name,
|
||||||
|
}
|
||||||
|
)
|
||||||
# Change check constraints?
|
# Change check constraints?
|
||||||
if old_db_params['check'] != new_db_params['check'] and old_db_params['check']:
|
if old_db_params['check'] != new_db_params['check'] and old_db_params['check']:
|
||||||
constraint_names = self._constraint_names(model, [old_field.column], check=True)
|
constraint_names = self._constraint_names(model, [old_field.column], check=True)
|
||||||
|
@ -524,15 +536,12 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
})
|
})
|
||||||
# Next, start accumulating actions to do
|
# Next, start accumulating actions to do
|
||||||
actions = []
|
actions = []
|
||||||
|
post_actions = []
|
||||||
# Type change?
|
# Type change?
|
||||||
if old_type != new_type:
|
if old_type != new_type:
|
||||||
actions.append((
|
fragment, other_actions = self._alter_column_type_sql(model._meta.db_table, new_field.column, new_type)
|
||||||
self.sql_alter_column_type % {
|
actions.append(fragment)
|
||||||
"column": self.quote_name(new_field.column),
|
post_actions.extend(other_actions)
|
||||||
"type": new_type,
|
|
||||||
},
|
|
||||||
[],
|
|
||||||
))
|
|
||||||
# Default change?
|
# Default change?
|
||||||
old_default = self.effective_default(old_field)
|
old_default = self.effective_default(old_field)
|
||||||
new_default = self.effective_default(new_field)
|
new_default = self.effective_default(new_field)
|
||||||
|
@ -596,6 +605,9 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
},
|
},
|
||||||
params,
|
params,
|
||||||
)
|
)
|
||||||
|
if post_actions:
|
||||||
|
for sql, params in post_actions:
|
||||||
|
self.execute(sql, params)
|
||||||
# Added a unique?
|
# Added a unique?
|
||||||
if not old_field.unique and new_field.unique:
|
if not old_field.unique and new_field.unique:
|
||||||
self.execute(
|
self.execute(
|
||||||
|
@ -619,7 +631,7 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
# referring to us.
|
# referring to us.
|
||||||
rels_to_update = []
|
rels_to_update = []
|
||||||
if old_field.primary_key and new_field.primary_key and old_type != new_type:
|
if old_field.primary_key and new_field.primary_key and old_type != new_type:
|
||||||
rels_to_update.extend(model._meta.get_all_related_objects())
|
rels_to_update.extend(new_field.model._meta.get_all_related_objects())
|
||||||
# Changed to become primary key?
|
# Changed to become primary key?
|
||||||
# Note that we don't detect unsetting of a PK, as we assume another field
|
# Note that we don't detect unsetting of a PK, as we assume another field
|
||||||
# will always come along and replace it.
|
# will always come along and replace it.
|
||||||
|
@ -647,8 +659,8 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
# Update all referencing columns
|
# Update all referencing columns
|
||||||
rels_to_update.extend(model._meta.get_all_related_objects())
|
rels_to_update.extend(new_field.model._meta.get_all_related_objects())
|
||||||
# Handle out type alters on the other end of rels from the PK stuff above
|
# Handle our type alters on the other end of rels from the PK stuff above
|
||||||
for rel in rels_to_update:
|
for rel in rels_to_update:
|
||||||
rel_db_params = rel.field.db_parameters(connection=self.connection)
|
rel_db_params = rel.field.db_parameters(connection=self.connection)
|
||||||
rel_type = rel_db_params['type']
|
rel_type = rel_db_params['type']
|
||||||
|
@ -672,6 +684,18 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
"to_column": self.quote_name(new_field.rel.get_related_field().column),
|
"to_column": self.quote_name(new_field.rel.get_related_field().column),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
# Rebuild FKs that pointed to us if we previously had to drop them
|
||||||
|
if old_field.primary_key and new_field.primary_key and old_type != new_type:
|
||||||
|
for rel in new_field.model._meta.get_all_related_objects():
|
||||||
|
self.execute(
|
||||||
|
self.sql_create_fk % {
|
||||||
|
"table": self.quote_name(rel.model._meta.db_table),
|
||||||
|
"name": self._create_index_name(rel.model, [rel.field.column], suffix="_fk"),
|
||||||
|
"column": self.quote_name(rel.field.column),
|
||||||
|
"to_table": self.quote_name(model._meta.db_table),
|
||||||
|
"to_column": self.quote_name(new_field.column),
|
||||||
|
}
|
||||||
|
)
|
||||||
# Does it have check constraints we need to add?
|
# Does it have check constraints we need to add?
|
||||||
if old_db_params['check'] != new_db_params['check'] and new_db_params['check']:
|
if old_db_params['check'] != new_db_params['check'] and new_db_params['check']:
|
||||||
self.execute(
|
self.execute(
|
||||||
|
@ -686,6 +710,27 @@ class BaseDatabaseSchemaEditor(object):
|
||||||
if self.connection.features.connection_persists_old_columns:
|
if self.connection.features.connection_persists_old_columns:
|
||||||
self.connection.close()
|
self.connection.close()
|
||||||
|
|
||||||
|
def _alter_column_type_sql(self, table, column, type):
|
||||||
|
"""
|
||||||
|
Hook to specialise column type alteration for different backends,
|
||||||
|
for cases when a creation type is different to an alteration type
|
||||||
|
(e.g. SERIAL in PostgreSQL, PostGIS fields).
|
||||||
|
|
||||||
|
Should return two things; an SQL fragment of (sql, params) to insert
|
||||||
|
into an ALTER TABLE statement, and a list of extra (sql, params) tuples
|
||||||
|
to run once the field is altered.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
(
|
||||||
|
self.sql_alter_column_type % {
|
||||||
|
"column": self.quote_name(column),
|
||||||
|
"type": type,
|
||||||
|
},
|
||||||
|
[],
|
||||||
|
),
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
|
||||||
def _alter_many_to_many(self, model, old_field, new_field, strict):
|
def _alter_many_to_many(self, model, old_field, new_field, strict):
|
||||||
"""
|
"""
|
||||||
Alters M2Ms to repoint their to= endpoints.
|
Alters M2Ms to repoint their to= endpoints.
|
||||||
|
|
|
@ -24,9 +24,10 @@ class AddField(Operation):
|
||||||
state.models[app_label, self.model_name.lower()].fields.append((self.name, field))
|
state.models[app_label, self.model_name.lower()].fields.append((self.name, field))
|
||||||
|
|
||||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||||
|
from_model = from_state.render().get_model(app_label, self.model_name)
|
||||||
to_model = to_state.render().get_model(app_label, self.model_name)
|
to_model = to_state.render().get_model(app_label, self.model_name)
|
||||||
if router.allow_migrate(schema_editor.connection.alias, to_model):
|
if router.allow_migrate(schema_editor.connection.alias, to_model):
|
||||||
schema_editor.add_field(to_model, to_model._meta.get_field_by_name(self.name)[0])
|
schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0])
|
||||||
|
|
||||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||||
from_model = from_state.render().get_model(app_label, self.model_name)
|
from_model = from_state.render().get_model(app_label, self.model_name)
|
||||||
|
@ -73,9 +74,10 @@ class RemoveField(Operation):
|
||||||
schema_editor.remove_field(from_model, from_model._meta.get_field_by_name(self.name)[0])
|
schema_editor.remove_field(from_model, from_model._meta.get_field_by_name(self.name)[0])
|
||||||
|
|
||||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||||
|
from_model = from_state.render().get_model(app_label, self.model_name)
|
||||||
to_model = to_state.render().get_model(app_label, self.model_name)
|
to_model = to_state.render().get_model(app_label, self.model_name)
|
||||||
if router.allow_migrate(schema_editor.connection.alias, to_model):
|
if router.allow_migrate(schema_editor.connection.alias, to_model):
|
||||||
schema_editor.add_field(to_model, to_model._meta.get_field_by_name(self.name)[0])
|
schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0])
|
||||||
|
|
||||||
def describe(self):
|
def describe(self):
|
||||||
return "Remove field %s from %s" % (self.name, self.model_name)
|
return "Remove field %s from %s" % (self.name, self.model_name)
|
||||||
|
@ -107,7 +109,7 @@ class AlterField(Operation):
|
||||||
to_model = to_state.render().get_model(app_label, self.model_name)
|
to_model = to_state.render().get_model(app_label, self.model_name)
|
||||||
if router.allow_migrate(schema_editor.connection.alias, to_model):
|
if router.allow_migrate(schema_editor.connection.alias, to_model):
|
||||||
schema_editor.alter_field(
|
schema_editor.alter_field(
|
||||||
to_model,
|
from_model,
|
||||||
from_model._meta.get_field_by_name(self.name)[0],
|
from_model._meta.get_field_by_name(self.name)[0],
|
||||||
to_model._meta.get_field_by_name(self.name)[0],
|
to_model._meta.get_field_by_name(self.name)[0],
|
||||||
)
|
)
|
||||||
|
@ -153,7 +155,7 @@ class RenameField(Operation):
|
||||||
to_model = to_state.render().get_model(app_label, self.model_name)
|
to_model = to_state.render().get_model(app_label, self.model_name)
|
||||||
if router.allow_migrate(schema_editor.connection.alias, to_model):
|
if router.allow_migrate(schema_editor.connection.alias, to_model):
|
||||||
schema_editor.alter_field(
|
schema_editor.alter_field(
|
||||||
to_model,
|
from_model,
|
||||||
from_model._meta.get_field_by_name(self.old_name)[0],
|
from_model._meta.get_field_by_name(self.old_name)[0],
|
||||||
to_model._meta.get_field_by_name(self.new_name)[0],
|
to_model._meta.get_field_by_name(self.new_name)[0],
|
||||||
)
|
)
|
||||||
|
@ -163,7 +165,7 @@ class RenameField(Operation):
|
||||||
to_model = to_state.render().get_model(app_label, self.model_name)
|
to_model = to_state.render().get_model(app_label, self.model_name)
|
||||||
if router.allow_migrate(schema_editor.connection.alias, to_model):
|
if router.allow_migrate(schema_editor.connection.alias, to_model):
|
||||||
schema_editor.alter_field(
|
schema_editor.alter_field(
|
||||||
to_model,
|
from_model,
|
||||||
from_model._meta.get_field_by_name(self.new_name)[0],
|
from_model._meta.get_field_by_name(self.new_name)[0],
|
||||||
to_model._meta.get_field_by_name(self.old_name)[0],
|
to_model._meta.get_field_by_name(self.old_name)[0],
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import unittest
|
||||||
from django.db import connection, models, migrations, router
|
from django.db import connection, models, migrations, router
|
||||||
from django.db.models.fields import NOT_PROVIDED
|
from django.db.models.fields import NOT_PROVIDED
|
||||||
from django.db.transaction import atomic
|
from django.db.transaction import atomic
|
||||||
|
@ -13,7 +14,7 @@ class OperationTests(MigrationTestBase):
|
||||||
both forwards and backwards.
|
both forwards and backwards.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def set_up_test_model(self, app_label, second_model=False):
|
def set_up_test_model(self, app_label, second_model=False, related_model=False):
|
||||||
"""
|
"""
|
||||||
Creates a test model state and database table.
|
Creates a test model state and database table.
|
||||||
"""
|
"""
|
||||||
|
@ -38,6 +39,14 @@ class OperationTests(MigrationTestBase):
|
||||||
)]
|
)]
|
||||||
if second_model:
|
if second_model:
|
||||||
operations.append(migrations.CreateModel("Stable", [("id", models.AutoField(primary_key=True))]))
|
operations.append(migrations.CreateModel("Stable", [("id", models.AutoField(primary_key=True))]))
|
||||||
|
if related_model:
|
||||||
|
operations.append(migrations.CreateModel(
|
||||||
|
"Rider",
|
||||||
|
[
|
||||||
|
("id", models.AutoField(primary_key=True)),
|
||||||
|
("pony", models.ForeignKey("Pony")),
|
||||||
|
],
|
||||||
|
))
|
||||||
project_state = ProjectState()
|
project_state = ProjectState()
|
||||||
for operation in operations:
|
for operation in operations:
|
||||||
operation.state_forwards(app_label, project_state)
|
operation.state_forwards(app_label, project_state)
|
||||||
|
@ -269,6 +278,52 @@ class OperationTests(MigrationTestBase):
|
||||||
operation.database_backwards("test_alfl", editor, new_state, project_state)
|
operation.database_backwards("test_alfl", editor, new_state, project_state)
|
||||||
self.assertColumnNotNull("test_alfl_pony", "pink")
|
self.assertColumnNotNull("test_alfl_pony", "pink")
|
||||||
|
|
||||||
|
def test_alter_field_pk(self):
|
||||||
|
"""
|
||||||
|
Tests the AlterField operation on primary keys (for things like PostgreSQL's SERIAL weirdness)
|
||||||
|
"""
|
||||||
|
project_state = self.set_up_test_model("test_alflpk")
|
||||||
|
# Test the state alteration
|
||||||
|
operation = migrations.AlterField("Pony", "id", models.IntegerField(primary_key=True))
|
||||||
|
new_state = project_state.clone()
|
||||||
|
operation.state_forwards("test_alflpk", new_state)
|
||||||
|
self.assertIsInstance(project_state.models["test_alflpk", "pony"].get_field_by_name("id"), models.AutoField)
|
||||||
|
self.assertIsInstance(new_state.models["test_alflpk", "pony"].get_field_by_name("id"), models.IntegerField)
|
||||||
|
# Test the database alteration
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_forwards("test_alflpk", editor, project_state, new_state)
|
||||||
|
# And test reversal
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_backwards("test_alflpk", editor, new_state, project_state)
|
||||||
|
|
||||||
|
@unittest.skipUnless(connection.features.supports_foreign_keys, "No FK support")
|
||||||
|
def test_alter_field_pk_fk(self):
|
||||||
|
"""
|
||||||
|
Tests the AlterField operation on primary keys changes any FKs pointing to it.
|
||||||
|
"""
|
||||||
|
project_state = self.set_up_test_model("test_alflpkfk", related_model=True)
|
||||||
|
# Test the state alteration
|
||||||
|
operation = migrations.AlterField("Pony", "id", models.FloatField(primary_key=True))
|
||||||
|
new_state = project_state.clone()
|
||||||
|
operation.state_forwards("test_alflpkfk", new_state)
|
||||||
|
self.assertIsInstance(project_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.AutoField)
|
||||||
|
self.assertIsInstance(new_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.FloatField)
|
||||||
|
# Test the database alteration
|
||||||
|
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
|
||||||
|
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
|
||||||
|
self.assertEqual(id_type, fk_type)
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_forwards("test_alflpkfk", editor, project_state, new_state)
|
||||||
|
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
|
||||||
|
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
|
||||||
|
self.assertEqual(id_type, fk_type)
|
||||||
|
# And test reversal
|
||||||
|
with connection.schema_editor() as editor:
|
||||||
|
operation.database_backwards("test_alflpkfk", editor, new_state, project_state)
|
||||||
|
id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0]
|
||||||
|
fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0]
|
||||||
|
self.assertEqual(id_type, fk_type)
|
||||||
|
|
||||||
def test_rename_field(self):
|
def test_rename_field(self):
|
||||||
"""
|
"""
|
||||||
Tests the RenameField operation.
|
Tests the RenameField operation.
|
||||||
|
|
|
@ -636,6 +636,7 @@ class SchemaTests(TransactionTestCase):
|
||||||
# Alter to change the PK
|
# Alter to change the PK
|
||||||
new_field = SlugField(primary_key=True)
|
new_field = SlugField(primary_key=True)
|
||||||
new_field.set_attributes_from_name("slug")
|
new_field.set_attributes_from_name("slug")
|
||||||
|
new_field.model = Tag
|
||||||
with connection.schema_editor() as editor:
|
with connection.schema_editor() as editor:
|
||||||
editor.remove_field(Tag, Tag._meta.get_field_by_name("id")[0])
|
editor.remove_field(Tag, Tag._meta.get_field_by_name("id")[0])
|
||||||
editor.alter_field(
|
editor.alter_field(
|
||||||
|
|
Loading…
Reference in New Issue