Fixed #30171 -- Fixed DatabaseError in servers tests.

Made DatabaseWrapper thread sharing logic reentrant. Used a reference
counting like scheme to allow nested uses.

The error appeared after 8c775391b7.
This commit is contained in:
Jon Dufresne 2019-02-14 07:04:55 -08:00 committed by Tim Graham
parent 21f9d43737
commit 76990cbbda
7 changed files with 100 additions and 66 deletions

View File

@ -1,4 +1,5 @@
import copy import copy
import threading
import time import time
import warnings import warnings
from collections import deque from collections import deque
@ -43,8 +44,7 @@ class BaseDatabaseWrapper:
queries_limit = 9000 queries_limit = 9000
def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS, def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
allow_thread_sharing=False):
# Connection related attributes. # Connection related attributes.
# The underlying database connection. # The underlying database connection.
self.connection = None self.connection = None
@ -80,7 +80,8 @@ class BaseDatabaseWrapper:
self.errors_occurred = False self.errors_occurred = False
# Thread-safety related attributes. # Thread-safety related attributes.
self.allow_thread_sharing = allow_thread_sharing self._thread_sharing_lock = threading.Lock()
self._thread_sharing_count = 0
self._thread_ident = _thread.get_ident() self._thread_ident = _thread.get_ident()
# A list of no-argument functions to run when the transaction commits. # A list of no-argument functions to run when the transaction commits.
@ -515,12 +516,27 @@ class BaseDatabaseWrapper:
# ##### Thread safety handling ##### # ##### Thread safety handling #####
@property
def allow_thread_sharing(self):
with self._thread_sharing_lock:
return self._thread_sharing_count > 0
def inc_thread_sharing(self):
with self._thread_sharing_lock:
self._thread_sharing_count += 1
def dec_thread_sharing(self):
with self._thread_sharing_lock:
if self._thread_sharing_count <= 0:
raise RuntimeError('Cannot decrement the thread sharing count below zero.')
self._thread_sharing_count -= 1
def validate_thread_sharing(self): def validate_thread_sharing(self):
""" """
Validate that the connection isn't accessed by another thread than the Validate that the connection isn't accessed by another thread than the
one which originally created it, unless the connection was explicitly one which originally created it, unless the connection was explicitly
authorized to be shared between threads (via the `allow_thread_sharing` authorized to be shared between threads (via the `inc_thread_sharing()`
property). Raise an exception if the validation fails. method). Raise an exception if the validation fails.
""" """
if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()): if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()):
raise DatabaseError( raise DatabaseError(
@ -589,11 +605,7 @@ class BaseDatabaseWrapper:
potential child threads while (or after) the test database is destroyed. potential child threads while (or after) the test database is destroyed.
Refs #10868, #17786, #16969. Refs #10868, #17786, #16969.
""" """
return self.__class__( return self.__class__({**self.settings_dict, 'NAME': None}, alias=NO_DB_ALIAS)
{**self.settings_dict, 'NAME': None},
alias=NO_DB_ALIAS,
allow_thread_sharing=False,
)
def schema_editor(self, *args, **kwargs): def schema_editor(self, *args, **kwargs):
""" """
@ -635,7 +647,7 @@ class BaseDatabaseWrapper:
finally: finally:
self.execute_wrappers.pop() self.execute_wrappers.pop()
def copy(self, alias=None, allow_thread_sharing=None): def copy(self, alias=None):
""" """
Return a copy of this connection. Return a copy of this connection.
@ -644,6 +656,4 @@ class BaseDatabaseWrapper:
settings_dict = copy.deepcopy(self.settings_dict) settings_dict = copy.deepcopy(self.settings_dict)
if alias is None: if alias is None:
alias = self.alias alias = self.alias
if allow_thread_sharing is None: return type(self)(settings_dict, alias)
allow_thread_sharing = self.allow_thread_sharing
return type(self)(settings_dict, alias, allow_thread_sharing)

View File

@ -277,7 +277,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return self.__class__( return self.__class__(
{**self.settings_dict, 'NAME': connection.settings_dict['NAME']}, {**self.settings_dict, 'NAME': connection.settings_dict['NAME']},
alias=self.alias, alias=self.alias,
allow_thread_sharing=False,
) )
return nodb_connection return nodb_connection

View File

@ -1442,7 +1442,7 @@ class LiveServerTestCase(TransactionTestCase):
# the server thread. # the server thread.
if conn.vendor == 'sqlite' and conn.is_in_memory_db(): if conn.vendor == 'sqlite' and conn.is_in_memory_db():
# Explicitly enable thread-shareability for this connection # Explicitly enable thread-shareability for this connection
conn.allow_thread_sharing = True conn.inc_thread_sharing()
connections_override[conn.alias] = conn connections_override[conn.alias] = conn
cls._live_server_modified_settings = modify_settings( cls._live_server_modified_settings = modify_settings(
@ -1478,10 +1478,9 @@ class LiveServerTestCase(TransactionTestCase):
# Terminate the live server's thread # Terminate the live server's thread
cls.server_thread.terminate() cls.server_thread.terminate()
# Restore sqlite in-memory database connections' non-shareability # Restore sqlite in-memory database connections' non-shareability.
for conn in connections.all(): for conn in cls.server_thread.connections_override.values():
if conn.vendor == 'sqlite' and conn.is_in_memory_db(): conn.dec_thread_sharing()
conn.allow_thread_sharing = False
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):

View File

@ -286,6 +286,9 @@ backends.
* ``_delete_fk_sql()`` (to pair with ``_create_fk_sql()``) * ``_delete_fk_sql()`` (to pair with ``_create_fk_sql()``)
* ``_create_check_sql()`` and ``_delete_check_sql()`` * ``_create_check_sql()`` and ``_delete_check_sql()``
* The third argument of ``DatabaseWrapper.__init__()``,
``allow_thread_sharing``, is removed.
Admin actions are no longer collected from base ``ModelAdmin`` classes Admin actions are no longer collected from base ``ModelAdmin`` classes
---------------------------------------------------------------------- ----------------------------------------------------------------------

View File

@ -605,21 +605,25 @@ class ThreadTests(TransactionTestCase):
connection = connections[DEFAULT_DB_ALIAS] connection = connections[DEFAULT_DB_ALIAS]
# Allow thread sharing so the connection can be closed by the # Allow thread sharing so the connection can be closed by the
# main thread. # main thread.
connection.allow_thread_sharing = True connection.inc_thread_sharing()
connection.cursor() connection.cursor()
connections_dict[id(connection)] = connection connections_dict[id(connection)] = connection
try:
for x in range(2): for x in range(2):
t = threading.Thread(target=runner) t = threading.Thread(target=runner)
t.start() t.start()
t.join() t.join()
# Each created connection got different inner connection. # Each created connection got different inner connection.
self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3) self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3)
# Finish by closing the connections opened by the other threads (the finally:
# connection opened in the main thread will automatically be closed on # Finish by closing the connections opened by the other threads
# teardown). # (the connection opened in the main thread will automatically be
# closed on teardown).
for conn in connections_dict.values(): for conn in connections_dict.values():
if conn is not connection: if conn is not connection:
if conn.allow_thread_sharing:
conn.close() conn.close()
conn.dec_thread_sharing()
def test_connections_thread_local(self): def test_connections_thread_local(self):
""" """
@ -636,19 +640,23 @@ class ThreadTests(TransactionTestCase):
for conn in connections.all(): for conn in connections.all():
# Allow thread sharing so the connection can be closed by the # Allow thread sharing so the connection can be closed by the
# main thread. # main thread.
conn.allow_thread_sharing = True conn.inc_thread_sharing()
connections_dict[id(conn)] = conn connections_dict[id(conn)] = conn
try:
for x in range(2): for x in range(2):
t = threading.Thread(target=runner) t = threading.Thread(target=runner)
t.start() t.start()
t.join() t.join()
self.assertEqual(len(connections_dict), 6) self.assertEqual(len(connections_dict), 6)
# Finish by closing the connections opened by the other threads (the finally:
# connection opened in the main thread will automatically be closed on # Finish by closing the connections opened by the other threads
# teardown). # (the connection opened in the main thread will automatically be
# closed on teardown).
for conn in connections_dict.values(): for conn in connections_dict.values():
if conn is not connection: if conn is not connection:
if conn.allow_thread_sharing:
conn.close() conn.close()
conn.dec_thread_sharing()
def test_pass_connection_between_threads(self): def test_pass_connection_between_threads(self):
""" """
@ -668,25 +676,21 @@ class ThreadTests(TransactionTestCase):
t.start() t.start()
t.join() t.join()
# Without touching allow_thread_sharing, which should be False by default. # Without touching thread sharing, which should be False by default.
exceptions = [] exceptions = []
do_thread() do_thread()
# Forbidden! # Forbidden!
self.assertIsInstance(exceptions[0], DatabaseError) self.assertIsInstance(exceptions[0], DatabaseError)
# If explicitly setting allow_thread_sharing to False # After calling inc_thread_sharing() on the connection.
connections['default'].allow_thread_sharing = False connections['default'].inc_thread_sharing()
exceptions = [] try:
do_thread()
# Forbidden!
self.assertIsInstance(exceptions[0], DatabaseError)
# If explicitly setting allow_thread_sharing to True
connections['default'].allow_thread_sharing = True
exceptions = [] exceptions = []
do_thread() do_thread()
# All good # All good
self.assertEqual(exceptions, []) self.assertEqual(exceptions, [])
finally:
connections['default'].dec_thread_sharing()
def test_closing_non_shared_connections(self): def test_closing_non_shared_connections(self):
""" """
@ -721,16 +725,33 @@ class ThreadTests(TransactionTestCase):
except DatabaseError as e: except DatabaseError as e:
exceptions.add(e) exceptions.add(e)
# Enable thread sharing # Enable thread sharing
connections['default'].allow_thread_sharing = True connections['default'].inc_thread_sharing()
try:
t2 = threading.Thread(target=runner2, args=[connections['default']]) t2 = threading.Thread(target=runner2, args=[connections['default']])
t2.start() t2.start()
t2.join() t2.join()
finally:
connections['default'].dec_thread_sharing()
t1 = threading.Thread(target=runner1) t1 = threading.Thread(target=runner1)
t1.start() t1.start()
t1.join() t1.join()
# No exception was raised # No exception was raised
self.assertEqual(len(exceptions), 0) self.assertEqual(len(exceptions), 0)
def test_thread_sharing_count(self):
self.assertIs(connection.allow_thread_sharing, False)
connection.inc_thread_sharing()
self.assertIs(connection.allow_thread_sharing, True)
connection.inc_thread_sharing()
self.assertIs(connection.allow_thread_sharing, True)
connection.dec_thread_sharing()
self.assertIs(connection.allow_thread_sharing, True)
connection.dec_thread_sharing()
self.assertIs(connection.allow_thread_sharing, False)
msg = 'Cannot decrement the thread sharing count below zero.'
with self.assertRaisesMessage(RuntimeError, msg):
connection.dec_thread_sharing()
class MySQLPKZeroTests(TestCase): class MySQLPKZeroTests(TestCase):
""" """

View File

@ -18,11 +18,10 @@ class LiveServerThreadTest(TestCase):
# Pass a connection to the thread to check they are being closed. # Pass a connection to the thread to check they are being closed.
connections_override = {DEFAULT_DB_ALIAS: conn} connections_override = {DEFAULT_DB_ALIAS: conn}
saved_sharing = conn.allow_thread_sharing conn.inc_thread_sharing()
try: try:
conn.allow_thread_sharing = True
self.assertTrue(conn.is_usable()) self.assertTrue(conn.is_usable())
self.run_live_server_thread(connections_override) self.run_live_server_thread(connections_override)
self.assertFalse(conn.is_usable()) self.assertFalse(conn.is_usable())
finally: finally:
conn.allow_thread_sharing = saved_sharing conn.dec_thread_sharing()

View File

@ -64,6 +64,9 @@ class StaticLiveServerChecks(LiveServerBase):
# app without having set the required STATIC_URL setting.") # app without having set the required STATIC_URL setting.")
pass pass
finally: finally:
# Use del to avoid decrementing the database thread sharing count a
# second time.
del cls.server_thread
super().tearDownClass() super().tearDownClass()
def test_test_test(self): def test_test_test(self):