diff --git a/django/utils/decorators.py b/django/utils/decorators.py index 735821d4758..d8814b0d4c6 100644 --- a/django/utils/decorators.py +++ b/django/utils/decorators.py @@ -2,6 +2,8 @@ from functools import partial, update_wrapper, wraps +from asgiref.sync import iscoroutinefunction + class classonlymethod(classmethod): def __get__(self, instance, cls=None): @@ -120,8 +122,7 @@ def make_middleware_decorator(middleware_class): def _decorator(view_func): middleware = middleware_class(view_func, *m_args, **m_kwargs) - @wraps(view_func) - def _wrapper_view(request, *args, **kwargs): + def _pre_process_request(request, *args, **kwargs): if hasattr(middleware, "process_request"): result = middleware.process_request(request) if result is not None: @@ -130,14 +131,16 @@ def make_middleware_decorator(middleware_class): result = middleware.process_view(request, view_func, args, kwargs) if result is not None: return result - try: - response = view_func(request, *args, **kwargs) - except Exception as e: - if hasattr(middleware, "process_exception"): - result = middleware.process_exception(request, e) - if result is not None: - return result - raise + return None + + def _process_exception(request, exception): + if hasattr(middleware, "process_exception"): + result = middleware.process_exception(request, exception) + if result is not None: + return result + raise + + def _post_process_request(request, response): if hasattr(response, "render") and callable(response.render): if hasattr(middleware, "process_template_response"): response = middleware.process_template_response( @@ -156,7 +159,39 @@ def make_middleware_decorator(middleware_class): return middleware.process_response(request, response) return response - return _wrapper_view + if iscoroutinefunction(view_func): + + async def _view_wrapper(request, *args, **kwargs): + result = _pre_process_request(request, *args, **kwargs) + if result is not None: + return result + + try: + response = await view_func(request, *args, **kwargs) + except Exception as e: + result = _process_exception(request, e) + if result is not None: + return result + + return _post_process_request(request, response) + + else: + + def _view_wrapper(request, *args, **kwargs): + result = _pre_process_request(request, *args, **kwargs) + if result is not None: + return result + + try: + response = view_func(request, *args, **kwargs) + except Exception as e: + result = _process_exception(request, e) + if result is not None: + return result + + return _post_process_request(request, response) + + return wraps(view_func)(_view_wrapper) return _decorator diff --git a/docs/ref/csrf.txt b/docs/ref/csrf.txt index d52539e6b63..ae94ccdee49 100644 --- a/docs/ref/csrf.txt +++ b/docs/ref/csrf.txt @@ -170,6 +170,10 @@ class-based views`. # ... return render(request, "a_template.html", c) + .. versionchanged:: 5.0 + + Support for wrapping asynchronous view functions was added. + .. function:: requires_csrf_token(view) Normally the :ttag:`csrf_token` template tag will not work if @@ -190,10 +194,18 @@ class-based views`. # ... return render(request, "a_template.html", c) + .. versionchanged:: 5.0 + + Support for wrapping asynchronous view functions was added. + .. function:: ensure_csrf_cookie(view) This decorator forces a view to send the CSRF cookie. + .. versionchanged:: 5.0 + + Support for wrapping asynchronous view functions was added. + Settings ======== diff --git a/docs/releases/5.0.txt b/docs/releases/5.0.txt index 10ab7066451..3bcec805eff 100644 --- a/docs/releases/5.0.txt +++ b/docs/releases/5.0.txt @@ -322,9 +322,14 @@ Decorators * :func:`~django.views.decorators.cache.never_cache` * :func:`~django.views.decorators.common.no_append_slash` * :func:`~django.views.decorators.csrf.csrf_exempt` + * :func:`~django.views.decorators.csrf.csrf_protect` + * :func:`~django.views.decorators.csrf.ensure_csrf_cookie` + * :func:`~django.views.decorators.csrf.requires_csrf_token` * :func:`~django.views.decorators.debug.sensitive_variables` * :func:`~django.views.decorators.debug.sensitive_post_parameters` + * :func:`~django.views.decorators.gzip.gzip_page` * :func:`~django.views.decorators.http.condition` + * ``conditional_page()`` * :func:`~django.views.decorators.http.etag` * :func:`~django.views.decorators.http.last_modified` * :func:`~django.views.decorators.http.require_http_methods` diff --git a/docs/topics/async.txt b/docs/topics/async.txt index 1faf7873807..0e33753e682 100644 --- a/docs/topics/async.txt +++ b/docs/topics/async.txt @@ -85,9 +85,14 @@ view functions: * :func:`~django.views.decorators.cache.never_cache` * :func:`~django.views.decorators.common.no_append_slash` * :func:`~django.views.decorators.csrf.csrf_exempt` +* :func:`~django.views.decorators.csrf.csrf_protect` +* :func:`~django.views.decorators.csrf.ensure_csrf_cookie` +* :func:`~django.views.decorators.csrf.requires_csrf_token` * :func:`~django.views.decorators.debug.sensitive_variables` * :func:`~django.views.decorators.debug.sensitive_post_parameters` +* :func:`~django.views.decorators.gzip.gzip_page` * :func:`~django.views.decorators.http.condition` +* ``conditional_page()`` * :func:`~django.views.decorators.http.etag` * :func:`~django.views.decorators.http.last_modified` * :func:`~django.views.decorators.http.require_http_methods` diff --git a/docs/topics/http/decorators.txt b/docs/topics/http/decorators.txt index 38e528ecf53..fa84e8d9b76 100644 --- a/docs/topics/http/decorators.txt +++ b/docs/topics/http/decorators.txt @@ -105,6 +105,10 @@ compression on a per-view basis. It sets the ``Vary`` header accordingly, so that caches will base their storage on the ``Accept-Encoding`` header. + .. versionchanged:: 5.0 + + Support for wrapping asynchronous view functions was added. + .. module:: django.views.decorators.vary Vary headers diff --git a/tests/decorators/test_csrf.py b/tests/decorators/test_csrf.py index 213ee7fdd86..c1934448643 100644 --- a/tests/decorators/test_csrf.py +++ b/tests/decorators/test_csrf.py @@ -24,6 +24,20 @@ class CsrfTestMixin: class CsrfProtectTests(CsrfTestMixin, SimpleTestCase): + def test_wrapped_sync_function_is_not_coroutine_function(self): + def sync_view(request): + return HttpResponse() + + wrapped_view = csrf_protect(sync_view) + self.assertIs(iscoroutinefunction(wrapped_view), False) + + def test_wrapped_async_function_is_coroutine_function(self): + async def async_view(request): + return HttpResponse() + + wrapped_view = csrf_protect(async_view) + self.assertIs(iscoroutinefunction(wrapped_view), True) + def test_csrf_protect_decorator(self): @csrf_protect def sync_view(request): @@ -39,8 +53,37 @@ class CsrfProtectTests(CsrfTestMixin, SimpleTestCase): response = sync_view(request) self.assertEqual(response.status_code, 403) + async def test_csrf_protect_decorator_async_view(self): + @csrf_protect + async def async_view(request): + return HttpResponse() + + request = self.get_request() + response = await async_view(request) + self.assertEqual(response.status_code, 200) + self.assertIs(request.csrf_processing_done, True) + + with self.assertLogs("django.security.csrf", "WARNING"): + request = self.get_request(token=None) + response = await async_view(request) + self.assertEqual(response.status_code, 403) + class RequiresCsrfTokenTests(CsrfTestMixin, SimpleTestCase): + def test_wrapped_sync_function_is_not_coroutine_function(self): + def sync_view(request): + return HttpResponse() + + wrapped_view = requires_csrf_token(sync_view) + self.assertIs(iscoroutinefunction(wrapped_view), False) + + def test_wrapped_async_function_is_coroutine_function(self): + async def async_view(request): + return HttpResponse() + + wrapped_view = requires_csrf_token(async_view) + self.assertIs(iscoroutinefunction(wrapped_view), True) + def test_requires_csrf_token_decorator(self): @requires_csrf_token def sync_view(request): @@ -56,8 +99,37 @@ class RequiresCsrfTokenTests(CsrfTestMixin, SimpleTestCase): response = sync_view(request) self.assertEqual(response.status_code, 200) + async def test_requires_csrf_token_decorator_async_view(self): + @requires_csrf_token + async def async_view(request): + return HttpResponse() + + request = self.get_request() + response = await async_view(request) + self.assertEqual(response.status_code, 200) + self.assertIs(request.csrf_processing_done, True) + + with self.assertNoLogs("django.security.csrf", "WARNING"): + request = self.get_request(token=None) + response = await async_view(request) + self.assertEqual(response.status_code, 200) + class EnsureCsrfCookieTests(CsrfTestMixin, SimpleTestCase): + def test_wrapped_sync_function_is_not_coroutine_function(self): + def sync_view(request): + return HttpResponse() + + wrapped_view = ensure_csrf_cookie(sync_view) + self.assertIs(iscoroutinefunction(wrapped_view), False) + + def test_wrapped_async_function_is_coroutine_function(self): + async def async_view(request): + return HttpResponse() + + wrapped_view = ensure_csrf_cookie(async_view) + self.assertIs(iscoroutinefunction(wrapped_view), True) + def test_ensure_csrf_cookie_decorator(self): @ensure_csrf_cookie def sync_view(request): @@ -73,6 +145,21 @@ class EnsureCsrfCookieTests(CsrfTestMixin, SimpleTestCase): response = sync_view(request) self.assertEqual(response.status_code, 200) + async def test_ensure_csrf_cookie_decorator_async_view(self): + @ensure_csrf_cookie + async def async_view(request): + return HttpResponse() + + request = self.get_request() + response = await async_view(request) + self.assertEqual(response.status_code, 200) + self.assertIs(request.csrf_processing_done, True) + + with self.assertNoLogs("django.security.csrf", "WARNING"): + request = self.get_request(token=None) + response = await async_view(request) + self.assertEqual(response.status_code, 200) + class CsrfExemptTests(SimpleTestCase): def test_wrapped_sync_function_is_not_coroutine_function(self): diff --git a/tests/decorators/test_gzip.py b/tests/decorators/test_gzip.py index 129befbd1e9..87eff0fe777 100644 --- a/tests/decorators/test_gzip.py +++ b/tests/decorators/test_gzip.py @@ -1,3 +1,5 @@ +from asgiref.sync import iscoroutinefunction + from django.http import HttpRequest, HttpResponse from django.test import SimpleTestCase from django.views.decorators.gzip import gzip_page @@ -7,6 +9,20 @@ class GzipPageTests(SimpleTestCase): # Gzip ignores content that is too short. content = "Content " * 100 + def test_wrapped_sync_function_is_not_coroutine_function(self): + def sync_view(request): + return HttpResponse() + + wrapped_view = gzip_page(sync_view) + self.assertIs(iscoroutinefunction(wrapped_view), False) + + def test_wrapped_async_function_is_coroutine_function(self): + async def async_view(request): + return HttpResponse() + + wrapped_view = gzip_page(async_view) + self.assertIs(iscoroutinefunction(wrapped_view), True) + def test_gzip_page_decorator(self): @gzip_page def sync_view(request): @@ -17,3 +33,14 @@ class GzipPageTests(SimpleTestCase): response = sync_view(request) self.assertEqual(response.status_code, 200) self.assertEqual(response.get("Content-Encoding"), "gzip") + + async def test_gzip_page_decorator_async_view(self): + @gzip_page + async def async_view(request): + return HttpResponse(content=self.content) + + request = HttpRequest() + request.META["HTTP_ACCEPT_ENCODING"] = "gzip" + response = await async_view(request) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.get("Content-Encoding"), "gzip") diff --git a/tests/decorators/test_http.py b/tests/decorators/test_http.py index d36e46562c5..4b1b81588e2 100644 --- a/tests/decorators/test_http.py +++ b/tests/decorators/test_http.py @@ -163,6 +163,20 @@ class ConditionDecoratorTest(SimpleTestCase): class ConditionalPageTests(SimpleTestCase): + def test_wrapped_sync_function_is_not_coroutine_function(self): + def sync_view(request): + return HttpResponse() + + wrapped_view = conditional_page(sync_view) + self.assertIs(iscoroutinefunction(wrapped_view), False) + + def test_wrapped_async_function_is_coroutine_function(self): + async def async_view(request): + return HttpResponse() + + wrapped_view = conditional_page(async_view) + self.assertIs(iscoroutinefunction(wrapped_view), True) + def test_conditional_page_decorator_successful(self): @conditional_page def sync_view(request): @@ -176,3 +190,17 @@ class ConditionalPageTests(SimpleTestCase): response = sync_view(request) self.assertEqual(response.status_code, 200) self.assertIsNotNone(response.get("Etag")) + + async def test_conditional_page_decorator_successful_async_view(self): + @conditional_page + async def async_view(request): + response = HttpResponse() + response.content = b"test" + response["Cache-Control"] = "public" + return response + + request = HttpRequest() + request.method = "GET" + response = await async_view(request) + self.assertEqual(response.status_code, 200) + self.assertIsNotNone(response.get("Etag"))