diff --git a/django/core/handlers/asgi.py b/django/core/handlers/asgi.py index 7fbabe45104..2b8cc8b76e9 100644 --- a/django/core/handlers/asgi.py +++ b/django/core/handlers/asgi.py @@ -3,7 +3,7 @@ import sys import tempfile import traceback -from asgiref.sync import sync_to_async +from asgiref.sync import ThreadSensitiveContext, sync_to_async from django.conf import settings from django.core import signals @@ -144,6 +144,14 @@ class ASGIHandler(base.BaseHandler): 'Django can only handle ASGI/HTTP connections, not %s.' % scope['type'] ) + + async with ThreadSensitiveContext(): + await self.handle(scope, receive, send) + + async def handle(self, scope, receive, send): + """ + Handles the ASGI request. Called via the __call__ method. + """ # Receive the HTTP request body as a stream object. try: body_file = await self.read_body(receive) diff --git a/tests/asgi/tests.py b/tests/asgi/tests.py index 3509bb0aa7e..7eb35724dfc 100644 --- a/tests/asgi/tests.py +++ b/tests/asgi/tests.py @@ -4,7 +4,6 @@ import threading from pathlib import Path from unittest import skipIf -from asgiref.sync import SyncToAsync from asgiref.testing import ApplicationCommunicator from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler @@ -16,7 +15,7 @@ from django.test import ( ) from django.utils.http import http_date -from .urls import test_filename +from .urls import sync_waiter, test_filename TEST_STATIC_ROOT = Path(__file__).parent / 'project' / 'static' @@ -235,11 +234,39 @@ class ASGITest(SimpleTestCase): # Give response.close() time to finish. await communicator.wait() - # At this point, AsyncToSync does not have a current executor. Thus - # SyncToAsync falls-back to .single_thread_executor. - target_thread = next(iter(SyncToAsync.single_thread_executor._threads)) + # AsyncToSync should have executed the signals in the same thread. request_started_thread, request_finished_thread = signal_handler.threads - self.assertEqual(request_started_thread, target_thread) - self.assertEqual(request_finished_thread, target_thread) + self.assertEqual(request_started_thread, request_finished_thread) request_started.disconnect(signal_handler) request_finished.disconnect(signal_handler) + + async def test_concurrent_async_uses_multiple_thread_pools(self): + sync_waiter.active_threads.clear() + + # Send 2 requests concurrently + application = get_asgi_application() + scope = self.async_request_factory._base_scope(path='/wait/') + communicators = [] + for _ in range(2): + communicators.append(ApplicationCommunicator(application, scope)) + await communicators[-1].send_input({'type': 'http.request'}) + + # Each request must complete with a status code of 200 + # If requests aren't scheduled concurrently, the barrier in the + # sync_wait view will time out, resulting in a 500 status code. + for communicator in communicators: + response_start = await communicator.receive_output() + self.assertEqual(response_start['type'], 'http.response.start') + self.assertEqual(response_start['status'], 200) + response_body = await communicator.receive_output() + self.assertEqual(response_body['type'], 'http.response.body') + self.assertEqual(response_body['body'], b'Hello World!') + # Give response.close() time to finish. + await communicator.wait() + + # The requests should have scheduled on different threads. Note + # active_threads is a set (a thread can only appear once), therefore + # length is a sufficient check. + self.assertEqual(len(sync_waiter.active_threads), 2) + + sync_waiter.active_threads.clear() diff --git a/tests/asgi/urls.py b/tests/asgi/urls.py index ff8d21ea7cd..22d85604d14 100644 --- a/tests/asgi/urls.py +++ b/tests/asgi/urls.py @@ -1,3 +1,5 @@ +import threading + from django.http import FileResponse, HttpResponse from django.urls import path @@ -14,6 +16,18 @@ def hello_meta(request): ) +def sync_waiter(request): + with sync_waiter.lock: + sync_waiter.active_threads.add(threading.current_thread()) + sync_waiter.barrier.wait(timeout=0.5) + return hello(request) + + +sync_waiter.active_threads = set() +sync_waiter.lock = threading.Lock() +sync_waiter.barrier = threading.Barrier(2) + + test_filename = __file__ @@ -21,4 +35,5 @@ urlpatterns = [ path('', hello), path('file/', lambda x: FileResponse(open(test_filename, 'rb'))), path('meta/', hello_meta), + path('wait/', sync_waiter), ]