Fixed #17258 -- Moved `threading.local` from `DatabaseWrapper` to the `django.db.connections` dictionary. This allows connections to be explicitly shared between multiple threads and is particularly useful for enabling the sharing of in-memory SQLite connections. Many thanks to Anssi Kääriäinen for the excellent suggestions and feedback, and to Alex Gaynor for the reviews. Refs #2879.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@17205 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Julien Phalip 2011-12-16 13:40:19 +00:00
parent 5df31c0164
commit 34e248efec
6 changed files with 187 additions and 11 deletions

View File

@ -22,9 +22,21 @@ router = ConnectionRouter(settings.DATABASE_ROUTERS)
# we manually create the dictionary from the settings, passing only the # we manually create the dictionary from the settings, passing only the
# settings that the database backends care about. Note that TIME_ZONE is used # settings that the database backends care about. Note that TIME_ZONE is used
# by the PostgreSQL backends. # by the PostgreSQL backends.
# we load all these up for backwards compatibility, you should use # We load all these up for backwards compatibility, you should use
# connections['default'] instead. # connections['default'] instead.
connection = connections[DEFAULT_DB_ALIAS] class DefaultConnectionProxy(object):
"""
Proxy for accessing the default DatabaseWrapper object's attributes. If you
need to access the DatabaseWrapper object itself, use
connections[DEFAULT_DB_ALIAS] instead.
"""
def __getattr__(self, item):
return getattr(connections[DEFAULT_DB_ALIAS], item)
def __setattr__(self, name, value):
return setattr(connections[DEFAULT_DB_ALIAS], name, value)
connection = DefaultConnectionProxy()
backend = load_backend(connection.settings_dict['ENGINE']) backend = load_backend(connection.settings_dict['ENGINE'])
# Register an event that closes the database connection # Register an event that closes the database connection

View File

@ -1,8 +1,9 @@
from django.db.utils import DatabaseError
try: try:
import thread import thread
except ImportError: except ImportError:
import dummy_thread as thread import dummy_thread as thread
from threading import local
from contextlib import contextmanager from contextlib import contextmanager
from django.conf import settings from django.conf import settings
@ -13,14 +14,15 @@ from django.utils.importlib import import_module
from django.utils.timezone import is_aware from django.utils.timezone import is_aware
class BaseDatabaseWrapper(local): class BaseDatabaseWrapper(object):
""" """
Represents a database connection. Represents a database connection.
""" """
ops = None ops = None
vendor = 'unknown' vendor = 'unknown'
def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS): def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS,
allow_thread_sharing=False):
# `settings_dict` should be a dictionary containing keys such as # `settings_dict` should be a dictionary containing keys such as
# NAME, USER, etc. It's called `settings_dict` instead of `settings` # NAME, USER, etc. It's called `settings_dict` instead of `settings`
# to disambiguate it from Django settings modules. # to disambiguate it from Django settings modules.
@ -34,6 +36,8 @@ class BaseDatabaseWrapper(local):
self.transaction_state = [] self.transaction_state = []
self.savepoint_state = 0 self.savepoint_state = 0
self._dirty = None self._dirty = None
self._thread_ident = thread.get_ident()
self.allow_thread_sharing = allow_thread_sharing
def __eq__(self, other): def __eq__(self, other):
return self.alias == other.alias return self.alias == other.alias
@ -116,6 +120,21 @@ class BaseDatabaseWrapper(local):
"pending COMMIT/ROLLBACK") "pending COMMIT/ROLLBACK")
self._dirty = False self._dirty = False
def validate_thread_sharing(self):
"""
Validates that the connection isn't accessed by another thread than the
one which originally created it, unless the connection was explicitly
authorized to be shared between threads (via the `allow_thread_sharing`
property). Raises an exception if the validation fails.
"""
if (not self.allow_thread_sharing
and self._thread_ident != thread.get_ident()):
raise DatabaseError("DatabaseWrapper objects created in a "
"thread can only be used in that same thread. The object"
"with alias '%s' was created in thread id %s and this is "
"thread id %s."
% (self.alias, self._thread_ident, thread.get_ident()))
def is_dirty(self): def is_dirty(self):
""" """
Returns True if the current transaction requires a commit for changes to Returns True if the current transaction requires a commit for changes to
@ -179,6 +198,7 @@ class BaseDatabaseWrapper(local):
""" """
Commits changes if the system is not in managed transaction mode. Commits changes if the system is not in managed transaction mode.
""" """
self.validate_thread_sharing()
if not self.is_managed(): if not self.is_managed():
self._commit() self._commit()
self.clean_savepoints() self.clean_savepoints()
@ -189,6 +209,7 @@ class BaseDatabaseWrapper(local):
""" """
Rolls back changes if the system is not in managed transaction mode. Rolls back changes if the system is not in managed transaction mode.
""" """
self.validate_thread_sharing()
if not self.is_managed(): if not self.is_managed():
self._rollback() self._rollback()
else: else:
@ -198,6 +219,7 @@ class BaseDatabaseWrapper(local):
""" """
Does the commit itself and resets the dirty flag. Does the commit itself and resets the dirty flag.
""" """
self.validate_thread_sharing()
self._commit() self._commit()
self.set_clean() self.set_clean()
@ -205,6 +227,7 @@ class BaseDatabaseWrapper(local):
""" """
This function does the rollback itself and resets the dirty flag. This function does the rollback itself and resets the dirty flag.
""" """
self.validate_thread_sharing()
self._rollback() self._rollback()
self.set_clean() self.set_clean()
@ -228,6 +251,7 @@ class BaseDatabaseWrapper(local):
Rolls back the most recent savepoint (if one exists). Does nothing if Rolls back the most recent savepoint (if one exists). Does nothing if
savepoints are not supported. savepoints are not supported.
""" """
self.validate_thread_sharing()
if self.savepoint_state: if self.savepoint_state:
self._savepoint_rollback(sid) self._savepoint_rollback(sid)
@ -236,6 +260,7 @@ class BaseDatabaseWrapper(local):
Commits the most recent savepoint (if one exists). Does nothing if Commits the most recent savepoint (if one exists). Does nothing if
savepoints are not supported. savepoints are not supported.
""" """
self.validate_thread_sharing()
if self.savepoint_state: if self.savepoint_state:
self._savepoint_commit(sid) self._savepoint_commit(sid)
@ -269,11 +294,13 @@ class BaseDatabaseWrapper(local):
pass pass
def close(self): def close(self):
self.validate_thread_sharing()
if self.connection is not None: if self.connection is not None:
self.connection.close() self.connection.close()
self.connection = None self.connection = None
def cursor(self): def cursor(self):
self.validate_thread_sharing()
if (self.use_debug_cursor or if (self.use_debug_cursor or
(self.use_debug_cursor is None and settings.DEBUG)): (self.use_debug_cursor is None and settings.DEBUG)):
cursor = self.make_debug_cursor(self._cursor()) cursor = self.make_debug_cursor(self._cursor())

