Fixed #17258 -- Moved `threading.local` from `DatabaseWrapper` to the `django.db.connections` dictionary. This allows connections to be explicitly shared between multiple threads and is particularly useful for enabling the sharing of in-memory SQLite connections. Many thanks to Anssi Kääriäinen for the excellent suggestions and feedback, and to Alex Gaynor for the reviews. Refs #2879.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@17205 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Julien Phalip 2011-12-16 13:40:19 +00:00
parent 5df31c0164
commit 34e248efec
6 changed files with 187 additions and 11 deletions

View File

@ -22,9 +22,21 @@ router = ConnectionRouter(settings.DATABASE_ROUTERS)
# we manually create the dictionary from the settings, passing only the
# settings that the database backends care about. Note that TIME_ZONE is used
# by the PostgreSQL backends.
# we load all these up for backwards compatibility, you should use
# We load all these up for backwards compatibility, you should use
# connections['default'] instead.
connection = connections[DEFAULT_DB_ALIAS]
class DefaultConnectionProxy(object):
"""
Proxy for accessing the default DatabaseWrapper object's attributes. If you
need to access the DatabaseWrapper object itself, use
connections[DEFAULT_DB_ALIAS] instead.
"""
def __getattr__(self, item):
return getattr(connections[DEFAULT_DB_ALIAS], item)
def __setattr__(self, name, value):
return setattr(connections[DEFAULT_DB_ALIAS], name, value)
connection = DefaultConnectionProxy()
backend = load_backend(connection.settings_dict['ENGINE'])
# Register an event that closes the database connection

View File

