Ensured that thread-shareability gets validated when closing a PostgreSQL or SQLite connection. Refs #17258.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@17206 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
parent
34e248efec
commit
a1d2f1f7b7
|
@ -130,7 +130,7 @@ class BaseDatabaseWrapper(object):
|
||||||
if (not self.allow_thread_sharing
|
if (not self.allow_thread_sharing
|
||||||
and self._thread_ident != thread.get_ident()):
|
and self._thread_ident != thread.get_ident()):
|
||||||
raise DatabaseError("DatabaseWrapper objects created in a "
|
raise DatabaseError("DatabaseWrapper objects created in a "
|
||||||
"thread can only be used in that same thread. The object"
|
"thread can only be used in that same thread. The object "
|
||||||
"with alias '%s' was created in thread id %s and this is "
|
"with alias '%s' was created in thread id %s and this is "
|
||||||
"thread id %s."
|
"thread id %s."
|
||||||
% (self.alias, self._thread_ident, thread.get_ident()))
|
% (self.alias, self._thread_ident, thread.get_ident()))
|
||||||
|
|
|
@ -129,6 +129,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
|
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
self.validate_thread_sharing()
|
||||||
if self.connection is None:
|
if self.connection is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -300,6 +300,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
referenced_table_name, referenced_column_name))
|
referenced_table_name, referenced_column_name))
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
self.validate_thread_sharing()
|
||||||
# If database is in memory, closing the connection destroys the
|
# If database is in memory, closing the connection destroys the
|
||||||
# database. To prevent accidental data loss, ignore close requests on
|
# database. To prevent accidental data loss, ignore close requests on
|
||||||
# an in-memory db.
|
# an in-memory db.
|
||||||
|
|
|
@ -487,6 +487,9 @@ class ThreadTests(TestCase):
|
||||||
def runner():
|
def runner():
|
||||||
from django.db import connections
|
from django.db import connections
|
||||||
for conn in connections.all():
|
for conn in connections.all():
|
||||||
|
# Allow thread sharing so the connection can be closed by the
|
||||||
|
# main thread.
|
||||||
|
conn.allow_thread_sharing = True
|
||||||
connections_set.add(conn)
|
connections_set.add(conn)
|
||||||
for x in xrange(2):
|
for x in xrange(2):
|
||||||
t = threading.Thread(target=runner)
|
t = threading.Thread(target=runner)
|
||||||
|
@ -537,4 +540,46 @@ class ThreadTests(TestCase):
|
||||||
exceptions = []
|
exceptions = []
|
||||||
do_thread()
|
do_thread()
|
||||||
# All good
|
# All good
|
||||||
|
self.assertEqual(len(exceptions), 0)
|
||||||
|
|
||||||
|
def test_closing_non_shared_connections(self):
|
||||||
|
"""
|
||||||
|
Ensure that a connection that is not explicitly shareable cannot be
|
||||||
|
closed by another thread.
|
||||||
|
Refs #17258.
|
||||||
|
"""
|
||||||
|
# First, without explicitly enabling the connection for sharing.
|
||||||
|
exceptions = set()
|
||||||
|
def runner1():
|
||||||
|
def runner2(other_thread_connection):
|
||||||
|
try:
|
||||||
|
other_thread_connection.close()
|
||||||
|
except DatabaseError, e:
|
||||||
|
exceptions.add(e)
|
||||||
|
t2 = threading.Thread(target=runner2, args=[connections['default']])
|
||||||
|
t2.start()
|
||||||
|
t2.join()
|
||||||
|
t1 = threading.Thread(target=runner1)
|
||||||
|
t1.start()
|
||||||
|
t1.join()
|
||||||
|
# The exception was raised
|
||||||
|
self.assertEqual(len(exceptions), 1)
|
||||||
|
|
||||||
|
# Then, with explicitly enabling the connection for sharing.
|
||||||
|
exceptions = set()
|
||||||
|
def runner1():
|
||||||
|
def runner2(other_thread_connection):
|
||||||
|
try:
|
||||||
|
other_thread_connection.close()
|
||||||
|
except DatabaseError, e:
|
||||||
|
exceptions.add(e)
|
||||||
|
# Enable thread sharing
|
||||||
|
connections['default'].allow_thread_sharing = True
|
||||||
|
t2 = threading.Thread(target=runner2, args=[connections['default']])
|
||||||
|
t2.start()
|
||||||
|
t2.join()
|
||||||
|
t1 = threading.Thread(target=runner1)
|
||||||
|
t1.start()
|
||||||
|
t1.join()
|
||||||
|
# No exception was raised
|
||||||
self.assertEqual(len(exceptions), 0)
|
self.assertEqual(len(exceptions), 0)
|
Loading…
Reference in New Issue