View File

@ -7,10 +7,10 @@ standard library.
import datetime import datetime
import decimal import decimal
import warnings
import re import re
import sys import sys
from django.conf import settings
from django.db import utils from django.db import utils
from django.db.backends import * from django.db.backends import *
from django.db.backends.signals import connection_created from django.db.backends.signals import connection_created
@ -241,6 +241,21 @@ class DatabaseWrapper(BaseDatabaseWrapper):
'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES, 'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
} }
kwargs.update(settings_dict['OPTIONS']) kwargs.update(settings_dict['OPTIONS'])
# Always allow the underlying SQLite connection to be shareable
# between multiple threads. The safe-guarding will be handled at a
# higher level by the `BaseDatabaseWrapper.allow_thread_sharing`
# property. This is necessary as the shareability is disabled by
# default in pysqlite and it cannot be changed once a connection is
# opened.
if 'check_same_thread' in kwargs and kwargs['check_same_thread']:
warnings.warn(
'The `check_same_thread` option was provided and set to '
'True. It will be overriden with False. Use the '
'`DatabaseWrapper.allow_thread_sharing` property instead '
'for controlling thread shareability.',
RuntimeWarning
)
kwargs.update({'check_same_thread': False})
self.connection = Database.connect(**kwargs) self.connection = Database.connect(**kwargs)
# Register extract, date_trunc, and regexp functions. # Register extract, date_trunc, and regexp functions.
self.connection.create_function("django_extract", 2, _sqlite_extract) self.connection.create_function("django_extract", 2, _sqlite_extract)

View File

@ -1,4 +1,5 @@
import os import os
from threading import local
from django.conf import settings from django.conf import settings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
@ -50,7 +51,7 @@ class ConnectionDoesNotExist(Exception):
class ConnectionHandler(object): class ConnectionHandler(object):
def __init__(self, databases): def __init__(self, databases):
self.databases = databases self.databases = databases
self._connections = {} self._connections = local()
def ensure_defaults(self, alias): def ensure_defaults(self, alias):
""" """
@ -73,16 +74,19 @@ class ConnectionHandler(object):
conn.setdefault(setting, None) conn.setdefault(setting, None)
def __getitem__(self, alias): def __getitem__(self, alias):
if alias in self._connections: if hasattr(self._connections, alias):
return self._connections[alias] return getattr(self._connections, alias)
self.ensure_defaults(alias) self.ensure_defaults(alias)
db = self.databases[alias] db = self.databases[alias]
backend = load_backend(db['ENGINE']) backend = load_backend(db['ENGINE'])
conn = backend.DatabaseWrapper(db, alias) conn = backend.DatabaseWrapper(db, alias)
self._connections[alias] = conn setattr(self._connections, alias, conn)
return conn return conn
def __setitem__(self, key, value):
setattr(self._connections, key, value)
def __iter__(self): def __iter__(self):
return iter(self.databases) return iter(self.databases)

