Fixed #19274 -- Made db connection creation overridable in subclasses

Connection creation was done in db backend ._cursor() call. This
included taking a new connection if needed, initializing the session
state for the new connection and finally creating the connection.

To allow easier modifying of these steps in subclasses (for example to
support connection pools) the _cursor() now calls get_new_connection()
and init_connection_state() if there isn't an existing connection. This
was done for all non-gis core backends. In addition the parameters used
for taking a connection are now created by get_connection_params().

We should also do the same for gis backends and encourage 3rd party
backends to use the same pattern. The pattern is not enforced in code,
and as the backends are private API this will not be required by
documentation either.
This commit is contained in:
Anssi Kääriäinen 2012-11-26 22:42:27 +02:00
parent 2ea80b94d7
commit 1893467784
4 changed files with 183 additions and 140 deletions

View File

@ -372,47 +372,56 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.connection.ping()
return True
except DatabaseError:
self.connection.close()
self.connection = None
self.close()
return False
def get_connection_params(self):
kwargs = {
'conv': django_conversions,
'charset': 'utf8',
'use_unicode': True,
}
settings_dict = self.settings_dict
if settings_dict['USER']:
kwargs['user'] = settings_dict['USER']
if settings_dict['NAME']:
kwargs['db'] = settings_dict['NAME']
if settings_dict['PASSWORD']:
kwargs['passwd'] = force_str(settings_dict['PASSWORD'])
if settings_dict['HOST'].startswith('/'):
kwargs['unix_socket'] = settings_dict['HOST']
elif settings_dict['HOST']:
kwargs['host'] = settings_dict['HOST']
if settings_dict['PORT']:
kwargs['port'] = int(settings_dict['PORT'])
# We need the number of potentially affected rows after an
# "UPDATE", not the number of changed rows.
kwargs['client_flag'] = CLIENT.FOUND_ROWS
kwargs.update(settings_dict['OPTIONS'])
return kwargs
def get_new_connection(self, conn_params):
conn = Database.connect(**conn_params)
conn.encoders[SafeText] = conn.encoders[six.text_type]
conn.encoders[SafeBytes] = conn.encoders[bytes]
return conn
def init_connection_state(self):
cursor = self.connection.cursor()
# SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column
# on a recently-inserted row will return when the field is tested for
# NULL. Disabling this value brings this aspect of MySQL in line with
# SQL standards.
cursor.execute('SET SQL_AUTO_IS_NULL = 0')
cursor.close()
def _cursor(self):
new_connection = False
if not self._valid_connection():
new_connection = True
kwargs = {
'conv': django_conversions,
'charset': 'utf8',
'use_unicode': True,
}
settings_dict = self.settings_dict
if settings_dict['USER']:
kwargs['user'] = settings_dict['USER']
if settings_dict['NAME']:
kwargs['db'] = settings_dict['NAME']
if settings_dict['PASSWORD']:
kwargs['passwd'] = force_str(settings_dict['PASSWORD'])
if settings_dict['HOST'].startswith('/'):
kwargs['unix_socket'] = settings_dict['HOST']
elif settings_dict['HOST']:
kwargs['host'] = settings_dict['HOST']
if settings_dict['PORT']:
kwargs['port'] = int(settings_dict['PORT'])
# We need the number of potentially affected rows after an
# "UPDATE", not the number of changed rows.
kwargs['client_flag'] = CLIENT.FOUND_ROWS
kwargs.update(settings_dict['OPTIONS'])
self.connection = Database.connect(**kwargs)
self.connection.encoders[SafeText] = self.connection.encoders[six.text_type]
self.connection.encoders[SafeBytes] = self.connection.encoders[bytes]
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
self.init_connection_state()
connection_created.send(sender=self.__class__, connection=self)
cursor = self.connection.cursor()
if new_connection:
# SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column
# on a recently-inserted row will return when the field is tested for
# NULL. Disabling this value brings this aspect of MySQL in line with
# SQL standards.
cursor.execute('SET SQL_AUTO_IS_NULL = 0')
return CursorWrapper(cursor)
def _rollback(self):
@ -433,8 +442,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
server_info = self.connection.get_server_info()
if new_connection:
# Make sure we close the connection
self.connection.close()
self.connection = None
self.close()
m = server_version_re.match(server_info)
if not m:
raise Exception('Unable to determine MySQL version from version string %r' % server_info)

View File

