Ensured that thread-shareability gets validated when closing a PostgreSQL or SQLite connection. Refs #17258.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@17206 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Julien Phalip 2011-12-16 17:02:41 +00:00
parent 34e248efec
commit a1d2f1f7b7
4 changed files with 48 additions and 1 deletions

View File

@ -130,7 +130,7 @@ class BaseDatabaseWrapper(object):
if (not self.allow_thread_sharing if (not self.allow_thread_sharing
and self._thread_ident != thread.get_ident()): and self._thread_ident != thread.get_ident()):
raise DatabaseError("DatabaseWrapper objects created in a " raise DatabaseError("DatabaseWrapper objects created in a "
"thread can only be used in that same thread. The object" "thread can only be used in that same thread. The object "
"with alias '%s' was created in thread id %s and this is " "with alias '%s' was created in thread id %s and this is "
"thread id %s." "thread id %s."
% (self.alias, self._thread_ident, thread.get_ident())) % (self.alias, self._thread_ident, thread.get_ident()))

View File

@ -129,6 +129,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED') self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
def close(self): def close(self):
self.validate_thread_sharing()
if self.connection is None: if self.connection is None:
return return

View File

@ -300,6 +300,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
referenced_table_name, referenced_column_name)) referenced_table_name, referenced_column_name))
def close(self): def close(self):
self.validate_thread_sharing()
# If database is in memory, closing the connection destroys the # If database is in memory, closing the connection destroys the
# database. To prevent accidental data loss, ignore close requests on # database. To prevent accidental data loss, ignore close requests on
# an in-memory db. # an in-memory db.

View File

@ -487,6 +487,9 @@ class ThreadTests(TestCase):
def runner(): def runner():
from django.db import connections from django.db import connections
for conn in connections.all(): for conn in connections.all():
# Allow thread sharing so the connection can be closed by the
# main thread.
conn.allow_thread_sharing = True
connections_set.add(conn) connections_set.add(conn)
for x in xrange(2): for x in xrange(2):
t = threading.Thread(target=runner) t = threading.Thread(target=runner)
@ -537,4 +540,46 @@ class ThreadTests(TestCase):
exceptions = [] exceptions = []
do_thread() do_thread()
# All good # All good
self.assertEqual(len(exceptions), 0)
def test_closing_non_shared_connections(self):
"""
Ensure that a connection that is not explicitly shareable cannot be
closed by another thread.
Refs #17258.
"""
# First, without explicitly enabling the connection for sharing.
exceptions = set()
def runner1():
def runner2(other_thread_connection):
try:
other_thread_connection.close()
except DatabaseError, e:
exceptions.add(e)
t2 = threading.Thread(target=runner2, args=[connections['default']])
t2.start()
t2.join()
t1 = threading.Thread(target=runner1)
t1.start()
t1.join()
# The exception was raised
self.assertEqual(len(exceptions), 1)
# Then, with explicitly enabling the connection for sharing.
exceptions = set()
def runner1():
def runner2(other_thread_connection):
try:
other_thread_connection.close()
except DatabaseError, e:
exceptions.add(e)
# Enable thread sharing
connections['default'].allow_thread_sharing = True
t2 = threading.Thread(target=runner2, args=[connections['default']])
t2.start()
t2.join()
t1 = threading.Thread(target=runner1)
t1.start()
t1.join()
# No exception was raised
self.assertEqual(len(exceptions), 0) self.assertEqual(len(exceptions), 0)