diff --git a/django/db/backends/base/schema.py b/django/db/backends/base/schema.py index bf22711131..e31251ae81 100644 --- a/django/db/backends/base/schema.py +++ b/django/db/backends/base/schema.py @@ -376,6 +376,10 @@ class BaseDatabaseSchemaEditor: "old_table": self.quote_name(old_db_table), "new_table": self.quote_name(new_db_table), }) + # Rename all references to the old table name. + for sql in self.deferred_sql: + if isinstance(sql, Statement): + sql.rename_table_references(old_db_table, new_db_table) def alter_db_tablespace(self, model, old_db_tablespace, new_db_tablespace): """Move a model's table between tablespaces.""" @@ -570,6 +574,10 @@ class BaseDatabaseSchemaEditor: # Have they renamed the column? if old_field.column != new_field.column: self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type)) + # Rename all references to the renamed column. + for sql in self.deferred_sql: + if isinstance(sql, Statement): + sql.rename_column_references(model._meta.db_table, old_field.column, new_field.column) # Next, start accumulating actions to do actions = [] null_actions = [] diff --git a/django/db/backends/ddl_references.py b/django/db/backends/ddl_references.py index dd4d1aa415..61b7b9eaf8 100644 --- a/django/db/backends/ddl_references.py +++ b/django/db/backends/ddl_references.py @@ -19,6 +19,18 @@ class Reference: """ return False + def rename_table_references(self, old_table, new_table): + """ + Rename all references to the old_name to the new_table. + """ + pass + + def rename_column_references(self, table, old_column, new_column): + """ + Rename all references to the old_column to the new_column. + """ + pass + def __repr__(self): return '<%s %r>' % (self.__class__.__name__, str(self)) @@ -36,6 +48,10 @@ class Table(Reference): def references_table(self, table): return self.table == table + def rename_table_references(self, old_table, new_table): + if self.table == old_table: + self.table = new_table + def __str__(self): return self.quote_name(self.table) @@ -50,6 +66,12 @@ class TableColumns(Table): def references_column(self, table, column): return self.table == table and column in self.columns + def rename_column_references(self, table, old_column, new_column): + if self.table == table: + for index, column in enumerate(self.columns): + if column == old_column: + self.columns[index] = new_column + class Columns(TableColumns): """Hold a reference to one or many columns.""" @@ -92,6 +114,14 @@ class ForeignKeyName(TableColumns): self.to_reference.references_column(table, column) ) + def rename_table_references(self, old_table, new_table): + super().rename_table_references(old_table, new_table) + self.to_reference.rename_table_references(old_table, new_table) + + def rename_column_references(self, table, old_column, new_column): + super().rename_column_references(table, old_column, new_column) + self.to_reference.rename_column_references(table, old_column, new_column) + def __str__(self): suffix = self.suffix_template % { 'to_table': self.to_reference.table, @@ -124,5 +154,15 @@ class Statement(Reference): for part in self.parts.values() ) + def rename_table_references(self, old_table, new_table): + for part in self.parts.values(): + if hasattr(part, 'rename_table_references'): + part.rename_table_references(old_table, new_table) + + def rename_column_references(self, table, old_column, new_column): + for part in self.parts.values(): + if hasattr(part, 'rename_column_references'): + part.rename_column_references(table, old_column, new_column) + def __str__(self): return self.template % self.parts diff --git a/tests/backends/test_ddl_references.py b/tests/backends/test_ddl_references.py index 268eed988b..dc5b8750a0 100644 --- a/tests/backends/test_ddl_references.py +++ b/tests/backends/test_ddl_references.py @@ -12,6 +12,14 @@ class TableTests(SimpleTestCase): self.assertIs(self.reference.references_table('table'), True) self.assertIs(self.reference.references_table('other'), False) + def test_rename_table_references(self): + self.reference.rename_table_references('other', 'table') + self.assertIs(self.reference.references_table('table'), True) + self.assertIs(self.reference.references_table('other'), False) + self.reference.rename_table_references('table', 'other') + self.assertIs(self.reference.references_table('table'), False) + self.assertIs(self.reference.references_table('other'), True) + def test_repr(self): self.assertEqual(repr(self.reference), "") @@ -30,6 +38,18 @@ class ColumnsTests(TableTests): self.assertIs(self.reference.references_column('table', 'third_column'), False) self.assertIs(self.reference.references_column('table', 'first_column'), True) + def test_rename_column_references(self): + self.reference.rename_column_references('other', 'first_column', 'third_column') + self.assertIs(self.reference.references_column('table', 'first_column'), True) + self.assertIs(self.reference.references_column('table', 'third_column'), False) + self.assertIs(self.reference.references_column('other', 'third_column'), False) + self.reference.rename_column_references('table', 'third_column', 'first_column') + self.assertIs(self.reference.references_column('table', 'first_column'), True) + self.assertIs(self.reference.references_column('table', 'third_column'), False) + self.reference.rename_column_references('table', 'first_column', 'third_column') + self.assertIs(self.reference.references_column('table', 'first_column'), False) + self.assertIs(self.reference.references_column('table', 'third_column'), True) + def test_repr(self): self.assertEqual(repr(self.reference), "") @@ -72,6 +92,21 @@ class ForeignKeyNameTests(IndexNameTests): self.assertIs(self.reference.references_column('to_table', 'second_column'), False) self.assertIs(self.reference.references_column('to_table', 'to_second_column'), True) + def test_rename_table_references(self): + super().test_rename_table_references() + self.reference.rename_table_references('to_table', 'other_to_table') + self.assertIs(self.reference.references_table('other_to_table'), True) + self.assertIs(self.reference.references_table('to_table'), False) + + def test_rename_column_references(self): + super().test_rename_column_references() + self.reference.rename_column_references('to_table', 'second_column', 'third_column') + self.assertIs(self.reference.references_column('table', 'second_column'), True) + self.assertIs(self.reference.references_column('to_table', 'to_second_column'), True) + self.reference.rename_column_references('to_table', 'to_first_column', 'to_third_column') + self.assertIs(self.reference.references_column('to_table', 'to_first_column'), False) + self.assertIs(self.reference.references_column('to_table', 'to_third_column'), True) + def test_repr(self): self.assertEqual( repr(self.reference), @@ -99,6 +134,17 @@ class MockReference(object): def references_column(self, table, column): return (table, column) in self.referenced_columns + def rename_table_references(self, old_table, new_table): + if old_table in self.referenced_tables: + self.referenced_tables.remove(old_table) + self.referenced_tables.add(new_table) + + def rename_column_references(self, table, old_column, new_column): + column = (table, old_column) + if column in self.referenced_columns: + self.referenced_columns.remove(column) + self.referenced_columns.add((table, new_column)) + def __str__(self): return self.representation @@ -114,6 +160,18 @@ class StatementTests(SimpleTestCase): self.assertIs(statement.references_column('table', 'column'), True) self.assertIs(statement.references_column('other', 'column'), False) + def test_rename_table_references(self): + reference = MockReference('', {'table'}, {}) + statement = Statement('', reference=reference, non_reference='') + statement.rename_table_references('table', 'other') + self.assertEqual(reference.referenced_tables, {'other'}) + + def test_rename_column_references(self): + reference = MockReference('', {}, {('table', 'column')}) + statement = Statement('', reference=reference, non_reference='') + statement.rename_column_references('table', 'column', 'other') + self.assertEqual(reference.referenced_columns, {('table', 'other')}) + def test_repr(self): reference = MockReference('reference', {}, {}) statement = Statement("%(reference)s - %(non_reference)s", reference=reference, non_reference='non_reference') diff --git a/tests/schema/tests.py b/tests/schema/tests.py index d2707ecea7..f4612a5dee 100644 --- a/tests/schema/tests.py +++ b/tests/schema/tests.py @@ -2359,3 +2359,32 @@ class SchemaTests(TransactionTestCase): doc = Document.objects.create(name='Test Name') student = Student.objects.create(name='Some man') doc.students.add(student) + + def test_rename_table_renames_deferred_sql_references(self): + with connection.schema_editor() as editor: + editor.create_model(Author) + editor.create_model(Book) + editor.alter_db_table(Author, 'schema_author', 'schema_renamed_author') + editor.alter_db_table(Author, 'schema_book', 'schema_renamed_book') + self.assertGreater(len(editor.deferred_sql), 0) + for statement in editor.deferred_sql: + self.assertIs(statement.references_table('schema_author'), False) + self.assertIs(statement.references_table('schema_book'), False) + + @unittest.skipIf(connection.vendor == 'sqlite', 'SQLite naively remakes the table on field alteration.') + def test_rename_column_renames_deferred_sql_references(self): + with connection.schema_editor() as editor: + editor.create_model(Author) + editor.create_model(Book) + old_title = Book._meta.get_field('title') + new_title = CharField(max_length=100, db_index=True) + new_title.set_attributes_from_name('renamed_title') + editor.alter_field(Book, old_title, new_title) + old_author = Book._meta.get_field('author') + new_author = ForeignKey(Author, CASCADE) + new_author.set_attributes_from_name('renamed_author') + editor.alter_field(Book, old_author, new_author) + self.assertGreater(len(editor.deferred_sql), 0) + for statement in editor.deferred_sql: + self.assertIs(statement.references_column('book', 'title'), False) + self.assertIs(statement.references_column('book', 'author_id'), False)