Fixed #10459 -- Refactored the internals of database connection objects so that connections know their own settings and pass around settings as dictionaries instead of passing around the Django settings module itself. This will make it easier for multiple database support. Thanks to Alex Gaynor for the initial patch.

This is backwards-compatible but will likely break third-party database backends. Specific API changes are:

* BaseDatabaseWrapper.__init__() now takes a settings_dict instead of a settings module. It's called settings_dict to disambiguate, and for easy grepability. This should be a dictionary containing DATABASE_NAME, etc.

* BaseDatabaseWrapper has a settings_dict attribute instead of an options attribute. BaseDatabaseWrapper.options is now BaseDatabaseWrapper['DATABASE_OPTIONS']

* BaseDatabaseWrapper._cursor() no longer takes a settings argument.

* BaseDatabaseClient.__init__() now takes a connection argument (a DatabaseWrapper instance) instead of no arguments.


git-svn-id: http://code.djangoproject.com/svn/django/trunk@10026 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Adrian Holovaty 2009-03-11 03:39:34 +00:00
parent 7daf0b9407
commit 315145f7ca
12 changed files with 115 additions and 92 deletions

View File

@ -36,8 +36,22 @@ except ImportError, e:
else: else:
raise # If there's some other error, this must be an error in Django itself. raise # If there's some other error, this must be an error in Django itself.
# Convenient aliases for backend bits. # `connection`, `DatabaseError` and `IntegrityError` are convenient aliases
connection = backend.DatabaseWrapper(**settings.DATABASE_OPTIONS) # for backend bits.
# DatabaseWrapper.__init__() takes a dictionary, not a settings module, so
# we manually create the dictionary from the settings, passing only the
# settings that the database backends care about. Note that TIME_ZONE is used
# by the PostgreSQL backends.
connection = backend.DatabaseWrapper({
'DATABASE_HOST': settings.DATABASE_HOST,
'DATABASE_NAME': settings.DATABASE_NAME,
'DATABASE_OPTIONS': settings.DATABASE_OPTIONS,
'DATABASE_PASSWORD': settings.DATABASE_PASSWORD,
'DATABASE_PORT': settings.DATABASE_PORT,
'DATABASE_USER': settings.DATABASE_USER,
'TIME_ZONE': settings.TIME_ZONE,
})
DatabaseError = backend.DatabaseError DatabaseError = backend.DatabaseError
IntegrityError = backend.IntegrityError IntegrityError = backend.IntegrityError

View File

@ -24,10 +24,14 @@ class BaseDatabaseWrapper(local):
Represents a database connection. Represents a database connection.
""" """
ops = None ops = None
def __init__(self, **kwargs): def __init__(self, settings_dict):
# `settings_dict` should be a dictionary containing keys such as
# DATABASE_NAME, DATABASE_USER, etc. It's called `settings_dict`
# instead of `settings` to disambiguate it from Django settings
# modules.
self.connection = None self.connection = None
self.queries = [] self.queries = []
self.options = kwargs self.settings_dict = settings_dict
def _commit(self): def _commit(self):
if self.connection is not None: if self.connection is not None:
@ -59,7 +63,7 @@ class BaseDatabaseWrapper(local):
def cursor(self): def cursor(self):
from django.conf import settings from django.conf import settings
cursor = self._cursor(settings) cursor = self._cursor()
if settings.DEBUG: if settings.DEBUG:
return self.make_debug_cursor(cursor) return self.make_debug_cursor(cursor)
return cursor return cursor
@ -498,6 +502,10 @@ class BaseDatabaseClient(object):
# (e.g., "psql"). Subclasses must override this. # (e.g., "psql"). Subclasses must override this.
executable_name = None executable_name = None
def __init__(self, connection):
# connection is an instance of BaseDatabaseWrapper.
self.connection = connection
def runshell(self): def runshell(self):
raise NotImplementedError() raise NotImplementedError()

View File

@ -46,7 +46,7 @@ class DatabaseWrapper(object):
self.features = BaseDatabaseFeatures() self.features = BaseDatabaseFeatures()
self.ops = DatabaseOperations() self.ops = DatabaseOperations()
self.client = DatabaseClient() self.client = DatabaseClient(self)
self.creation = BaseDatabaseCreation(self) self.creation = BaseDatabaseCreation(self)
self.introspection = DatabaseIntrospection(self) self.introspection = DatabaseIntrospection(self)
self.validation = BaseDatabaseValidation() self.validation = BaseDatabaseValidation()

