Fixed #27631 -- Prevented execution of transactional DDL statements when unsupported.

Executing a DDL statement during a transaction on backends that don't support
it silently commits, leaving atomic() in an incoherent state.

While schema_editor.execute() could technically be used to execute DML
statements such usage should be uncommon as these are usually performed through
the ORM. In other cases schema_editor.connection.execute() can be used to
circumvent this check.

Thanks Adam and Tim for the review.
This commit is contained in:
Simon Charette 2016-12-23 22:34:52 -05:00
parent 755406f5ff
commit 813805833a
3 changed files with 63 additions and 60 deletions

View File

@ -11,7 +11,7 @@ from django.core.management.base import BaseCommand, CommandError
from django.core.management.sql import (
emit_post_migrate_signal, emit_pre_migrate_signal,
)
from django.db import DEFAULT_DB_ALIAS, connections, router, transaction
from django.db import DEFAULT_DB_ALIAS, connections, router
from django.db.migrations.autodetector import MigrationAutodetector
from django.db.migrations.executor import MigrationExecutor
from django.db.migrations.loader import AmbiguityError
@ -259,18 +259,16 @@ class Command(BaseCommand):
self.stdout.write(self.style.SUCCESS(" DONE" + elapsed))
def sync_apps(self, connection, app_labels):
"Runs the old syncdb-style operation on a list of app_labels."
cursor = connection.cursor()
try:
# Get a list of already installed *models* so that references work right.
"""Run the old syncdb-style operation on a list of app_labels."""
with connection.cursor() as cursor:
tables = connection.introspection.table_names(cursor)
created_models = set()
# Build the manifest of apps and models that are to be synchronized
# Build the manifest of apps and models that are to be synchronized.
all_models = [
(app_config.label,
router.get_migratable_models(app_config, connection.alias, include_auto_created=False))
(
app_config.label,
router.get_migratable_models(app_config, connection.alias, include_auto_created=False),
)
for app_config in apps.get_app_configs()
if app_config.models_module is not None and app_config.label in app_labels
]
@ -278,7 +276,6 @@ class Command(BaseCommand):
def model_installed(model):
opts = model._meta
converter = connection.introspection.table_name_converter
# Note that if a model is unmanaged we short-circuit and never try to install it
return not (
(converter(opts.db_table) in tables) or
(opts.auto_created and converter(opts.auto_created._meta.db_table) in tables)
@ -292,30 +289,20 @@ class Command(BaseCommand):
# Create the tables for each model
if self.verbosity >= 1:
self.stdout.write(" Creating tables...\n")
with transaction.atomic(using=connection.alias, savepoint=connection.features.can_rollback_ddl):
deferred_sql = []
with connection.schema_editor() as editor:
for app_name, model_list in manifest.items():
for model in model_list:
# Never install unmanaged models, etc.
if not model._meta.can_migrate(connection):
continue
if self.verbosity >= 3:
self.stdout.write(
" Processing %s.%s model\n" % (app_name, model._meta.object_name)
)
with connection.schema_editor() as editor:
if self.verbosity >= 1:
self.stdout.write(" Creating table %s\n" % model._meta.db_table)
editor.create_model(model)
deferred_sql.extend(editor.deferred_sql)
editor.deferred_sql = []
created_models.add(model)
# Deferred SQL is executed when exiting the editor's context.
if self.verbosity >= 1:
self.stdout.write(" Running deferred SQL...\n")
with connection.schema_editor() as editor:
for statement in deferred_sql:
editor.execute(statement)
finally:
cursor.close()
return created_models

View File

@ -2,7 +2,7 @@ import hashlib
import logging
from datetime import datetime
from django.db.transaction import atomic
from django.db.transaction import TransactionManagementError, atomic
from django.utils import six, timezone
from django.utils.encoding import force_bytes
@ -98,6 +98,13 @@ class BaseDatabaseSchemaEditor(object):
"""
Executes the given SQL statement, with optional parameters.
"""
# Don't perform the transactional DDL check if SQL is being collected
# as it's not going to be executed anyway.
if not self.collect_sql and self.connection.in_atomic_block and not self.connection.features.can_rollback_ddl:
raise TransactionManagementError(
"Executing DDL statements while in a transaction on databases "
"that can't perform a rollback is prohibited."
)
# Log the command we're running, then run it
logger.debug("%s; (params %r)", sql, params, extra={'params': params, 'sql': sql})
if self.collect_sql:

View File

@ -17,7 +17,7 @@ from django.db.models.fields.related import (
ForeignKey, ForeignObject, ManyToManyField, OneToOneField,
)
from django.db.models.indexes import Index
from django.db.transaction import atomic
from django.db.transaction import TransactionManagementError, atomic
from django.test import (
TransactionTestCase, mock, skipIfDBFeature, skipUnlessDBFeature,
)
@ -76,13 +76,12 @@ class SchemaTests(TransactionTestCase):
def delete_tables(self):
"Deletes all model tables for our models for a clean test environment"
converter = connection.introspection.table_name_converter
with atomic():
with connection.schema_editor() as editor:
connection.disable_constraint_checking()
table_names = connection.introspection.table_names()
for model in itertools.chain(SchemaTests.models, self.local_models):
tbl = converter(model._meta.db_table)
if tbl in table_names:
with connection.schema_editor() as editor:
editor.delete_model(model)
table_names.remove(tbl)
connection.enable_constraint_checking()
@ -1740,6 +1739,16 @@ class SchemaTests(TransactionTestCase):
except SomeError:
self.assertFalse(connection.in_atomic_block)
@skipIfDBFeature('can_rollback_ddl')
def test_unsupported_transactional_ddl_disallowed(self):
message = (
"Executing DDL statements while in a transaction on databases "
"that can't perform a rollback is prohibited."
)
with atomic(), connection.schema_editor() as editor:
with self.assertRaisesMessage(TransactionManagementError, message):
editor.execute(editor.sql_create_table % {'table': 'foo', 'definition': ''})
@skipUnlessDBFeature('supports_foreign_keys')
def test_foreign_key_index_long_names_regression(self):
"""