Reordered methods in database wrappers.

* Grouped related methods together -- with banner comments :/
* Described which methods are intended to be implemented in backends.
* Added docstrings.
* Used the same order in all wrappers.
This commit is contained in:
Aymeric Augustin 2013-03-02 12:12:51 +01:00
parent c5a25c2771
commit d63e55039d
6 changed files with 273 additions and 217 deletions

View File

@ -40,20 +40,24 @@ class BaseDatabaseWrapper(object):
self.alias = alias self.alias = alias
self.use_debug_cursor = None self.use_debug_cursor = None
# Transaction related attributes # Savepoint management related attributes
self.transaction_state = []
self.savepoint_state = 0 self.savepoint_state = 0
# Transaction management related attributes
self.transaction_state = []
# Tracks if the connection is believed to be in transaction. This is # Tracks if the connection is believed to be in transaction. This is
# set somewhat aggressively, as the DBAPI doesn't make it easy to # set somewhat aggressively, as the DBAPI doesn't make it easy to
# deduce if the connection is in transaction or not. # deduce if the connection is in transaction or not.
self._dirty = False self._dirty = False
self._thread_ident = thread.get_ident()
self.allow_thread_sharing = allow_thread_sharing
# Connection termination related attributes # Connection termination related attributes
self.close_at = None self.close_at = None
self.errors_occurred = False self.errors_occurred = False
# Thread-safety related attributes
self.allow_thread_sharing = allow_thread_sharing
self._thread_ident = thread.get_ident()
def __eq__(self, other): def __eq__(self, other):
return self.alias == other.alias return self.alias == other.alias
@ -63,21 +67,26 @@ class BaseDatabaseWrapper(object):
def __hash__(self): def __hash__(self):
return hash(self.alias) return hash(self.alias)
def wrap_database_errors(self): ##### Backend-specific methods for creating connections and cursors #####
return DatabaseErrorWrapper(self)
def get_connection_params(self): def get_connection_params(self):
"""Returns a dict of parameters suitable for get_new_connection."""
raise NotImplementedError raise NotImplementedError
def get_new_connection(self, conn_params): def get_new_connection(self, conn_params):
"""Opens a connection to the database."""
raise NotImplementedError raise NotImplementedError
def init_connection_state(self): def init_connection_state(self):
"""Initializes the database connection settings."""
raise NotImplementedError raise NotImplementedError
def create_cursor(self): def create_cursor(self):
"""Creates a cursor. Assumes that a connection is established."""
raise NotImplementedError raise NotImplementedError
##### Backend-specific wrappers for PEP-249 connection methods #####
def _cursor(self): def _cursor(self):
with self.wrap_database_errors(): with self.wrap_database_errors():
if self.connection is None: if self.connection is None:
@ -107,20 +116,48 @@ class BaseDatabaseWrapper(object):
with self.wrap_database_errors(): with self.wrap_database_errors():
return self.connection.close() return self.connection.close()
def _enter_transaction_management(self, managed): ##### Generic wrappers for PEP-249 connection methods #####
"""
A hook for backend-specific changes required when entering manual
transaction handling.
"""
pass
def _leave_transaction_management(self, managed): def cursor(self):
""" """
A hook for backend-specific changes required when leaving manual Creates a cursor, opening a connection if necessary.
transaction handling. Will usually be implemented only when
_enter_transaction_management() is also required.
""" """
pass self.validate_thread_sharing()
if (self.use_debug_cursor or
(self.use_debug_cursor is None and settings.DEBUG)):
cursor = self.make_debug_cursor(self._cursor())
else:
cursor = util.CursorWrapper(self._cursor(), self)
return cursor
def commit(self):
"""
Does the commit itself and resets the dirty flag.
"""
self.validate_thread_sharing()
self._commit()
self.set_clean()
def rollback(self):
"""
Does the rollback itself and resets the dirty flag.
"""
self.validate_thread_sharing()
self._rollback()
self.set_clean()
def close(self):
"""
Closes the connection to the database.
"""
self.validate_thread_sharing()
try:
self._close()
finally:
self.connection = None
self.set_clean()
##### Backend-specific savepoint management methods #####
def _savepoint(self, sid): def _savepoint(self, sid):
if not self.features.uses_savepoints: if not self.features.uses_savepoints:
@ -137,15 +174,65 @@ class BaseDatabaseWrapper(object):
return return
self.cursor().execute(self.ops.savepoint_commit_sql(sid)) self.cursor().execute(self.ops.savepoint_commit_sql(sid))
def abort(self): ##### Generic savepoint management methods #####
def savepoint(self):
""" """
Roll back any ongoing transaction and clean the transaction state Creates a savepoint (if supported and required by the backend) inside the
stack. current transaction. Returns an identifier for the savepoint that will be
used for the subsequent rollback or commit.
""" """
if self._dirty: thread_ident = thread.get_ident()
self.rollback()
while self.transaction_state: self.savepoint_state += 1
self.leave_transaction_management()
tid = str(thread_ident).replace('-', '')
sid = "s%s_x%d" % (tid, self.savepoint_state)
self._savepoint(sid)
return sid
def savepoint_rollback(self, sid):
"""
Rolls back the most recent savepoint (if one exists). Does nothing if
savepoints are not supported.
"""
self.validate_thread_sharing()
if self.savepoint_state:
self._savepoint_rollback(sid)
def savepoint_commit(self, sid):
"""
Commits the most recent savepoint (if one exists). Does nothing if
savepoints are not supported.
"""
self.validate_thread_sharing()
if self.savepoint_state:
self._savepoint_commit(sid)
def clean_savepoints(self):
"""
Resets the counter used to generate unique savepoint ids in this thread.
"""
self.savepoint_state = 0
##### Backend-specific transaction management methods #####
def _enter_transaction_management(self, managed):
"""
A hook for backend-specific changes required when entering manual
transaction handling.
"""
pass
def _leave_transaction_management(self, managed):
"""
A hook for backend-specific changes required when leaving manual
transaction handling. Will usually be implemented only when
_enter_transaction_management() is also required.
"""
pass
##### Generic transaction management methods #####
def enter_transaction_management(self, managed=True): def enter_transaction_management(self, managed=True):
""" """
@ -185,20 +272,15 @@ class BaseDatabaseWrapper(object):
raise TransactionManagementError( raise TransactionManagementError(
"Transaction managed block ended with pending COMMIT/ROLLBACK") "Transaction managed block ended with pending COMMIT/ROLLBACK")
def validate_thread_sharing(self): def abort(self):
""" """
Validates that the connection isn't accessed by another thread than the Roll back any ongoing transaction and clean the transaction state
one which originally created it, unless the connection was explicitly stack.
authorized to be shared between threads (via the `allow_thread_sharing`
property). Raises an exception if the validation fails.
""" """
if (not self.allow_thread_sharing if self._dirty:
and self._thread_ident != thread.get_ident()): self.rollback()
raise DatabaseError("DatabaseWrapper objects created in a " while self.transaction_state:
"thread can only be used in that same thread. The object " self.leave_transaction_management()
"with alias '%s' was created in thread id %s and this is "
"thread id %s."
% (self.alias, self._thread_ident, thread.get_ident()))
def is_dirty(self): def is_dirty(self):
""" """
@ -224,12 +306,6 @@ class BaseDatabaseWrapper(object):
self._dirty = False self._dirty = False
self.clean_savepoints() self.clean_savepoints()
def clean_savepoints(self):
"""
Resets the counter used to generate unique savepoint ids in this thread.
"""
self.savepoint_state = 0
def is_managed(self): def is_managed(self):
""" """
Checks whether the transaction manager is in manual or in auto state. Checks whether the transaction manager is in manual or in auto state.
@ -275,57 +351,13 @@ class BaseDatabaseWrapper(object):
else: else:
self.set_dirty() self.set_dirty()
def commit(self): ##### Foreign key constraints checks handling #####
"""
Does the commit itself and resets the dirty flag.
"""
self.validate_thread_sharing()
self._commit()
self.set_clean()
def rollback(self):
"""
This function does the rollback itself and resets the dirty flag.
"""
self.validate_thread_sharing()
self._rollback()
self.set_clean()
def savepoint(self):
"""
Creates a savepoint (if supported and required by the backend) inside the
current transaction. Returns an identifier for the savepoint that will be
used for the subsequent rollback or commit.
"""
thread_ident = thread.get_ident()
self.savepoint_state += 1
tid = str(thread_ident).replace('-', '')
sid = "s%s_x%d" % (tid, self.savepoint_state)
self._savepoint(sid)
return sid
def savepoint_rollback(self, sid):
"""
Rolls back the most recent savepoint (if one exists). Does nothing if
savepoints are not supported.
"""
self.validate_thread_sharing()
if self.savepoint_state:
self._savepoint_rollback(sid)
def savepoint_commit(self, sid):
"""
Commits the most recent savepoint (if one exists). Does nothing if
savepoints are not supported.
"""
self.validate_thread_sharing()
if self.savepoint_state:
self._savepoint_commit(sid)
@contextmanager @contextmanager
def constraint_checks_disabled(self): def constraint_checks_disabled(self):
"""
Context manager that disables foreign key constraint checking.
"""
disabled = self.disable_constraint_checking() disabled = self.disable_constraint_checking()
try: try:
yield yield
@ -335,33 +367,40 @@ class BaseDatabaseWrapper(object):
def disable_constraint_checking(self): def disable_constraint_checking(self):
""" """
Backends can implement as needed to temporarily disable foreign key constraint Backends can implement as needed to temporarily disable foreign key
checking. constraint checking.
""" """
pass pass
def enable_constraint_checking(self): def enable_constraint_checking(self):
""" """
Backends can implement as needed to re-enable foreign key constraint checking. Backends can implement as needed to re-enable foreign key constraint
checking.
""" """
pass pass
def check_constraints(self, table_names=None): def check_constraints(self, table_names=None):
""" """
Backends can override this method if they can apply constraint checking (e.g. via "SET CONSTRAINTS Backends can override this method if they can apply constraint
ALL IMMEDIATE"). Should raise an IntegrityError if any invalid foreign key references are encountered. checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE"). Should raise an
IntegrityError if any invalid foreign key references are encountered.
""" """
pass pass
def close(self): ##### Connection termination handling #####
self.validate_thread_sharing()
try: def is_usable(self):
self._close() """
finally: Tests if the database connection is usable.
self.connection = None This function may assume that self.connection is not None.
self.set_clean() """
raise NotImplementedError
def close_if_unusable_or_obsolete(self): def close_if_unusable_or_obsolete(self):
"""
Closes the current connection if unrecoverable errors have occurred,
or if it outlived its maximum age.
"""
if self.connection is not None: if self.connection is not None:
if self.errors_occurred: if self.errors_occurred:
if self.is_usable(): if self.is_usable():
@ -373,30 +412,45 @@ class BaseDatabaseWrapper(object):
self.close() self.close()
return return
def is_usable(self): ##### Thread safety handling #####
"""
Test if the database connection is usable.
This function may assume that self.connection is not None. def validate_thread_sharing(self):
""" """
raise NotImplementedError Validates that the connection isn't accessed by another thread than the
one which originally created it, unless the connection was explicitly
authorized to be shared between threads (via the `allow_thread_sharing`
property). Raises an exception if the validation fails.
"""
if not (self.allow_thread_sharing
or self._thread_ident == thread.get_ident()):
raise DatabaseError("DatabaseWrapper objects created in a "
"thread can only be used in that same thread. The object "
"with alias '%s' was created in thread id %s and this is "
"thread id %s."
% (self.alias, self._thread_ident, thread.get_ident()))
def cursor(self): ##### Miscellaneous #####
self.validate_thread_sharing()
if (self.use_debug_cursor or def wrap_database_errors(self):
(self.use_debug_cursor is None and settings.DEBUG)): """
cursor = self.make_debug_cursor(self._cursor()) Context manager and decorator that re-throws backend-specific database
else: exceptions using Django's common wrappers.
cursor = util.CursorWrapper(self._cursor(), self) """
return cursor return DatabaseErrorWrapper(self)
def make_debug_cursor(self, cursor): def make_debug_cursor(self, cursor):
"""
Creates a cursor that logs all queries in self.queries.
"""
return util.CursorDebugWrapper(cursor, self) return util.CursorDebugWrapper(cursor, self)
@contextmanager @contextmanager
def temporary_connection(self): def temporary_connection(self):
# Ensure a connection is established, and avoid leaving a dangling """
# connection, for operations outside of the request-response cycle. Context manager that ensures that a connection is established, and
if it opened one, closes it to avoid leaving a dangling connection.
This is useful for operations outside of the request-response cycle.
"""
must_close = self.connection is None must_close = self.connection is None
cursor = self.cursor() cursor = self.cursor()
try: try:
@ -406,6 +460,7 @@ class BaseDatabaseWrapper(object):
if must_close: if must_close:
self.close() self.close()
class BaseDatabaseFeatures(object): class BaseDatabaseFeatures(object):
allows_group_by_pk = False allows_group_by_pk = False
# True if django.db.backend.utils.typecast_timestamp is used on values # True if django.db.backend.utils.typecast_timestamp is used on values

View File

@ -48,19 +48,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# implementations. Anything that tries to actually # implementations. Anything that tries to actually
# do something raises complain; anything that tries # do something raises complain; anything that tries
# to rollback or undo something raises ignore. # to rollback or undo something raises ignore.
_cursor = complain
_commit = complain _commit = complain
_rollback = ignore _rollback = ignore
enter_transaction_management = complain _close = ignore
leave_transaction_management = ignore _savepoint = ignore
_savepoint_commit = complain
_savepoint_rollback = ignore
_enter_transaction_management = complain
_leave_transaction_management = ignore
set_dirty = complain set_dirty = complain
set_clean = complain set_clean = complain
commit_unless_managed = complain commit_unless_managed = complain
rollback_unless_managed = ignore rollback_unless_managed = ignore
savepoint = ignore
savepoint_commit = complain
savepoint_rollback = ignore
close = ignore
cursor = complain
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(DatabaseWrapper, self).__init__(*args, **kwargs) super(DatabaseWrapper, self).__init__(*args, **kwargs)

View File

@ -439,29 +439,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
cursor = self.connection.cursor() cursor = self.connection.cursor()
return CursorWrapper(cursor) return CursorWrapper(cursor)
def is_usable(self):
try:
self.connection.ping()
except DatabaseError:
return False
else:
return True
def _rollback(self): def _rollback(self):
try: try:
BaseDatabaseWrapper._rollback(self) BaseDatabaseWrapper._rollback(self)
except Database.NotSupportedError: except Database.NotSupportedError:
pass pass
@cached_property
def mysql_version(self):
with self.temporary_connection():
server_info = self.connection.get_server_info()
match = server_version_re.match(server_info)
if not match:
raise Exception('Unable to determine MySQL version from version string %r' % server_info)
return tuple([int(x) for x in match.groups()])
def disable_constraint_checking(self): def disable_constraint_checking(self):
""" """
Disables foreign key checks, primarily for use in adding rows with forward references. Always returns True, Disables foreign key checks, primarily for use in adding rows with forward references. Always returns True,
@ -510,3 +493,20 @@ class DatabaseWrapper(BaseDatabaseWrapper):
% (table_name, bad_row[0], % (table_name, bad_row[0],
table_name, column_name, bad_row[1], table_name, column_name, bad_row[1],
referenced_table_name, referenced_column_name)) referenced_table_name, referenced_column_name))
def is_usable(self):
try:
self.connection.ping()
except DatabaseError:
return False
else:
return True
@cached_property
def mysql_version(self):
with self.temporary_connection():
server_info = self.connection.get_server_info()
match = server_version_re.match(server_info)
if not match:
raise Exception('Unable to determine MySQL version from version string %r' % server_info)
return tuple([int(x) for x in match.groups()])

View File

@ -515,14 +515,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.introspection = DatabaseIntrospection(self) self.introspection = DatabaseIntrospection(self)
self.validation = BaseDatabaseValidation(self) self.validation = BaseDatabaseValidation(self)
def check_constraints(self, table_names=None):
"""
To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they
are returned to deferred.
"""
self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
def _connect_string(self): def _connect_string(self):
settings_dict = self.settings_dict settings_dict = self.settings_dict
if not settings_dict['HOST'].strip(): if not settings_dict['HOST'].strip():
@ -536,9 +528,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return "%s/%s@%s" % (settings_dict['USER'], return "%s/%s@%s" % (settings_dict['USER'],
settings_dict['PASSWORD'], dsn) settings_dict['PASSWORD'], dsn)
def create_cursor(self):
return FormatStylePlaceholderCursor(self.connection)
def get_connection_params(self): def get_connection_params(self):
conn_params = self.settings_dict['OPTIONS'].copy() conn_params = self.settings_dict['OPTIONS'].copy()
if 'use_returning_into' in conn_params: if 'use_returning_into' in conn_params:
@ -598,21 +587,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# stmtcachesize is available only in 4.3.2 and up. # stmtcachesize is available only in 4.3.2 and up.
pass pass
def is_usable(self): def create_cursor(self):
try: return FormatStylePlaceholderCursor(self.connection)
if hasattr(self.connection, 'ping'): # Oracle 10g R2 and higher
self.connection.ping()
else:
# Use a cx_Oracle cursor directly, bypassing Django's utilities.
self.connection.cursor().execute("SELECT 1 FROM DUAL")
except DatabaseError:
return False
else:
return True
# Oracle doesn't support savepoint commits. Ignore them.
def _savepoint_commit(self, sid):
pass
def _commit(self): def _commit(self):
if self.connection is not None: if self.connection is not None:
@ -632,6 +608,30 @@ class DatabaseWrapper(BaseDatabaseWrapper):
six.reraise(utils.IntegrityError, utils.IntegrityError(*tuple(e.args)), sys.exc_info()[2]) six.reraise(utils.IntegrityError, utils.IntegrityError(*tuple(e.args)), sys.exc_info()[2])
raise raise
# Oracle doesn't support savepoint commits. Ignore them.
def _savepoint_commit(self, sid):
pass
def check_constraints(self, table_names=None):
"""
To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they
are returned to deferred.
"""
self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
def is_usable(self):
try:
if hasattr(self.connection, 'ping'): # Oracle 10g R2 and higher
self.connection.ping()
else:
# Use a cx_Oracle cursor directly, bypassing Django's utilities.
self.connection.cursor().execute("SELECT 1 FROM DUAL")
except DatabaseError:
return False
else:
return True
@cached_property @cached_property
def oracle_version(self): def oracle_version(self):
with self.temporary_connection(): with self.temporary_connection():

View File

@ -91,40 +91,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.introspection = DatabaseIntrospection(self) self.introspection = DatabaseIntrospection(self)
self.validation = BaseDatabaseValidation(self) self.validation = BaseDatabaseValidation(self)
def check_constraints(self, table_names=None):
"""
To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they
are returned to deferred.
"""
self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
def close(self):
self.validate_thread_sharing()
if self.connection is None:
return
try:
self.connection.close()
self.connection = None
except Database.Error:
# In some cases (database restart, network connection lost etc...)
# the connection to the database is lost without giving Django a
# notification. If we don't set self.connection to None, the error
# will occur a every request.
self.connection = None
logger.warning('psycopg2 error while closing the connection.',
exc_info=sys.exc_info()
)
raise
finally:
self.set_clean()
@cached_property
def pg_version(self):
with self.temporary_connection():
return get_version(self.connection)
def get_connection_params(self): def get_connection_params(self):
settings_dict = self.settings_dict settings_dict = self.settings_dict
if not settings_dict['NAME']: if not settings_dict['NAME']:
@ -177,14 +143,26 @@ class DatabaseWrapper(BaseDatabaseWrapper):
cursor.tzinfo_factory = utc_tzinfo_factory if settings.USE_TZ else None cursor.tzinfo_factory = utc_tzinfo_factory if settings.USE_TZ else None
return cursor return cursor
def is_usable(self): def close(self):
self.validate_thread_sharing()
if self.connection is None:
return
try: try:
# Use a psycopg cursor directly, bypassing Django's utilities. self.connection.close()
self.connection.cursor().execute("SELECT 1") self.connection = None
except DatabaseError: except Database.Error:
return False # In some cases (database restart, network connection lost etc...)
else: # the connection to the database is lost without giving Django a
return True # notification. If we don't set self.connection to None, the error
# will occur a every request.
self.connection = None
logger.warning('psycopg2 error while closing the connection.',
exc_info=sys.exc_info()
)
raise
finally:
self.set_clean()
def _enter_transaction_management(self, managed): def _enter_transaction_management(self, managed):
""" """
@ -222,3 +200,25 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if ((self.transaction_state and self.transaction_state[-1]) or if ((self.transaction_state and self.transaction_state[-1]) or
not self.features.uses_autocommit): not self.features.uses_autocommit):
super(DatabaseWrapper, self).set_dirty() super(DatabaseWrapper, self).set_dirty()
def check_constraints(self, table_names=None):
"""
To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they
are returned to deferred.
"""
self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
def is_usable(self):
try:
# Use a psycopg cursor directly, bypassing Django's utilities.
self.connection.cursor().execute("SELECT 1")
except DatabaseError:
return False
else:
return True
@cached_property
def pg_version(self):
with self.temporary_connection():
return get_version(self.connection)

View File

@ -347,8 +347,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def create_cursor(self): def create_cursor(self):
return self.connection.cursor(factory=SQLiteCursorWrapper) return self.connection.cursor(factory=SQLiteCursorWrapper)
def is_usable(self): def close(self):
return True self.validate_thread_sharing()
# If database is in memory, closing the connection destroys the
# database. To prevent accidental data loss, ignore close requests on
# an in-memory db.
if self.settings_dict['NAME'] != ":memory:":
BaseDatabaseWrapper.close(self)
def check_constraints(self, table_names=None): def check_constraints(self, table_names=None):
""" """
@ -384,13 +389,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
% (table_name, bad_row[0], table_name, column_name, bad_row[1], % (table_name, bad_row[0], table_name, column_name, bad_row[1],
referenced_table_name, referenced_column_name)) referenced_table_name, referenced_column_name))
def close(self): def is_usable(self):
self.validate_thread_sharing() return True
# If database is in memory, closing the connection destroys the
# database. To prevent accidental data loss, ignore close requests on
# an in-memory db.
if self.settings_dict['NAME'] != ":memory:":
BaseDatabaseWrapper.close(self)
FORMAT_QMARK_REGEX = re.compile(r'(?<!%)%s') FORMAT_QMARK_REGEX = re.compile(r'(?<!%)%s')