@ -1,8 +1,9 @@
from django.db.utils import DatabaseError
try:
import thread
except ImportError:
import dummy_thread as thread
from threading import local
from contextlib import contextmanager
from django.conf import settings
@ -13,14 +14,15 @@ from django.utils.importlib import import_module
from django.utils.timezone import is_aware
class BaseDatabaseWrapper(local):
class BaseDatabaseWrapper(object):
"""
Represents a database connection.
"""
ops = None
vendor = 'unknown'
def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS,
allow_thread_sharing=False):
# `settings_dict` should be a dictionary containing keys such as
# NAME, USER, etc. It's called `settings_dict` instead of `settings`
# to disambiguate it from Django settings modules.
@ -34,6 +36,8 @@ class BaseDatabaseWrapper(local):
self.transaction_state = []
self.savepoint_state = 0
self._dirty = None
self._thread_ident = thread.get_ident()
self.allow_thread_sharing = allow_thread_sharing
def __eq__(self, other):
return self.alias == other.alias
@ -116,6 +120,21 @@ class BaseDatabaseWrapper(local):
"pending COMMIT/ROLLBACK")
self._dirty = False
def validate_thread_sharing(self):
"""
Validates that the connection isn't accessed by another thread than the
one which originally created it, unless the connection was explicitly
authorized to be shared between threads (via the `allow_thread_sharing`
property). Raises an exception if the validation fails.
"""
if (not self.allow_thread_sharing
and self._thread_ident != thread.get_ident()):
raise DatabaseError("DatabaseWrapper objects created in a "
"thread can only be used in that same thread. The object"
"with alias '%s' was created in thread id %s and this is "
"thread id %s."
% (self.alias, self._thread_ident, thread.get_ident()))
def is_dirty(self):
"""
Returns True if the current transaction requires a commit for changes to
@ -179,6 +198,7 @@ class BaseDatabaseWrapper(local):
"""
Commits changes if the system is not in managed transaction mode.
"""
self.validate_thread_sharing()
if not self.is_managed():
self._commit()
self.clean_savepoints()
@ -189,6 +209,7 @@ class BaseDatabaseWrapper(local):
"""
Rolls back changes if the system is not in managed transaction mode.
"""
self.validate_thread_sharing()
if not self.is_managed():
self._rollback()
else:
@ -198,6 +219,7 @@ class BaseDatabaseWrapper(local):
"""
Does the commit itself and resets the dirty flag.
"""
self.validate_thread_sharing()
self._commit()
self.set_clean()
@ -205,6 +227,7 @@ class BaseDatabaseWrapper(local):
"""
This function does the rollback itself and resets the dirty flag.
"""
self.validate_thread_sharing()
self._rollback()
self.set_clean()
@ -228,6 +251,7 @@ class BaseDatabaseWrapper(local):
Rolls back the most recent savepoint (if one exists). Does nothing if
savepoints are not supported.
"""
self.validate_thread_sharing()
if self.savepoint_state:
self._savepoint_rollback(sid)
@ -236,6 +260,7 @@ class BaseDatabaseWrapper(local):
Commits the most recent savepoint (if one exists). Does nothing if
savepoints are not supported.
"""
self.validate_thread_sharing()
if self.savepoint_state:
self._savepoint_commit(sid)
@ -269,11 +294,13 @@ class BaseDatabaseWrapper(local):
pass
def close(self):
self.validate_thread_sharing()
if self.connection is not None:
self.connection.close()
self.connection = None
def cursor(self):
self.validate_thread_sharing()
if (self.use_debug_cursor or
(self.use_debug_cursor is None and settings.DEBUG)):
cursor = self.make_debug_cursor(self._cursor())

View File

@ -7,10 +7,10 @@ standard library.
import datetime
import decimal
import warnings
import re
import sys
from django.conf import settings
from django.db import utils
from django.db.backends import *
from django.db.backends.signals import connection_created
@ -241,6 +241,21 @@ class DatabaseWrapper(BaseDatabaseWrapper):
'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
}
kwargs.update(settings_dict['OPTIONS'])
# Always allow the underlying SQLite connection to be shareable
# between multiple threads. The safe-guarding will be handled at a
# higher level by the `BaseDatabaseWrapper.allow_thread_sharing`
# property. This is necessary as the shareability is disabled by
# default in pysqlite and it cannot be changed once a connection is
# opened.
if 'check_same_thread' in kwargs and kwargs['check_same_thread']:
warnings.warn(
'The `check_same_thread` option was provided and set to '
'True. It will be overriden with False. Use the '
'`DatabaseWrapper.allow_thread_sharing` property instead '
'for controlling thread shareability.',
RuntimeWarning
)
kwargs.update({'check_same_thread': False})
self.connection = Database.connect(**kwargs)
# Register extract, date_trunc, and regexp functions.
self.connection.create_function("django_extract", 2, _sqlite_extract)

View File

@ -1,4 +1,5 @@
import os
from threading import local
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
@ -50,7 +51,7 @@ class ConnectionDoesNotExist(Exception):
class ConnectionHandler(object):
def __init__(self, databases):
self.databases = databases
self._connections = {}
self._connections = local()
def ensure_defaults(self, alias):
"""
@ -73,16 +74,19 @@ class ConnectionHandler(object):
conn.setdefault(setting, None)
def __getitem__(self, alias):
if alias in self._connections:
return self._connections[alias]
if hasattr(self._connections, alias):
return getattr(self._connections, alias)
self.ensure_defaults(alias)
db = self.databases[alias]
backend = load_backend(db['ENGINE'])
conn = backend.DatabaseWrapper(db, alias)
self._connections[alias] = conn
setattr(self._connections, alias, conn)
return conn
def __setitem__(self, key, value):
setattr(self._connections, key, value)
def __iter__(self):
return iter(self.databases)

View File

