diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 9fb2b236442..fe26c98baf9 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -40,20 +40,24 @@ class BaseDatabaseWrapper(object): self.alias = alias self.use_debug_cursor = None - # Transaction related attributes - self.transaction_state = [] + # Savepoint management related attributes self.savepoint_state = 0 + + # Transaction management related attributes + self.transaction_state = [] # Tracks if the connection is believed to be in transaction. This is # set somewhat aggressively, as the DBAPI doesn't make it easy to # deduce if the connection is in transaction or not. self._dirty = False - self._thread_ident = thread.get_ident() - self.allow_thread_sharing = allow_thread_sharing # Connection termination related attributes self.close_at = None self.errors_occurred = False + # Thread-safety related attributes + self.allow_thread_sharing = allow_thread_sharing + self._thread_ident = thread.get_ident() + def __eq__(self, other): return self.alias == other.alias @@ -63,21 +67,26 @@ class BaseDatabaseWrapper(object): def __hash__(self): return hash(self.alias) - def wrap_database_errors(self): - return DatabaseErrorWrapper(self) + ##### Backend-specific methods for creating connections and cursors ##### def get_connection_params(self): + """Returns a dict of parameters suitable for get_new_connection.""" raise NotImplementedError def get_new_connection(self, conn_params): + """Opens a connection to the database.""" raise NotImplementedError def init_connection_state(self): + """Initializes the database connection settings.""" raise NotImplementedError def create_cursor(self): + """Creates a cursor. Assumes that a connection is established.""" raise NotImplementedError + ##### Backend-specific wrappers for PEP-249 connection methods ##### + def _cursor(self): with self.wrap_database_errors(): if self.connection is None: @@ -107,20 +116,48 @@ class BaseDatabaseWrapper(object): with self.wrap_database_errors(): return self.connection.close() - def _enter_transaction_management(self, managed): - """ - A hook for backend-specific changes required when entering manual - transaction handling. - """ - pass + ##### Generic wrappers for PEP-249 connection methods ##### - def _leave_transaction_management(self, managed): + def cursor(self): """ - A hook for backend-specific changes required when leaving manual - transaction handling. Will usually be implemented only when - _enter_transaction_management() is also required. + Creates a cursor, opening a connection if necessary. """ - pass + self.validate_thread_sharing() + if (self.use_debug_cursor or + (self.use_debug_cursor is None and settings.DEBUG)): + cursor = self.make_debug_cursor(self._cursor()) + else: + cursor = util.CursorWrapper(self._cursor(), self) + return cursor + + def commit(self): + """ + Does the commit itself and resets the dirty flag. + """ + self.validate_thread_sharing() + self._commit() + self.set_clean() + + def rollback(self): + """ + Does the rollback itself and resets the dirty flag. + """ + self.validate_thread_sharing() + self._rollback() + self.set_clean() + + def close(self): + """ + Closes the connection to the database. + """ + self.validate_thread_sharing() + try: + self._close() + finally: + self.connection = None + self.set_clean() + + ##### Backend-specific savepoint management methods ##### def _savepoint(self, sid): if not self.features.uses_savepoints: @@ -137,15 +174,65 @@ class BaseDatabaseWrapper(object): return self.cursor().execute(self.ops.savepoint_commit_sql(sid)) - def abort(self): + ##### Generic savepoint management methods ##### + + def savepoint(self): """ - Roll back any ongoing transaction and clean the transaction state - stack. + Creates a savepoint (if supported and required by the backend) inside the + current transaction. Returns an identifier for the savepoint that will be + used for the subsequent rollback or commit. """ - if self._dirty: - self.rollback() - while self.transaction_state: - self.leave_transaction_management() + thread_ident = thread.get_ident() + + self.savepoint_state += 1 + + tid = str(thread_ident).replace('-', '') + sid = "s%s_x%d" % (tid, self.savepoint_state) + self._savepoint(sid) + return sid + + def savepoint_rollback(self, sid): + """ + Rolls back the most recent savepoint (if one exists). Does nothing if + savepoints are not supported. + """ + self.validate_thread_sharing() + if self.savepoint_state: + self._savepoint_rollback(sid) + + def savepoint_commit(self, sid): + """ + Commits the most recent savepoint (if one exists). Does nothing if + savepoints are not supported. + """ + self.validate_thread_sharing() + if self.savepoint_state: + self._savepoint_commit(sid) + + def clean_savepoints(self): + """ + Resets the counter used to generate unique savepoint ids in this thread. + """ + self.savepoint_state = 0 + + ##### Backend-specific transaction management methods ##### + + def _enter_transaction_management(self, managed): + """ + A hook for backend-specific changes required when entering manual + transaction handling. + """ + pass + + def _leave_transaction_management(self, managed): + """ + A hook for backend-specific changes required when leaving manual + transaction handling. Will usually be implemented only when + _enter_transaction_management() is also required. + """ + pass + + ##### Generic transaction management methods ##### def enter_transaction_management(self, managed=True): """ @@ -185,20 +272,15 @@ class BaseDatabaseWrapper(object): raise TransactionManagementError( "Transaction managed block ended with pending COMMIT/ROLLBACK") - def validate_thread_sharing(self): + def abort(self): """ - Validates that the connection isn't accessed by another thread than the - one which originally created it, unless the connection was explicitly - authorized to be shared between threads (via the `allow_thread_sharing` - property). Raises an exception if the validation fails. + Roll back any ongoing transaction and clean the transaction state + stack. """ - if (not self.allow_thread_sharing - and self._thread_ident != thread.get_ident()): - raise DatabaseError("DatabaseWrapper objects created in a " - "thread can only be used in that same thread. The object " - "with alias '%s' was created in thread id %s and this is " - "thread id %s." - % (self.alias, self._thread_ident, thread.get_ident())) + if self._dirty: + self.rollback() + while self.transaction_state: + self.leave_transaction_management() def is_dirty(self): """ @@ -224,12 +306,6 @@ class BaseDatabaseWrapper(object): self._dirty = False self.clean_savepoints() - def clean_savepoints(self): - """ - Resets the counter used to generate unique savepoint ids in this thread. - """ - self.savepoint_state = 0 - def is_managed(self): """ Checks whether the transaction manager is in manual or in auto state. @@ -275,57 +351,13 @@ class BaseDatabaseWrapper(object): else: self.set_dirty() - def commit(self): - """ - Does the commit itself and resets the dirty flag. - """ - self.validate_thread_sharing() - self._commit() - self.set_clean() - - def rollback(self): - """ - This function does the rollback itself and resets the dirty flag. - """ - self.validate_thread_sharing() - self._rollback() - self.set_clean() - - def savepoint(self): - """ - Creates a savepoint (if supported and required by the backend) inside the - current transaction. Returns an identifier for the savepoint that will be - used for the subsequent rollback or commit. - """ - thread_ident = thread.get_ident() - - self.savepoint_state += 1 - - tid = str(thread_ident).replace('-', '') - sid = "s%s_x%d" % (tid, self.savepoint_state) - self._savepoint(sid) - return sid - - def savepoint_rollback(self, sid): - """ - Rolls back the most recent savepoint (if one exists). Does nothing if - savepoints are not supported. - """ - self.validate_thread_sharing() - if self.savepoint_state: - self._savepoint_rollback(sid) - - def savepoint_commit(self, sid): - """ - Commits the most recent savepoint (if one exists). Does nothing if - savepoints are not supported. - """ - self.validate_thread_sharing() - if self.savepoint_state: - self._savepoint_commit(sid) + ##### Foreign key constraints checks handling ##### @contextmanager def constraint_checks_disabled(self): + """ + Context manager that disables foreign key constraint checking. + """ disabled = self.disable_constraint_checking() try: yield @@ -335,33 +367,40 @@ class BaseDatabaseWrapper(object): def disable_constraint_checking(self): """ - Backends can implement as needed to temporarily disable foreign key constraint - checking. + Backends can implement as needed to temporarily disable foreign key + constraint checking. """ pass def enable_constraint_checking(self): """ - Backends can implement as needed to re-enable foreign key constraint checking. + Backends can implement as needed to re-enable foreign key constraint + checking. """ pass def check_constraints(self, table_names=None): """ - Backends can override this method if they can apply constraint checking (e.g. via "SET CONSTRAINTS - ALL IMMEDIATE"). Should raise an IntegrityError if any invalid foreign key references are encountered. + Backends can override this method if they can apply constraint + checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE"). Should raise an + IntegrityError if any invalid foreign key references are encountered. """ pass - def close(self): - self.validate_thread_sharing() - try: - self._close() - finally: - self.connection = None - self.set_clean() + ##### Connection termination handling ##### + + def is_usable(self): + """ + Tests if the database connection is usable. + This function may assume that self.connection is not None. + """ + raise NotImplementedError def close_if_unusable_or_obsolete(self): + """ + Closes the current connection if unrecoverable errors have occurred, + or if it outlived its maximum age. + """ if self.connection is not None: if self.errors_occurred: if self.is_usable(): @@ -373,30 +412,45 @@ class BaseDatabaseWrapper(object): self.close() return - def is_usable(self): - """ - Test if the database connection is usable. + ##### Thread safety handling ##### - This function may assume that self.connection is not None. + def validate_thread_sharing(self): """ - raise NotImplementedError + Validates that the connection isn't accessed by another thread than the + one which originally created it, unless the connection was explicitly + authorized to be shared between threads (via the `allow_thread_sharing` + property). Raises an exception if the validation fails. + """ + if not (self.allow_thread_sharing + or self._thread_ident == thread.get_ident()): + raise DatabaseError("DatabaseWrapper objects created in a " + "thread can only be used in that same thread. The object " + "with alias '%s' was created in thread id %s and this is " + "thread id %s." + % (self.alias, self._thread_ident, thread.get_ident())) - def cursor(self): - self.validate_thread_sharing() - if (self.use_debug_cursor or - (self.use_debug_cursor is None and settings.DEBUG)): - cursor = self.make_debug_cursor(self._cursor()) - else: - cursor = util.CursorWrapper(self._cursor(), self) - return cursor + ##### Miscellaneous ##### + + def wrap_database_errors(self): + """ + Context manager and decorator that re-throws backend-specific database + exceptions using Django's common wrappers. + """ + return DatabaseErrorWrapper(self) def make_debug_cursor(self, cursor): + """ + Creates a cursor that logs all queries in self.queries. + """ return util.CursorDebugWrapper(cursor, self) @contextmanager def temporary_connection(self): - # Ensure a connection is established, and avoid leaving a dangling - # connection, for operations outside of the request-response cycle. + """ + Context manager that ensures that a connection is established, and + if it opened one, closes it to avoid leaving a dangling connection. + This is useful for operations outside of the request-response cycle. + """ must_close = self.connection is None cursor = self.cursor() try: @@ -406,6 +460,7 @@ class BaseDatabaseWrapper(object): if must_close: self.close() + class BaseDatabaseFeatures(object): allows_group_by_pk = False # True if django.db.backend.utils.typecast_timestamp is used on values diff --git a/django/db/backends/dummy/base.py b/django/db/backends/dummy/base.py index b648aae9c9c..720b0176d63 100644 --- a/django/db/backends/dummy/base.py +++ b/django/db/backends/dummy/base.py @@ -48,19 +48,19 @@ class DatabaseWrapper(BaseDatabaseWrapper): # implementations. Anything that tries to actually # do something raises complain; anything that tries # to rollback or undo something raises ignore. + _cursor = complain _commit = complain _rollback = ignore - enter_transaction_management = complain - leave_transaction_management = ignore + _close = ignore + _savepoint = ignore + _savepoint_commit = complain + _savepoint_rollback = ignore + _enter_transaction_management = complain + _leave_transaction_management = ignore set_dirty = complain set_clean = complain commit_unless_managed = complain rollback_unless_managed = ignore - savepoint = ignore - savepoint_commit = complain - savepoint_rollback = ignore - close = ignore - cursor = complain def __init__(self, *args, **kwargs): super(DatabaseWrapper, self).__init__(*args, **kwargs) diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index 6b2ecaead19..400fe6cdac0 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -439,29 +439,12 @@ class DatabaseWrapper(BaseDatabaseWrapper): cursor = self.connection.cursor() return CursorWrapper(cursor) - def is_usable(self): - try: - self.connection.ping() - except DatabaseError: - return False - else: - return True - def _rollback(self): try: BaseDatabaseWrapper._rollback(self) except Database.NotSupportedError: pass - @cached_property - def mysql_version(self): - with self.temporary_connection(): - server_info = self.connection.get_server_info() - match = server_version_re.match(server_info) - if not match: - raise Exception('Unable to determine MySQL version from version string %r' % server_info) - return tuple([int(x) for x in match.groups()]) - def disable_constraint_checking(self): """ Disables foreign key checks, primarily for use in adding rows with forward references. Always returns True, @@ -510,3 +493,20 @@ class DatabaseWrapper(BaseDatabaseWrapper): % (table_name, bad_row[0], table_name, column_name, bad_row[1], referenced_table_name, referenced_column_name)) + + def is_usable(self): + try: + self.connection.ping() + except DatabaseError: + return False + else: + return True + + @cached_property + def mysql_version(self): + with self.temporary_connection(): + server_info = self.connection.get_server_info() + match = server_version_re.match(server_info) + if not match: + raise Exception('Unable to determine MySQL version from version string %r' % server_info) + return tuple([int(x) for x in match.groups()]) diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index 478124f5dfb..a56813e28e0 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -515,14 +515,6 @@ class DatabaseWrapper(BaseDatabaseWrapper): self.introspection = DatabaseIntrospection(self) self.validation = BaseDatabaseValidation(self) - def check_constraints(self, table_names=None): - """ - To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they - are returned to deferred. - """ - self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE') - self.cursor().execute('SET CONSTRAINTS ALL DEFERRED') - def _connect_string(self): settings_dict = self.settings_dict if not settings_dict['HOST'].strip(): @@ -536,9 +528,6 @@ class DatabaseWrapper(BaseDatabaseWrapper): return "%s/%s@%s" % (settings_dict['USER'], settings_dict['PASSWORD'], dsn) - def create_cursor(self): - return FormatStylePlaceholderCursor(self.connection) - def get_connection_params(self): conn_params = self.settings_dict['OPTIONS'].copy() if 'use_returning_into' in conn_params: @@ -598,21 +587,8 @@ class DatabaseWrapper(BaseDatabaseWrapper): # stmtcachesize is available only in 4.3.2 and up. pass - def is_usable(self): - try: - if hasattr(self.connection, 'ping'): # Oracle 10g R2 and higher - self.connection.ping() - else: - # Use a cx_Oracle cursor directly, bypassing Django's utilities. - self.connection.cursor().execute("SELECT 1 FROM DUAL") - except DatabaseError: - return False - else: - return True - - # Oracle doesn't support savepoint commits. Ignore them. - def _savepoint_commit(self, sid): - pass + def create_cursor(self): + return FormatStylePlaceholderCursor(self.connection) def _commit(self): if self.connection is not None: @@ -632,6 +608,30 @@ class DatabaseWrapper(BaseDatabaseWrapper): six.reraise(utils.IntegrityError, utils.IntegrityError(*tuple(e.args)), sys.exc_info()[2]) raise + # Oracle doesn't support savepoint commits. Ignore them. + def _savepoint_commit(self, sid): + pass + + def check_constraints(self, table_names=None): + """ + To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they + are returned to deferred. + """ + self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE') + self.cursor().execute('SET CONSTRAINTS ALL DEFERRED') + + def is_usable(self): + try: + if hasattr(self.connection, 'ping'): # Oracle 10g R2 and higher + self.connection.ping() + else: + # Use a cx_Oracle cursor directly, bypassing Django's utilities. + self.connection.cursor().execute("SELECT 1 FROM DUAL") + except DatabaseError: + return False + else: + return True + @cached_property def oracle_version(self): with self.temporary_connection(): diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py index db4b5ade053..d5b6f136961 100644 --- a/django/db/backends/postgresql_psycopg2/base.py +++ b/django/db/backends/postgresql_psycopg2/base.py @@ -91,40 +91,6 @@ class DatabaseWrapper(BaseDatabaseWrapper): self.introspection = DatabaseIntrospection(self) self.validation = BaseDatabaseValidation(self) - def check_constraints(self, table_names=None): - """ - To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they - are returned to deferred. - """ - self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE') - self.cursor().execute('SET CONSTRAINTS ALL DEFERRED') - - def close(self): - self.validate_thread_sharing() - if self.connection is None: - return - - try: - self.connection.close() - self.connection = None - except Database.Error: - # In some cases (database restart, network connection lost etc...) - # the connection to the database is lost without giving Django a - # notification. If we don't set self.connection to None, the error - # will occur a every request. - self.connection = None - logger.warning('psycopg2 error while closing the connection.', - exc_info=sys.exc_info() - ) - raise - finally: - self.set_clean() - - @cached_property - def pg_version(self): - with self.temporary_connection(): - return get_version(self.connection) - def get_connection_params(self): settings_dict = self.settings_dict if not settings_dict['NAME']: @@ -177,14 +143,26 @@ class DatabaseWrapper(BaseDatabaseWrapper): cursor.tzinfo_factory = utc_tzinfo_factory if settings.USE_TZ else None return cursor - def is_usable(self): + def close(self): + self.validate_thread_sharing() + if self.connection is None: + return + try: - # Use a psycopg cursor directly, bypassing Django's utilities. - self.connection.cursor().execute("SELECT 1") - except DatabaseError: - return False - else: - return True + self.connection.close() + self.connection = None + except Database.Error: + # In some cases (database restart, network connection lost etc...) + # the connection to the database is lost without giving Django a + # notification. If we don't set self.connection to None, the error + # will occur a every request. + self.connection = None + logger.warning('psycopg2 error while closing the connection.', + exc_info=sys.exc_info() + ) + raise + finally: + self.set_clean() def _enter_transaction_management(self, managed): """ @@ -222,3 +200,25 @@ class DatabaseWrapper(BaseDatabaseWrapper): if ((self.transaction_state and self.transaction_state[-1]) or not self.features.uses_autocommit): super(DatabaseWrapper, self).set_dirty() + + def check_constraints(self, table_names=None): + """ + To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they + are returned to deferred. + """ + self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE') + self.cursor().execute('SET CONSTRAINTS ALL DEFERRED') + + def is_usable(self): + try: + # Use a psycopg cursor directly, bypassing Django's utilities. + self.connection.cursor().execute("SELECT 1") + except DatabaseError: + return False + else: + return True + + @cached_property + def pg_version(self): + with self.temporary_connection(): + return get_version(self.connection) diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index ad54af46add..416a6293f55 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -347,8 +347,13 @@ class DatabaseWrapper(BaseDatabaseWrapper): def create_cursor(self): return self.connection.cursor(factory=SQLiteCursorWrapper) - def is_usable(self): - return True + def close(self): + self.validate_thread_sharing() + # If database is in memory, closing the connection destroys the + # database. To prevent accidental data loss, ignore close requests on + # an in-memory db. + if self.settings_dict['NAME'] != ":memory:": + BaseDatabaseWrapper.close(self) def check_constraints(self, table_names=None): """ @@ -384,13 +389,9 @@ class DatabaseWrapper(BaseDatabaseWrapper): % (table_name, bad_row[0], table_name, column_name, bad_row[1], referenced_table_name, referenced_column_name)) - def close(self): - self.validate_thread_sharing() - # If database is in memory, closing the connection destroys the - # database. To prevent accidental data loss, ignore close requests on - # an in-memory db. - if self.settings_dict['NAME'] != ":memory:": - BaseDatabaseWrapper.close(self) + def is_usable(self): + return True + FORMAT_QMARK_REGEX = re.compile(r'(?