Ensured a connection is established when checking the database version.

Fixed a test broken by 21765c0a. Refs #18135.
This commit is contained in:
Aymeric Augustin 2013-02-19 10:51:24 +01:00
parent 9a3988ca5a
commit ebabd77291
5 changed files with 20 additions and 5 deletions

View File

@ -352,6 +352,18 @@ class BaseDatabaseWrapper(object):
def make_debug_cursor(self, cursor): def make_debug_cursor(self, cursor):
return util.CursorDebugWrapper(cursor, self) return util.CursorDebugWrapper(cursor, self)
@contextmanager
def temporary_connection(self):
# Ensure a connection is established, and avoid leaving a dangling
# connection, for operations outside of the request-response cycle.
must_close = self.connection is None
cursor = self.cursor()
try:
yield
finally:
cursor.close()
if must_close:
self.close()
class BaseDatabaseFeatures(object): class BaseDatabaseFeatures(object):
allows_group_by_pk = False allows_group_by_pk = False

View File

@ -453,7 +453,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
@cached_property @cached_property
def mysql_version(self): def mysql_version(self):
server_info = self.connection.get_server_info() with self.temporary_connection():
server_info = self.connection.get_server_info()
match = server_version_re.match(server_info) match = server_version_re.match(server_info)
if not match: if not match:
raise Exception('Unable to determine MySQL version from version string %r' % server_info) raise Exception('Unable to determine MySQL version from version string %r' % server_info)

View File

@ -623,8 +623,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
@cached_property @cached_property
def oracle_version(self): def oracle_version(self):
with self.temporary_connection():
version = self.connection.version
try: try:
return int(self.connection.version.split('.')[0]) return int(version.split('.')[0])
except ValueError: except ValueError:
return None return None

View File

@ -152,7 +152,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
@cached_property @cached_property
def pg_version(self): def pg_version(self):
return get_version(self.connection) with self.temporary_connection():
return get_version(self.connection)
def get_connection_params(self): def get_connection_params(self):
settings_dict = self.settings_dict settings_dict = self.settings_dict

View File

@ -195,8 +195,7 @@ class DatabaseOperations(BaseDatabaseOperations):
NotImplementedError if this is the database in use. NotImplementedError if this is the database in use.
""" """
if aggregate.sql_function in ('STDDEV_POP', 'VAR_POP'): if aggregate.sql_function in ('STDDEV_POP', 'VAR_POP'):
pg_version = self.connection.pg_version if 80200 <= self.connection.pg_version <= 80204:
if pg_version >= 80200 and pg_version <= 80204:
raise NotImplementedError('PostgreSQL 8.2 to 8.2.4 is known to have a faulty implementation of %s. Please upgrade your version of PostgreSQL.' % aggregate.sql_function) raise NotImplementedError('PostgreSQL 8.2 to 8.2.4 is known to have a faulty implementation of %s. Please upgrade your version of PostgreSQL.' % aggregate.sql_function)
def max_name_length(self): def max_name_length(self):