From ce5bd42259bc95d372ab0d65dbae793e6251ea80 Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Sat, 18 May 2013 11:06:30 +0200 Subject: [PATCH] Turn SchemaEditor into a context manager --- django/db/backends/schema.py | 22 +++++++++++++++------- django/db/migrations/recorder.py | 6 ++---- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/django/db/backends/schema.py b/django/db/backends/schema.py index e4a923f2beb..78ea80022ff 100644 --- a/django/db/backends/schema.py +++ b/django/db/backends/schema.py @@ -1,8 +1,10 @@ +import sys import hashlib from django.db.backends.creation import BaseDatabaseCreation from django.db.backends.util import truncate_name from django.utils.log import getLogger from django.db.models.fields.related import ManyToManyField +from django.db.transaction import atomic logger = getLogger('django.db.backends.schema') @@ -64,9 +66,7 @@ class BaseDatabaseSchemaEditor(object): Marks the start of a schema-altering run. """ self.deferred_sql = [] - self.old_autocommit = self.connection.autocommit - if self.connection.autocommit: - self.connection.set_autocommit(False) + atomic(self.connection.alias).__enter__() def commit(self): """ @@ -74,8 +74,7 @@ class BaseDatabaseSchemaEditor(object): """ for sql in self.deferred_sql: self.execute(sql) - self.connection.commit() - self.connection.set_autocommit(self.old_autocommit) + atomic(self.connection.alias).__exit__(None, None, None) def rollback(self): """ @@ -83,8 +82,17 @@ class BaseDatabaseSchemaEditor(object): """ if not self.connection.features.can_rollback_ddl: raise RuntimeError("Cannot rollback schema changes on this backend") - self.connection.rollback() - self.connection.set_autocommit(self.old_autocommit) + atomic(self.connection.alias).__exit__(*sys.exc_info()) + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is None: + self.commit() + else: + self.rollback() # Core utility functions diff --git a/django/db/migrations/recorder.py b/django/db/migrations/recorder.py index 6fb927adef4..a1f111f2bcf 100644 --- a/django/db/migrations/recorder.py +++ b/django/db/migrations/recorder.py @@ -37,10 +37,8 @@ class MigrationRecorder(object): if self.Migration._meta.db_table in self.connection.introspection.get_table_list(self.connection.cursor()): return # Make the table - editor = self.connection.schema_editor() - editor.start() - editor.create_model(self.Migration) - editor.commit() + with self.connection.schema_editor() as editor: + editor.create_model(self.Migration) def applied_migrations(self): """