diff --git a/django/db/backends/postgresql_psycopg2/schema.py b/django/db/backends/postgresql_psycopg2/schema.py index b86e0857bb..946f39d586 100644 --- a/django/db/backends/postgresql_psycopg2/schema.py +++ b/django/db/backends/postgresql_psycopg2/schema.py @@ -2,4 +2,57 @@ from django.db.backends.schema import BaseDatabaseSchemaEditor class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): - pass + + sql_create_sequence = "CREATE SEQUENCE %(sequence)s" + sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE" + sql_set_sequence_max = "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s" + + def _alter_column_type_sql(self, table, column, type): + """ + Makes ALTER TYPE with SERIAL make sense. + """ + if type.lower() == "serial": + sequence_name = "%s_%s_seq" % (table, column) + return ( + ( + self.sql_alter_column_type % { + "column": self.quote_name(column), + "type": "integer", + }, + [], + ), + [ + ( + self.sql_delete_sequence % { + "sequence": sequence_name, + }, + [], + ), + ( + self.sql_create_sequence % { + "sequence": sequence_name, + }, + [], + ), + ( + self.sql_alter_column % { + "table": table, + "changes": self.sql_alter_column_default % { + "column": column, + "default": "nextval('%s')" % sequence_name, + } + }, + [], + ), + ( + self.sql_set_sequence_max % { + "table": table, + "column": column, + "sequence": sequence_name, + }, + [], + ), + ], + ) + else: + return super(DatabaseSchemaEditor, self)._alter_column_type_sql(table, column, type) diff --git a/django/db/backends/schema.py b/django/db/backends/schema.py index 5e3ad4dd80..e33956763c 100644 --- a/django/db/backends/schema.py +++ b/django/db/backends/schema.py @@ -498,6 +498,18 @@ class BaseDatabaseSchemaEditor(object): "name": fk_name, } ) + # Drop incoming FK constraints if we're a primary key and things are going + # to change. + if old_field.primary_key and new_field.primary_key and old_type != new_type: + for rel in new_field.model._meta.get_all_related_objects(): + rel_fk_names = self._constraint_names(rel.model, [rel.field.column], foreign_key=True) + for fk_name in rel_fk_names: + self.execute( + self.sql_delete_fk % { + "table": self.quote_name(rel.model._meta.db_table), + "name": fk_name, + } + ) # Change check constraints? if old_db_params['check'] != new_db_params['check'] and old_db_params['check']: constraint_names = self._constraint_names(model, [old_field.column], check=True) @@ -524,15 +536,12 @@ class BaseDatabaseSchemaEditor(object): }) # Next, start accumulating actions to do actions = [] + post_actions = [] # Type change? if old_type != new_type: - actions.append(( - self.sql_alter_column_type % { - "column": self.quote_name(new_field.column), - "type": new_type, - }, - [], - )) + fragment, other_actions = self._alter_column_type_sql(model._meta.db_table, new_field.column, new_type) + actions.append(fragment) + post_actions.extend(other_actions) # Default change? old_default = self.effective_default(old_field) new_default = self.effective_default(new_field) @@ -596,6 +605,9 @@ class BaseDatabaseSchemaEditor(object): }, params, ) + if post_actions: + for sql, params in post_actions: + self.execute(sql, params) # Added a unique? if not old_field.unique and new_field.unique: self.execute( @@ -619,7 +631,7 @@ class BaseDatabaseSchemaEditor(object): # referring to us. rels_to_update = [] if old_field.primary_key and new_field.primary_key and old_type != new_type: - rels_to_update.extend(model._meta.get_all_related_objects()) + rels_to_update.extend(new_field.model._meta.get_all_related_objects()) # Changed to become primary key? # Note that we don't detect unsetting of a PK, as we assume another field # will always come along and replace it. @@ -647,8 +659,8 @@ class BaseDatabaseSchemaEditor(object): } ) # Update all referencing columns - rels_to_update.extend(model._meta.get_all_related_objects()) - # Handle out type alters on the other end of rels from the PK stuff above + rels_to_update.extend(new_field.model._meta.get_all_related_objects()) + # Handle our type alters on the other end of rels from the PK stuff above for rel in rels_to_update: rel_db_params = rel.field.db_parameters(connection=self.connection) rel_type = rel_db_params['type'] @@ -672,6 +684,18 @@ class BaseDatabaseSchemaEditor(object): "to_column": self.quote_name(new_field.rel.get_related_field().column), } ) + # Rebuild FKs that pointed to us if we previously had to drop them + if old_field.primary_key and new_field.primary_key and old_type != new_type: + for rel in new_field.model._meta.get_all_related_objects(): + self.execute( + self.sql_create_fk % { + "table": self.quote_name(rel.model._meta.db_table), + "name": self._create_index_name(rel.model, [rel.field.column], suffix="_fk"), + "column": self.quote_name(rel.field.column), + "to_table": self.quote_name(model._meta.db_table), + "to_column": self.quote_name(new_field.column), + } + ) # Does it have check constraints we need to add? if old_db_params['check'] != new_db_params['check'] and new_db_params['check']: self.execute( @@ -686,6 +710,27 @@ class BaseDatabaseSchemaEditor(object): if self.connection.features.connection_persists_old_columns: self.connection.close() + def _alter_column_type_sql(self, table, column, type): + """ + Hook to specialise column type alteration for different backends, + for cases when a creation type is different to an alteration type + (e.g. SERIAL in PostgreSQL, PostGIS fields). + + Should return two things; an SQL fragment of (sql, params) to insert + into an ALTER TABLE statement, and a list of extra (sql, params) tuples + to run once the field is altered. + """ + return ( + ( + self.sql_alter_column_type % { + "column": self.quote_name(column), + "type": type, + }, + [], + ), + [], + ) + def _alter_many_to_many(self, model, old_field, new_field, strict): """ Alters M2Ms to repoint their to= endpoints. diff --git a/django/db/migrations/operations/fields.py b/django/db/migrations/operations/fields.py index 73efef4691..c5f0bd1e2b 100644 --- a/django/db/migrations/operations/fields.py +++ b/django/db/migrations/operations/fields.py @@ -24,9 +24,10 @@ class AddField(Operation): state.models[app_label, self.model_name.lower()].fields.append((self.name, field)) def database_forwards(self, app_label, schema_editor, from_state, to_state): + from_model = from_state.render().get_model(app_label, self.model_name) to_model = to_state.render().get_model(app_label, self.model_name) if router.allow_migrate(schema_editor.connection.alias, to_model): - schema_editor.add_field(to_model, to_model._meta.get_field_by_name(self.name)[0]) + schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0]) def database_backwards(self, app_label, schema_editor, from_state, to_state): from_model = from_state.render().get_model(app_label, self.model_name) @@ -73,9 +74,10 @@ class RemoveField(Operation): schema_editor.remove_field(from_model, from_model._meta.get_field_by_name(self.name)[0]) def database_backwards(self, app_label, schema_editor, from_state, to_state): + from_model = from_state.render().get_model(app_label, self.model_name) to_model = to_state.render().get_model(app_label, self.model_name) if router.allow_migrate(schema_editor.connection.alias, to_model): - schema_editor.add_field(to_model, to_model._meta.get_field_by_name(self.name)[0]) + schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0]) def describe(self): return "Remove field %s from %s" % (self.name, self.model_name) @@ -107,7 +109,7 @@ class AlterField(Operation): to_model = to_state.render().get_model(app_label, self.model_name) if router.allow_migrate(schema_editor.connection.alias, to_model): schema_editor.alter_field( - to_model, + from_model, from_model._meta.get_field_by_name(self.name)[0], to_model._meta.get_field_by_name(self.name)[0], ) @@ -153,7 +155,7 @@ class RenameField(Operation): to_model = to_state.render().get_model(app_label, self.model_name) if router.allow_migrate(schema_editor.connection.alias, to_model): schema_editor.alter_field( - to_model, + from_model, from_model._meta.get_field_by_name(self.old_name)[0], to_model._meta.get_field_by_name(self.new_name)[0], ) @@ -163,7 +165,7 @@ class RenameField(Operation): to_model = to_state.render().get_model(app_label, self.model_name) if router.allow_migrate(schema_editor.connection.alias, to_model): schema_editor.alter_field( - to_model, + from_model, from_model._meta.get_field_by_name(self.new_name)[0], to_model._meta.get_field_by_name(self.old_name)[0], ) diff --git a/tests/migrations/test_operations.py b/tests/migrations/test_operations.py index 93f0842bcd..20aff59de2 100644 --- a/tests/migrations/test_operations.py +++ b/tests/migrations/test_operations.py @@ -1,3 +1,4 @@ +import unittest from django.db import connection, models, migrations, router from django.db.models.fields import NOT_PROVIDED from django.db.transaction import atomic @@ -13,7 +14,7 @@ class OperationTests(MigrationTestBase): both forwards and backwards. """ - def set_up_test_model(self, app_label, second_model=False): + def set_up_test_model(self, app_label, second_model=False, related_model=False): """ Creates a test model state and database table. """ @@ -38,6 +39,14 @@ class OperationTests(MigrationTestBase): )] if second_model: operations.append(migrations.CreateModel("Stable", [("id", models.AutoField(primary_key=True))])) + if related_model: + operations.append(migrations.CreateModel( + "Rider", + [ + ("id", models.AutoField(primary_key=True)), + ("pony", models.ForeignKey("Pony")), + ], + )) project_state = ProjectState() for operation in operations: operation.state_forwards(app_label, project_state) @@ -269,6 +278,52 @@ class OperationTests(MigrationTestBase): operation.database_backwards("test_alfl", editor, new_state, project_state) self.assertColumnNotNull("test_alfl_pony", "pink") + def test_alter_field_pk(self): + """ + Tests the AlterField operation on primary keys (for things like PostgreSQL's SERIAL weirdness) + """ + project_state = self.set_up_test_model("test_alflpk") + # Test the state alteration + operation = migrations.AlterField("Pony", "id", models.IntegerField(primary_key=True)) + new_state = project_state.clone() + operation.state_forwards("test_alflpk", new_state) + self.assertIsInstance(project_state.models["test_alflpk", "pony"].get_field_by_name("id"), models.AutoField) + self.assertIsInstance(new_state.models["test_alflpk", "pony"].get_field_by_name("id"), models.IntegerField) + # Test the database alteration + with connection.schema_editor() as editor: + operation.database_forwards("test_alflpk", editor, project_state, new_state) + # And test reversal + with connection.schema_editor() as editor: + operation.database_backwards("test_alflpk", editor, new_state, project_state) + + @unittest.skipUnless(connection.features.supports_foreign_keys, "No FK support") + def test_alter_field_pk_fk(self): + """ + Tests the AlterField operation on primary keys changes any FKs pointing to it. + """ + project_state = self.set_up_test_model("test_alflpkfk", related_model=True) + # Test the state alteration + operation = migrations.AlterField("Pony", "id", models.FloatField(primary_key=True)) + new_state = project_state.clone() + operation.state_forwards("test_alflpkfk", new_state) + self.assertIsInstance(project_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.AutoField) + self.assertIsInstance(new_state.models["test_alflpkfk", "pony"].get_field_by_name("id"), models.FloatField) + # Test the database alteration + id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0] + fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0] + self.assertEqual(id_type, fk_type) + with connection.schema_editor() as editor: + operation.database_forwards("test_alflpkfk", editor, project_state, new_state) + id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0] + fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0] + self.assertEqual(id_type, fk_type) + # And test reversal + with connection.schema_editor() as editor: + operation.database_backwards("test_alflpkfk", editor, new_state, project_state) + id_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_pony") if c.name == "id"][0] + fk_type = [c.type_code for c in connection.introspection.get_table_description(connection.cursor(), "test_alflpkfk_rider") if c.name == "pony_id"][0] + self.assertEqual(id_type, fk_type) + def test_rename_field(self): """ Tests the RenameField operation. diff --git a/tests/schema/tests.py b/tests/schema/tests.py index d69e5387f1..66450956ef 100644 --- a/tests/schema/tests.py +++ b/tests/schema/tests.py @@ -636,6 +636,7 @@ class SchemaTests(TransactionTestCase): # Alter to change the PK new_field = SlugField(primary_key=True) new_field.set_attributes_from_name("slug") + new_field.model = Tag with connection.schema_editor() as editor: editor.remove_field(Tag, Tag._meta.get_field_by_name("id")[0]) editor.alter_field(