From 3b429c96736b8328c40e5d77282b0d30de563c3c Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Tue, 24 May 2016 15:25:05 -0400 Subject: [PATCH] Refs #25530 -- Tracked references of deferred SQL statements. --- django/db/backends/base/schema.py | 61 ++++++----- django/db/backends/ddl_references.py | 128 ++++++++++++++++++++++++ django/db/backends/sqlite3/schema.py | 10 +- tests/backends/test_ddl_references.py | 125 +++++++++++++++++++++++ tests/indexes/tests.py | 8 +- tests/model_options/test_tablespaces.py | 2 +- 6 files changed, 302 insertions(+), 32 deletions(-) create mode 100644 django/db/backends/ddl_references.py create mode 100644 tests/backends/test_ddl_references.py diff --git a/django/db/backends/base/schema.py b/django/db/backends/base/schema.py index 20243ae070..bf22711131 100644 --- a/django/db/backends/base/schema.py +++ b/django/db/backends/base/schema.py @@ -2,6 +2,9 @@ import hashlib import logging from datetime import datetime +from django.db.backends.ddl_references import ( + Columns, ForeignKeyName, IndexName, Statement, Table, +) from django.db.backends.utils import strip_quotes from django.db.models import Index from django.db.transaction import TransactionManagementError, atomic @@ -97,6 +100,8 @@ class BaseDatabaseSchemaEditor: "Executing DDL statements while in a transaction on databases " "that can't perform a rollback is prohibited." ) + # Account for non-string statement objects. + sql = str(sql) # Log the command we're running, then run it logger.debug("%s; (params %r)", sql, params, extra={'params': params, 'sql': sql}) if self.collect_sql: @@ -878,13 +883,19 @@ class BaseDatabaseSchemaEditor: tablespace_sql = self._get_index_tablespace_sql(model, fields) columns = [field.column for field in fields] sql_create_index = sql or self.sql_create_index - return sql_create_index % { - "table": self.quote_name(model._meta.db_table), - "name": self.quote_name(self._create_index_name(model._meta.db_table, columns, suffix=suffix)), - "using": "", - "columns": ", ".join(self.quote_name(column) for column in columns), - "extra": tablespace_sql, - } + table = model._meta.db_table + + def create_index_name(*args, **kwargs): + return self.quote_name(self._create_index_name(*args, **kwargs)) + + return Statement( + sql_create_index, + table=Table(table, self.quote_name), + name=IndexName(table, columns, suffix, create_index_name), + using='', + columns=Columns(table, columns, self.quote_name), + extra=tablespace_sql, + ) def _model_indexes_sql(self, model): """ @@ -930,26 +941,28 @@ class BaseDatabaseSchemaEditor: from_column = field.column to_table = field.target_field.model._meta.db_table to_column = field.target_field.column - suffix = suffix % { - "to_table": to_table, - "to_column": to_column, - } - return self.sql_create_fk % { - "table": self.quote_name(from_table), - "name": self.quote_name(self._create_index_name(model._meta.db_table, [from_column], suffix=suffix)), - "column": self.quote_name(from_column), - "to_table": self.quote_name(to_table), - "to_column": self.quote_name(to_column), - "deferrable": self.connection.ops.deferrable_sql(), - } + def create_fk_name(*args, **kwargs): + return self.quote_name(self._create_index_name(*args, **kwargs)) + + return Statement( + self.sql_create_fk, + table=Table(from_table, self.quote_name), + name=ForeignKeyName(from_table, [from_column], to_table, [to_column], suffix, create_fk_name), + column=Columns(from_table, [from_column], self.quote_name), + to_table=Table(to_table, self.quote_name), + to_column=Columns(to_table, [to_column], self.quote_name), + deferrable=self.connection.ops.deferrable_sql(), + ) def _create_unique_sql(self, model, columns): - return self.sql_create_unique % { - "table": self.quote_name(model._meta.db_table), - "name": self.quote_name(self._create_index_name(model._meta.db_table, columns, suffix="_uniq")), - "columns": ", ".join(self.quote_name(column) for column in columns), - } + table = model._meta.db_table + return Statement( + self.sql_create_unique, + table=Table(table, self.quote_name), + name=IndexName(table, columns, '_uniq', self._create_index_name), + columns=Columns(table, columns, self.quote_name), + ) def _delete_constraint_sql(self, template, model, name): return template % { diff --git a/django/db/backends/ddl_references.py b/django/db/backends/ddl_references.py new file mode 100644 index 0000000000..dd4d1aa415 --- /dev/null +++ b/django/db/backends/ddl_references.py @@ -0,0 +1,128 @@ +""" +Helpers to manipulate deferred DDL statements that might need to be adjusted or +discarded within when executing a migration. +""" + + +class Reference: + """Base class that defines the reference interface.""" + + def references_table(self, table): + """ + Return whether or not this instance references the specified table. + """ + return False + + def references_column(self, table, column): + """ + Return whether or not this instance references the specified column. + """ + return False + + def __repr__(self): + return '<%s %r>' % (self.__class__.__name__, str(self)) + + def __str__(self): + raise NotImplementedError('Subclasses must define how they should be converted to string.') + + +class Table(Reference): + """Hold a reference to a table.""" + + def __init__(self, table, quote_name): + self.table = table + self.quote_name = quote_name + + def references_table(self, table): + return self.table == table + + def __str__(self): + return self.quote_name(self.table) + + +class TableColumns(Table): + """Base class for references to multiple columns of a table.""" + + def __init__(self, table, columns): + self.table = table + self.columns = columns + + def references_column(self, table, column): + return self.table == table and column in self.columns + + +class Columns(TableColumns): + """Hold a reference to one or many columns.""" + + def __init__(self, table, columns, quote_name): + self.quote_name = quote_name + super().__init__(table, columns) + + def __str__(self): + return ', '.join(self.quote_name(column) for column in self.columns) + + +class IndexName(TableColumns): + """Hold a reference to an index name.""" + + def __init__(self, table, columns, suffix, create_index_name): + self.suffix = suffix + self.create_index_name = create_index_name + super().__init__(table, columns) + + def __str__(self): + return self.create_index_name(self.table, self.columns, self.suffix) + + +class ForeignKeyName(TableColumns): + """Hold a reference to a foreign key name.""" + + def __init__(self, from_table, from_columns, to_table, to_columns, suffix_template, create_fk_name): + self.to_reference = TableColumns(to_table, to_columns) + self.suffix_template = suffix_template + self.create_fk_name = create_fk_name + super().__init__(from_table, from_columns,) + + def references_table(self, table): + return super().references_table(table) or self.to_reference.references_table(table) + + def references_column(self, table, column): + return ( + super().references_column(table, column) or + self.to_reference.references_column(table, column) + ) + + def __str__(self): + suffix = self.suffix_template % { + 'to_table': self.to_reference.table, + 'to_column': self.to_reference.columns[0], + } + return self.create_fk_name(self.table, self.columns, suffix) + + +class Statement(Reference): + """ + Statement template and formatting parameters container. + + Allows keeping a reference to a statement without interpolating identifiers + that might have to be adjusted if they're referencing a table or column + that is removed + """ + def __init__(self, template, **parts): + self.template = template + self.parts = parts + + def references_table(self, table): + return any( + hasattr(part, 'references_table') and part.references_table(table) + for part in self.parts.values() + ) + + def references_column(self, table, column): + return any( + hasattr(part, 'references_column') and part.references_column(table, column) + for part in self.parts.values() + ) + + def __str__(self): + return self.template % self.parts diff --git a/django/db/backends/sqlite3/schema.py b/django/db/backends/sqlite3/schema.py index 10d7c623f8..5517d78a97 100644 --- a/django/db/backends/sqlite3/schema.py +++ b/django/db/backends/sqlite3/schema.py @@ -5,6 +5,7 @@ from decimal import Decimal from django.apps.registry import Apps from django.db.backends.base.schema import BaseDatabaseSchemaEditor +from django.db.backends.ddl_references import Statement class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): @@ -189,9 +190,12 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): # Rename the old table to make way for the new self.alter_db_table(model, temp_model._meta.db_table, model._meta.db_table) - # Create a new table with the updated schema. We remove things - # from the deferred SQL that match our table name, too - self.deferred_sql = [x for x in self.deferred_sql if temp_model._meta.db_table not in x] + # Remove all deferred statements referencing the temporary table. + for sql in list(self.deferred_sql): + if isinstance(sql, Statement) and sql.references_table(temp_model._meta.db_table): + self.deferred_sql.remove(sql) + + # Create a new table with the updated schema. self.create_model(temp_model) # Copy data from the old table into the new table diff --git a/tests/backends/test_ddl_references.py b/tests/backends/test_ddl_references.py new file mode 100644 index 0000000000..268eed988b --- /dev/null +++ b/tests/backends/test_ddl_references.py @@ -0,0 +1,125 @@ +from django.db.backends.ddl_references import ( + Columns, ForeignKeyName, IndexName, Statement, Table, +) +from django.test import SimpleTestCase + + +class TableTests(SimpleTestCase): + def setUp(self): + self.reference = Table('table', lambda table: table.upper()) + + def test_references_table(self): + self.assertIs(self.reference.references_table('table'), True) + self.assertIs(self.reference.references_table('other'), False) + + def test_repr(self): + self.assertEqual(repr(self.reference), "") + + def test_str(self): + self.assertEqual(str(self.reference), 'TABLE') + + +class ColumnsTests(TableTests): + def setUp(self): + self.reference = Columns( + 'table', ['first_column', 'second_column'], lambda column: column.upper() + ) + + def test_references_column(self): + self.assertIs(self.reference.references_column('other', 'first_column'), False) + self.assertIs(self.reference.references_column('table', 'third_column'), False) + self.assertIs(self.reference.references_column('table', 'first_column'), True) + + def test_repr(self): + self.assertEqual(repr(self.reference), "") + + def test_str(self): + self.assertEqual(str(self.reference), 'FIRST_COLUMN, SECOND_COLUMN') + + +class IndexNameTests(ColumnsTests): + def setUp(self): + def create_index_name(table_name, column_names, suffix): + return ', '.join("%s_%s_%s" % (table_name, column_name, suffix) for column_name in column_names) + self.reference = IndexName( + 'table', ['first_column', 'second_column'], 'suffix', create_index_name + ) + + def test_repr(self): + self.assertEqual(repr(self.reference), "") + + def test_str(self): + self.assertEqual(str(self.reference), 'table_first_column_suffix, table_second_column_suffix') + + +class ForeignKeyNameTests(IndexNameTests): + def setUp(self): + def create_foreign_key_name(table_name, column_names, suffix): + return ', '.join("%s_%s_%s" % (table_name, column_name, suffix) for column_name in column_names) + self.reference = ForeignKeyName( + 'table', ['first_column', 'second_column'], + 'to_table', ['to_first_column', 'to_second_column'], + '%(to_table)s_%(to_column)s_fk', + create_foreign_key_name, + ) + + def test_references_table(self): + super().test_references_table() + self.assertIs(self.reference.references_table('to_table'), True) + + def test_references_column(self): + super().test_references_column() + self.assertIs(self.reference.references_column('to_table', 'second_column'), False) + self.assertIs(self.reference.references_column('to_table', 'to_second_column'), True) + + def test_repr(self): + self.assertEqual( + repr(self.reference), + "" + ) + + def test_str(self): + self.assertEqual( + str(self.reference), + 'table_first_column_to_table_to_first_column_fk, ' + 'table_second_column_to_table_to_first_column_fk' + ) + + +class MockReference(object): + def __init__(self, representation, referenced_tables, referenced_columns): + self.representation = representation + self.referenced_tables = referenced_tables + self.referenced_columns = referenced_columns + + def references_table(self, table): + return table in self.referenced_tables + + def references_column(self, table, column): + return (table, column) in self.referenced_columns + + def __str__(self): + return self.representation + + +class StatementTests(SimpleTestCase): + def test_references_table(self): + statement = Statement('', reference=MockReference('', {'table'}, {}), non_reference='') + self.assertIs(statement.references_table('table'), True) + self.assertIs(statement.references_table('other'), False) + + def test_references_column(self): + statement = Statement('', reference=MockReference('', {}, {('table', 'column')}), non_reference='') + self.assertIs(statement.references_column('table', 'column'), True) + self.assertIs(statement.references_column('other', 'column'), False) + + def test_repr(self): + reference = MockReference('reference', {}, {}) + statement = Statement("%(reference)s - %(non_reference)s", reference=reference, non_reference='non_reference') + self.assertEqual(repr(statement), "") + + def test_str(self): + reference = MockReference('reference', {}, {}) + statement = Statement("%(reference)s - %(non_reference)s", reference=reference, non_reference='non_reference') + self.assertEqual(str(statement), 'reference - non_reference') diff --git a/tests/indexes/tests.py b/tests/indexes/tests.py index c2d76feeb9..ee2cbd1564 100644 --- a/tests/indexes/tests.py +++ b/tests/indexes/tests.py @@ -51,7 +51,7 @@ class SchemaIndexesTests(TestCase): def test_index_together(self): editor = connection.schema_editor() - index_sql = editor._model_indexes_sql(Article) + index_sql = [str(statement) for statement in editor._model_indexes_sql(Article)] self.assertEqual(len(index_sql), 1) # Ensure the index name is properly quoted self.assertIn( @@ -70,7 +70,7 @@ class SchemaIndexesTests(TestCase): def test_postgresql_text_indexes(self): """Test creation of PostgreSQL-specific text indexes (#12234)""" from .models import IndexedArticle - index_sql = connection.schema_editor()._model_indexes_sql(IndexedArticle) + index_sql = [str(statement) for statement in connection.schema_editor()._model_indexes_sql(IndexedArticle)] self.assertEqual(len(index_sql), 5) self.assertIn('("headline" varchar_pattern_ops)', index_sql[1]) self.assertIn('("body" text_pattern_ops)', index_sql[3]) @@ -99,7 +99,7 @@ class SchemaIndexesMySQLTests(TransactionTestCase): ) if storage != "InnoDB": self.skip("This test only applies to the InnoDB storage engine") - index_sql = connection.schema_editor()._model_indexes_sql(ArticleTranslation) + index_sql = [str(statement) for statement in connection.schema_editor()._model_indexes_sql(ArticleTranslation)] self.assertEqual(index_sql, [ 'CREATE INDEX `indexes_articletranslation_article_no_constraint_id_d6c0806b` ' 'ON `indexes_articletranslation` (`article_no_constraint_id`)' @@ -114,7 +114,7 @@ class SchemaIndexesMySQLTests(TransactionTestCase): new_field.set_attributes_from_name('new_foreign_key') editor.add_field(ArticleTranslation, new_field) field_created = True - self.assertEqual(editor.deferred_sql, [ + self.assertEqual([str(statement) for statement in editor.deferred_sql], [ 'ALTER TABLE `indexes_articletranslation` ' 'ADD CONSTRAINT `indexes_articletrans_new_foreign_key_id_d27a9146_fk_indexes_a` ' 'FOREIGN KEY (`new_foreign_key_id`) REFERENCES `indexes_article` (`id`)' diff --git a/tests/model_options/test_tablespaces.py b/tests/model_options/test_tablespaces.py index 03a137603b..79b0a8bb75 100644 --- a/tests/model_options/test_tablespaces.py +++ b/tests/model_options/test_tablespaces.py @@ -15,7 +15,7 @@ def sql_for_table(model): def sql_for_index(model): - return '\n'.join(connection.schema_editor()._model_indexes_sql(model)) + return '\n'.join(str(sql) for sql in connection.schema_editor()._model_indexes_sql(model)) # We can't test the DEFAULT_TABLESPACE and DEFAULT_INDEX_TABLESPACE settings