View File

@ -234,11 +234,11 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(DatabaseWrapper, self).__init__(**kwargs) super(DatabaseWrapper, self).__init__(**kwargs)
self.server_version = None
self.server_version = None
self.features = DatabaseFeatures() self.features = DatabaseFeatures()
self.ops = DatabaseOperations() self.ops = DatabaseOperations()
self.client = DatabaseClient() self.client = DatabaseClient(self)
self.creation = DatabaseCreation(self) self.creation = DatabaseCreation(self)
self.introspection = DatabaseIntrospection(self) self.introspection = DatabaseIntrospection(self)
self.validation = DatabaseValidation() self.validation = DatabaseValidation()
@ -253,26 +253,27 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.connection = None self.connection = None
return False return False
def _cursor(self, settings): def _cursor(self):
if not self._valid_connection(): if not self._valid_connection():
kwargs = { kwargs = {
'conv': django_conversions, 'conv': django_conversions,
'charset': 'utf8', 'charset': 'utf8',
'use_unicode': True, 'use_unicode': True,
} }
if settings.DATABASE_USER: settings_dict = self.settings_dict
kwargs['user'] = settings.DATABASE_USER if settings_dict['DATABASE_USER']:
if settings.DATABASE_NAME: kwargs['user'] = settings_dict['DATABASE_USER']
kwargs['db'] = settings.DATABASE_NAME if settings_dict['DATABASE_NAME']:
if settings.DATABASE_PASSWORD: kwargs['db'] = settings_dict['DATABASE_NAME']
kwargs['passwd'] = settings.DATABASE_PASSWORD if settings_dict['DATABASE_PASSWORD']:
if settings.DATABASE_HOST.startswith('/'): kwargs['passwd'] = settings_dict['DATABASE_PASSWORD']
kwargs['unix_socket'] = settings.DATABASE_HOST if settings_dict['DATABASE_HOST'].startswith('/'):
elif settings.DATABASE_HOST: kwargs['unix_socket'] = settings_dict['DATABASE_HOST']
kwargs['host'] = settings.DATABASE_HOST elif settings_dict['DATABASE_HOST']:
if settings.DATABASE_PORT: kwargs['host'] = settings_dict['DATABASE_HOST']
kwargs['port'] = int(settings.DATABASE_PORT) if settings_dict['DATABASE_PORT']:
kwargs.update(self.options) kwargs['port'] = int(settings_dict['DATABASE_PORT'])
kwargs.update(settings_dict['DATABASE_OPTIONS'])
self.connection = Database.connect(**kwargs) self.connection = Database.connect(**kwargs)
self.connection.encoders[SafeUnicode] = self.connection.encoders[unicode] self.connection.encoders[SafeUnicode] = self.connection.encoders[unicode]
self.connection.encoders[SafeString] = self.connection.encoders[str] self.connection.encoders[SafeString] = self.connection.encoders[str]

View File

@ -1,18 +1,18 @@
from django.db.backends import BaseDatabaseClient from django.db.backends import BaseDatabaseClient
from django.conf import settings
import os import os
class DatabaseClient(BaseDatabaseClient): class DatabaseClient(BaseDatabaseClient):
executable_name = 'mysql' executable_name = 'mysql'
def runshell(self): def runshell(self):
settings_dict = self.connection.settings_dict
args = [''] args = ['']
db = settings.DATABASE_OPTIONS.get('db', settings.DATABASE_NAME) db = settings_dict['DATABASE_OPTIONS'].get('db', settings_dict['DATABASE_NAME'])
user = settings.DATABASE_OPTIONS.get('user', settings.DATABASE_USER) user = settings_dict['DATABASE_OPTIONS'].get('user', settings_dict['DATABASE_USER'])
passwd = settings.DATABASE_OPTIONS.get('passwd', settings.DATABASE_PASSWORD) passwd = settings_dict['DATABASE_OPTIONS'].get('passwd', settings_dict['DATABASE_PASSWORD'])
host = settings.DATABASE_OPTIONS.get('host', settings.DATABASE_HOST) host = settings_dict['DATABASE_OPTIONS'].get('host', settings_dict['DATABASE_HOST'])
port = settings.DATABASE_OPTIONS.get('port', settings.DATABASE_PORT) port = settings_dict['DATABASE_OPTIONS'].get('port', settings_dict['DATABASE_PORT'])
defaults_file = settings.DATABASE_OPTIONS.get('read_default_file') defaults_file = settings_dict['DATABASE_OPTIONS'].get('read_default_file')
# Seems to be no good way to set sql_mode with CLI. # Seems to be no good way to set sql_mode with CLI.
if defaults_file: if defaults_file:

