Fixed #31905 -- Made MiddlewareMixin call process_request()/process_response() with thread sensitive.

Co-authored-by: Carlton Gibson <carlton.gibson@noumenal.es>
This commit is contained in:
Michael Galler 2020-08-20 18:41:22 +02:00 committed by Mariusz Felisiak
parent 0b0658111c
commit 547a07fa7e
3 changed files with 54 additions and 3 deletions

View File

@ -126,10 +126,16 @@ class MiddlewareMixin:
""" """
response = None response = None
if hasattr(self, 'process_request'): if hasattr(self, 'process_request'):
response = await sync_to_async(self.process_request)(request) response = await sync_to_async(
self.process_request,
thread_sensitive=True,
)(request)
response = response or await self.get_response(request) response = response or await self.get_response(request)
if hasattr(self, 'process_response'): if hasattr(self, 'process_response'):
response = await sync_to_async(self.process_response)(request, response) response = await sync_to_async(
self.process_response,
thread_sensitive=True,
)(request, response)
return response return response
def _get_response_none_deprecation(self, get_response): def _get_response_none_deprecation(self, get_response):

View File

@ -35,3 +35,7 @@ Bugfixes
* Reverted a deprecation in Django 3.1 that caused a crash when passing * Reverted a deprecation in Django 3.1 that caused a crash when passing
deprecated keyword arguments to a queryset in deprecated keyword arguments to a queryset in
``TemplateView.get_context_data()`` (:ticket:`31877`). ``TemplateView.get_context_data()`` (:ticket:`31877`).
* Enforced thread sensitivity of the :class:`MiddlewareMixin.process_request()
<django.utils.deprecation.MiddlewareMixin>` and ``process_response()`` hooks
when in an async context (:ticket:`31905`).

View File

@ -1,11 +1,18 @@
import threading
from asgiref.sync import async_to_sync
from django.contrib.sessions.middleware import SessionMiddleware from django.contrib.sessions.middleware import SessionMiddleware
from django.db import connection
from django.http.request import HttpRequest
from django.http.response import HttpResponse
from django.middleware.cache import ( from django.middleware.cache import (
CacheMiddleware, FetchFromCacheMiddleware, UpdateCacheMiddleware, CacheMiddleware, FetchFromCacheMiddleware, UpdateCacheMiddleware,
) )
from django.middleware.common import CommonMiddleware from django.middleware.common import CommonMiddleware
from django.middleware.security import SecurityMiddleware from django.middleware.security import SecurityMiddleware
from django.test import SimpleTestCase from django.test import SimpleTestCase
from django.utils.deprecation import RemovedInDjango40Warning from django.utils.deprecation import MiddlewareMixin, RemovedInDjango40Warning
class MiddlewareMixinTests(SimpleTestCase): class MiddlewareMixinTests(SimpleTestCase):
@ -37,3 +44,37 @@ class MiddlewareMixinTests(SimpleTestCase):
with self.subTest(middleware=middleware): with self.subTest(middleware=middleware):
with self.assertRaisesMessage(RemovedInDjango40Warning, self.msg): with self.assertRaisesMessage(RemovedInDjango40Warning, self.msg):
middleware() middleware()
def test_sync_to_async_uses_base_thread_and_connection(self):
"""
The process_request() and process_response() hooks must be called with
the sync_to_async thread_sensitive flag enabled, so that database
operations use the correct thread and connection.
"""
def request_lifecycle():
"""Fake request_started/request_finished."""
return (threading.get_ident(), id(connection))
async def get_response(self):
return HttpResponse()
class SimpleMiddleWare(MiddlewareMixin):
def process_request(self, request):
request.thread_and_connection = request_lifecycle()
def process_response(self, request, response):
response.thread_and_connection = request_lifecycle()
return response
threads_and_connections = []
threads_and_connections.append(request_lifecycle())
request = HttpRequest()
response = async_to_sync(SimpleMiddleWare(get_response))(request)
threads_and_connections.append(request.thread_and_connection)
threads_and_connections.append(response.thread_and_connection)
threads_and_connections.append(request_lifecycle())
self.assertEqual(len(threads_and_connections), 4)
self.assertEqual(len(set(threads_and_connections)), 1)