From 315145f7ca682f8361d956e985f533a7fb421cde Mon Sep 17 00:00:00 2001 From: Adrian Holovaty Date: Wed, 11 Mar 2009 03:39:34 +0000 Subject: [PATCH] Fixed #10459 -- Refactored the internals of database connection objects so that connections know their own settings and pass around settings as dictionaries instead of passing around the Django settings module itself. This will make it easier for multiple database support. Thanks to Alex Gaynor for the initial patch. This is backwards-compatible but will likely break third-party database backends. Specific API changes are: * BaseDatabaseWrapper.__init__() now takes a settings_dict instead of a settings module. It's called settings_dict to disambiguate, and for easy grepability. This should be a dictionary containing DATABASE_NAME, etc. * BaseDatabaseWrapper has a settings_dict attribute instead of an options attribute. BaseDatabaseWrapper.options is now BaseDatabaseWrapper['DATABASE_OPTIONS'] * BaseDatabaseWrapper._cursor() no longer takes a settings argument. * BaseDatabaseClient.__init__() now takes a connection argument (a DatabaseWrapper instance) instead of no arguments. git-svn-id: http://code.djangoproject.com/svn/django/trunk@10026 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/__init__.py | 18 ++++++++-- django/db/backends/__init__.py | 14 ++++++-- django/db/backends/dummy/base.py | 2 +- django/db/backends/mysql/base.py | 33 ++++++++++--------- django/db/backends/mysql/client.py | 14 ++++---- django/db/backends/oracle/base.py | 29 ++++++++-------- django/db/backends/oracle/client.py | 4 +-- django/db/backends/postgresql/base.py | 29 ++++++++-------- django/db/backends/postgresql/client.py | 18 +++++----- .../db/backends/postgresql_psycopg2/base.py | 29 ++++++++-------- django/db/backends/sqlite3/base.py | 14 ++++---- django/db/backends/sqlite3/client.py | 3 +- 12 files changed, 115 insertions(+), 92 deletions(-) diff --git a/django/db/__init__.py b/django/db/__init__.py index 8025721e723..517fec6e84d 100644 --- a/django/db/__init__.py +++ b/django/db/__init__.py @@ -36,8 +36,22 @@ except ImportError, e: else: raise # If there's some other error, this must be an error in Django itself. -# Convenient aliases for backend bits. -connection = backend.DatabaseWrapper(**settings.DATABASE_OPTIONS) +# `connection`, `DatabaseError` and `IntegrityError` are convenient aliases +# for backend bits. + +# DatabaseWrapper.__init__() takes a dictionary, not a settings module, so +# we manually create the dictionary from the settings, passing only the +# settings that the database backends care about. Note that TIME_ZONE is used +# by the PostgreSQL backends. +connection = backend.DatabaseWrapper({ + 'DATABASE_HOST': settings.DATABASE_HOST, + 'DATABASE_NAME': settings.DATABASE_NAME, + 'DATABASE_OPTIONS': settings.DATABASE_OPTIONS, + 'DATABASE_PASSWORD': settings.DATABASE_PASSWORD, + 'DATABASE_PORT': settings.DATABASE_PORT, + 'DATABASE_USER': settings.DATABASE_USER, + 'TIME_ZONE': settings.TIME_ZONE, +}) DatabaseError = backend.DatabaseError IntegrityError = backend.IntegrityError diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 187ff6cfe6b..6ca7b871160 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -24,10 +24,14 @@ class BaseDatabaseWrapper(local): Represents a database connection. """ ops = None - def __init__(self, **kwargs): + def __init__(self, settings_dict): + # `settings_dict` should be a dictionary containing keys such as + # DATABASE_NAME, DATABASE_USER, etc. It's called `settings_dict` + # instead of `settings` to disambiguate it from Django settings + # modules. self.connection = None self.queries = [] - self.options = kwargs + self.settings_dict = settings_dict def _commit(self): if self.connection is not None: @@ -59,7 +63,7 @@ class BaseDatabaseWrapper(local): def cursor(self): from django.conf import settings - cursor = self._cursor(settings) + cursor = self._cursor() if settings.DEBUG: return self.make_debug_cursor(cursor) return cursor @@ -498,6 +502,10 @@ class BaseDatabaseClient(object): # (e.g., "psql"). Subclasses must override this. executable_name = None + def __init__(self, connection): + # connection is an instance of BaseDatabaseWrapper. + self.connection = connection + def runshell(self): raise NotImplementedError() diff --git a/django/db/backends/dummy/base.py b/django/db/backends/dummy/base.py index 530ea9c5191..a18c6094bfd 100644 --- a/django/db/backends/dummy/base.py +++ b/django/db/backends/dummy/base.py @@ -46,7 +46,7 @@ class DatabaseWrapper(object): self.features = BaseDatabaseFeatures() self.ops = DatabaseOperations() - self.client = DatabaseClient() + self.client = DatabaseClient(self) self.creation = BaseDatabaseCreation(self) self.introspection = DatabaseIntrospection(self) self.validation = BaseDatabaseValidation() diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index 00da726ac59..7ac64e44f31 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -234,11 +234,11 @@ class DatabaseWrapper(BaseDatabaseWrapper): def __init__(self, **kwargs): super(DatabaseWrapper, self).__init__(**kwargs) - self.server_version = None + self.server_version = None self.features = DatabaseFeatures() self.ops = DatabaseOperations() - self.client = DatabaseClient() + self.client = DatabaseClient(self) self.creation = DatabaseCreation(self) self.introspection = DatabaseIntrospection(self) self.validation = DatabaseValidation() @@ -253,26 +253,27 @@ class DatabaseWrapper(BaseDatabaseWrapper): self.connection = None return False - def _cursor(self, settings): + def _cursor(self): if not self._valid_connection(): kwargs = { 'conv': django_conversions, 'charset': 'utf8', 'use_unicode': True, } - if settings.DATABASE_USER: - kwargs['user'] = settings.DATABASE_USER - if settings.DATABASE_NAME: - kwargs['db'] = settings.DATABASE_NAME - if settings.DATABASE_PASSWORD: - kwargs['passwd'] = settings.DATABASE_PASSWORD - if settings.DATABASE_HOST.startswith('/'): - kwargs['unix_socket'] = settings.DATABASE_HOST - elif settings.DATABASE_HOST: - kwargs['host'] = settings.DATABASE_HOST - if settings.DATABASE_PORT: - kwargs['port'] = int(settings.DATABASE_PORT) - kwargs.update(self.options) + settings_dict = self.settings_dict + if settings_dict['DATABASE_USER']: + kwargs['user'] = settings_dict['DATABASE_USER'] + if settings_dict['DATABASE_NAME']: + kwargs['db'] = settings_dict['DATABASE_NAME'] + if settings_dict['DATABASE_PASSWORD']: + kwargs['passwd'] = settings_dict['DATABASE_PASSWORD'] + if settings_dict['DATABASE_HOST'].startswith('/'): + kwargs['unix_socket'] = settings_dict['DATABASE_HOST'] + elif settings_dict['DATABASE_HOST']: + kwargs['host'] = settings_dict['DATABASE_HOST'] + if settings_dict['DATABASE_PORT']: + kwargs['port'] = int(settings_dict['DATABASE_PORT']) + kwargs.update(settings_dict['DATABASE_OPTIONS']) self.connection = Database.connect(**kwargs) self.connection.encoders[SafeUnicode] = self.connection.encoders[unicode] self.connection.encoders[SafeString] = self.connection.encoders[str] diff --git a/django/db/backends/mysql/client.py b/django/db/backends/mysql/client.py index 17daca9fd69..129f86a951b 100644 --- a/django/db/backends/mysql/client.py +++ b/django/db/backends/mysql/client.py @@ -1,18 +1,18 @@ from django.db.backends import BaseDatabaseClient -from django.conf import settings import os class DatabaseClient(BaseDatabaseClient): executable_name = 'mysql' def runshell(self): + settings_dict = self.connection.settings_dict args = [''] - db = settings.DATABASE_OPTIONS.get('db', settings.DATABASE_NAME) - user = settings.DATABASE_OPTIONS.get('user', settings.DATABASE_USER) - passwd = settings.DATABASE_OPTIONS.get('passwd', settings.DATABASE_PASSWORD) - host = settings.DATABASE_OPTIONS.get('host', settings.DATABASE_HOST) - port = settings.DATABASE_OPTIONS.get('port', settings.DATABASE_PORT) - defaults_file = settings.DATABASE_OPTIONS.get('read_default_file') + db = settings_dict['DATABASE_OPTIONS'].get('db', settings_dict['DATABASE_NAME']) + user = settings_dict['DATABASE_OPTIONS'].get('user', settings_dict['DATABASE_USER']) + passwd = settings_dict['DATABASE_OPTIONS'].get('passwd', settings_dict['DATABASE_PASSWORD']) + host = settings_dict['DATABASE_OPTIONS'].get('host', settings_dict['DATABASE_HOST']) + port = settings_dict['DATABASE_OPTIONS'].get('port', settings_dict['DATABASE_PORT']) + defaults_file = settings_dict['DATABASE_OPTIONS'].get('read_default_file') # Seems to be no good way to set sql_mode with CLI. if defaults_file: diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index baa8486c50c..e8570c9fce2 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -262,7 +262,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): self.features = DatabaseFeatures() self.ops = DatabaseOperations() - self.client = DatabaseClient() + self.client = DatabaseClient(self) self.creation = DatabaseCreation(self) self.introspection = DatabaseIntrospection(self) self.validation = BaseDatabaseValidation() @@ -270,23 +270,24 @@ class DatabaseWrapper(BaseDatabaseWrapper): def _valid_connection(self): return self.connection is not None - def _connect_string(self, settings): - if len(settings.DATABASE_HOST.strip()) == 0: - settings.DATABASE_HOST = 'localhost' - if len(settings.DATABASE_PORT.strip()) != 0: - dsn = Database.makedsn(settings.DATABASE_HOST, - int(settings.DATABASE_PORT), - settings.DATABASE_NAME) + def _connect_string(self): + settings_dict = self.settings_dict + if len(settings_dict['DATABASE_HOST'].strip()) == 0: + settings_dict['DATABASE_HOST'] = 'localhost' + if len(settings_dict['DATABASE_PORT'].strip()) != 0: + dsn = Database.makedsn(settings_dict['DATABASE_HOST'], + int(settings_dict['DATABASE_PORT']), + settings_dict['DATABASE_NAME']) else: - dsn = settings.DATABASE_NAME - return "%s/%s@%s" % (settings.DATABASE_USER, - settings.DATABASE_PASSWORD, dsn) + dsn = settings_dict['DATABASE_NAME'] + return "%s/%s@%s" % (settings_dict['DATABASE_USER'], + settings_dict['DATABASE_PASSWORD'], dsn) - def _cursor(self, settings): + def _cursor(self): cursor = None if not self._valid_connection(): - conn_string = self._connect_string(settings) - self.connection = Database.connect(conn_string, **self.options) + conn_string = self._connect_string() + self.connection = Database.connect(conn_string, **self.settings_dict['DATABASE_OPTIONS']) cursor = FormatStylePlaceholderCursor(self.connection) # Set oracle date to ansi date format. This only needs to execute # once when we create a new connection. We also set the Territory diff --git a/django/db/backends/oracle/client.py b/django/db/backends/oracle/client.py index c95b8109ba5..84193eaedc1 100644 --- a/django/db/backends/oracle/client.py +++ b/django/db/backends/oracle/client.py @@ -1,12 +1,10 @@ from django.db.backends import BaseDatabaseClient -from django.conf import settings import os class DatabaseClient(BaseDatabaseClient): executable_name = 'sqlplus' def runshell(self): - from django.db import connection - conn_string = connection._connect_string(settings) + conn_string = self.connection._connect_string() args = [self.executable_name, "-L", conn_string] os.execvp(self.executable_name, args) diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py index ad271f2e38e..9050cfd9e0d 100644 --- a/django/db/backends/postgresql/base.py +++ b/django/db/backends/postgresql/base.py @@ -90,32 +90,33 @@ class DatabaseWrapper(BaseDatabaseWrapper): self.features = DatabaseFeatures() self.ops = DatabaseOperations() - self.client = DatabaseClient() + self.client = DatabaseClient(self) self.creation = DatabaseCreation(self) self.introspection = DatabaseIntrospection(self) self.validation = BaseDatabaseValidation() - def _cursor(self, settings): + def _cursor(self): set_tz = False + settings_dict = self.settings_dict if self.connection is None: set_tz = True - if settings.DATABASE_NAME == '': + if settings_dict['DATABASE_NAME'] == '': from django.core.exceptions import ImproperlyConfigured raise ImproperlyConfigured("You need to specify DATABASE_NAME in your Django settings file.") - conn_string = "dbname=%s" % settings.DATABASE_NAME - if settings.DATABASE_USER: - conn_string = "user=%s %s" % (settings.DATABASE_USER, conn_string) - if settings.DATABASE_PASSWORD: - conn_string += " password='%s'" % settings.DATABASE_PASSWORD - if settings.DATABASE_HOST: - conn_string += " host=%s" % settings.DATABASE_HOST - if settings.DATABASE_PORT: - conn_string += " port=%s" % settings.DATABASE_PORT - self.connection = Database.connect(conn_string, **self.options) + conn_string = "dbname=%s" % settings_dict['DATABASE_NAME'] + if settings_dict['DATABASE_USER']: + conn_string = "user=%s %s" % (settings_dict['DATABASE_USER'], conn_string) + if settings_dict['DATABASE_PASSWORD']: + conn_string += " password='%s'" % settings_dict['DATABASE_PASSWORD'] + if settings_dict['DATABASE_HOST']: + conn_string += " host=%s" % settings_dict['DATABASE_HOST'] + if settings_dict['DATABASE_PORT']: + conn_string += " port=%s" % settings_dict['DATABASE_PORT'] + self.connection = Database.connect(conn_string, **settings_dict['DATABASE_OPTIONS']) self.connection.set_isolation_level(1) # make transactions transparent to all cursors cursor = self.connection.cursor() if set_tz: - cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE]) + cursor.execute("SET TIME ZONE %s", [settings_dict['TIME_ZONE']]) if not hasattr(self, '_version'): self.__class__._version = get_version(cursor) if self._version < (8, 0): diff --git a/django/db/backends/postgresql/client.py b/django/db/backends/postgresql/client.py index 63f28a7b57e..506372bfc49 100644 --- a/django/db/backends/postgresql/client.py +++ b/django/db/backends/postgresql/client.py @@ -1,19 +1,19 @@ from django.db.backends import BaseDatabaseClient -from django.conf import settings import os class DatabaseClient(BaseDatabaseClient): executable_name = 'psql' def runshell(self): + settings_dict = self.connection.settings_dict args = [self.executable_name] - if settings.DATABASE_USER: - args += ["-U", settings.DATABASE_USER] - if settings.DATABASE_PASSWORD: + if settings_dict['DATABASE_USER']: + args += ["-U", settings_dict['DATABASE_USER']] + if settings_dict['DATABASE_PASSWORD']: args += ["-W"] - if settings.DATABASE_HOST: - args.extend(["-h", settings.DATABASE_HOST]) - if settings.DATABASE_PORT: - args.extend(["-p", str(settings.DATABASE_PORT)]) - args += [settings.DATABASE_NAME] + if settings_dict['DATABASE_HOST']: + args.extend(["-h", settings_dict['DATABASE_HOST']]) + if settings_dict['DATABASE_PORT']: + args.extend(["-p", str(settings_dict['DATABASE_PORT'])]) + args += [settings_dict['DATABASE_NAME']] os.execvp(self.executable_name, args) diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py index 27de942207e..b4a69dca89a 100644 --- a/django/db/backends/postgresql_psycopg2/base.py +++ b/django/db/backends/postgresql_psycopg2/base.py @@ -60,37 +60,38 @@ class DatabaseWrapper(BaseDatabaseWrapper): self.features = DatabaseFeatures() self.ops = DatabaseOperations() - self.client = DatabaseClient() + self.client = DatabaseClient(self) self.creation = DatabaseCreation(self) self.introspection = DatabaseIntrospection(self) self.validation = BaseDatabaseValidation() - def _cursor(self, settings): + def _cursor(self): set_tz = False + settings_dict = self.settings_dict if self.connection is None: set_tz = True - if settings.DATABASE_NAME == '': + if settings_dict['DATABASE_NAME'] == '': from django.core.exceptions import ImproperlyConfigured raise ImproperlyConfigured("You need to specify DATABASE_NAME in your Django settings file.") conn_params = { - 'database': settings.DATABASE_NAME, + 'database': settings_dict['DATABASE_NAME'], } - conn_params.update(self.options) - if settings.DATABASE_USER: - conn_params['user'] = settings.DATABASE_USER - if settings.DATABASE_PASSWORD: - conn_params['password'] = settings.DATABASE_PASSWORD - if settings.DATABASE_HOST: - conn_params['host'] = settings.DATABASE_HOST - if settings.DATABASE_PORT: - conn_params['port'] = settings.DATABASE_PORT + conn_params.update(settings_dict['DATABASE_OPTIONS']) + if settings_dict['DATABASE_USER']: + conn_params['user'] = settings_dict['DATABASE_USER'] + if settings_dict['DATABASE_PASSWORD']: + conn_params['password'] = settings_dict['DATABASE_PASSWORD'] + if settings_dict['DATABASE_HOST']: + conn_params['host'] = settings_dict['DATABASE_HOST'] + if settings_dict['DATABASE_PORT']: + conn_params['port'] = settings_dict['DATABASE_PORT'] self.connection = Database.connect(**conn_params) self.connection.set_isolation_level(1) # make transactions transparent to all cursors self.connection.set_client_encoding('UTF8') cursor = self.connection.cursor() cursor.tzinfo_factory = None if set_tz: - cursor.execute("SET TIME ZONE %s", [settings.TIME_ZONE]) + cursor.execute("SET TIME ZONE %s", [settings_dict['TIME_ZONE']]) if not hasattr(self, '_version'): self.__class__._version = get_version(cursor) if self._version < (8, 0): diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index ba0ef16b61b..b0c087d1cd9 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -149,21 +149,22 @@ class DatabaseWrapper(BaseDatabaseWrapper): self.features = DatabaseFeatures() self.ops = DatabaseOperations() - self.client = DatabaseClient() + self.client = DatabaseClient(self) self.creation = DatabaseCreation(self) self.introspection = DatabaseIntrospection(self) self.validation = BaseDatabaseValidation() - def _cursor(self, settings): + def _cursor(self): if self.connection is None: - if not settings.DATABASE_NAME: + settings_dict = self.settings_dict + if not settings_dict['DATABASE_NAME']: from django.core.exceptions import ImproperlyConfigured raise ImproperlyConfigured, "Please fill out DATABASE_NAME in the settings module before using the database." kwargs = { - 'database': settings.DATABASE_NAME, + 'database': settings_dict['DATABASE_NAME'], 'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES, } - kwargs.update(self.options) + kwargs.update(settings_dict['DATABASE_OPTIONS']) self.connection = Database.connect(**kwargs) # Register extract, date_trunc, and regexp functions. self.connection.create_function("django_extract", 2, _sqlite_extract) @@ -172,11 +173,10 @@ class DatabaseWrapper(BaseDatabaseWrapper): return self.connection.cursor(factory=SQLiteCursorWrapper) def close(self): - from django.conf import settings # If database is in memory, closing the connection destroys the # database. To prevent accidental data loss, ignore close requests on # an in-memory db. - if settings.DATABASE_NAME != ":memory:": + if self.settings_dict['DATABASE_NAME'] != ":memory:": BaseDatabaseWrapper.close(self) class SQLiteCursorWrapper(Database.Cursor): diff --git a/django/db/backends/sqlite3/client.py b/django/db/backends/sqlite3/client.py index 239e72f1e92..0b65444d74a 100644 --- a/django/db/backends/sqlite3/client.py +++ b/django/db/backends/sqlite3/client.py @@ -1,10 +1,9 @@ from django.db.backends import BaseDatabaseClient -from django.conf import settings import os class DatabaseClient(BaseDatabaseClient): executable_name = 'sqlite3' def runshell(self): - args = ['', settings.DATABASE_NAME] + args = ['', self.connection.settings_dict['DATABASE_NAME']] os.execvp(self.executable_name, args)