@ -673,6 +673,32 @@ datetimes are now stored without time zone information in SQLite. When
:setting:`USE_TZ` is ``False``, if you attempt to save an aware datetime
object, Django raises an exception.
Database connection's thread-locality
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``DatabaseWrapper`` objects (i.e. the connection objects referenced by
``django.db.connection`` and ``django.db.connections["some_alias"]``) used to
be thread-local. They are now global objects in order to be potentially shared
between multiple threads. While the individual connection objects are now
global, the ``django.db.connections`` dictionary referencing those objects is
still thread-local. Therefore if you just use the ORM or
``DatabaseWrapper.cursor()`` then the behavior is still the same as before.
Note, however, that ``django.db.connection`` does not directly reference the
default ``DatabaseWrapper`` object any more and is now a proxy to access that
object's attributes. If you need to access the actual ``DatabaseWrapper``
object, use ``django.db.connections[DEFAULT_DB_ALIAS]`` instead.
As part of this change, all underlying SQLite connections are now enabled for
potential thread-sharing (by passing the ``check_same_thread=False`` attribute
to pysqlite). ``DatabaseWrapper`` however preserves the previous behavior by
disabling thread-sharing by default, so this does not affect any existing
code that purely relies on the ORM or on ``DatabaseWrapper.cursor()``.
Finally, while it is now possible to pass connections between threads, Django
does not make any effort to synchronize access to the underlying backend.
Concurrency behavior is defined by the underlying backend implementation.
Check their documentation for details.
`COMMENTS_BANNED_USERS_GROUP` setting
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -3,6 +3,7 @@
from __future__ import with_statement, absolute_import
import datetime
import threading
from django.conf import settings
from django.core.management.color import no_style
@ -283,7 +284,7 @@ class ConnectionCreatedSignalTest(TestCase):
connection_created.connect(receiver)
connection.close()
cursor = connection.cursor()
self.assertTrue(data["connection"] is connection)
self.assertTrue(data["connection"].connection is connection.connection)
connection_created.disconnect(receiver)
data.clear()
@ -446,3 +447,94 @@ class FkConstraintsTests(TransactionTestCase):
connection.check_constraints()
finally:
transaction.rollback()
class ThreadTests(TestCase):
def test_default_connection_thread_local(self):
"""
Ensure that the default connection (i.e. django.db.connection) is
different for each thread.
Refs #17258.
"""
connections_set = set()
connection.cursor()
connections_set.add(connection.connection)
def runner():
from django.db import connection
connection.cursor()
connections_set.add(connection.connection)
for x in xrange(2):
t = threading.Thread(target=runner)
t.start()
t.join()
self.assertEquals(len(connections_set), 3)
# Finish by closing the connections opened by the other threads (the
# connection opened in the main thread will automatically be closed on
# teardown).
for conn in connections_set:
if conn != connection.connection:
conn.close()
def test_connections_thread_local(self):
"""
Ensure that the connections are different for each thread.
Refs #17258.
"""
connections_set = set()
for conn in connections.all():
connections_set.add(conn)
def runner():
from django.db import connections
for conn in connections.all():
connections_set.add(conn)
for x in xrange(2):
t = threading.Thread(target=runner)
t.start()
t.join()
self.assertEquals(len(connections_set), 6)
# Finish by closing the connections opened by the other threads (the
# connection opened in the main thread will automatically be closed on
# teardown).
for conn in connections_set:
if conn != connection:
conn.close()
def test_pass_connection_between_threads(self):
"""
Ensure that a connection can be passed from one thread to the other.
Refs #17258.
"""
models.Person.objects.create(first_name="John", last_name="Doe")
def do_thread():
def runner(main_thread_connection):
from django.db import connections
connections['default'] = main_thread_connection
try:
models.Person.objects.get(first_name="John", last_name="Doe")
except DatabaseError, e:
exceptions.append(e)
t = threading.Thread(target=runner, args=[connections['default']])
t.start()
t.join()
# Without touching allow_thread_sharing, which should be False by default.
exceptions = []
do_thread()
# Forbidden!
self.assertTrue(isinstance(exceptions[0], DatabaseError))
# If explicitly setting allow_thread_sharing to False
connections['default'].allow_thread_sharing = False
exceptions = []
do_thread()
# Forbidden!
self.assertTrue(isinstance(exceptions[0], DatabaseError))
# If explicitly setting allow_thread_sharing to True
connections['default'].allow_thread_sharing = True
exceptions = []
do_thread()
# All good
self.assertEqual(len(exceptions), 0)