import threading from asgiref.sync import async_to_sync from django.contrib.sessions.middleware import SessionMiddleware from django.db import connection from django.http.request import HttpRequest from django.http.response import HttpResponse from django.middleware.cache import ( CacheMiddleware, FetchFromCacheMiddleware, UpdateCacheMiddleware, ) from django.middleware.common import CommonMiddleware from django.middleware.security import SecurityMiddleware from django.test import SimpleTestCase from django.utils.deprecation import MiddlewareMixin, RemovedInDjango40Warning class MiddlewareMixinTests(SimpleTestCase): """ Deprecation warning is raised when using get_response=None. """ msg = 'Passing None for the middleware get_response argument is deprecated.' def test_deprecation(self): with self.assertRaisesMessage(RemovedInDjango40Warning, self.msg): CommonMiddleware() def test_passing_explicit_none(self): with self.assertRaisesMessage(RemovedInDjango40Warning, self.msg): CommonMiddleware(None) def test_subclass_deprecation(self): """ Deprecation warning is raised in subclasses overriding __init__() without calling super(). """ for middleware in [ SessionMiddleware, CacheMiddleware, FetchFromCacheMiddleware, UpdateCacheMiddleware, SecurityMiddleware, ]: with self.subTest(middleware=middleware): with self.assertRaisesMessage(RemovedInDjango40Warning, self.msg): middleware() def test_sync_to_async_uses_base_thread_and_connection(self): """ The process_request() and process_response() hooks must be called with the sync_to_async thread_sensitive flag enabled, so that database operations use the correct thread and connection. """ def request_lifecycle(): """Fake request_started/request_finished.""" return (threading.get_ident(), id(connection)) async def get_response(self): return HttpResponse() class SimpleMiddleWare(MiddlewareMixin): def process_request(self, request): request.thread_and_connection = request_lifecycle() def process_response(self, request, response): response.thread_and_connection = request_lifecycle() return response threads_and_connections = [] threads_and_connections.append(request_lifecycle()) request = HttpRequest() response = async_to_sync(SimpleMiddleWare(get_response))(request) threads_and_connections.append(request.thread_and_connection) threads_and_connections.append(response.thread_and_connection) threads_and_connections.append(request_lifecycle()) self.assertEqual(len(threads_and_connections), 4) self.assertEqual(len(set(threads_and_connections)), 1)