Fixed #27631 -- Prevented execution of transactional DDL statements when unsupported.

Executing a DDL statement during a transaction on backends that don't support
it silently commits, leaving atomic() in an incoherent state.

While schema_editor.execute() could technically be used to execute DML
statements such usage should be uncommon as these are usually performed through
the ORM. In other cases schema_editor.connection.execute() can be used to
circumvent this check.

Thanks Adam and Tim for the review.
This commit is contained in:
Simon Charette 2016-12-23 22:34:52 -05:00
parent 755406f5ff
commit 813805833a
3 changed files with 63 additions and 60 deletions

View File

@ -11,7 +11,7 @@ from django.core.management.base import BaseCommand, CommandError
from django.core.management.sql import ( from django.core.management.sql import (
emit_post_migrate_signal, emit_pre_migrate_signal, emit_post_migrate_signal, emit_pre_migrate_signal,
) )
from django.db import DEFAULT_DB_ALIAS, connections, router, transaction from django.db import DEFAULT_DB_ALIAS, connections, router
from django.db.migrations.autodetector import MigrationAutodetector from django.db.migrations.autodetector import MigrationAutodetector
from django.db.migrations.executor import MigrationExecutor from django.db.migrations.executor import MigrationExecutor
from django.db.migrations.loader import AmbiguityError from django.db.migrations.loader import AmbiguityError
@ -259,63 +259,50 @@ class Command(BaseCommand):
self.stdout.write(self.style.SUCCESS(" DONE" + elapsed)) self.stdout.write(self.style.SUCCESS(" DONE" + elapsed))
def sync_apps(self, connection, app_labels): def sync_apps(self, connection, app_labels):
"Runs the old syncdb-style operation on a list of app_labels." """Run the old syncdb-style operation on a list of app_labels."""
cursor = connection.cursor() with connection.cursor() as cursor:
try:
# Get a list of already installed *models* so that references work right.
tables = connection.introspection.table_names(cursor) tables = connection.introspection.table_names(cursor)
created_models = set()
# Build the manifest of apps and models that are to be synchronized # Build the manifest of apps and models that are to be synchronized.
all_models = [ all_models = [
(app_config.label, (
router.get_migratable_models(app_config, connection.alias, include_auto_created=False)) app_config.label,
for app_config in apps.get_app_configs() router.get_migratable_models(app_config, connection.alias, include_auto_created=False),
if app_config.models_module is not None and app_config.label in app_labels )
] for app_config in apps.get_app_configs()
if app_config.models_module is not None and app_config.label in app_labels
]
def model_installed(model): def model_installed(model):
opts = model._meta opts = model._meta
converter = connection.introspection.table_name_converter converter = connection.introspection.table_name_converter
# Note that if a model is unmanaged we short-circuit and never try to install it return not (
return not ( (converter(opts.db_table) in tables) or
(converter(opts.db_table) in tables) or (opts.auto_created and converter(opts.auto_created._meta.db_table) in tables)
(opts.auto_created and converter(opts.auto_created._meta.db_table) in tables)
)
manifest = OrderedDict(
(app_name, list(filter(model_installed, model_list)))
for app_name, model_list in all_models
) )
# Create the tables for each model manifest = OrderedDict(
(app_name, list(filter(model_installed, model_list)))
for app_name, model_list in all_models
)
# Create the tables for each model
if self.verbosity >= 1:
self.stdout.write(" Creating tables...\n")
with connection.schema_editor() as editor:
for app_name, model_list in manifest.items():
for model in model_list:
# Never install unmanaged models, etc.
if not model._meta.can_migrate(connection):
continue
if self.verbosity >= 3:
self.stdout.write(
" Processing %s.%s model\n" % (app_name, model._meta.object_name)
)
if self.verbosity >= 1:
self.stdout.write(" Creating table %s\n" % model._meta.db_table)
editor.create_model(model)
# Deferred SQL is executed when exiting the editor's context.
if self.verbosity >= 1: if self.verbosity >= 1:
self.stdout.write(" Creating tables...\n") self.stdout.write(" Running deferred SQL...\n")
with transaction.atomic(using=connection.alias, savepoint=connection.features.can_rollback_ddl):
deferred_sql = []
for app_name, model_list in manifest.items():
for model in model_list:
if not model._meta.can_migrate(connection):
continue
if self.verbosity >= 3:
self.stdout.write(
" Processing %s.%s model\n" % (app_name, model._meta.object_name)
)
with connection.schema_editor() as editor:
if self.verbosity >= 1:
self.stdout.write(" Creating table %s\n" % model._meta.db_table)
editor.create_model(model)
deferred_sql.extend(editor.deferred_sql)
editor.deferred_sql = []
created_models.add(model)
if self.verbosity >= 1:
self.stdout.write(" Running deferred SQL...\n")
with connection.schema_editor() as editor:
for statement in deferred_sql:
editor.execute(statement)
finally:
cursor.close()
return created_models

