Fixed #35469 -- Removed deferred SQL to create index removed by AlterField operation.

This commit is contained in:
Jacob Walls 2024-05-25 17:17:15 -04:00 committed by Sarah Boyce
parent d3a7ed5bcc
commit 99f23eaabd
4 changed files with 70 additions and 9 deletions

View File

@ -1582,12 +1582,23 @@ class BaseDatabaseSchemaEditor:
) )
def _delete_index_sql(self, model, name, sql=None): def _delete_index_sql(self, model, name, sql=None):
return Statement( statement = Statement(
sql or self.sql_delete_index, sql or self.sql_delete_index,
table=Table(model._meta.db_table, self.quote_name), table=Table(model._meta.db_table, self.quote_name),
name=self.quote_name(name), name=self.quote_name(name),
) )
# Remove all deferred statements referencing the deleted index.
table_name = statement.parts["table"].table
index_name = statement.parts["name"]
for sql in list(self.deferred_sql):
if isinstance(sql, Statement) and sql.references_index(
table_name, index_name
):
self.deferred_sql.remove(sql)
return statement
def _rename_index_sql(self, model, old_name, new_name): def _rename_index_sql(self, model, old_name, new_name):
return Statement( return Statement(
self.sql_rename_index, self.sql_rename_index,

View File

@ -21,6 +21,12 @@ class Reference:
""" """
return False return False
def references_index(self, table, index):
"""
Return whether or not this instance references the specified index.
"""
return False
def rename_table_references(self, old_table, new_table): def rename_table_references(self, old_table, new_table):
""" """
Rename all references to the old_name to the new_table. Rename all references to the old_name to the new_table.
@ -52,6 +58,9 @@ class Table(Reference):
def references_table(self, table): def references_table(self, table):
return self.table == table return self.table == table
def references_index(self, table, index):
return self.references_table(table) and str(self) == index
def rename_table_references(self, old_table, new_table): def rename_table_references(self, old_table, new_table):
if self.table == old_table: if self.table == old_table:
self.table = new_table self.table = new_table
@ -207,6 +216,12 @@ class Statement(Reference):
for part in self.parts.values() for part in self.parts.values()
) )
def references_index(self, table, index):
return any(
hasattr(part, "references_index") and part.references_index(table, index)
for part in self.parts.values()
)
def rename_table_references(self, old_table, new_table): def rename_table_references(self, old_table, new_table):
for part in self.parts.values(): for part in self.parts.values():
if hasattr(part, "rename_table_references"): if hasattr(part, "rename_table_references"):

View File

