diff --git a/django/db/backends/base/base.py b/django/db/backends/base/base.py index f97d171c96..9fa03cc0ee 100644 --- a/django/db/backends/base/base.py +++ b/django/db/backends/base/base.py @@ -1,4 +1,5 @@ import copy +import threading import time import warnings from collections import deque @@ -43,8 +44,7 @@ class BaseDatabaseWrapper: queries_limit = 9000 - def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS, - allow_thread_sharing=False): + def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS): # Connection related attributes. # The underlying database connection. self.connection = None @@ -80,7 +80,8 @@ class BaseDatabaseWrapper: self.errors_occurred = False # Thread-safety related attributes. - self.allow_thread_sharing = allow_thread_sharing + self._thread_sharing_lock = threading.Lock() + self._thread_sharing_count = 0 self._thread_ident = _thread.get_ident() # A list of no-argument functions to run when the transaction commits. @@ -515,12 +516,27 @@ class BaseDatabaseWrapper: # ##### Thread safety handling ##### + @property + def allow_thread_sharing(self): + with self._thread_sharing_lock: + return self._thread_sharing_count > 0 + + def inc_thread_sharing(self): + with self._thread_sharing_lock: + self._thread_sharing_count += 1 + + def dec_thread_sharing(self): + with self._thread_sharing_lock: + if self._thread_sharing_count <= 0: + raise RuntimeError('Cannot decrement the thread sharing count below zero.') + self._thread_sharing_count -= 1 + def validate_thread_sharing(self): """ Validate that the connection isn't accessed by another thread than the one which originally created it, unless the connection was explicitly - authorized to be shared between threads (via the `allow_thread_sharing` - property). Raise an exception if the validation fails. + authorized to be shared between threads (via the `inc_thread_sharing()` + method). Raise an exception if the validation fails. """ if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()): raise DatabaseError( @@ -589,11 +605,7 @@ class BaseDatabaseWrapper: potential child threads while (or after) the test database is destroyed. Refs #10868, #17786, #16969. """ - return self.__class__( - {**self.settings_dict, 'NAME': None}, - alias=NO_DB_ALIAS, - allow_thread_sharing=False, - ) + return self.__class__({**self.settings_dict, 'NAME': None}, alias=NO_DB_ALIAS) def schema_editor(self, *args, **kwargs): """ @@ -635,7 +647,7 @@ class BaseDatabaseWrapper: finally: self.execute_wrappers.pop() - def copy(self, alias=None, allow_thread_sharing=None): + def copy(self, alias=None): """ Return a copy of this connection. @@ -644,6 +656,4 @@ class BaseDatabaseWrapper: settings_dict = copy.deepcopy(self.settings_dict) if alias is None: alias = self.alias - if allow_thread_sharing is None: - allow_thread_sharing = self.allow_thread_sharing - return type(self)(settings_dict, alias, allow_thread_sharing) + return type(self)(settings_dict, alias) diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py index 4be5a193bb..b3a3202c90 100644 --- a/django/db/backends/postgresql/base.py +++ b/django/db/backends/postgresql/base.py @@ -277,7 +277,6 @@ class DatabaseWrapper(BaseDatabaseWrapper): return self.__class__( {**self.settings_dict, 'NAME': connection.settings_dict['NAME']}, alias=self.alias, - allow_thread_sharing=False, ) return nodb_connection diff --git a/django/test/testcases.py b/django/test/testcases.py index 991165c04d..dea7fedbcc 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -1442,7 +1442,7 @@ class LiveServerTestCase(TransactionTestCase): # the server thread. if conn.vendor == 'sqlite' and conn.is_in_memory_db(): # Explicitly enable thread-shareability for this connection - conn.allow_thread_sharing = True + conn.inc_thread_sharing() connections_override[conn.alias] = conn cls._live_server_modified_settings = modify_settings( @@ -1478,10 +1478,9 @@ class LiveServerTestCase(TransactionTestCase): # Terminate the live server's thread cls.server_thread.terminate() - # Restore sqlite in-memory database connections' non-shareability - for conn in connections.all(): - if conn.vendor == 'sqlite' and conn.is_in_memory_db(): - conn.allow_thread_sharing = False + # Restore sqlite in-memory database connections' non-shareability. + for conn in cls.server_thread.connections_override.values(): + conn.dec_thread_sharing() @classmethod def tearDownClass(cls): diff --git a/docs/releases/2.2.txt b/docs/releases/2.2.txt index 1fe5af93fe..ec6639280a 100644 --- a/docs/releases/2.2.txt +++ b/docs/releases/2.2.txt @@ -286,6 +286,9 @@ backends. * ``_delete_fk_sql()`` (to pair with ``_create_fk_sql()``) * ``_create_check_sql()`` and ``_delete_check_sql()`` +* The third argument of ``DatabaseWrapper.__init__()``, + ``allow_thread_sharing``, is removed. + Admin actions are no longer collected from base ``ModelAdmin`` classes ---------------------------------------------------------------------- diff --git a/tests/backends/tests.py b/tests/backends/tests.py index 7e4e665758..6138a3626c 100644 --- a/tests/backends/tests.py +++ b/tests/backends/tests.py @@ -605,21 +605,25 @@ class ThreadTests(TransactionTestCase): connection = connections[DEFAULT_DB_ALIAS] # Allow thread sharing so the connection can be closed by the # main thread. - connection.allow_thread_sharing = True + connection.inc_thread_sharing() connection.cursor() connections_dict[id(connection)] = connection - for x in range(2): - t = threading.Thread(target=runner) - t.start() - t.join() - # Each created connection got different inner connection. - self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3) - # Finish by closing the connections opened by the other threads (the - # connection opened in the main thread will automatically be closed on - # teardown). - for conn in connections_dict.values(): - if conn is not connection: - conn.close() + try: + for x in range(2): + t = threading.Thread(target=runner) + t.start() + t.join() + # Each created connection got different inner connection. + self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3) + finally: + # Finish by closing the connections opened by the other threads + # (the connection opened in the main thread will automatically be + # closed on teardown). + for conn in connections_dict.values(): + if conn is not connection: + if conn.allow_thread_sharing: + conn.close() + conn.dec_thread_sharing() def test_connections_thread_local(self): """ @@ -636,19 +640,23 @@ class ThreadTests(TransactionTestCase): for conn in connections.all(): # Allow thread sharing so the connection can be closed by the # main thread. - conn.allow_thread_sharing = True + conn.inc_thread_sharing() connections_dict[id(conn)] = conn - for x in range(2): - t = threading.Thread(target=runner) - t.start() - t.join() - self.assertEqual(len(connections_dict), 6) - # Finish by closing the connections opened by the other threads (the - # connection opened in the main thread will automatically be closed on - # teardown). - for conn in connections_dict.values(): - if conn is not connection: - conn.close() + try: + for x in range(2): + t = threading.Thread(target=runner) + t.start() + t.join() + self.assertEqual(len(connections_dict), 6) + finally: + # Finish by closing the connections opened by the other threads + # (the connection opened in the main thread will automatically be + # closed on teardown). + for conn in connections_dict.values(): + if conn is not connection: + if conn.allow_thread_sharing: + conn.close() + conn.dec_thread_sharing() def test_pass_connection_between_threads(self): """ @@ -668,25 +676,21 @@ class ThreadTests(TransactionTestCase): t.start() t.join() - # Without touching allow_thread_sharing, which should be False by default. + # Without touching thread sharing, which should be False by default. exceptions = [] do_thread() # Forbidden! self.assertIsInstance(exceptions[0], DatabaseError) - # If explicitly setting allow_thread_sharing to False - connections['default'].allow_thread_sharing = False - exceptions = [] - do_thread() - # Forbidden! - self.assertIsInstance(exceptions[0], DatabaseError) - - # If explicitly setting allow_thread_sharing to True - connections['default'].allow_thread_sharing = True - exceptions = [] - do_thread() - # All good - self.assertEqual(exceptions, []) + # After calling inc_thread_sharing() on the connection. + connections['default'].inc_thread_sharing() + try: + exceptions = [] + do_thread() + # All good + self.assertEqual(exceptions, []) + finally: + connections['default'].dec_thread_sharing() def test_closing_non_shared_connections(self): """ @@ -721,16 +725,33 @@ class ThreadTests(TransactionTestCase): except DatabaseError as e: exceptions.add(e) # Enable thread sharing - connections['default'].allow_thread_sharing = True - t2 = threading.Thread(target=runner2, args=[connections['default']]) - t2.start() - t2.join() + connections['default'].inc_thread_sharing() + try: + t2 = threading.Thread(target=runner2, args=[connections['default']]) + t2.start() + t2.join() + finally: + connections['default'].dec_thread_sharing() t1 = threading.Thread(target=runner1) t1.start() t1.join() # No exception was raised self.assertEqual(len(exceptions), 0) + def test_thread_sharing_count(self): + self.assertIs(connection.allow_thread_sharing, False) + connection.inc_thread_sharing() + self.assertIs(connection.allow_thread_sharing, True) + connection.inc_thread_sharing() + self.assertIs(connection.allow_thread_sharing, True) + connection.dec_thread_sharing() + self.assertIs(connection.allow_thread_sharing, True) + connection.dec_thread_sharing() + self.assertIs(connection.allow_thread_sharing, False) + msg = 'Cannot decrement the thread sharing count below zero.' + with self.assertRaisesMessage(RuntimeError, msg): + connection.dec_thread_sharing() + class MySQLPKZeroTests(TestCase): """ diff --git a/tests/servers/test_liveserverthread.py b/tests/servers/test_liveserverthread.py index d39aac8183..9762b53791 100644 --- a/tests/servers/test_liveserverthread.py +++ b/tests/servers/test_liveserverthread.py @@ -18,11 +18,10 @@ class LiveServerThreadTest(TestCase): # Pass a connection to the thread to check they are being closed. connections_override = {DEFAULT_DB_ALIAS: conn} - saved_sharing = conn.allow_thread_sharing + conn.inc_thread_sharing() try: - conn.allow_thread_sharing = True self.assertTrue(conn.is_usable()) self.run_live_server_thread(connections_override) self.assertFalse(conn.is_usable()) finally: - conn.allow_thread_sharing = saved_sharing + conn.dec_thread_sharing() diff --git a/tests/staticfiles_tests/test_liveserver.py b/tests/staticfiles_tests/test_liveserver.py index 264242bbae..820fa5bc89 100644 --- a/tests/staticfiles_tests/test_liveserver.py +++ b/tests/staticfiles_tests/test_liveserver.py @@ -64,6 +64,9 @@ class StaticLiveServerChecks(LiveServerBase): # app without having set the required STATIC_URL setting.") pass finally: + # Use del to avoid decrementing the database thread sharing count a + # second time. + del cls.server_thread super().tearDownClass() def test_test_test(self):