View File

@ -673,6 +673,32 @@ datetimes are now stored without time zone information in SQLite. When
:setting:`USE_TZ` is ``False``, if you attempt to save an aware datetime :setting:`USE_TZ` is ``False``, if you attempt to save an aware datetime
object, Django raises an exception. object, Django raises an exception.
Database connection's thread-locality
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``DatabaseWrapper`` objects (i.e. the connection objects referenced by
``django.db.connection`` and ``django.db.connections["some_alias"]``) used to
be thread-local. They are now global objects in order to be potentially shared
between multiple threads. While the individual connection objects are now
global, the ``django.db.connections`` dictionary referencing those objects is
still thread-local. Therefore if you just use the ORM or
``DatabaseWrapper.cursor()`` then the behavior is still the same as before.
Note, however, that ``django.db.connection`` does not directly reference the
default ``DatabaseWrapper`` object any more and is now a proxy to access that
object's attributes. If you need to access the actual ``DatabaseWrapper``
object, use ``django.db.connections[DEFAULT_DB_ALIAS]`` instead.
As part of this change, all underlying SQLite connections are now enabled for
potential thread-sharing (by passing the ``check_same_thread=False`` attribute
to pysqlite). ``DatabaseWrapper`` however preserves the previous behavior by
disabling thread-sharing by default, so this does not affect any existing
code that purely relies on the ORM or on ``DatabaseWrapper.cursor()``.
Finally, while it is now possible to pass connections between threads, Django
does not make any effort to synchronize access to the underlying backend.
Concurrency behavior is defined by the underlying backend implementation.
Check their documentation for details.
`COMMENTS_BANNED_USERS_GROUP` setting `COMMENTS_BANNED_USERS_GROUP` setting
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -3,6 +3,7 @@
from __future__ import with_statement, absolute_import from __future__ import with_statement, absolute_import
import datetime import datetime
import threading
from django.conf import settings from django.conf import settings
from django.core.management.color import no_style from django.core.management.color import no_style
@ -283,7 +284,7 @@ class ConnectionCreatedSignalTest(TestCase):
connection_created.connect(receiver) connection_created.connect(receiver)
connection.close() connection.close()
cursor = connection.cursor() cursor = connection.cursor()
self.assertTrue(data["connection"] is connection) self.assertTrue(data["connection"].connection is connection.connection)
connection_created.disconnect(receiver) connection_created.disconnect(receiver)
data.clear() data.clear()
@ -446,3 +447,94 @@ class FkConstraintsTests(TransactionTestCase):
connection.check_constraints() connection.check_constraints()
finally: finally:
transaction.rollback() transaction.rollback()
class ThreadTests(TestCase):
def test_default_connection_thread_local(self):
"""
Ensure that the default connection (i.e. django.db.connection) is
different for each thread.
Refs #17258.
"""
connections_set = set()
connection.cursor()
connections_set.add(connection.connection)
def runner():
from django.db import connection
connection.cursor()
connections_set.add(connection.connection)
for x in xrange(2):
t = threading.Thread(target=runner)
t.start()
t.join()
self.assertEquals(len(connections_set), 3)
# Finish by closing the connections opened by the other threads (the
# connection opened in the main thread will automatically be closed on
# teardown).
for conn in connections_set:
if conn != connection.connection:
conn.close()
def test_connections_thread_local(self):
"""
Ensure that the connections are different for each thread.
Refs #17258.
"""
connections_set = set()
for conn in connections.all():
connections_set.add(conn)
def runner():
from django.db import connections
for conn in connections.all():
connections_set.add(conn)
for x in xrange(2):
t = threading.Thread(target=runner)
t.start()
t.join()
self.assertEquals(len(connections_set), 6)
# Finish by closing the connections opened by the other threads (the
# connection opened in the main thread will automatically be closed on
# teardown).
for conn in connections_set:
if conn != connection:
conn.close()
def test_pass_connection_between_threads(self):
"""
Ensure that a connection can be passed from one thread to the other.
Refs #17258.
"""
models.Person.objects.create(first_name="John", last_name="Doe")
def do_thread():
def runner(main_thread_connection):
from django.db import connections
connections['default'] = main_thread_connection
try:
models.Person.objects.get(first_name="John", last_name="Doe")
except DatabaseError, e:
exceptions.append(e)
t = threading.Thread(target=runner, args=[connections['default']])
t.start()
t.join()
# Without touching allow_thread_sharing, which should be False by default.
exceptions = []
do_thread()
# Forbidden!
self.assertTrue(isinstance(exceptions[0], DatabaseError))
# If explicitly setting allow_thread_sharing to False
connections['default'].allow_thread_sharing = False
exceptions = []
do_thread()
# Forbidden!
self.assertTrue(isinstance(exceptions[0], DatabaseError))
# If explicitly setting allow_thread_sharing to True
connections['default'].allow_thread_sharing = True
exceptions = []
do_thread()
# All good
self.assertEqual(len(exceptions), 0)