@ -166,10 +166,13 @@ class ForeignKeyNameTests(IndexNameTests):
class MockReference: class MockReference:
def __init__(self, representation, referenced_tables, referenced_columns): def __init__(
self, representation, referenced_tables, referenced_columns, referenced_indexes
):
self.representation = representation self.representation = representation
self.referenced_tables = referenced_tables self.referenced_tables = referenced_tables
self.referenced_columns = referenced_columns self.referenced_columns = referenced_columns
self.referenced_indexes = referenced_indexes
def references_table(self, table): def references_table(self, table):
return table in self.referenced_tables return table in self.referenced_tables
@ -177,6 +180,9 @@ class MockReference:
def references_column(self, table, column): def references_column(self, table, column):
return (table, column) in self.referenced_columns return (table, column) in self.referenced_columns
def references_index(self, table, index):
return (table, index) in self.referenced_indexes
def rename_table_references(self, old_table, new_table): def rename_table_references(self, old_table, new_table):
if old_table in self.referenced_tables: if old_table in self.referenced_tables:
self.referenced_tables.remove(old_table) self.referenced_tables.remove(old_table)
@ -195,32 +201,43 @@ class MockReference:
class StatementTests(SimpleTestCase): class StatementTests(SimpleTestCase):
def test_references_table(self): def test_references_table(self):
statement = Statement( statement = Statement(
"", reference=MockReference("", {"table"}, {}), non_reference="" "", reference=MockReference("", {"table"}, {}, {}), non_reference=""
) )
self.assertIs(statement.references_table("table"), True) self.assertIs(statement.references_table("table"), True)
self.assertIs(statement.references_table("other"), False) self.assertIs(statement.references_table("other"), False)
def test_references_column(self): def test_references_column(self):
statement = Statement( statement = Statement(
"", reference=MockReference("", {}, {("table", "column")}), non_reference="" "",
reference=MockReference("", {}, {("table", "column")}, {}),
non_reference="",
) )
self.assertIs(statement.references_column("table", "column"), True) self.assertIs(statement.references_column("table", "column"), True)
self.assertIs(statement.references_column("other", "column"), False) self.assertIs(statement.references_column("other", "column"), False)
def test_references_index(self):
statement = Statement(
"",
reference=MockReference("", {}, {}, {("table", "index")}),
non_reference="",
)
self.assertIs(statement.references_index("table", "index"), True)
self.assertIs(statement.references_index("other", "index"), False)
def test_rename_table_references(self): def test_rename_table_references(self):
reference = MockReference("", {"table"}, {}) reference = MockReference("", {"table"}, {}, {})
statement = Statement("", reference=reference, non_reference="") statement = Statement("", reference=reference, non_reference="")
statement.rename_table_references("table", "other") statement.rename_table_references("table", "other")
self.assertEqual(reference.referenced_tables, {"other"}) self.assertEqual(reference.referenced_tables, {"other"})
def test_rename_column_references(self): def test_rename_column_references(self):
reference = MockReference("", {}, {("table", "column")}) reference = MockReference("", {}, {("table", "column")}, {})
statement = Statement("", reference=reference, non_reference="") statement = Statement("", reference=reference, non_reference="")
statement.rename_column_references("table", "column", "other") statement.rename_column_references("table", "column", "other")
self.assertEqual(reference.referenced_columns, {("table", "other")}) self.assertEqual(reference.referenced_columns, {("table", "other")})
def test_repr(self): def test_repr(self):
reference = MockReference("reference", {}, {}) reference = MockReference("reference", {}, {}, {})
statement = Statement( statement = Statement(
"%(reference)s - %(non_reference)s", "%(reference)s - %(non_reference)s",
reference=reference, reference=reference,
@ -229,7 +246,7 @@ class StatementTests(SimpleTestCase):
self.assertEqual(repr(statement), "<Statement 'reference - non_reference'>") self.assertEqual(repr(statement), "<Statement 'reference - non_reference'>")
def test_str(self): def test_str(self):
reference = MockReference("reference", {}, {}) reference = MockReference("reference", {}, {}, {})
statement = Statement( statement = Statement(
"%(reference)s - %(non_reference)s", "%(reference)s - %(non_reference)s",
reference=reference, reference=reference,

View File

@ -3,7 +3,7 @@ from unittest import skipUnless
from django.conf import settings from django.conf import settings
from django.db import connection from django.db import connection
from django.db.models import CASCADE, ForeignKey, Index, Q from django.db.models import CASCADE, CharField, ForeignKey, Index, Q
from django.db.models.functions import Lower from django.db.models.functions import Lower
from django.test import ( from django.test import (
TestCase, TestCase,
@ -87,6 +87,24 @@ class SchemaIndexesTests(TestCase):
str(index.create_sql(Article, editor)), str(index.create_sql(Article, editor)),
) )
@skipUnlessDBFeature("can_create_inline_fk", "can_rollback_ddl")
def test_alter_field_unique_false_removes_deferred_sql(self):
field_added = CharField(max_length=127, unique=True)
field_added.set_attributes_from_name("charfield_added")
field_to_alter = CharField(max_length=127, unique=True)
field_to_alter.set_attributes_from_name("charfield_altered")
altered_field = CharField(max_length=127, unique=False)
altered_field.set_attributes_from_name("charfield_altered")
with connection.schema_editor() as editor:
editor.add_field(ArticleTranslation, field_added)
editor.add_field(ArticleTranslation, field_to_alter)
self.assertEqual(len(editor.deferred_sql), 2)
editor.alter_field(ArticleTranslation, field_to_alter, altered_field)
self.assertEqual(len(editor.deferred_sql), 1)
self.assertIn("charfield_added", str(editor.deferred_sql[0].parts["name"]))
class SchemaIndexesNotPostgreSQLTests(TransactionTestCase): class SchemaIndexesNotPostgreSQLTests(TransactionTestCase):
available_apps = ["indexes"] available_apps = ["indexes"]