@ -489,66 +489,78 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return "%s/%s@%s" % (settings_dict['USER'],
settings_dict['PASSWORD'], dsn)
def create_cursor(self, conn):
return FormatStylePlaceholderCursor(conn)
def get_connection_params(self):
conn_params = self.settings_dict['OPTIONS'].copy()
if 'use_returning_into' in conn_params:
del conn_params['use_returning_into']
return conn_params
def get_new_connection(self, conn_params):
conn_string = convert_unicode(self._connect_string())
return Database.connect(conn_string, **conn_params)
def init_connection_state(self):
cursor = self.create_cursor(self.connection)
# Set the territory first. The territory overrides NLS_DATE_FORMAT
# and NLS_TIMESTAMP_FORMAT to the territory default. When all of
# these are set in single statement it isn't clear what is supposed
# to happen.
cursor.execute("ALTER SESSION SET NLS_TERRITORY = 'AMERICA'")
# Set oracle date to ansi date format. This only needs to execute
# once when we create a new connection. We also set the Territory
# to 'AMERICA' which forces Sunday to evaluate to a '1' in
# TO_CHAR().
cursor.execute(
"ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD HH24:MI:SS'"
" NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'"
+ (" TIME_ZONE = 'UTC'" if settings.USE_TZ else ''))
cursor.close()
if 'operators' not in self.__dict__:
# Ticket #14149: Check whether our LIKE implementation will
# work for this connection or we need to fall back on LIKEC.
# This check is performed only once per DatabaseWrapper
# instance per thread, since subsequent connections will use
# the same settings.
cursor = self.create_cursor(self.connection)
try:
cursor.execute("SELECT 1 FROM DUAL WHERE DUMMY %s"
% self._standard_operators['contains'],
['X'])
except utils.DatabaseError:
self.operators = self._likec_operators
else:
self.operators = self._standard_operators
cursor.close()
try:
self.oracle_version = int(self.connection.version.split('.')[0])
# There's no way for the DatabaseOperations class to know the
# currently active Oracle version, so we do some setups here.
# TODO: Multi-db support will need a better solution (a way to
# communicate the current version).
if self.oracle_version <= 9:
self.ops.regex_lookup = self.ops.regex_lookup_9
else:
self.ops.regex_lookup = self.ops.regex_lookup_10
except ValueError:
pass
try:
self.connection.stmtcachesize = 20
except:
# Django docs specify cx_Oracle version 4.3.1 or higher, but
# stmtcachesize is available only in 4.3.2 and up.
pass
def _cursor(self):
cursor = None
if not self._valid_connection():
conn_string = convert_unicode(self._connect_string())
conn_params = self.settings_dict['OPTIONS'].copy()
if 'use_returning_into' in conn_params:
del conn_params['use_returning_into']
self.connection = Database.connect(conn_string, **conn_params)
cursor = FormatStylePlaceholderCursor(self.connection)
# Set the territory first. The territory overrides NLS_DATE_FORMAT
# and NLS_TIMESTAMP_FORMAT to the territory default. When all of
# these are set in single statement it isn't clear what is supposed
# to happen.
cursor.execute("ALTER SESSION SET NLS_TERRITORY = 'AMERICA'")
# Set oracle date to ansi date format. This only needs to execute
# once when we create a new connection. We also set the Territory
# to 'AMERICA' which forces Sunday to evaluate to a '1' in
# TO_CHAR().
cursor.execute(
"ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD HH24:MI:SS'"
" NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'"
+ (" TIME_ZONE = 'UTC'" if settings.USE_TZ else ''))
if 'operators' not in self.__dict__:
# Ticket #14149: Check whether our LIKE implementation will
# work for this connection or we need to fall back on LIKEC.
# This check is performed only once per DatabaseWrapper
# instance per thread, since subsequent connections will use
# the same settings.
try:
cursor.execute("SELECT 1 FROM DUAL WHERE DUMMY %s"
% self._standard_operators['contains'],
['X'])
except utils.DatabaseError:
self.operators = self._likec_operators
else:
self.operators = self._standard_operators
try:
self.oracle_version = int(self.connection.version.split('.')[0])
# There's no way for the DatabaseOperations class to know the
# currently active Oracle version, so we do some setups here.
# TODO: Multi-db support will need a better solution (a way to
# communicate the current version).
if self.oracle_version <= 9:
self.ops.regex_lookup = self.ops.regex_lookup_9
else:
self.ops.regex_lookup = self.ops.regex_lookup_10
except ValueError:
pass
try:
self.connection.stmtcachesize = 20
except:
# Django docs specify cx_Oracle version 4.3.1 or higher, but
# stmtcachesize is available only in 4.3.2 and up.
pass
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
self.init_connection_state()
connection_created.send(sender=self.__class__, connection=self)
if not cursor:
cursor = FormatStylePlaceholderCursor(self.connection)
return cursor
return self.create_cursor(self.connection)
# Oracle doesn't support savepoint commits. Ignore them.
def _savepoint_commit(self, sid):

View File

