Fixed #35030 -- Made django.contrib.auth decorators to work with async functions.

This commit is contained in:
Dingning 2023-12-12 22:09:18 +08:00 committed by Mariusz Felisiak
parent 1fffa4af12
commit 549320946d
4 changed files with 287 additions and 18 deletions

View File

@ -1,6 +1,9 @@
import asyncio
from functools import wraps from functools import wraps
from urllib.parse import urlparse from urllib.parse import urlparse
from asgiref.sync import async_to_sync, sync_to_async
from django.conf import settings from django.conf import settings
from django.contrib.auth import REDIRECT_FIELD_NAME from django.contrib.auth import REDIRECT_FIELD_NAME
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
@ -17,10 +20,7 @@ def user_passes_test(
""" """
def decorator(view_func): def decorator(view_func):
@wraps(view_func) def _redirect_to_login(request):
def _wrapper_view(request, *args, **kwargs):
if test_func(request.user):
return view_func(request, *args, **kwargs)
path = request.build_absolute_uri() path = request.build_absolute_uri()
resolved_login_url = resolve_url(login_url or settings.LOGIN_URL) resolved_login_url = resolve_url(login_url or settings.LOGIN_URL)
# If the login url is the same scheme and net location then just # If the login url is the same scheme and net location then just
@ -35,7 +35,32 @@ def user_passes_test(
return redirect_to_login(path, resolved_login_url, redirect_field_name) return redirect_to_login(path, resolved_login_url, redirect_field_name)
return _wrapper_view if asyncio.iscoroutinefunction(view_func):
async def _view_wrapper(request, *args, **kwargs):
auser = await request.auser()
if asyncio.iscoroutinefunction(test_func):
test_pass = await test_func(auser)
else:
test_pass = await sync_to_async(test_func)(auser)
if test_pass:
return await view_func(request, *args, **kwargs)
return _redirect_to_login(request)
else:
def _view_wrapper(request, *args, **kwargs):
if asyncio.iscoroutinefunction(test_func):
test_pass = async_to_sync(test_func)(request.user)
else:
test_pass = test_func(request.user)
if test_pass:
return view_func(request, *args, **kwargs)
return _redirect_to_login(request)
return wraps(view_func)(_view_wrapper)
return decorator return decorator
@ -64,19 +89,36 @@ def permission_required(perm, login_url=None, raise_exception=False):
If the raise_exception parameter is given the PermissionDenied exception If the raise_exception parameter is given the PermissionDenied exception
is raised. is raised.
""" """
if isinstance(perm, str):
perms = (perm,)
else:
perms = perm
def decorator(view_func):
if asyncio.iscoroutinefunction(view_func):
async def check_perms(user):
# First check if the user has the permission (even anon users).
if await sync_to_async(user.has_perms)(perms):
return True
# In case the 403 handler should be called raise the exception.
if raise_exception:
raise PermissionDenied
# As the last resort, show the login form.
return False
def check_perms(user):
if isinstance(perm, str):
perms = (perm,)
else: else:
perms = perm
# First check if the user has the permission (even anon users)
if user.has_perms(perms):
return True
# In case the 403 handler should be called raise the exception
if raise_exception:
raise PermissionDenied
# As the last resort, show the login form
return False
return user_passes_test(check_perms, login_url=login_url) def check_perms(user):
# First check if the user has the permission (even anon users).
if user.has_perms(perms):
return True
# In case the 403 handler should be called raise the exception.
if raise_exception:
raise PermissionDenied
# As the last resort, show the login form.
return False
return user_passes_test(check_perms, login_url=login_url)(view_func)
return decorator

View File

@ -52,6 +52,11 @@ Minor features
form save. This is now available in the admin when visiting the user creation form save. This is now available in the admin when visiting the user creation
and password change pages. and password change pages.
* :func:`~.django.contrib.auth.decorators.login_required`,
:func:`~.django.contrib.auth.decorators.permission_required`, and
:func:`~.django.contrib.auth.decorators.user_passes_test` decorators now
support wrapping asynchronous view functions.
:mod:`django.contrib.contenttypes` :mod:`django.contrib.contenttypes`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -617,6 +617,10 @@ The ``login_required`` decorator
:func:`django.contrib.admin.views.decorators.staff_member_required` :func:`django.contrib.admin.views.decorators.staff_member_required`
decorator a useful alternative to ``login_required()``. decorator a useful alternative to ``login_required()``.
.. versionchanged:: 5.1
Support for wrapping asynchronous view functions was added.
.. currentmodule:: django.contrib.auth.mixins .. currentmodule:: django.contrib.auth.mixins
The ``LoginRequiredMixin`` mixin The ``LoginRequiredMixin`` mixin
@ -714,6 +718,11 @@ email in the desired domain and if not, redirects to the login page::
@user_passes_test(email_check, login_url="/login/") @user_passes_test(email_check, login_url="/login/")
def my_view(request): ... def my_view(request): ...
.. versionchanged:: 5.1
Support for wrapping asynchronous view functions and using asynchronous
test callables was added.
.. currentmodule:: django.contrib.auth.mixins .. currentmodule:: django.contrib.auth.mixins
.. class:: UserPassesTestMixin .. class:: UserPassesTestMixin
@ -818,6 +827,10 @@ The ``permission_required`` decorator
``redirect_authenticated_user=True`` and the logged-in user doesn't have ``redirect_authenticated_user=True`` and the logged-in user doesn't have
all of the required permissions. all of the required permissions.
.. versionchanged:: 5.1
Support for wrapping asynchronous view functions was added.
.. currentmodule:: django.contrib.auth.mixins .. currentmodule:: django.contrib.auth.mixins
The ``PermissionRequiredMixin`` mixin The ``PermissionRequiredMixin`` mixin

View File

@ -1,3 +1,7 @@
from asyncio import iscoroutinefunction
from asgiref.sync import sync_to_async
from django.conf import settings from django.conf import settings
from django.contrib.auth import models from django.contrib.auth import models
from django.contrib.auth.decorators import ( from django.contrib.auth.decorators import (
@ -19,6 +23,22 @@ class LoginRequiredTestCase(AuthViewsTestCase):
Tests the login_required decorators Tests the login_required decorators
""" """
factory = RequestFactory()
def test_wrapped_sync_function_is_not_coroutine_function(self):
def sync_view(request):
return HttpResponse()
wrapped_view = login_required(sync_view)
self.assertIs(iscoroutinefunction(wrapped_view), False)
def test_wrapped_async_function_is_coroutine_function(self):
async def async_view(request):
return HttpResponse()
wrapped_view = login_required(async_view)
self.assertIs(iscoroutinefunction(wrapped_view), True)
def test_callable(self): def test_callable(self):
""" """
login_required is assignable to callable objects. login_required is assignable to callable objects.
@ -63,6 +83,35 @@ class LoginRequiredTestCase(AuthViewsTestCase):
view_url="/login_required_login_url/", login_url="/somewhere/" view_url="/login_required_login_url/", login_url="/somewhere/"
) )
async def test_login_required_async_view(self, login_url=None):
async def async_view(request):
return HttpResponse()
async def auser_anonymous():
return models.AnonymousUser()
async def auser():
return self.u1
if login_url is None:
async_view = login_required(async_view)
login_url = settings.LOGIN_URL
else:
async_view = login_required(async_view, login_url=login_url)
request = self.factory.get("/rand")
request.auser = auser_anonymous
response = await async_view(request)
self.assertEqual(response.status_code, 302)
self.assertIn(login_url, response.url)
request.auser = auser
response = await async_view(request)
self.assertEqual(response.status_code, 200)
async def test_login_required_next_url_async_view(self):
await self.test_login_required_async_view(login_url="/somewhere/")
class PermissionsRequiredDecoratorTest(TestCase): class PermissionsRequiredDecoratorTest(TestCase):
""" """
@ -80,6 +129,24 @@ class PermissionsRequiredDecoratorTest(TestCase):
) )
cls.user.user_permissions.add(*perms) cls.user.user_permissions.add(*perms)
@classmethod
async def auser(cls):
return cls.user
def test_wrapped_sync_function_is_not_coroutine_function(self):
def sync_view(request):
return HttpResponse()
wrapped_view = permission_required([])(sync_view)
self.assertIs(iscoroutinefunction(wrapped_view), False)
def test_wrapped_async_function_is_coroutine_function(self):
async def async_view(request):
return HttpResponse()
wrapped_view = permission_required([])(async_view)
self.assertIs(iscoroutinefunction(wrapped_view), True)
def test_many_permissions_pass(self): def test_many_permissions_pass(self):
@permission_required( @permission_required(
["auth_tests.add_customuser", "auth_tests.change_customuser"] ["auth_tests.add_customuser", "auth_tests.change_customuser"]
@ -147,6 +214,73 @@ class PermissionsRequiredDecoratorTest(TestCase):
with self.assertRaises(PermissionDenied): with self.assertRaises(PermissionDenied):
a_view(request) a_view(request)
async def test_many_permissions_pass_async_view(self):
@permission_required(
["auth_tests.add_customuser", "auth_tests.change_customuser"]
)
async def async_view(request):
return HttpResponse()
request = self.factory.get("/rand")
request.auser = self.auser
response = await async_view(request)
self.assertEqual(response.status_code, 200)
async def test_many_permissions_in_set_pass_async_view(self):
@permission_required(
{"auth_tests.add_customuser", "auth_tests.change_customuser"}
)
async def async_view(request):
return HttpResponse()
request = self.factory.get("/rand")
request.auser = self.auser
response = await async_view(request)
self.assertEqual(response.status_code, 200)
async def test_single_permission_pass_async_view(self):
@permission_required("auth_tests.add_customuser")
async def async_view(request):
return HttpResponse()
request = self.factory.get("/rand")
request.auser = self.auser
response = await async_view(request)
self.assertEqual(response.status_code, 200)
async def test_permissioned_denied_redirect_async_view(self):
@permission_required(
[
"auth_tests.add_customuser",
"auth_tests.change_customuser",
"nonexistent-permission",
]
)
async def async_view(request):
return HttpResponse()
request = self.factory.get("/rand")
request.auser = self.auser
response = await async_view(request)
self.assertEqual(response.status_code, 302)
async def test_permissioned_denied_exception_raised_async_view(self):
@permission_required(
[
"auth_tests.add_customuser",
"auth_tests.change_customuser",
"nonexistent-permission",
],
raise_exception=True,
)
async def async_view(request):
return HttpResponse()
request = self.factory.get("/rand")
request.auser = self.auser
with self.assertRaises(PermissionDenied):
await async_view(request)
class UserPassesTestDecoratorTest(TestCase): class UserPassesTestDecoratorTest(TestCase):
factory = RequestFactory() factory = RequestFactory()
@ -162,6 +296,28 @@ class UserPassesTestDecoratorTest(TestCase):
) )
cls.user_pass.user_permissions.add(*perms) cls.user_pass.user_permissions.add(*perms)
@classmethod
async def auser_pass(cls):
return cls.user_pass
@classmethod
async def auser_deny(cls):
return cls.user_deny
def test_wrapped_sync_function_is_not_coroutine_function(self):
def sync_view(request):
return HttpResponse()
wrapped_view = user_passes_test(lambda user: True)(sync_view)
self.assertIs(iscoroutinefunction(wrapped_view), False)
def test_wrapped_async_function_is_coroutine_function(self):
async def async_view(request):
return HttpResponse()
wrapped_view = user_passes_test(lambda user: True)(async_view)
self.assertIs(iscoroutinefunction(wrapped_view), True)
def test_decorator(self): def test_decorator(self):
def sync_test_func(user): def sync_test_func(user):
return bool( return bool(
@ -180,3 +336,56 @@ class UserPassesTestDecoratorTest(TestCase):
request.user = self.user_deny request.user = self.user_deny
response = sync_view(request) response = sync_view(request)
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
def test_decorator_async_test_func(self):
async def async_test_func(user):
return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"])
@user_passes_test(async_test_func)
def sync_view(request):
return HttpResponse()
request = self.factory.get("/rand")
request.user = self.user_pass
response = sync_view(request)
self.assertEqual(response.status_code, 200)
request.user = self.user_deny
response = sync_view(request)
self.assertEqual(response.status_code, 302)
async def test_decorator_async_view(self):
def sync_test_func(user):
return bool(
models.Group.objects.filter(name__istartswith=user.username).exists()
)
@user_passes_test(sync_test_func)
async def async_view(request):
return HttpResponse()
request = self.factory.get("/rand")
request.auser = self.auser_pass
response = await async_view(request)
self.assertEqual(response.status_code, 200)
request.auser = self.auser_deny
response = await async_view(request)
self.assertEqual(response.status_code, 302)
async def test_decorator_async_view_async_test_func(self):
async def async_test_func(user):
return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"])
@user_passes_test(async_test_func)
async def async_view(request):
return HttpResponse()
request = self.factory.get("/rand")
request.auser = self.auser_pass
response = await async_view(request)
self.assertEqual(response.status_code, 200)
request.auser = self.auser_deny
response = await async_view(request)
self.assertEqual(response.status_code, 302)