Refs #31949 -- Made make_middleware_decorator to work with async functions.

This commit is contained in:
Ben Lomax 2023-07-08 22:51:19 +01:00 committed by Mariusz Felisiak
parent 059cb0dbc9
commit 74f7deec9e
8 changed files with 214 additions and 11 deletions

View File

@ -2,6 +2,8 @@
from functools import partial, update_wrapper, wraps
from asgiref.sync import iscoroutinefunction
class classonlymethod(classmethod):
def __get__(self, instance, cls=None):
@ -120,8 +122,7 @@ def make_middleware_decorator(middleware_class):
def _decorator(view_func):
middleware = middleware_class(view_func, *m_args, **m_kwargs)
@wraps(view_func)
def _wrapper_view(request, *args, **kwargs):
def _pre_process_request(request, *args, **kwargs):
if hasattr(middleware, "process_request"):
result = middleware.process_request(request)
if result is not None:
@ -130,14 +131,16 @@ def make_middleware_decorator(middleware_class):
result = middleware.process_view(request, view_func, args, kwargs)
if result is not None:
return result
try:
response = view_func(request, *args, **kwargs)
except Exception as e:
if hasattr(middleware, "process_exception"):
result = middleware.process_exception(request, e)
if result is not None:
return result
raise
return None
def _process_exception(request, exception):
if hasattr(middleware, "process_exception"):
result = middleware.process_exception(request, exception)
if result is not None:
return result
raise
def _post_process_request(request, response):
if hasattr(response, "render") and callable(response.render):
if hasattr(middleware, "process_template_response"):
response = middleware.process_template_response(
@ -156,7 +159,39 @@ def make_middleware_decorator(middleware_class):
return middleware.process_response(request, response)
return response
return _wrapper_view
if iscoroutinefunction(view_func):
async def _view_wrapper(request, *args, **kwargs):
result = _pre_process_request(request, *args, **kwargs)
if result is not None:
return result
try:
response = await view_func(request, *args, **kwargs)
except Exception as e:
result = _process_exception(request, e)
if result is not None:
return result
return _post_process_request(request, response)
else:
def _view_wrapper(request, *args, **kwargs):
result = _pre_process_request(request, *args, **kwargs)
if result is not None:
return result
try:
response = view_func(request, *args, **kwargs)
except Exception as e:
result = _process_exception(request, e)
if result is not None:
return result
return _post_process_request(request, response)
return wraps(view_func)(_view_wrapper)
return _decorator

View File

@ -170,6 +170,10 @@ class-based views<decorating-class-based-views>`.
# ...
return render(request, "a_template.html", c)
.. versionchanged:: 5.0
Support for wrapping asynchronous view functions was added.
.. function:: requires_csrf_token(view)
Normally the :ttag:`csrf_token` template tag will not work if
@ -190,10 +194,18 @@ class-based views<decorating-class-based-views>`.
# ...
return render(request, "a_template.html", c)
.. versionchanged:: 5.0
Support for wrapping asynchronous view functions was added.
.. function:: ensure_csrf_cookie(view)
This decorator forces a view to send the CSRF cookie.
.. versionchanged:: 5.0
Support for wrapping asynchronous view functions was added.
Settings
========

View File

@ -322,9 +322,14 @@ Decorators
* :func:`~django.views.decorators.cache.never_cache`
* :func:`~django.views.decorators.common.no_append_slash`
* :func:`~django.views.decorators.csrf.csrf_exempt`
* :func:`~django.views.decorators.csrf.csrf_protect`
* :func:`~django.views.decorators.csrf.ensure_csrf_cookie`
* :func:`~django.views.decorators.csrf.requires_csrf_token`
* :func:`~django.views.decorators.debug.sensitive_variables`
* :func:`~django.views.decorators.debug.sensitive_post_parameters`
* :func:`~django.views.decorators.gzip.gzip_page`
* :func:`~django.views.decorators.http.condition`
* ``conditional_page()``
* :func:`~django.views.decorators.http.etag`
* :func:`~django.views.decorators.http.last_modified`
* :func:`~django.views.decorators.http.require_http_methods`

View File

@ -85,9 +85,14 @@ view functions:
* :func:`~django.views.decorators.cache.never_cache`
* :func:`~django.views.decorators.common.no_append_slash`
* :func:`~django.views.decorators.csrf.csrf_exempt`
* :func:`~django.views.decorators.csrf.csrf_protect`
* :func:`~django.views.decorators.csrf.ensure_csrf_cookie`
* :func:`~django.views.decorators.csrf.requires_csrf_token`
* :func:`~django.views.decorators.debug.sensitive_variables`
* :func:`~django.views.decorators.debug.sensitive_post_parameters`
* :func:`~django.views.decorators.gzip.gzip_page`
* :func:`~django.views.decorators.http.condition`
* ``conditional_page()``
* :func:`~django.views.decorators.http.etag`
* :func:`~django.views.decorators.http.last_modified`
* :func:`~django.views.decorators.http.require_http_methods`

View File

@ -105,6 +105,10 @@ compression on a per-view basis.
It sets the ``Vary`` header accordingly, so that caches will base their
storage on the ``Accept-Encoding`` header.
.. versionchanged:: 5.0
Support for wrapping asynchronous view functions was added.
.. module:: django.views.decorators.vary
Vary headers

View File

@ -24,6 +24,20 @@ class CsrfTestMixin:
class CsrfProtectTests(CsrfTestMixin, SimpleTestCase):
def test_wrapped_sync_function_is_not_coroutine_function(self):
def sync_view(request):
return HttpResponse()
wrapped_view = csrf_protect(sync_view)
self.assertIs(iscoroutinefunction(wrapped_view), False)
def test_wrapped_async_function_is_coroutine_function(self):
async def async_view(request):
return HttpResponse()
wrapped_view = csrf_protect(async_view)
self.assertIs(iscoroutinefunction(wrapped_view), True)
def test_csrf_protect_decorator(self):
@csrf_protect
def sync_view(request):
@ -39,8 +53,37 @@ class CsrfProtectTests(CsrfTestMixin, SimpleTestCase):
response = sync_view(request)
self.assertEqual(response.status_code, 403)
async def test_csrf_protect_decorator_async_view(self):
@csrf_protect
async def async_view(request):
return HttpResponse()
request = self.get_request()
response = await async_view(request)
self.assertEqual(response.status_code, 200)
self.assertIs(request.csrf_processing_done, True)
with self.assertLogs("django.security.csrf", "WARNING"):
request = self.get_request(token=None)
response = await async_view(request)
self.assertEqual(response.status_code, 403)
class RequiresCsrfTokenTests(CsrfTestMixin, SimpleTestCase):
def test_wrapped_sync_function_is_not_coroutine_function(self):
def sync_view(request):
return HttpResponse()
wrapped_view = requires_csrf_token(sync_view)
self.assertIs(iscoroutinefunction(wrapped_view), False)
def test_wrapped_async_function_is_coroutine_function(self):
async def async_view(request):
return HttpResponse()
wrapped_view = requires_csrf_token(async_view)
self.assertIs(iscoroutinefunction(wrapped_view), True)
def test_requires_csrf_token_decorator(self):
@requires_csrf_token
def sync_view(request):
@ -56,8 +99,37 @@ class RequiresCsrfTokenTests(CsrfTestMixin, SimpleTestCase):
response = sync_view(request)
self.assertEqual(response.status_code, 200)
async def test_requires_csrf_token_decorator_async_view(self):
@requires_csrf_token
async def async_view(request):
return HttpResponse()
request = self.get_request()
response = await async_view(request)
self.assertEqual(response.status_code, 200)
self.assertIs(request.csrf_processing_done, True)
with self.assertNoLogs("django.security.csrf", "WARNING"):
request = self.get_request(token=None)
response = await async_view(request)
self.assertEqual(response.status_code, 200)
class EnsureCsrfCookieTests(CsrfTestMixin, SimpleTestCase):
def test_wrapped_sync_function_is_not_coroutine_function(self):
def sync_view(request):
return HttpResponse()
wrapped_view = ensure_csrf_cookie(sync_view)
self.assertIs(iscoroutinefunction(wrapped_view), False)
def test_wrapped_async_function_is_coroutine_function(self):
async def async_view(request):
return HttpResponse()
wrapped_view = ensure_csrf_cookie(async_view)
self.assertIs(iscoroutinefunction(wrapped_view), True)
def test_ensure_csrf_cookie_decorator(self):
@ensure_csrf_cookie
def sync_view(request):
@ -73,6 +145,21 @@ class EnsureCsrfCookieTests(CsrfTestMixin, SimpleTestCase):
response = sync_view(request)
self.assertEqual(response.status_code, 200)
async def test_ensure_csrf_cookie_decorator_async_view(self):
@ensure_csrf_cookie
async def async_view(request):
return HttpResponse()
request = self.get_request()
response = await async_view(request)
self.assertEqual(response.status_code, 200)
self.assertIs(request.csrf_processing_done, True)
with self.assertNoLogs("django.security.csrf", "WARNING"):
request = self.get_request(token=None)
response = await async_view(request)
self.assertEqual(response.status_code, 200)
class CsrfExemptTests(SimpleTestCase):
def test_wrapped_sync_function_is_not_coroutine_function(self):

View File

@ -1,3 +1,5 @@
from asgiref.sync import iscoroutinefunction
from django.http import HttpRequest, HttpResponse
from django.test import SimpleTestCase
from django.views.decorators.gzip import gzip_page
@ -7,6 +9,20 @@ class GzipPageTests(SimpleTestCase):
# Gzip ignores content that is too short.
content = "Content " * 100
def test_wrapped_sync_function_is_not_coroutine_function(self):
def sync_view(request):
return HttpResponse()
wrapped_view = gzip_page(sync_view)
self.assertIs(iscoroutinefunction(wrapped_view), False)
def test_wrapped_async_function_is_coroutine_function(self):
async def async_view(request):
return HttpResponse()
wrapped_view = gzip_page(async_view)
self.assertIs(iscoroutinefunction(wrapped_view), True)
def test_gzip_page_decorator(self):
@gzip_page
def sync_view(request):
@ -17,3 +33,14 @@ class GzipPageTests(SimpleTestCase):
response = sync_view(request)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.get("Content-Encoding"), "gzip")
async def test_gzip_page_decorator_async_view(self):
@gzip_page
async def async_view(request):
return HttpResponse(content=self.content)
request = HttpRequest()
request.META["HTTP_ACCEPT_ENCODING"] = "gzip"
response = await async_view(request)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.get("Content-Encoding"), "gzip")

View File

@ -163,6 +163,20 @@ class ConditionDecoratorTest(SimpleTestCase):
class ConditionalPageTests(SimpleTestCase):
def test_wrapped_sync_function_is_not_coroutine_function(self):
def sync_view(request):
return HttpResponse()
wrapped_view = conditional_page(sync_view)
self.assertIs(iscoroutinefunction(wrapped_view), False)
def test_wrapped_async_function_is_coroutine_function(self):
async def async_view(request):
return HttpResponse()
wrapped_view = conditional_page(async_view)
self.assertIs(iscoroutinefunction(wrapped_view), True)
def test_conditional_page_decorator_successful(self):
@conditional_page
def sync_view(request):
@ -176,3 +190,17 @@ class ConditionalPageTests(SimpleTestCase):
response = sync_view(request)
self.assertEqual(response.status_code, 200)
self.assertIsNotNone(response.get("Etag"))
async def test_conditional_page_decorator_successful_async_view(self):
@conditional_page
async def async_view(request):
response = HttpResponse()
response.content = b"test"
response["Cache-Control"] = "public"
return response
request = HttpRequest()
request.method = "GET"
response = await async_view(request)
self.assertEqual(response.status_code, 200)
self.assertIsNotNone(response.get("Etag"))