2020-08-21 00:41:22 +08:00
|
|
|
import threading
|
|
|
|
|
|
|
|
from asgiref.sync import async_to_sync
|
|
|
|
|
2019-09-27 01:06:35 +08:00
|
|
|
from django.contrib.sessions.middleware import SessionMiddleware
|
2020-08-21 00:41:22 +08:00
|
|
|
from django.db import connection
|
|
|
|
from django.http.request import HttpRequest
|
|
|
|
from django.http.response import HttpResponse
|
2019-09-27 01:06:35 +08:00
|
|
|
from django.middleware.cache import (
|
|
|
|
CacheMiddleware, FetchFromCacheMiddleware, UpdateCacheMiddleware,
|
|
|
|
)
|
|
|
|
from django.middleware.common import CommonMiddleware
|
|
|
|
from django.middleware.security import SecurityMiddleware
|
|
|
|
from django.test import SimpleTestCase
|
2020-08-21 00:41:22 +08:00
|
|
|
from django.utils.deprecation import MiddlewareMixin, RemovedInDjango40Warning
|
2019-09-27 01:06:35 +08:00
|
|
|
|
|
|
|
|
|
|
|
class MiddlewareMixinTests(SimpleTestCase):
    """
    Tests for MiddlewareMixin: the deprecation warning raised when
    get_response is None, and the thread/connection used by the
    process_request()/process_response() hooks when the middleware is called
    asynchronously.
    """
    msg = 'Passing None for the middleware get_response argument is deprecated.'

    def test_deprecation(self):
        # Omitting get_response entirely is deprecated.
        with self.assertRaisesMessage(RemovedInDjango40Warning, self.msg):
            CommonMiddleware()

    def test_passing_explicit_none(self):
        # Passing None explicitly is deprecated as well.
        with self.assertRaisesMessage(RemovedInDjango40Warning, self.msg):
            CommonMiddleware(None)

    def test_subclass_deprecation(self):
        """
        Deprecation warning is raised in subclasses overriding __init__()
        without calling super().
        """
        for middleware in [
            SessionMiddleware,
            CacheMiddleware,
            FetchFromCacheMiddleware,
            UpdateCacheMiddleware,
            SecurityMiddleware,
        ]:
            with self.subTest(middleware=middleware):
                with self.assertRaisesMessage(RemovedInDjango40Warning, self.msg):
                    middleware()

    def test_sync_to_async_uses_base_thread_and_connection(self):
        """
        The process_request() and process_response() hooks must be called with
        the sync_to_async thread_sensitive flag enabled, so that database
        operations use the correct thread and connection.
        """
        # Local import: only the ``connection`` proxy is imported at module
        # level, but the per-thread wrappers live in ``connections``.
        from django.db import connections

        def request_lifecycle():
            """Fake request_started/request_finished."""
            # ``connection`` is one module-level proxy object, so
            # id(connection) is identical in every thread and cannot reveal a
            # connection switch. Compare the database wrapper the proxy
            # currently resolves to, which Django stores per thread.
            return (threading.get_ident(), id(connections[connection.alias]))

        async def get_response(request):
            return HttpResponse()

        class SimpleMiddleWare(MiddlewareMixin):
            def process_request(self, request):
                request.thread_and_connection = request_lifecycle()

            def process_response(self, request, response):
                response.thread_and_connection = request_lifecycle()
                return response

        # Sample the calling thread before the request is processed...
        threads_and_connections = [request_lifecycle()]

        request = HttpRequest()
        response = async_to_sync(SimpleMiddleWare(get_response))(request)

        # ...then what the hooks observed, and the calling thread afterwards.
        threads_and_connections += [
            request.thread_and_connection,
            response.thread_and_connection,
            request_lifecycle(),
        ]

        self.assertEqual(len(threads_and_connections), 4)
        # Every sample must come from the same thread and the same connection.
        self.assertEqual(len(set(threads_and_connections)), 1)
|