Reordered methods in database wrappers.

* Grouped related methods together -- with banner comments :/
* Described which methods are intended to be implemented in backends.
* Added docstrings.
* Used the same order in all wrappers.
This commit is contained in:
Aymeric Augustin 2013-03-02 12:12:51 +01:00
parent c5a25c2771
commit d63e55039d
6 changed files with 273 additions and 217 deletions

View File

@ -40,20 +40,24 @@ class BaseDatabaseWrapper(object):
self.alias = alias
self.use_debug_cursor = None
# Transaction related attributes
self.transaction_state = []
# Savepoint management related attributes
self.savepoint_state = 0
# Transaction management related attributes
self.transaction_state = []
# Tracks if the connection is believed to be in transaction. This is
# set somewhat aggressively, as the DBAPI doesn't make it easy to
# deduce if the connection is in transaction or not.
self._dirty = False
self._thread_ident = thread.get_ident()
self.allow_thread_sharing = allow_thread_sharing
# Connection termination related attributes
self.close_at = None
self.errors_occurred = False
# Thread-safety related attributes
self.allow_thread_sharing = allow_thread_sharing
self._thread_ident = thread.get_ident()
def __eq__(self, other):
return self.alias == other.alias
@ -63,21 +67,26 @@ class BaseDatabaseWrapper(object):
def __hash__(self):
return hash(self.alias)
def wrap_database_errors(self):
return DatabaseErrorWrapper(self)
##### Backend-specific methods for creating connections and cursors #####
def get_connection_params(self):
"""Returns a dict of parameters suitable for get_new_connection."""
raise NotImplementedError
def get_new_connection(self, conn_params):
"""Opens a connection to the database."""
raise NotImplementedError
def init_connection_state(self):
"""Initializes the database connection settings."""
raise NotImplementedError
def create_cursor(self):
"""Creates a cursor. Assumes that a connection is established."""
raise NotImplementedError
##### Backend-specific wrappers for PEP-249 connection methods #####
def _cursor(self):
with self.wrap_database_errors():
if self.connection is None:
@ -107,20 +116,48 @@ class BaseDatabaseWrapper(object):
with self.wrap_database_errors():
return self.connection.close()
def _enter_transaction_management(self, managed):
"""
A hook for backend-specific changes required when entering manual
transaction handling.
"""
pass
##### Generic wrappers for PEP-249 connection methods #####
def _leave_transaction_management(self, managed):
def cursor(self):
"""
A hook for backend-specific changes required when leaving manual
transaction handling. Will usually be implemented only when
_enter_transaction_management() is also required.
Creates a cursor, opening a connection if necessary.
"""
pass
self.validate_thread_sharing()
if (self.use_debug_cursor or
(self.use_debug_cursor is None and settings.DEBUG)):
cursor = self.make_debug_cursor(self._cursor())
else:
cursor = util.CursorWrapper(self._cursor(), self)
return cursor
def commit(self):
"""
Does the commit itself and resets the dirty flag.
"""
self.validate_thread_sharing()
self._commit()
self.set_clean()
def rollback(self):
"""
Does the rollback itself and resets the dirty flag.
"""
self.validate_thread_sharing()
self._rollback()
self.set_clean()
def close(self):
"""
Closes the connection to the database.
"""
self.validate_thread_sharing()
try:
self._close()
finally:
self.connection = None
self.set_clean()
##### Backend-specific savepoint management methods #####
def _savepoint(self, sid):
if not self.features.uses_savepoints:
@ -137,15 +174,65 @@ class BaseDatabaseWrapper(object):
return
self.cursor().execute(self.ops.savepoint_commit_sql(sid))
def abort(self):
##### Generic savepoint management methods #####
def savepoint(self):
"""
Roll back any ongoing transaction and clean the transaction state
stack.
Creates a savepoint (if supported and required by the backend) inside the
current transaction. Returns an identifier for the savepoint that will be
used for the subsequent rollback or commit.
"""
if self._dirty:
self.rollback()
while self.transaction_state:
self.leave_transaction_management()
thread_ident = thread.get_ident()
self.savepoint_state += 1
tid = str(thread_ident).replace('-', '')
sid = "s%s_x%d" % (tid, self.savepoint_state)
self._savepoint(sid)
return sid
def savepoint_rollback(self, sid):
"""
Rolls back the most recent savepoint (if one exists). Does nothing if
savepoints are not supported.
"""
self.validate_thread_sharing()
if self.savepoint_state:
self._savepoint_rollback(sid)
def savepoint_commit(self, sid):
"""
Commits the most recent savepoint (if one exists). Does nothing if
savepoints are not supported.
"""
self.validate_thread_sharing()
if self.savepoint_state:
self._savepoint_commit(sid)
def clean_savepoints(self):
"""
Resets the counter used to generate unique savepoint ids in this thread.
"""
self.savepoint_state = 0
##### Backend-specific transaction management methods #####
def _enter_transaction_management(self, managed):
"""
A hook for backend-specific changes required when entering manual
transaction handling.
"""
pass
def _leave_transaction_management(self, managed):
"""
A hook for backend-specific changes required when leaving manual
transaction handling. Will usually be implemented only when
_enter_transaction_management() is also required.
"""
pass
##### Generic transaction management methods #####
def enter_transaction_management(self, managed=True):
"""
@ -185,20 +272,15 @@ class BaseDatabaseWrapper(object):
raise TransactionManagementError(
"Transaction managed block ended with pending COMMIT/ROLLBACK")
def validate_thread_sharing(self):
def abort(self):
"""
Validates that the connection isn't accessed by another thread than the
one which originally created it, unless the connection was explicitly
authorized to be shared between threads (via the `allow_thread_sharing`
property). Raises an exception if the validation fails.
Roll back any ongoing transaction and clean the transaction state
stack.
"""
if (not self.allow_thread_sharing
and self._thread_ident != thread.get_ident()):
raise DatabaseError("DatabaseWrapper objects created in a "
"thread can only be used in that same thread. The object "
"with alias '%s' was created in thread id %s and this is "
"thread id %s."
% (self.alias, self._thread_ident, thread.get_ident()))
if self._dirty:
self.rollback()
while self.transaction_state:
self.leave_transaction_management()
def is_dirty(self):
"""
@ -224,12 +306,6 @@ class BaseDatabaseWrapper(object):
self._dirty = False
self.clean_savepoints()
def clean_savepoints(self):
"""
Resets the counter used to generate unique savepoint ids in this thread.
"""
self.savepoint_state = 0
def is_managed(self):
"""
Checks whether the transaction manager is in manual or in auto state.
@ -275,57 +351,13 @@ class BaseDatabaseWrapper(object):
else:
self.set_dirty()
def commit(self):
"""
Does the commit itself and resets the dirty flag.
"""
self.validate_thread_sharing()
self._commit()
self.set_clean()
def rollback(self):
"""
This function does the rollback itself and resets the dirty flag.
"""
self.validate_thread_sharing()
self._rollback()
self.set_clean()
def savepoint(self):
"""
Creates a savepoint (if supported and required by the backend) inside the
current transaction. Returns an identifier for the savepoint that will be
used for the subsequent rollback or commit.
"""
thread_ident = thread.get_ident()
self.savepoint_state += 1
tid = str(thread_ident).replace('-', '')
sid = "s%s_x%d" % (tid, self.savepoint_state)
self._savepoint(sid)
return sid
def savepoint_rollback(self, sid):
"""
Rolls back the most recent savepoint (if one exists). Does nothing if
savepoints are not supported.
"""
self.validate_thread_sharing()
if self.savepoint_state:
self._savepoint_rollback(sid)
def savepoint_commit(self, sid):
"""
Commits the most recent savepoint (if one exists). Does nothing if
savepoints are not supported.
"""
self.validate_thread_sharing()
if self.savepoint_state:
self._savepoint_commit(sid)
##### Foreign key constraints checks handling #####
@contextmanager
def constraint_checks_disabled(self):
"""
Context manager that disables foreign key constraint checking.
"""
disabled = self.disable_constraint_checking()
try:
yield
@ -335,33 +367,40 @@ class BaseDatabaseWrapper(object):
def disable_constraint_checking(self):
"""
Backends can implement as needed to temporarily disable foreign key constraint
checking.
Backends can implement as needed to temporarily disable foreign key
constraint checking.
"""
pass
def enable_constraint_checking(self):
"""
Backends can implement as needed to re-enable foreign key constraint checking.
Backends can implement as needed to re-enable foreign key constraint
checking.
"""
pass
def check_constraints(self, table_names=None):
"""
Backends can override this method if they can apply constraint checking (e.g. via "SET CONSTRAINTS
ALL IMMEDIATE"). Should raise an IntegrityError if any invalid foreign key references are encountered.
Backends can override this method if they can apply constraint
checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE"). Should raise an
IntegrityError if any invalid foreign key references are encountered.
"""
pass
def close(self):
self.validate_thread_sharing()
try:
self._close()
finally:
self.connection = None
self.set_clean()
##### Connection termination handling #####
def is_usable(self):
"""
Tests if the database connection is usable.
This function may assume that self.connection is not None.
"""
raise NotImplementedError
def close_if_unusable_or_obsolete(self):
"""
Closes the current connection if unrecoverable errors have occurred,
or if it outlived its maximum age.
"""
if self.connection is not None:
if self.errors_occurred:
if self.is_usable():
@ -373,30 +412,45 @@ class BaseDatabaseWrapper(object):
self.close()
return
def is_usable(self):
"""
Test if the database connection is usable.
##### Thread safety handling #####
This function may assume that self.connection is not None.
def validate_thread_sharing(self):
"""
raise NotImplementedError
Validates that the connection isn't accessed by another thread than the
one which originally created it, unless the connection was explicitly
authorized to be shared between threads (via the `allow_thread_sharing`
property). Raises an exception if the validation fails.
"""
if not (self.allow_thread_sharing
or self._thread_ident == thread.get_ident()):
raise DatabaseError("DatabaseWrapper objects created in a "
"thread can only be used in that same thread. The object "
"with alias '%s' was created in thread id %s and this is "
"thread id %s."
% (self.alias, self._thread_ident, thread.get_ident()))
def cursor(self):
self.validate_thread_sharing()
if (self.use_debug_cursor or
(self.use_debug_cursor is None and settings.DEBUG)):
cursor = self.make_debug_cursor(self._cursor())
else:
cursor = util.CursorWrapper(self._cursor(), self)
return cursor
##### Miscellaneous #####
def wrap_database_errors(self):
"""
Context manager and decorator that re-throws backend-specific database
exceptions using Django's common wrappers.
"""
return DatabaseErrorWrapper(self)
def make_debug_cursor(self, cursor):
"""
Creates a cursor that logs all queries in self.queries.
"""
return util.CursorDebugWrapper(cursor, self)
@contextmanager
def temporary_connection(self):
# Ensure a connection is established, and avoid leaving a dangling
# connection, for operations outside of the request-response cycle.
"""
Context manager that ensures that a connection is established, and
if it opened one, closes it to avoid leaving a dangling connection.
This is useful for operations outside of the request-response cycle.
"""
must_close = self.connection is None
cursor = self.cursor()
try:
@ -406,6 +460,7 @@ class BaseDatabaseWrapper(object):
if must_close:
self.close()
class BaseDatabaseFeatures(object):
allows_group_by_pk = False
# True if django.db.backend.utils.typecast_timestamp is used on values

View File

@ -48,19 +48,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# implementations. Anything that tries to actually
# do something raises complain; anything that tries
# to rollback or undo something raises ignore.
_cursor = complain
_commit = complain
_rollback = ignore
enter_transaction_management = complain
leave_transaction_management = ignore
_close = ignore
_savepoint = ignore
_savepoint_commit = complain
_savepoint_rollback = ignore
_enter_transaction_management = complain
_leave_transaction_management = ignore
set_dirty = complain
set_clean = complain
commit_unless_managed = complain
rollback_unless_managed = ignore
savepoint = ignore
savepoint_commit = complain
savepoint_rollback = ignore
close = ignore
cursor = complain
def __init__(self, *args, **kwargs):
super(DatabaseWrapper, self).__init__(*args, **kwargs)

View File

@ -439,29 +439,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
cursor = self.connection.cursor()
return CursorWrapper(cursor)
def is_usable(self):
try:
self.connection.ping()
except DatabaseError:
return False
else:
return True
def _rollback(self):
try:
BaseDatabaseWrapper._rollback(self)
except Database.NotSupportedError:
pass
@cached_property
def mysql_version(self):
with self.temporary_connection():
server_info = self.connection.get_server_info()
match = server_version_re.match(server_info)
if not match:
raise Exception('Unable to determine MySQL version from version string %r' % server_info)
return tuple([int(x) for x in match.groups()])
def disable_constraint_checking(self):
"""
Disables foreign key checks, primarily for use in adding rows with forward references. Always returns True,
@ -510,3 +493,20 @@ class DatabaseWrapper(BaseDatabaseWrapper):
% (table_name, bad_row[0],
table_name, column_name, bad_row[1],
referenced_table_name, referenced_column_name))
def is_usable(self):
try:
self.connection.ping()
except DatabaseError:
return False
else:
return True
@cached_property
def mysql_version(self):
with self.temporary_connection():
server_info = self.connection.get_server_info()
match = server_version_re.match(server_info)
if not match:
raise Exception('Unable to determine MySQL version from version string %r' % server_info)
return tuple([int(x) for x in match.groups()])

View File

@ -515,14 +515,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.introspection = DatabaseIntrospection(self)
self.validation = BaseDatabaseValidation(self)
def check_constraints(self, table_names=None):
"""
To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they
are returned to deferred.
"""
self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
def _connect_string(self):
settings_dict = self.settings_dict
if not settings_dict['HOST'].strip():
@ -536,9 +528,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return "%s/%s@%s" % (settings_dict['USER'],
settings_dict['PASSWORD'], dsn)
def create_cursor(self):
return FormatStylePlaceholderCursor(self.connection)
def get_connection_params(self):
conn_params = self.settings_dict['OPTIONS'].copy()
if 'use_returning_into' in conn_params:
@ -598,21 +587,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# stmtcachesize is available only in 4.3.2 and up.
pass
def is_usable(self):
try:
if hasattr(self.connection, 'ping'): # Oracle 10g R2 and higher
self.connection.ping()
else:
# Use a cx_Oracle cursor directly, bypassing Django's utilities.
self.connection.cursor().execute("SELECT 1 FROM DUAL")
except DatabaseError:
return False
else:
return True
# Oracle doesn't support savepoint commits. Ignore them.
def _savepoint_commit(self, sid):
pass
def create_cursor(self):
return FormatStylePlaceholderCursor(self.connection)
def _commit(self):
if self.connection is not None:
@ -632,6 +608,30 @@ class DatabaseWrapper(BaseDatabaseWrapper):
six.reraise(utils.IntegrityError, utils.IntegrityError(*tuple(e.args)), sys.exc_info()[2])
raise
# Oracle doesn't support savepoint commits. Ignore them.
def _savepoint_commit(self, sid):
pass
def check_constraints(self, table_names=None):
"""
To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they
are returned to deferred.
"""
self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
def is_usable(self):
try:
if hasattr(self.connection, 'ping'): # Oracle 10g R2 and higher
self.connection.ping()
else:
# Use a cx_Oracle cursor directly, bypassing Django's utilities.
self.connection.cursor().execute("SELECT 1 FROM DUAL")
except DatabaseError:
return False
else:
return True
@cached_property
def oracle_version(self):
with self.temporary_connection():

View File

@ -91,40 +91,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.introspection = DatabaseIntrospection(self)
self.validation = BaseDatabaseValidation(self)
def check_constraints(self, table_names=None):
"""
To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they
are returned to deferred.
"""
self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
def close(self):
self.validate_thread_sharing()
if self.connection is None:
return
try:
self.connection.close()
self.connection = None
except Database.Error:
# In some cases (database restart, network connection lost etc...)
# the connection to the database is lost without giving Django a
# notification. If we don't set self.connection to None, the error
# will occur a every request.
self.connection = None
logger.warning('psycopg2 error while closing the connection.',
exc_info=sys.exc_info()
)
raise
finally:
self.set_clean()
@cached_property
def pg_version(self):
with self.temporary_connection():
return get_version(self.connection)
def get_connection_params(self):
settings_dict = self.settings_dict
if not settings_dict['NAME']:
@ -177,14 +143,26 @@ class DatabaseWrapper(BaseDatabaseWrapper):
cursor.tzinfo_factory = utc_tzinfo_factory if settings.USE_TZ else None
return cursor
def is_usable(self):
def close(self):
self.validate_thread_sharing()
if self.connection is None:
return
try:
# Use a psycopg cursor directly, bypassing Django's utilities.
self.connection.cursor().execute("SELECT 1")
except DatabaseError:
return False
else:
return True
self.connection.close()
self.connection = None
except Database.Error:
# In some cases (database restart, network connection lost etc...)
# the connection to the database is lost without giving Django a
# notification. If we don't set self.connection to None, the error
# will occur a every request.
self.connection = None
logger.warning('psycopg2 error while closing the connection.',
exc_info=sys.exc_info()
)
raise
finally:
self.set_clean()
def _enter_transaction_management(self, managed):
"""
@ -222,3 +200,25 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if ((self.transaction_state and self.transaction_state[-1]) or
not self.features.uses_autocommit):
super(DatabaseWrapper, self).set_dirty()
def check_constraints(self, table_names=None):
"""
To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they
are returned to deferred.
"""
self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
def is_usable(self):
try:
# Use a psycopg cursor directly, bypassing Django's utilities.
self.connection.cursor().execute("SELECT 1")
except DatabaseError:
return False
else:
return True
@cached_property
def pg_version(self):
with self.temporary_connection():
return get_version(self.connection)

View File

@ -347,8 +347,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def create_cursor(self):
return self.connection.cursor(factory=SQLiteCursorWrapper)
def is_usable(self):
return True
def close(self):
self.validate_thread_sharing()
# If database is in memory, closing the connection destroys the
# database. To prevent accidental data loss, ignore close requests on
# an in-memory db.
if self.settings_dict['NAME'] != ":memory:":
BaseDatabaseWrapper.close(self)
def check_constraints(self, table_names=None):
"""
@ -384,13 +389,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
% (table_name, bad_row[0], table_name, column_name, bad_row[1],
referenced_table_name, referenced_column_name))
def close(self):
self.validate_thread_sharing()
# If database is in memory, closing the connection destroys the
# database. To prevent accidental data loss, ignore close requests on
# an in-memory db.
if self.settings_dict['NAME'] != ":memory:":
BaseDatabaseWrapper.close(self)
def is_usable(self):
return True
FORMAT_QMARK_REGEX = re.compile(r'(?<!%)%s')