diff --git a/django/db/__init__.py b/django/db/__init__.py index 8395468b8c..26c7add0af 100644 --- a/django/db/__init__.py +++ b/django/db/__init__.py @@ -22,9 +22,21 @@ router = ConnectionRouter(settings.DATABASE_ROUTERS) # we manually create the dictionary from the settings, passing only the # settings that the database backends care about. Note that TIME_ZONE is used # by the PostgreSQL backends. -# we load all these up for backwards compatibility, you should use +# We load all these up for backwards compatibility, you should use # connections['default'] instead. -connection = connections[DEFAULT_DB_ALIAS] +class DefaultConnectionProxy(object): + """ + Proxy for accessing the default DatabaseWrapper object's attributes. If you + need to access the DatabaseWrapper object itself, use + connections[DEFAULT_DB_ALIAS] instead. + """ + def __getattr__(self, item): + return getattr(connections[DEFAULT_DB_ALIAS], item) + + def __setattr__(self, name, value): + return setattr(connections[DEFAULT_DB_ALIAS], name, value) + +connection = DefaultConnectionProxy() backend = load_backend(connection.settings_dict['ENGINE']) # Register an event that closes the database connection diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index f2bde840d7..59c0992af5 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -1,8 +1,9 @@ +from django.db.utils import DatabaseError + try: import thread except ImportError: import dummy_thread as thread -from threading import local from contextlib import contextmanager from django.conf import settings @@ -13,14 +14,15 @@ from django.utils.importlib import import_module from django.utils.timezone import is_aware -class BaseDatabaseWrapper(local): +class BaseDatabaseWrapper(object): """ Represents a database connection. """ ops = None vendor = 'unknown' - def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS): + def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS, + allow_thread_sharing=False): # `settings_dict` should be a dictionary containing keys such as # NAME, USER, etc. It's called `settings_dict` instead of `settings` # to disambiguate it from Django settings modules. @@ -34,6 +36,8 @@ class BaseDatabaseWrapper(local): self.transaction_state = [] self.savepoint_state = 0 self._dirty = None + self._thread_ident = thread.get_ident() + self.allow_thread_sharing = allow_thread_sharing def __eq__(self, other): return self.alias == other.alias @@ -116,6 +120,21 @@ class BaseDatabaseWrapper(local): "pending COMMIT/ROLLBACK") self._dirty = False + def validate_thread_sharing(self): + """ + Validates that the connection isn't accessed by another thread than the + one which originally created it, unless the connection was explicitly + authorized to be shared between threads (via the `allow_thread_sharing` + property). Raises an exception if the validation fails. + """ + if (not self.allow_thread_sharing + and self._thread_ident != thread.get_ident()): + raise DatabaseError("DatabaseWrapper objects created in a " + "thread can only be used in that same thread. The object" + "with alias '%s' was created in thread id %s and this is " + "thread id %s." + % (self.alias, self._thread_ident, thread.get_ident())) + def is_dirty(self): """ Returns True if the current transaction requires a commit for changes to @@ -179,6 +198,7 @@ class BaseDatabaseWrapper(local): """ Commits changes if the system is not in managed transaction mode. """ + self.validate_thread_sharing() if not self.is_managed(): self._commit() self.clean_savepoints() @@ -189,6 +209,7 @@ class BaseDatabaseWrapper(local): """ Rolls back changes if the system is not in managed transaction mode. """ + self.validate_thread_sharing() if not self.is_managed(): self._rollback() else: @@ -198,6 +219,7 @@ class BaseDatabaseWrapper(local): """ Does the commit itself and resets the dirty flag. """ + self.validate_thread_sharing() self._commit() self.set_clean() @@ -205,6 +227,7 @@ class BaseDatabaseWrapper(local): """ This function does the rollback itself and resets the dirty flag. """ + self.validate_thread_sharing() self._rollback() self.set_clean() @@ -228,6 +251,7 @@ class BaseDatabaseWrapper(local): Rolls back the most recent savepoint (if one exists). Does nothing if savepoints are not supported. """ + self.validate_thread_sharing() if self.savepoint_state: self._savepoint_rollback(sid) @@ -236,6 +260,7 @@ class BaseDatabaseWrapper(local): Commits the most recent savepoint (if one exists). Does nothing if savepoints are not supported. """ + self.validate_thread_sharing() if self.savepoint_state: self._savepoint_commit(sid) @@ -269,11 +294,13 @@ class BaseDatabaseWrapper(local): pass def close(self): + self.validate_thread_sharing() if self.connection is not None: self.connection.close() self.connection = None def cursor(self): + self.validate_thread_sharing() if (self.use_debug_cursor or (self.use_debug_cursor is None and settings.DEBUG)): cursor = self.make_debug_cursor(self._cursor()) diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index a610606635..75e8fa027d 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -7,10 +7,10 @@ standard library. import datetime import decimal +import warnings import re import sys -from django.conf import settings from django.db import utils from django.db.backends import * from django.db.backends.signals import connection_created @@ -241,6 +241,21 @@ class DatabaseWrapper(BaseDatabaseWrapper): 'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES, } kwargs.update(settings_dict['OPTIONS']) + # Always allow the underlying SQLite connection to be shareable + # between multiple threads. The safe-guarding will be handled at a + # higher level by the `BaseDatabaseWrapper.allow_thread_sharing` + # property. This is necessary as the shareability is disabled by + # default in pysqlite and it cannot be changed once a connection is + # opened. + if 'check_same_thread' in kwargs and kwargs['check_same_thread']: + warnings.warn( + 'The `check_same_thread` option was provided and set to ' + 'True. It will be overriden with False. Use the ' + '`DatabaseWrapper.allow_thread_sharing` property instead ' + 'for controlling thread shareability.', + RuntimeWarning + ) + kwargs.update({'check_same_thread': False}) self.connection = Database.connect(**kwargs) # Register extract, date_trunc, and regexp functions. self.connection.create_function("django_extract", 2, _sqlite_extract) diff --git a/django/db/utils.py b/django/db/utils.py index f0c13e3aac..41ad6df728 100644 --- a/django/db/utils.py +++ b/django/db/utils.py @@ -1,4 +1,5 @@ import os +from threading import local from django.conf import settings from django.core.exceptions import ImproperlyConfigured @@ -50,7 +51,7 @@ class ConnectionDoesNotExist(Exception): class ConnectionHandler(object): def __init__(self, databases): self.databases = databases - self._connections = {} + self._connections = local() def ensure_defaults(self, alias): """ @@ -73,16 +74,19 @@ class ConnectionHandler(object): conn.setdefault(setting, None) def __getitem__(self, alias): - if alias in self._connections: - return self._connections[alias] + if hasattr(self._connections, alias): + return getattr(self._connections, alias) self.ensure_defaults(alias) db = self.databases[alias] backend = load_backend(db['ENGINE']) conn = backend.DatabaseWrapper(db, alias) - self._connections[alias] = conn + setattr(self._connections, alias, conn) return conn + def __setitem__(self, key, value): + setattr(self._connections, key, value) + def __iter__(self): return iter(self.databases) diff --git a/docs/releases/1.4.txt b/docs/releases/1.4.txt index 491556bfa2..f614dee5f8 100644 --- a/docs/releases/1.4.txt +++ b/docs/releases/1.4.txt @@ -673,6 +673,32 @@ datetimes are now stored without time zone information in SQLite. When :setting:`USE_TZ` is ``False``, if you attempt to save an aware datetime object, Django raises an exception. +Database connection's thread-locality +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +``DatabaseWrapper`` objects (i.e. the connection objects referenced by +``django.db.connection`` and ``django.db.connections["some_alias"]``) used to +be thread-local. They are now global objects in order to be potentially shared +between multiple threads. While the individual connection objects are now +global, the ``django.db.connections`` dictionary referencing those objects is +still thread-local. Therefore if you just use the ORM or +``DatabaseWrapper.cursor()`` then the behavior is still the same as before. +Note, however, that ``django.db.connection`` does not directly reference the +default ``DatabaseWrapper`` object any more and is now a proxy to access that +object's attributes. If you need to access the actual ``DatabaseWrapper`` +object, use ``django.db.connections[DEFAULT_DB_ALIAS]`` instead. + +As part of this change, all underlying SQLite connections are now enabled for +potential thread-sharing (by passing the ``check_same_thread=False`` attribute +to pysqlite). ``DatabaseWrapper`` however preserves the previous behavior by +disabling thread-sharing by default, so this does not affect any existing +code that purely relies on the ORM or on ``DatabaseWrapper.cursor()``. + +Finally, while it is now possible to pass connections between threads, Django +does not make any effort to synchronize access to the underlying backend. +Concurrency behavior is defined by the underlying backend implementation. +Check their documentation for details. + `COMMENTS_BANNED_USERS_GROUP` setting ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/regressiontests/backends/tests.py b/tests/regressiontests/backends/tests.py index 936f010186..82c21c8c7b 100644 --- a/tests/regressiontests/backends/tests.py +++ b/tests/regressiontests/backends/tests.py @@ -3,6 +3,7 @@ from __future__ import with_statement, absolute_import import datetime +import threading from django.conf import settings from django.core.management.color import no_style @@ -283,7 +284,7 @@ class ConnectionCreatedSignalTest(TestCase): connection_created.connect(receiver) connection.close() cursor = connection.cursor() - self.assertTrue(data["connection"] is connection) + self.assertTrue(data["connection"].connection is connection.connection) connection_created.disconnect(receiver) data.clear() @@ -446,3 +447,94 @@ class FkConstraintsTests(TransactionTestCase): connection.check_constraints() finally: transaction.rollback() + + +class ThreadTests(TestCase): + + def test_default_connection_thread_local(self): + """ + Ensure that the default connection (i.e. django.db.connection) is + different for each thread. + Refs #17258. + """ + connections_set = set() + connection.cursor() + connections_set.add(connection.connection) + def runner(): + from django.db import connection + connection.cursor() + connections_set.add(connection.connection) + for x in xrange(2): + t = threading.Thread(target=runner) + t.start() + t.join() + self.assertEquals(len(connections_set), 3) + # Finish by closing the connections opened by the other threads (the + # connection opened in the main thread will automatically be closed on + # teardown). + for conn in connections_set: + if conn != connection.connection: + conn.close() + + def test_connections_thread_local(self): + """ + Ensure that the connections are different for each thread. + Refs #17258. + """ + connections_set = set() + for conn in connections.all(): + connections_set.add(conn) + def runner(): + from django.db import connections + for conn in connections.all(): + connections_set.add(conn) + for x in xrange(2): + t = threading.Thread(target=runner) + t.start() + t.join() + self.assertEquals(len(connections_set), 6) + # Finish by closing the connections opened by the other threads (the + # connection opened in the main thread will automatically be closed on + # teardown). + for conn in connections_set: + if conn != connection: + conn.close() + + def test_pass_connection_between_threads(self): + """ + Ensure that a connection can be passed from one thread to the other. + Refs #17258. + """ + models.Person.objects.create(first_name="John", last_name="Doe") + + def do_thread(): + def runner(main_thread_connection): + from django.db import connections + connections['default'] = main_thread_connection + try: + models.Person.objects.get(first_name="John", last_name="Doe") + except DatabaseError, e: + exceptions.append(e) + t = threading.Thread(target=runner, args=[connections['default']]) + t.start() + t.join() + + # Without touching allow_thread_sharing, which should be False by default. + exceptions = [] + do_thread() + # Forbidden! + self.assertTrue(isinstance(exceptions[0], DatabaseError)) + + # If explicitly setting allow_thread_sharing to False + connections['default'].allow_thread_sharing = False + exceptions = [] + do_thread() + # Forbidden! + self.assertTrue(isinstance(exceptions[0], DatabaseError)) + + # If explicitly setting allow_thread_sharing to True + connections['default'].allow_thread_sharing = True + exceptions = [] + do_thread() + # All good + self.assertEqual(len(exceptions), 0) \ No newline at end of file