@ -156,48 +156,59 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return self._pg_version
pg_version = property(_get_pg_version)
def _cursor(self):
def get_connection_params(self):
settings_dict = self.settings_dict
if self.connection is None:
if not settings_dict['NAME']:
from django.core.exceptions import ImproperlyConfigured
raise ImproperlyConfigured(
"settings.DATABASES is improperly configured. "
"Please supply the NAME value.")
conn_params = {
'database': settings_dict['NAME'],
}
conn_params.update(settings_dict['OPTIONS'])
if 'autocommit' in conn_params:
del conn_params['autocommit']
if settings_dict['USER']:
conn_params['user'] = settings_dict['USER']
if settings_dict['PASSWORD']:
conn_params['password'] = force_str(settings_dict['PASSWORD'])
if settings_dict['HOST']:
conn_params['host'] = settings_dict['HOST']
if settings_dict['PORT']:
conn_params['port'] = settings_dict['PORT']
self.connection = Database.connect(**conn_params)
self.connection.set_client_encoding('UTF8')
tz = 'UTC' if settings.USE_TZ else settings_dict.get('TIME_ZONE')
if tz:
try:
get_parameter_status = self.connection.get_parameter_status
except AttributeError:
# psycopg2 < 2.0.12 doesn't have get_parameter_status
conn_tz = None
else:
conn_tz = get_parameter_status('TimeZone')
if not settings_dict['NAME']:
from django.core.exceptions import ImproperlyConfigured
raise ImproperlyConfigured(
"settings.DATABASES is improperly configured. "
"Please supply the NAME value.")
conn_params = {
'database': settings_dict['NAME'],
}
conn_params.update(settings_dict['OPTIONS'])
if 'autocommit' in conn_params:
del conn_params['autocommit']
if settings_dict['USER']:
conn_params['user'] = settings_dict['USER']
if settings_dict['PASSWORD']:
conn_params['password'] = force_str(settings_dict['PASSWORD'])
if settings_dict['HOST']:
conn_params['host'] = settings_dict['HOST']
if settings_dict['PORT']:
conn_params['port'] = settings_dict['PORT']
return conn_params
if conn_tz != tz:
# Set the time zone in autocommit mode (see #17062)
self.connection.set_isolation_level(
psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
self.connection.cursor().execute(
self.ops.set_time_zone_sql(), [tz])
self.connection.set_isolation_level(self.isolation_level)
self._get_pg_version()
def get_new_connection(self, conn_params):
return Database.connect(**conn_params)
def init_connection_state(self):
settings_dict = self.settings_dict
self.connection.set_client_encoding('UTF8')
tz = 'UTC' if settings.USE_TZ else settings_dict.get('TIME_ZONE')
if tz:
try:
get_parameter_status = self.connection.get_parameter_status
except AttributeError:
# psycopg2 < 2.0.12 doesn't have get_parameter_status
conn_tz = None
else:
conn_tz = get_parameter_status('TimeZone')
if conn_tz != tz:
# Set the time zone in autocommit mode (see #17062)
self.connection.set_isolation_level(
psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
self.connection.cursor().execute(
self.ops.set_time_zone_sql(), [tz])
self.connection.set_isolation_level(self.isolation_level)
self._get_pg_version()
def _cursor(self):
if self.connection is None:
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
self.init_connection_state()
connection_created.send(sender=self.__class__, connection=self)
cursor = self.connection.cursor()
cursor.tzinfo_factory = utc_tzinfo_factory if settings.USE_TZ else None

View File

@ -266,7 +266,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.introspection = DatabaseIntrospection(self)
self.validation = BaseDatabaseValidation(self)
def _sqlite_create_connection(self):
def get_connection_params(self):
settings_dict = self.settings_dict
if not settings_dict['NAME']:
from django.core.exceptions import ImproperlyConfigured
@ -293,12 +293,24 @@ class DatabaseWrapper(BaseDatabaseWrapper):
RuntimeWarning
)
kwargs.update({'check_same_thread': False})
self.connection = Database.connect(**kwargs)
return kwargs
def get_new_connection(self, conn_params):
conn = Database.connect(**conn_params)
# Register extract, date_trunc, and regexp functions.
self.connection.create_function("django_extract", 2, _sqlite_extract)
self.connection.create_function("django_date_trunc", 2, _sqlite_date_trunc)
self.connection.create_function("regexp", 2, _sqlite_regexp)
self.connection.create_function("django_format_dtdelta", 5, _sqlite_format_dtdelta)
conn.create_function("django_extract", 2, _sqlite_extract)
conn.create_function("django_date_trunc", 2, _sqlite_date_trunc)
conn.create_function("regexp", 2, _sqlite_regexp)
conn.create_function("django_format_dtdelta", 5, _sqlite_format_dtdelta)
return conn
def init_connection_state(self):
pass
def _sqlite_create_connection(self):
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
self.init_connection_state()
connection_created.send(sender=self.__class__, connection=self)
def _cursor(self):