Turn SchemaEditor into a context manager

This commit is contained in:
Andrew Godwin 2013-05-18 11:06:30 +02:00
parent b31eea069c
commit ce5bd42259
2 changed files with 17 additions and 11 deletions

View File

@ -1,8 +1,10 @@
import sys
import hashlib import hashlib
from django.db.backends.creation import BaseDatabaseCreation from django.db.backends.creation import BaseDatabaseCreation
from django.db.backends.util import truncate_name from django.db.backends.util import truncate_name
from django.utils.log import getLogger from django.utils.log import getLogger
from django.db.models.fields.related import ManyToManyField from django.db.models.fields.related import ManyToManyField
from django.db.transaction import atomic
logger = getLogger('django.db.backends.schema') logger = getLogger('django.db.backends.schema')
@ -64,9 +66,7 @@ class BaseDatabaseSchemaEditor(object):
Marks the start of a schema-altering run. Marks the start of a schema-altering run.
""" """
self.deferred_sql = [] self.deferred_sql = []
self.old_autocommit = self.connection.autocommit atomic(self.connection.alias).__enter__()
if self.connection.autocommit:
self.connection.set_autocommit(False)
def commit(self): def commit(self):
""" """
@ -74,8 +74,7 @@ class BaseDatabaseSchemaEditor(object):
""" """
for sql in self.deferred_sql: for sql in self.deferred_sql:
self.execute(sql) self.execute(sql)
self.connection.commit() atomic(self.connection.alias).__exit__(None, None, None)
self.connection.set_autocommit(self.old_autocommit)
def rollback(self): def rollback(self):
""" """
@ -83,8 +82,17 @@ class BaseDatabaseSchemaEditor(object):
""" """
if not self.connection.features.can_rollback_ddl: if not self.connection.features.can_rollback_ddl:
raise RuntimeError("Cannot rollback schema changes on this backend") raise RuntimeError("Cannot rollback schema changes on this backend")
self.connection.rollback() atomic(self.connection.alias).__exit__(*sys.exc_info())
self.connection.set_autocommit(self.old_autocommit)
def __enter__(self):
self.start()
return self
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is None:
self.commit()
else:
self.rollback()
# Core utility functions # Core utility functions

View File

@ -37,10 +37,8 @@ class MigrationRecorder(object):
if self.Migration._meta.db_table in self.connection.introspection.get_table_list(self.connection.cursor()): if self.Migration._meta.db_table in self.connection.introspection.get_table_list(self.connection.cursor()):
return return
# Make the table # Make the table
editor = self.connection.schema_editor() with self.connection.schema_editor() as editor:
editor.start() editor.create_model(self.Migration)
editor.create_model(self.Migration)
editor.commit()
def applied_migrations(self): def applied_migrations(self):
""" """