From b6d2419120598c6643e16a9b309258a6e9c2a736 Mon Sep 17 00:00:00 2001 From: Michael Galler Date: Thu, 20 Aug 2020 18:41:22 +0200 Subject: [PATCH] [3.1.x] Fixed #31905 -- Made MiddlewareMixin call process_request()/process_response() with thread sensitive. Co-authored-by: Carlton Gibson Backport of 547a07fa7ec4364ea9ecd2aabfdd16ee4c63003c from master --- django/utils/deprecation.py | 10 ++++- docs/releases/3.1.1.txt | 4 ++ tests/deprecation/test_middleware_mixin.py | 43 +++++++++++++++++++++- 3 files changed, 54 insertions(+), 3 deletions(-) diff --git a/django/utils/deprecation.py b/django/utils/deprecation.py index 6336558a817..36943d3ae8a 100644 --- a/django/utils/deprecation.py +++ b/django/utils/deprecation.py @@ -123,10 +123,16 @@ class MiddlewareMixin: """ response = None if hasattr(self, 'process_request'): - response = await sync_to_async(self.process_request)(request) + response = await sync_to_async( + self.process_request, + thread_sensitive=True, + )(request) response = response or await self.get_response(request) if hasattr(self, 'process_response'): - response = await sync_to_async(self.process_response)(request, response) + response = await sync_to_async( + self.process_response, + thread_sensitive=True, + )(request, response) return response def _get_response_none_deprecation(self, get_response): diff --git a/docs/releases/3.1.1.txt b/docs/releases/3.1.1.txt index fd0843c2ae2..6359f181fa3 100644 --- a/docs/releases/3.1.1.txt +++ b/docs/releases/3.1.1.txt @@ -35,3 +35,7 @@ Bugfixes * Reverted a deprecation in Django 3.1 that caused a crash when passing deprecated keyword arguments to a queryset in ``TemplateView.get_context_data()`` (:ticket:`31877`). + +* Enforced thread sensitivity of the :class:`MiddlewareMixin.process_request() + ` and ``process_response()`` hooks + when in an async context (:ticket:`31905`). diff --git a/tests/deprecation/test_middleware_mixin.py b/tests/deprecation/test_middleware_mixin.py index f03d9168ec6..c90aeb83609 100644 --- a/tests/deprecation/test_middleware_mixin.py +++ b/tests/deprecation/test_middleware_mixin.py @@ -1,11 +1,18 @@ +import threading + +from asgiref.sync import async_to_sync + from django.contrib.sessions.middleware import SessionMiddleware +from django.db import connection +from django.http.request import HttpRequest +from django.http.response import HttpResponse from django.middleware.cache import ( CacheMiddleware, FetchFromCacheMiddleware, UpdateCacheMiddleware, ) from django.middleware.common import CommonMiddleware from django.middleware.security import SecurityMiddleware from django.test import SimpleTestCase -from django.utils.deprecation import RemovedInDjango40Warning +from django.utils.deprecation import MiddlewareMixin, RemovedInDjango40Warning class MiddlewareMixinTests(SimpleTestCase): @@ -37,3 +44,37 @@ class MiddlewareMixinTests(SimpleTestCase): with self.subTest(middleware=middleware): with self.assertRaisesMessage(RemovedInDjango40Warning, self.msg): middleware() + + def test_sync_to_async_uses_base_thread_and_connection(self): + """ + The process_request() and process_response() hooks must be called with + the sync_to_async thread_sensitive flag enabled, so that database + operations use the correct thread and connection. + """ + def request_lifecycle(): + """Fake request_started/request_finished.""" + return (threading.get_ident(), id(connection)) + + async def get_response(self): + return HttpResponse() + + class SimpleMiddleWare(MiddlewareMixin): + def process_request(self, request): + request.thread_and_connection = request_lifecycle() + + def process_response(self, request, response): + response.thread_and_connection = request_lifecycle() + return response + + threads_and_connections = [] + threads_and_connections.append(request_lifecycle()) + + request = HttpRequest() + response = async_to_sync(SimpleMiddleWare(get_response))(request) + threads_and_connections.append(request.thread_and_connection) + threads_and_connections.append(response.thread_and_connection) + + threads_and_connections.append(request_lifecycle()) + + self.assertEqual(len(threads_and_connections), 4) + self.assertEqual(len(set(threads_and_connections)), 1)