mirror of https://github.com/django/django.git
Fixed #19274 -- Made db connection creation overridable in subclasses
Connection creation was done in db backend ._cursor() call. This included taking a new connection if needed, initializing the session state for the new connection and finally creating the connection. To allow easier modifying of these steps in subclasses (for example to support connection pools) the _cursor() now calls get_new_connection() and init_connection_state() if there isn't an existing connection. This was done for all non-gis core backends. In addition the parameters used for taking a connection are now created by get_connection_params(). We should also do the same for gis backends and encourage 3rd party backends to use the same pattern. The pattern is not enforced in code, and as the backends are private API this will not be required by documentation either.
This commit is contained in:
parent
2ea80b94d7
commit
1893467784
|
@ -372,14 +372,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
self.connection.ping()
|
self.connection.ping()
|
||||||
return True
|
return True
|
||||||
except DatabaseError:
|
except DatabaseError:
|
||||||
self.connection.close()
|
self.close()
|
||||||
self.connection = None
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _cursor(self):
|
def get_connection_params(self):
|
||||||
new_connection = False
|
|
||||||
if not self._valid_connection():
|
|
||||||
new_connection = True
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
'conv': django_conversions,
|
'conv': django_conversions,
|
||||||
'charset': 'utf8',
|
'charset': 'utf8',
|
||||||
|
@ -402,17 +398,30 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
# "UPDATE", not the number of changed rows.
|
# "UPDATE", not the number of changed rows.
|
||||||
kwargs['client_flag'] = CLIENT.FOUND_ROWS
|
kwargs['client_flag'] = CLIENT.FOUND_ROWS
|
||||||
kwargs.update(settings_dict['OPTIONS'])
|
kwargs.update(settings_dict['OPTIONS'])
|
||||||
self.connection = Database.connect(**kwargs)
|
return kwargs
|
||||||
self.connection.encoders[SafeText] = self.connection.encoders[six.text_type]
|
|
||||||
self.connection.encoders[SafeBytes] = self.connection.encoders[bytes]
|
def get_new_connection(self, conn_params):
|
||||||
connection_created.send(sender=self.__class__, connection=self)
|
conn = Database.connect(**conn_params)
|
||||||
|
conn.encoders[SafeText] = conn.encoders[six.text_type]
|
||||||
|
conn.encoders[SafeBytes] = conn.encoders[bytes]
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def init_connection_state(self):
|
||||||
cursor = self.connection.cursor()
|
cursor = self.connection.cursor()
|
||||||
if new_connection:
|
|
||||||
# SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column
|
# SQL_AUTO_IS_NULL in MySQL controls whether an AUTO_INCREMENT column
|
||||||
# on a recently-inserted row will return when the field is tested for
|
# on a recently-inserted row will return when the field is tested for
|
||||||
# NULL. Disabling this value brings this aspect of MySQL in line with
|
# NULL. Disabling this value brings this aspect of MySQL in line with
|
||||||
# SQL standards.
|
# SQL standards.
|
||||||
cursor.execute('SET SQL_AUTO_IS_NULL = 0')
|
cursor.execute('SET SQL_AUTO_IS_NULL = 0')
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
def _cursor(self):
|
||||||
|
if not self._valid_connection():
|
||||||
|
conn_params = self.get_connection_params()
|
||||||
|
self.connection = self.get_new_connection(conn_params)
|
||||||
|
self.init_connection_state()
|
||||||
|
connection_created.send(sender=self.__class__, connection=self)
|
||||||
|
cursor = self.connection.cursor()
|
||||||
return CursorWrapper(cursor)
|
return CursorWrapper(cursor)
|
||||||
|
|
||||||
def _rollback(self):
|
def _rollback(self):
|
||||||
|
@ -433,8 +442,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
server_info = self.connection.get_server_info()
|
server_info = self.connection.get_server_info()
|
||||||
if new_connection:
|
if new_connection:
|
||||||
# Make sure we close the connection
|
# Make sure we close the connection
|
||||||
self.connection.close()
|
self.close()
|
||||||
self.connection = None
|
|
||||||
m = server_version_re.match(server_info)
|
m = server_version_re.match(server_info)
|
||||||
if not m:
|
if not m:
|
||||||
raise Exception('Unable to determine MySQL version from version string %r' % server_info)
|
raise Exception('Unable to determine MySQL version from version string %r' % server_info)
|
||||||
|
|
|
@ -489,15 +489,21 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
return "%s/%s@%s" % (settings_dict['USER'],
|
return "%s/%s@%s" % (settings_dict['USER'],
|
||||||
settings_dict['PASSWORD'], dsn)
|
settings_dict['PASSWORD'], dsn)
|
||||||
|
|
||||||
def _cursor(self):
|
def create_cursor(self, conn):
|
||||||
cursor = None
|
return FormatStylePlaceholderCursor(conn)
|
||||||
if not self._valid_connection():
|
|
||||||
conn_string = convert_unicode(self._connect_string())
|
def get_connection_params(self):
|
||||||
conn_params = self.settings_dict['OPTIONS'].copy()
|
conn_params = self.settings_dict['OPTIONS'].copy()
|
||||||
if 'use_returning_into' in conn_params:
|
if 'use_returning_into' in conn_params:
|
||||||
del conn_params['use_returning_into']
|
del conn_params['use_returning_into']
|
||||||
self.connection = Database.connect(conn_string, **conn_params)
|
return conn_params
|
||||||
cursor = FormatStylePlaceholderCursor(self.connection)
|
|
||||||
|
def get_new_connection(self, conn_params):
|
||||||
|
conn_string = convert_unicode(self._connect_string())
|
||||||
|
return Database.connect(conn_string, **conn_params)
|
||||||
|
|
||||||
|
def init_connection_state(self):
|
||||||
|
cursor = self.create_cursor(self.connection)
|
||||||
# Set the territory first. The territory overrides NLS_DATE_FORMAT
|
# Set the territory first. The territory overrides NLS_DATE_FORMAT
|
||||||
# and NLS_TIMESTAMP_FORMAT to the territory default. When all of
|
# and NLS_TIMESTAMP_FORMAT to the territory default. When all of
|
||||||
# these are set in single statement it isn't clear what is supposed
|
# these are set in single statement it isn't clear what is supposed
|
||||||
|
@ -511,13 +517,14 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
"ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD HH24:MI:SS'"
|
"ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD HH24:MI:SS'"
|
||||||
" NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'"
|
" NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'"
|
||||||
+ (" TIME_ZONE = 'UTC'" if settings.USE_TZ else ''))
|
+ (" TIME_ZONE = 'UTC'" if settings.USE_TZ else ''))
|
||||||
|
cursor.close()
|
||||||
if 'operators' not in self.__dict__:
|
if 'operators' not in self.__dict__:
|
||||||
# Ticket #14149: Check whether our LIKE implementation will
|
# Ticket #14149: Check whether our LIKE implementation will
|
||||||
# work for this connection or we need to fall back on LIKEC.
|
# work for this connection or we need to fall back on LIKEC.
|
||||||
# This check is performed only once per DatabaseWrapper
|
# This check is performed only once per DatabaseWrapper
|
||||||
# instance per thread, since subsequent connections will use
|
# instance per thread, since subsequent connections will use
|
||||||
# the same settings.
|
# the same settings.
|
||||||
|
cursor = self.create_cursor(self.connection)
|
||||||
try:
|
try:
|
||||||
cursor.execute("SELECT 1 FROM DUAL WHERE DUMMY %s"
|
cursor.execute("SELECT 1 FROM DUAL WHERE DUMMY %s"
|
||||||
% self._standard_operators['contains'],
|
% self._standard_operators['contains'],
|
||||||
|
@ -526,6 +533,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
self.operators = self._likec_operators
|
self.operators = self._likec_operators
|
||||||
else:
|
else:
|
||||||
self.operators = self._standard_operators
|
self.operators = self._standard_operators
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.oracle_version = int(self.connection.version.split('.')[0])
|
self.oracle_version = int(self.connection.version.split('.')[0])
|
||||||
|
@ -545,10 +553,14 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
# Django docs specify cx_Oracle version 4.3.1 or higher, but
|
# Django docs specify cx_Oracle version 4.3.1 or higher, but
|
||||||
# stmtcachesize is available only in 4.3.2 and up.
|
# stmtcachesize is available only in 4.3.2 and up.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _cursor(self):
|
||||||
|
if not self._valid_connection():
|
||||||
|
conn_params = self.get_connection_params()
|
||||||
|
self.connection = self.get_new_connection(conn_params)
|
||||||
|
self.init_connection_state()
|
||||||
connection_created.send(sender=self.__class__, connection=self)
|
connection_created.send(sender=self.__class__, connection=self)
|
||||||
if not cursor:
|
return self.create_cursor(self.connection)
|
||||||
cursor = FormatStylePlaceholderCursor(self.connection)
|
|
||||||
return cursor
|
|
||||||
|
|
||||||
# Oracle doesn't support savepoint commits. Ignore them.
|
# Oracle doesn't support savepoint commits. Ignore them.
|
||||||
def _savepoint_commit(self, sid):
|
def _savepoint_commit(self, sid):
|
||||||
|
|
|
@ -156,9 +156,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
return self._pg_version
|
return self._pg_version
|
||||||
pg_version = property(_get_pg_version)
|
pg_version = property(_get_pg_version)
|
||||||
|
|
||||||
def _cursor(self):
|
def get_connection_params(self):
|
||||||
settings_dict = self.settings_dict
|
settings_dict = self.settings_dict
|
||||||
if self.connection is None:
|
|
||||||
if not settings_dict['NAME']:
|
if not settings_dict['NAME']:
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
raise ImproperlyConfigured(
|
raise ImproperlyConfigured(
|
||||||
|
@ -178,7 +177,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
conn_params['host'] = settings_dict['HOST']
|
conn_params['host'] = settings_dict['HOST']
|
||||||
if settings_dict['PORT']:
|
if settings_dict['PORT']:
|
||||||
conn_params['port'] = settings_dict['PORT']
|
conn_params['port'] = settings_dict['PORT']
|
||||||
self.connection = Database.connect(**conn_params)
|
return conn_params
|
||||||
|
|
||||||
|
def get_new_connection(self, conn_params):
|
||||||
|
return Database.connect(**conn_params)
|
||||||
|
|
||||||
|
def init_connection_state(self):
|
||||||
|
settings_dict = self.settings_dict
|
||||||
self.connection.set_client_encoding('UTF8')
|
self.connection.set_client_encoding('UTF8')
|
||||||
tz = 'UTC' if settings.USE_TZ else settings_dict.get('TIME_ZONE')
|
tz = 'UTC' if settings.USE_TZ else settings_dict.get('TIME_ZONE')
|
||||||
if tz:
|
if tz:
|
||||||
|
@ -198,6 +203,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
self.ops.set_time_zone_sql(), [tz])
|
self.ops.set_time_zone_sql(), [tz])
|
||||||
self.connection.set_isolation_level(self.isolation_level)
|
self.connection.set_isolation_level(self.isolation_level)
|
||||||
self._get_pg_version()
|
self._get_pg_version()
|
||||||
|
|
||||||
|
def _cursor(self):
|
||||||
|
if self.connection is None:
|
||||||
|
conn_params = self.get_connection_params()
|
||||||
|
self.connection = self.get_new_connection(conn_params)
|
||||||
|
self.init_connection_state()
|
||||||
connection_created.send(sender=self.__class__, connection=self)
|
connection_created.send(sender=self.__class__, connection=self)
|
||||||
cursor = self.connection.cursor()
|
cursor = self.connection.cursor()
|
||||||
cursor.tzinfo_factory = utc_tzinfo_factory if settings.USE_TZ else None
|
cursor.tzinfo_factory = utc_tzinfo_factory if settings.USE_TZ else None
|
||||||
|
|
|
@ -266,7 +266,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
self.introspection = DatabaseIntrospection(self)
|
self.introspection = DatabaseIntrospection(self)
|
||||||
self.validation = BaseDatabaseValidation(self)
|
self.validation = BaseDatabaseValidation(self)
|
||||||
|
|
||||||
def _sqlite_create_connection(self):
|
def get_connection_params(self):
|
||||||
settings_dict = self.settings_dict
|
settings_dict = self.settings_dict
|
||||||
if not settings_dict['NAME']:
|
if not settings_dict['NAME']:
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
|
@ -293,12 +293,24 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
RuntimeWarning
|
RuntimeWarning
|
||||||
)
|
)
|
||||||
kwargs.update({'check_same_thread': False})
|
kwargs.update({'check_same_thread': False})
|
||||||
self.connection = Database.connect(**kwargs)
|
return kwargs
|
||||||
|
|
||||||
|
def get_new_connection(self, conn_params):
|
||||||
|
conn = Database.connect(**conn_params)
|
||||||
# Register extract, date_trunc, and regexp functions.
|
# Register extract, date_trunc, and regexp functions.
|
||||||
self.connection.create_function("django_extract", 2, _sqlite_extract)
|
conn.create_function("django_extract", 2, _sqlite_extract)
|
||||||
self.connection.create_function("django_date_trunc", 2, _sqlite_date_trunc)
|
conn.create_function("django_date_trunc", 2, _sqlite_date_trunc)
|
||||||
self.connection.create_function("regexp", 2, _sqlite_regexp)
|
conn.create_function("regexp", 2, _sqlite_regexp)
|
||||||
self.connection.create_function("django_format_dtdelta", 5, _sqlite_format_dtdelta)
|
conn.create_function("django_format_dtdelta", 5, _sqlite_format_dtdelta)
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def init_connection_state(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _sqlite_create_connection(self):
|
||||||
|
conn_params = self.get_connection_params()
|
||||||
|
self.connection = self.get_new_connection(conn_params)
|
||||||
|
self.init_connection_state()
|
||||||
connection_created.send(sender=self.__class__, connection=self)
|
connection_created.send(sender=self.__class__, connection=self)
|
||||||
|
|
||||||
def _cursor(self):
|
def _cursor(self):
|
||||||
|
|
Loading…
Reference in New Issue