View File

@ -2,7 +2,7 @@ import hashlib
import logging import logging
from datetime import datetime from datetime import datetime
from django.db.transaction import atomic from django.db.transaction import TransactionManagementError, atomic
from django.utils import six, timezone from django.utils import six, timezone
from django.utils.encoding import force_bytes from django.utils.encoding import force_bytes
@ -98,6 +98,13 @@ class BaseDatabaseSchemaEditor(object):
""" """
Executes the given SQL statement, with optional parameters. Executes the given SQL statement, with optional parameters.
""" """
# Don't perform the transactional DDL check if SQL is being collected
# as it's not going to be executed anyway.
if not self.collect_sql and self.connection.in_atomic_block and not self.connection.features.can_rollback_ddl:
raise TransactionManagementError(
"Executing DDL statements while in a transaction on databases "
"that can't perform a rollback is prohibited."
)
# Log the command we're running, then run it # Log the command we're running, then run it
logger.debug("%s; (params %r)", sql, params, extra={'params': params, 'sql': sql}) logger.debug("%s; (params %r)", sql, params, extra={'params': params, 'sql': sql})
if self.collect_sql: if self.collect_sql:

View File

@ -17,7 +17,7 @@ from django.db.models.fields.related import (
ForeignKey, ForeignObject, ManyToManyField, OneToOneField, ForeignKey, ForeignObject, ManyToManyField, OneToOneField,
) )
from django.db.models.indexes import Index from django.db.models.indexes import Index
from django.db.transaction import atomic from django.db.transaction import TransactionManagementError, atomic
from django.test import ( from django.test import (
TransactionTestCase, mock, skipIfDBFeature, skipUnlessDBFeature, TransactionTestCase, mock, skipIfDBFeature, skipUnlessDBFeature,
) )
@ -76,14 +76,13 @@ class SchemaTests(TransactionTestCase):
def delete_tables(self): def delete_tables(self):
"Deletes all model tables for our models for a clean test environment" "Deletes all model tables for our models for a clean test environment"
converter = connection.introspection.table_name_converter converter = connection.introspection.table_name_converter
with atomic(): with connection.schema_editor() as editor:
connection.disable_constraint_checking() connection.disable_constraint_checking()
table_names = connection.introspection.table_names() table_names = connection.introspection.table_names()
for model in itertools.chain(SchemaTests.models, self.local_models): for model in itertools.chain(SchemaTests.models, self.local_models):
tbl = converter(model._meta.db_table) tbl = converter(model._meta.db_table)
if tbl in table_names: if tbl in table_names:
with connection.schema_editor() as editor: editor.delete_model(model)
editor.delete_model(model)
table_names.remove(tbl) table_names.remove(tbl)
connection.enable_constraint_checking() connection.enable_constraint_checking()
@ -1740,6 +1739,16 @@ class SchemaTests(TransactionTestCase):
except SomeError: except SomeError:
self.assertFalse(connection.in_atomic_block) self.assertFalse(connection.in_atomic_block)
@skipIfDBFeature('can_rollback_ddl')
def test_unsupported_transactional_ddl_disallowed(self):
message = (
"Executing DDL statements while in a transaction on databases "
"that can't perform a rollback is prohibited."
)
with atomic(), connection.schema_editor() as editor:
with self.assertRaisesMessage(TransactionManagementError, message):
editor.execute(editor.sql_create_table % {'table': 'foo', 'definition': ''})
@skipUnlessDBFeature('supports_foreign_keys') @skipUnlessDBFeature('supports_foreign_keys')
def test_foreign_key_index_long_names_regression(self): def test_foreign_key_index_long_names_regression(self):
""" """