View File

@ -262,7 +262,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.features = DatabaseFeatures() self.features = DatabaseFeatures()
self.ops = DatabaseOperations() self.ops = DatabaseOperations()
self.client = DatabaseClient() self.client = DatabaseClient(self)
self.creation = DatabaseCreation(self) self.creation = DatabaseCreation(self)
self.introspection = DatabaseIntrospection(self) self.introspection = DatabaseIntrospection(self)
self.validation = BaseDatabaseValidation() self.validation = BaseDatabaseValidation()
@ -270,23 +270,24 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def _valid_connection(self): def _valid_connection(self):
return self.connection is not None return self.connection is not None
def _connect_string(self, settings): def _connect_string(self):
if len(settings.DATABASE_HOST.strip()) == 0: settings_dict = self.settings_dict
settings.DATABASE_HOST = 'localhost' if len(settings_dict['DATABASE_HOST'].strip()) == 0:
if len(settings.DATABASE_PORT.strip()) != 0: settings_dict['DATABASE_HOST'] = 'localhost'
dsn = Database.makedsn(settings.DATABASE_HOST, if len(settings_dict['DATABASE_PORT'].strip()) != 0:
int(settings.DATABASE_PORT), dsn = Database.makedsn(settings_dict['DATABASE_HOST'],
settings.DATABASE_NAME) int(settings_dict['DATABASE_PORT']),
settings_dict['DATABASE_NAME'])
else: else:
dsn = settings.DATABASE_NAME dsn = settings_dict['DATABASE_NAME']
return "%s/%s@%s" % (settings.DATABASE_USER, return "%s/%s@%s" % (settings_dict['DATABASE_USER'],
settings.DATABASE_PASSWORD, dsn) settings_dict['DATABASE_PASSWORD'], dsn)
def _cursor(self, settings): def _cursor(self):
cursor = None cursor = None
if not self._valid_connection(): if not self._valid_connection():
conn_string = self._connect_string(settings) conn_string = self._connect_string()
self.connection = Database.connect(conn_string, **self.options) self.connection = Database.connect(conn_string, **self.settings_dict['DATABASE_OPTIONS'])
cursor = FormatStylePlaceholderCursor(self.connection) cursor = FormatStylePlaceholderCursor(self.connection)
# Set oracle date to ansi date format. This only needs to execute # Set oracle date to ansi date format. This only needs to execute
# once when we create a new connection. We also set the Territory # once when we create a new connection. We also set the Territory

View File

@ -1,12 +1,10 @@
from django.db.backends import BaseDatabaseClient from django.db.backends import BaseDatabaseClient
from django.conf import settings
import os import os
class DatabaseClient(BaseDatabaseClient): class DatabaseClient(BaseDatabaseClient):
executable_name = 'sqlplus' executable_name = 'sqlplus'
def runshell(self): def runshell(self):
from django.db import connection conn_string = self.connection._connect_string()
conn_string = connection._connect_string(settings)
args = [self.executable_name, "-L", conn_string] args = [self.executable_name, "-L", conn_string]
os.execvp(self.executable_name, args) os.execvp(self.executable_name, args)

View File

