diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 2abc81ae5d..890249a77d 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -1,9 +1,14 @@ import decimal +try: + import thread +except ImportError: + import dummy_thread as thread from threading import local from django.conf import settings from django.db import DEFAULT_DB_ALIAS from django.db.backends import util +from django.db.transaction import TransactionManagementError from django.utils import datetime_safe from django.utils.importlib import import_module @@ -28,7 +33,7 @@ class BaseDatabaseWrapper(local): # Transaction related attributes self.transaction_state = [] self.savepoint_state = 0 - self.dirty = None + self._dirty = None def __eq__(self, other): return self.alias == other.alias @@ -74,6 +79,166 @@ class BaseDatabaseWrapper(local): return self.cursor().execute(self.ops.savepoint_commit_sql(sid)) + def enter_transaction_management(self, managed=True): + """ + Enters transaction management for a running thread. It must be balanced with + the appropriate leave_transaction_management call, since the actual state is + managed as a stack. + + The state and dirty flag are carried over from the surrounding block or + from the settings, if there is no surrounding block (dirty is always false + when no current block is running). + """ + if self.transaction_state: + self.transaction_state.append(self.transaction_state[-1]) + else: + self.transaction_state.append(settings.TRANSACTIONS_MANAGED) + + if self._dirty is None: + self._dirty = False + self._enter_transaction_management(managed) + + def leave_transaction_management(self): + """ + Leaves transaction management for a running thread. A dirty flag is carried + over to the surrounding block, as a commit will commit all changes, even + those from outside. (Commits are on connection level.) + """ + self._leave_transaction_management(self.is_managed()) + if self.transaction_state: + del self.transaction_state[-1] + else: + raise TransactionManagementError("This code isn't under transaction " + "management") + if self._dirty: + self.rollback() + raise TransactionManagementError("Transaction managed block ended with " + "pending COMMIT/ROLLBACK") + self._dirty = False + + def is_dirty(self): + """ + Returns True if the current transaction requires a commit for changes to + happen. + """ + return self._dirty + + def set_dirty(self): + """ + Sets a dirty flag for the current thread and code streak. This can be used + to decide in a managed block of code to decide whether there are open + changes waiting for commit. + """ + if self._dirty is not None: + self._dirty = True + else: + raise TransactionManagementError("This code isn't under transaction " + "management") + + def set_clean(self): + """ + Resets a dirty flag for the current thread and code streak. This can be used + to decide in a managed block of code to decide whether a commit or rollback + should happen. + """ + if self._dirty is not None: + self._dirty = False + else: + raise TransactionManagementError("This code isn't under transaction management") + self.clean_savepoints() + + def clean_savepoints(self): + self.savepoint_state = 0 + + def is_managed(self): + """ + Checks whether the transaction manager is in manual or in auto state. + """ + if self.transaction_state: + return self.transaction_state[-1] + return settings.TRANSACTIONS_MANAGED + + def managed(self, flag=True): + """ + Puts the transaction manager into a manual state: managed transactions have + to be committed explicitly by the user. If you switch off transaction + management and there is a pending commit/rollback, the data will be + commited. + """ + top = self.transaction_state + if top: + top[-1] = flag + if not flag and self.is_dirty(): + self._commit() + self.set_clean() + else: + raise TransactionManagementError("This code isn't under transaction " + "management") + + def commit_unless_managed(self): + """ + Commits changes if the system is not in managed transaction mode. + """ + if not self.is_managed(): + self._commit() + self.clean_savepoints() + else: + self.set_dirty() + + def rollback_unless_managed(self): + """ + Rolls back changes if the system is not in managed transaction mode. + """ + if not self.is_managed(): + self._rollback() + else: + self.set_dirty() + + def commit(self): + """ + Does the commit itself and resets the dirty flag. + """ + self._commit() + self.set_clean() + + def rollback(self): + """ + This function does the rollback itself and resets the dirty flag. + """ + self._rollback() + self.set_clean() + + def savepoint(self): + """ + Creates a savepoint (if supported and required by the backend) inside the + current transaction. Returns an identifier for the savepoint that will be + used for the subsequent rollback or commit. + """ + thread_ident = thread.get_ident() + + self.savepoint_state += 1 + + tid = str(thread_ident).replace('-', '') + sid = "s%s_x%d" % (tid, self.savepoint_state) + self._savepoint(sid) + return sid + + def savepoint_rollback(self, sid): + """ + Rolls back the most recent savepoint (if one exists). Does nothing if + savepoints are not supported. + """ + if self.savepoint_state: + self._savepoint_rollback(sid) + + def savepoint_commit(self, sid): + """ + Commits the most recent savepoint (if one exists). Does nothing if + savepoints are not supported. + """ + if self.savepoint_state: + self._savepoint_commit(sid) + def close(self): if self.connection is not None: self.connection.close() diff --git a/django/db/transaction.py b/django/db/transaction.py index a565da2466..b5584dd8b9 100644 --- a/django/db/transaction.py +++ b/django/db/transaction.py @@ -13,10 +13,6 @@ or implicit commits or rollbacks. """ import sys -try: - import thread -except ImportError: - import dummy_thread as thread try: from functools import wraps except ImportError: @@ -46,15 +42,7 @@ def enter_transaction_management(managed=True, using=None): if using is None: using = DEFAULT_DB_ALIAS connection = connections[using] - - if connection.transaction_state: - connection.transaction_state.append(connection.transaction_state[-1]) - else: - connection.transaction_state.append(settings.TRANSACTIONS_MANAGED) - - if connection.dirty is None: - connection.dirty = False - connection._enter_transaction_management(managed) + connection.enter_transaction_management(managed) def leave_transaction_management(using=None): """ @@ -65,18 +53,7 @@ def leave_transaction_management(using=None): if using is None: using = DEFAULT_DB_ALIAS connection = connections[using] - - connection._leave_transaction_management(is_managed(using=using)) - if connection.transaction_state: - del connection.transaction_state[-1] - else: - raise TransactionManagementError("This code isn't under transaction " - "management") - if connection.dirty: - rollback(using=using) - raise TransactionManagementError("Transaction managed block ended with " - "pending COMMIT/ROLLBACK") - connection.dirty = False + connection.leave_transaction_management() def is_dirty(using=None): """ @@ -86,8 +63,7 @@ def is_dirty(using=None): if using is None: using = DEFAULT_DB_ALIAS connection = connections[using] - - return connection.dirty + return connection.is_dirty() def set_dirty(using=None): """ @@ -98,12 +74,7 @@ def set_dirty(using=None): if using is None: using = DEFAULT_DB_ALIAS connection = connections[using] - - if connection.dirty is not None: - connection.dirty = True - else: - raise TransactionManagementError("This code isn't under transaction " - "management") + connection.set_dirty() def set_clean(using=None): """ @@ -114,18 +85,13 @@ def set_clean(using=None): if using is None: using = DEFAULT_DB_ALIAS connection = connections[using] - - if connection.dirty is not None: - connection.dirty = False - else: - raise TransactionManagementError("This code isn't under transaction management") - clean_savepoints(using=using) + connection.set_clean() def clean_savepoints(using=None): if using is None: using = DEFAULT_DB_ALIAS connection = connections[using] - connection.savepoint_state = 0 + connection.clean_savepoints() def is_managed(using=None): """ @@ -134,9 +100,7 @@ def is_managed(using=None): if using is None: using = DEFAULT_DB_ALIAS connection = connections[using] - if connection.transaction_state: - return connection.transaction_state[-1] - return settings.TRANSACTIONS_MANAGED + return connection.is_managed() def managed(flag=True, using=None): """ @@ -148,16 +112,7 @@ def managed(flag=True, using=None): if using is None: using = DEFAULT_DB_ALIAS connection = connections[using] - - top = connection.transaction_state - if top: - top[-1] = flag - if not flag and is_dirty(using=using): - connection._commit() - set_clean(using=using) - else: - raise TransactionManagementError("This code isn't under transaction " - "management") + connection.managed(flag) def commit_unless_managed(using=None): """ @@ -166,11 +121,7 @@ def commit_unless_managed(using=None): if using is None: using = DEFAULT_DB_ALIAS connection = connections[using] - if not is_managed(using=using): - connection._commit() - clean_savepoints(using=using) - else: - set_dirty(using=using) + connection.commit_unless_managed() def rollback_unless_managed(using=None): """ @@ -179,10 +130,7 @@ def rollback_unless_managed(using=None): if using is None: using = DEFAULT_DB_ALIAS connection = connections[using] - if not is_managed(using=using): - connection._rollback() - else: - set_dirty(using=using) + connection.rollback_unless_managed() def commit(using=None): """ @@ -191,8 +139,7 @@ def commit(using=None): if using is None: using = DEFAULT_DB_ALIAS connection = connections[using] - connection._commit() - set_clean(using=using) + connection.commit() def rollback(using=None): """ @@ -201,8 +148,7 @@ def rollback(using=None): if using is None: using = DEFAULT_DB_ALIAS connection = connections[using] - connection._rollback() - set_clean(using=using) + connection.rollback() def savepoint(using=None): """ @@ -213,14 +159,7 @@ def savepoint(using=None): if using is None: using = DEFAULT_DB_ALIAS connection = connections[using] - thread_ident = thread.get_ident() - - connection.savepoint_state += 1 - - tid = str(thread_ident).replace('-', '') - sid = "s%s_x%d" % (tid, connection.savepoint_state) - connection._savepoint(sid) - return sid + return connection.savepoint() def savepoint_rollback(sid, using=None): """ @@ -230,9 +169,7 @@ def savepoint_rollback(sid, using=None): if using is None: using = DEFAULT_DB_ALIAS connection = connections[using] - - if connection.savepoint_state: - connection._savepoint_rollback(sid) + connection.savepoint_rollback(sid) def savepoint_commit(sid, using=None): """ @@ -242,9 +179,7 @@ def savepoint_commit(sid, using=None): if using is None: using = DEFAULT_DB_ALIAS connection = connections[using] - - if connection.savepoint_state: - connection._savepoint_commit(sid) + connection.savepoint_commit(sid) ############## # DECORATORS #