@ -90,32 +90,33 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.features = DatabaseFeatures() self.features = DatabaseFeatures()
self.ops = DatabaseOperations() self.ops = DatabaseOperations()
self.client = DatabaseClient() self.client = DatabaseClient(self)
self.creation = DatabaseCreation(self) self.creation = DatabaseCreation(self)
self.introspection = DatabaseIntrospection(self) self.introspection = DatabaseIntrospection(self)
self.validation = BaseDatabaseValidation() self.validation = BaseDatabaseValidation()
def _cursor(self, settings): def _cursor(self):
set_tz = False set_tz = False
settings_dict = self.settings_dict
if self.connection is None: if self.connection is None:
set_tz = True set_tz = True
if settings.DATABASE_NAME == '': if settings_dict['DATABASE_NAME'] == '':
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
raise ImproperlyConfigured("You need to specify DATABASE_NAME in your Django settings file.") raise ImproperlyConfigured("You need to specify DATABASE_NAME in your Django settings file.")
conn_string = "dbname=%s" % settings.DATABASE_NAME conn_string = "dbname=%s" % settings_dict['DATABASE_NAME']
if settings.DATABASE_USER: if settings_dict['DATABASE_USER']:
conn_string = "user=%s %s" % (settings.DATABASE_USER, conn_string) conn_string = "user=%s %s" % (settings_dict['DATABASE_USER'], conn_string)
if settings.DATABASE_PASSWORD: if settings_dict['DATABASE_PASSWORD']:
conn_string += " password='%s'" % settings.DATABASE_PASSWORD conn_string += " password='%s'" % settings_dict['DATABASE_PASSWORD']
if settings.DATABASE_HOST: if settings_dict['DATABASE_HOST']:
conn_string += " host=%s" % settings.DATABASE_HOST conn_string += " host=%s" % settings_dict['DATABASE_HOST']
if settings.DATABASE_PORT: if settings_dict['DATABASE_PORT']:
conn_string += " port=%s" % settings.DATABASE_PORT conn_string += " port=%s" % settings_dict['DATABASE_PORT']
self.connection = Database.connect(conn_string, **self.options) self.connection = Database.connect(conn_string, **settings_dict['DATABASE_OPTIONS'])
self.connection.set_isolation_level(1) # make transactions transparent to all cursors self.connection.set_isolation_level(1) # make transactions transparent to all cursors
cursor = self.connection.cursor() cursor = self.connection.cursor()
if set_tz: if set_tz:
cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE]) cursor.execute("SET TIME ZONE %s", [settings_dict['TIME_ZONE']])
if not hasattr(self, '_version'): if not hasattr(self, '_version'):
self.__class__._version = get_version(cursor) self.__class__._version = get_version(cursor)
if self._version < (8, 0): if self._version < (8, 0):

View File

@ -1,19 +1,19 @@
from django.db.backends import BaseDatabaseClient from django.db.backends import BaseDatabaseClient
from django.conf import settings
import os import os
class DatabaseClient(BaseDatabaseClient): class DatabaseClient(BaseDatabaseClient):
executable_name = 'psql' executable_name = 'psql'
def runshell(self): def runshell(self):
settings_dict = self.connection.settings_dict
args = [self.executable_name] args = [self.executable_name]
if settings.DATABASE_USER: if settings_dict['DATABASE_USER']:
args += ["-U", settings.DATABASE_USER] args += ["-U", settings_dict['DATABASE_USER']]
if settings.DATABASE_PASSWORD: if settings_dict['DATABASE_PASSWORD']:
args += ["-W"] args += ["-W"]
if settings.DATABASE_HOST: if settings_dict['DATABASE_HOST']:
args.extend(["-h", settings.DATABASE_HOST]) args.extend(["-h", settings_dict['DATABASE_HOST']])
if settings.DATABASE_PORT: if settings_dict['DATABASE_PORT']:
args.extend(["-p", str(settings.DATABASE_PORT)]) args.extend(["-p", str(settings_dict['DATABASE_PORT'])])
args += [settings.DATABASE_NAME] args += [settings_dict['DATABASE_NAME']]
os.execvp(self.executable_name, args) os.execvp(self.executable_name, args)

View File

@ -60,37 +60,38 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.features = DatabaseFeatures() self.features = DatabaseFeatures()
self.ops = DatabaseOperations() self.ops = DatabaseOperations()
self.client = DatabaseClient() self.client = DatabaseClient(self)
self.creation = DatabaseCreation(self) self.creation = DatabaseCreation(self)
self.introspection = DatabaseIntrospection(self) self.introspection = DatabaseIntrospection(self)
self.validation = BaseDatabaseValidation() self.validation = BaseDatabaseValidation()
def _cursor(self, settings): def _cursor(self):
set_tz = False set_tz = False
settings_dict = self.settings_dict
if self.connection is None: if self.connection is None:
set_tz = True set_tz = True
if settings.DATABASE_NAME == '': if settings_dict['DATABASE_NAME'] == '':
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
raise ImproperlyConfigured("You need to specify DATABASE_NAME in your Django settings file.") raise ImproperlyConfigured("You need to specify DATABASE_NAME in your Django settings file.")
conn_params = { conn_params = {
'database': settings.DATABASE_NAME, 'database': settings_dict['DATABASE_NAME'],
} }
conn_params.update(self.options) conn_params.update(settings_dict['DATABASE_OPTIONS'])
if settings.DATABASE_USER: if settings_dict['DATABASE_USER']:
conn_params['user'] = settings.DATABASE_USER conn_params['user'] = settings_dict['DATABASE_USER']
if settings.DATABASE_PASSWORD: if settings_dict['DATABASE_PASSWORD']:
conn_params['password'] = settings.DATABASE_PASSWORD conn_params['password'] = settings_dict['DATABASE_PASSWORD']
if settings.DATABASE_HOST: if settings_dict['DATABASE_HOST']:
conn_params['host'] = settings.DATABASE_HOST conn_params['host'] = settings_dict['DATABASE_HOST']
if settings.DATABASE_PORT: if settings_dict['DATABASE_PORT']:
conn_params['port'] = settings.DATABASE_PORT conn_params['port'] = settings_dict['DATABASE_PORT']
self.connection = Database.connect(**conn_params) self.connection = Database.connect(**conn_params)
self.connection.set_isolation_level(1) # make transactions transparent to all cursors self.connection.set_isolation_level(1) # make transactions transparent to all cursors
self.connection.set_client_encoding('UTF8') self.connection.set_client_encoding('UTF8')
cursor = self.connection.cursor() cursor = self.connection.cursor()
cursor.tzinfo_factory = None cursor.tzinfo_factory = None
if set_tz: if set_tz:
cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE]) cursor.execute("SET TIME ZONE %s", [settings_dict['TIME_ZONE']])
if not hasattr(self, '_version'): if not hasattr(self, '_version'):
self.__class__._version = get_version(cursor) self.__class__._version = get_version(cursor)
if self._version < (8, 0): if self._version < (8, 0):

View File

@ -149,21 +149,22 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.features = DatabaseFeatures() self.features = DatabaseFeatures()
self.ops = DatabaseOperations() self.ops = DatabaseOperations()
self.client = DatabaseClient() self.client = DatabaseClient(self)
self.creation = DatabaseCreation(self) self.creation = DatabaseCreation(self)
self.introspection = DatabaseIntrospection(self) self.introspection = DatabaseIntrospection(self)
self.validation = BaseDatabaseValidation() self.validation = BaseDatabaseValidation()
def _cursor(self, settings): def _cursor(self):
if self.connection is None: if self.connection is None:
if not settings.DATABASE_NAME: settings_dict = self.settings_dict
if not settings_dict['DATABASE_NAME']:
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
raise ImproperlyConfigured, "Please fill out DATABASE_NAME in the settings module before using the database." raise ImproperlyConfigured, "Please fill out DATABASE_NAME in the settings module before using the database."
kwargs = { kwargs = {
'database': settings.DATABASE_NAME, 'database': settings_dict['DATABASE_NAME'],
'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES, 'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
} }
kwargs.update(self.options) kwargs.update(settings_dict['DATABASE_OPTIONS'])
self.connection = Database.connect(**kwargs) self.connection = Database.connect(**kwargs)
# Register extract, date_trunc, and regexp functions. # Register extract, date_trunc, and regexp functions.
self.connection.create_function("django_extract", 2, _sqlite_extract) self.connection.create_function("django_extract", 2, _sqlite_extract)
@ -172,11 +173,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return self.connection.cursor(factory=SQLiteCursorWrapper) return self.connection.cursor(factory=SQLiteCursorWrapper)
def close(self): def close(self):
from django.conf import settings
# If database is in memory, closing the connection destroys the # If database is in memory, closing the connection destroys the
# database. To prevent accidental data loss, ignore close requests on # database. To prevent accidental data loss, ignore close requests on
# an in-memory db. # an in-memory db.
if settings.DATABASE_NAME != ":memory:": if self.settings_dict['DATABASE_NAME'] != ":memory:":
BaseDatabaseWrapper.close(self) BaseDatabaseWrapper.close(self)
class SQLiteCursorWrapper(Database.Cursor): class SQLiteCursorWrapper(Database.Cursor):

View File

@ -1,10 +1,9 @@
from django.db.backends import BaseDatabaseClient from django.db.backends import BaseDatabaseClient
from django.conf import settings
import os import os
class DatabaseClient(BaseDatabaseClient): class DatabaseClient(BaseDatabaseClient):
executable_name = 'sqlite3' executable_name = 'sqlite3'
def runshell(self): def runshell(self):
args = ['', settings.DATABASE_NAME] args = ['', self.connection.settings_dict['DATABASE_NAME']]
os.execvp(self.executable_name, args) os.execvp(self.executable_name, args)