mirror of https://github.com/django/django.git
Fixed #35030 -- Made django.contrib.auth decorators to work with async functions.
This commit is contained in:
parent
1fffa4af12
commit
549320946d
|
@ -1,6 +1,9 @@
|
||||||
|
import asyncio
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from asgiref.sync import async_to_sync, sync_to_async
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.contrib.auth import REDIRECT_FIELD_NAME
|
from django.contrib.auth import REDIRECT_FIELD_NAME
|
||||||
from django.core.exceptions import PermissionDenied
|
from django.core.exceptions import PermissionDenied
|
||||||
|
@ -17,10 +20,7 @@ def user_passes_test(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(view_func):
|
def decorator(view_func):
|
||||||
@wraps(view_func)
|
def _redirect_to_login(request):
|
||||||
def _wrapper_view(request, *args, **kwargs):
|
|
||||||
if test_func(request.user):
|
|
||||||
return view_func(request, *args, **kwargs)
|
|
||||||
path = request.build_absolute_uri()
|
path = request.build_absolute_uri()
|
||||||
resolved_login_url = resolve_url(login_url or settings.LOGIN_URL)
|
resolved_login_url = resolve_url(login_url or settings.LOGIN_URL)
|
||||||
# If the login url is the same scheme and net location then just
|
# If the login url is the same scheme and net location then just
|
||||||
|
@ -35,7 +35,32 @@ def user_passes_test(
|
||||||
|
|
||||||
return redirect_to_login(path, resolved_login_url, redirect_field_name)
|
return redirect_to_login(path, resolved_login_url, redirect_field_name)
|
||||||
|
|
||||||
return _wrapper_view
|
if asyncio.iscoroutinefunction(view_func):
|
||||||
|
|
||||||
|
async def _view_wrapper(request, *args, **kwargs):
|
||||||
|
auser = await request.auser()
|
||||||
|
if asyncio.iscoroutinefunction(test_func):
|
||||||
|
test_pass = await test_func(auser)
|
||||||
|
else:
|
||||||
|
test_pass = await sync_to_async(test_func)(auser)
|
||||||
|
|
||||||
|
if test_pass:
|
||||||
|
return await view_func(request, *args, **kwargs)
|
||||||
|
return _redirect_to_login(request)
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def _view_wrapper(request, *args, **kwargs):
|
||||||
|
if asyncio.iscoroutinefunction(test_func):
|
||||||
|
test_pass = async_to_sync(test_func)(request.user)
|
||||||
|
else:
|
||||||
|
test_pass = test_func(request.user)
|
||||||
|
|
||||||
|
if test_pass:
|
||||||
|
return view_func(request, *args, **kwargs)
|
||||||
|
return _redirect_to_login(request)
|
||||||
|
|
||||||
|
return wraps(view_func)(_view_wrapper)
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
@ -64,19 +89,36 @@ def permission_required(perm, login_url=None, raise_exception=False):
|
||||||
If the raise_exception parameter is given the PermissionDenied exception
|
If the raise_exception parameter is given the PermissionDenied exception
|
||||||
is raised.
|
is raised.
|
||||||
"""
|
"""
|
||||||
|
if isinstance(perm, str):
|
||||||
|
perms = (perm,)
|
||||||
|
else:
|
||||||
|
perms = perm
|
||||||
|
|
||||||
|
def decorator(view_func):
|
||||||
|
if asyncio.iscoroutinefunction(view_func):
|
||||||
|
|
||||||
|
async def check_perms(user):
|
||||||
|
# First check if the user has the permission (even anon users).
|
||||||
|
if await sync_to_async(user.has_perms)(perms):
|
||||||
|
return True
|
||||||
|
# In case the 403 handler should be called raise the exception.
|
||||||
|
if raise_exception:
|
||||||
|
raise PermissionDenied
|
||||||
|
# As the last resort, show the login form.
|
||||||
|
return False
|
||||||
|
|
||||||
def check_perms(user):
|
|
||||||
if isinstance(perm, str):
|
|
||||||
perms = (perm,)
|
|
||||||
else:
|
else:
|
||||||
perms = perm
|
|
||||||
# First check if the user has the permission (even anon users)
|
|
||||||
if user.has_perms(perms):
|
|
||||||
return True
|
|
||||||
# In case the 403 handler should be called raise the exception
|
|
||||||
if raise_exception:
|
|
||||||
raise PermissionDenied
|
|
||||||
# As the last resort, show the login form
|
|
||||||
return False
|
|
||||||
|
|
||||||
return user_passes_test(check_perms, login_url=login_url)
|
def check_perms(user):
|
||||||
|
# First check if the user has the permission (even anon users).
|
||||||
|
if user.has_perms(perms):
|
||||||
|
return True
|
||||||
|
# In case the 403 handler should be called raise the exception.
|
||||||
|
if raise_exception:
|
||||||
|
raise PermissionDenied
|
||||||
|
# As the last resort, show the login form.
|
||||||
|
return False
|
||||||
|
|
||||||
|
return user_passes_test(check_perms, login_url=login_url)(view_func)
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
|
@ -52,6 +52,11 @@ Minor features
|
||||||
form save. This is now available in the admin when visiting the user creation
|
form save. This is now available in the admin when visiting the user creation
|
||||||
and password change pages.
|
and password change pages.
|
||||||
|
|
||||||
|
* :func:`~.django.contrib.auth.decorators.login_required`,
|
||||||
|
:func:`~.django.contrib.auth.decorators.permission_required`, and
|
||||||
|
:func:`~.django.contrib.auth.decorators.user_passes_test` decorators now
|
||||||
|
support wrapping asynchronous view functions.
|
||||||
|
|
||||||
:mod:`django.contrib.contenttypes`
|
:mod:`django.contrib.contenttypes`
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|
|
@ -617,6 +617,10 @@ The ``login_required`` decorator
|
||||||
:func:`django.contrib.admin.views.decorators.staff_member_required`
|
:func:`django.contrib.admin.views.decorators.staff_member_required`
|
||||||
decorator a useful alternative to ``login_required()``.
|
decorator a useful alternative to ``login_required()``.
|
||||||
|
|
||||||
|
.. versionchanged:: 5.1
|
||||||
|
|
||||||
|
Support for wrapping asynchronous view functions was added.
|
||||||
|
|
||||||
.. currentmodule:: django.contrib.auth.mixins
|
.. currentmodule:: django.contrib.auth.mixins
|
||||||
|
|
||||||
The ``LoginRequiredMixin`` mixin
|
The ``LoginRequiredMixin`` mixin
|
||||||
|
@ -714,6 +718,11 @@ email in the desired domain and if not, redirects to the login page::
|
||||||
@user_passes_test(email_check, login_url="/login/")
|
@user_passes_test(email_check, login_url="/login/")
|
||||||
def my_view(request): ...
|
def my_view(request): ...
|
||||||
|
|
||||||
|
.. versionchanged:: 5.1
|
||||||
|
|
||||||
|
Support for wrapping asynchronous view functions and using asynchronous
|
||||||
|
test callables was added.
|
||||||
|
|
||||||
.. currentmodule:: django.contrib.auth.mixins
|
.. currentmodule:: django.contrib.auth.mixins
|
||||||
|
|
||||||
.. class:: UserPassesTestMixin
|
.. class:: UserPassesTestMixin
|
||||||
|
@ -818,6 +827,10 @@ The ``permission_required`` decorator
|
||||||
``redirect_authenticated_user=True`` and the logged-in user doesn't have
|
``redirect_authenticated_user=True`` and the logged-in user doesn't have
|
||||||
all of the required permissions.
|
all of the required permissions.
|
||||||
|
|
||||||
|
.. versionchanged:: 5.1
|
||||||
|
|
||||||
|
Support for wrapping asynchronous view functions was added.
|
||||||
|
|
||||||
.. currentmodule:: django.contrib.auth.mixins
|
.. currentmodule:: django.contrib.auth.mixins
|
||||||
|
|
||||||
The ``PermissionRequiredMixin`` mixin
|
The ``PermissionRequiredMixin`` mixin
|
||||||
|
|
|
@ -1,3 +1,7 @@
|
||||||
|
from asyncio import iscoroutinefunction
|
||||||
|
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.contrib.auth import models
|
from django.contrib.auth import models
|
||||||
from django.contrib.auth.decorators import (
|
from django.contrib.auth.decorators import (
|
||||||
|
@ -19,6 +23,22 @@ class LoginRequiredTestCase(AuthViewsTestCase):
|
||||||
Tests the login_required decorators
|
Tests the login_required decorators
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
factory = RequestFactory()
|
||||||
|
|
||||||
|
def test_wrapped_sync_function_is_not_coroutine_function(self):
|
||||||
|
def sync_view(request):
|
||||||
|
return HttpResponse()
|
||||||
|
|
||||||
|
wrapped_view = login_required(sync_view)
|
||||||
|
self.assertIs(iscoroutinefunction(wrapped_view), False)
|
||||||
|
|
||||||
|
def test_wrapped_async_function_is_coroutine_function(self):
|
||||||
|
async def async_view(request):
|
||||||
|
return HttpResponse()
|
||||||
|
|
||||||
|
wrapped_view = login_required(async_view)
|
||||||
|
self.assertIs(iscoroutinefunction(wrapped_view), True)
|
||||||
|
|
||||||
def test_callable(self):
|
def test_callable(self):
|
||||||
"""
|
"""
|
||||||
login_required is assignable to callable objects.
|
login_required is assignable to callable objects.
|
||||||
|
@ -63,6 +83,35 @@ class LoginRequiredTestCase(AuthViewsTestCase):
|
||||||
view_url="/login_required_login_url/", login_url="/somewhere/"
|
view_url="/login_required_login_url/", login_url="/somewhere/"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def test_login_required_async_view(self, login_url=None):
|
||||||
|
async def async_view(request):
|
||||||
|
return HttpResponse()
|
||||||
|
|
||||||
|
async def auser_anonymous():
|
||||||
|
return models.AnonymousUser()
|
||||||
|
|
||||||
|
async def auser():
|
||||||
|
return self.u1
|
||||||
|
|
||||||
|
if login_url is None:
|
||||||
|
async_view = login_required(async_view)
|
||||||
|
login_url = settings.LOGIN_URL
|
||||||
|
else:
|
||||||
|
async_view = login_required(async_view, login_url=login_url)
|
||||||
|
|
||||||
|
request = self.factory.get("/rand")
|
||||||
|
request.auser = auser_anonymous
|
||||||
|
response = await async_view(request)
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
self.assertIn(login_url, response.url)
|
||||||
|
|
||||||
|
request.auser = auser
|
||||||
|
response = await async_view(request)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
async def test_login_required_next_url_async_view(self):
|
||||||
|
await self.test_login_required_async_view(login_url="/somewhere/")
|
||||||
|
|
||||||
|
|
||||||
class PermissionsRequiredDecoratorTest(TestCase):
|
class PermissionsRequiredDecoratorTest(TestCase):
|
||||||
"""
|
"""
|
||||||
|
@ -80,6 +129,24 @@ class PermissionsRequiredDecoratorTest(TestCase):
|
||||||
)
|
)
|
||||||
cls.user.user_permissions.add(*perms)
|
cls.user.user_permissions.add(*perms)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def auser(cls):
|
||||||
|
return cls.user
|
||||||
|
|
||||||
|
def test_wrapped_sync_function_is_not_coroutine_function(self):
|
||||||
|
def sync_view(request):
|
||||||
|
return HttpResponse()
|
||||||
|
|
||||||
|
wrapped_view = permission_required([])(sync_view)
|
||||||
|
self.assertIs(iscoroutinefunction(wrapped_view), False)
|
||||||
|
|
||||||
|
def test_wrapped_async_function_is_coroutine_function(self):
|
||||||
|
async def async_view(request):
|
||||||
|
return HttpResponse()
|
||||||
|
|
||||||
|
wrapped_view = permission_required([])(async_view)
|
||||||
|
self.assertIs(iscoroutinefunction(wrapped_view), True)
|
||||||
|
|
||||||
def test_many_permissions_pass(self):
|
def test_many_permissions_pass(self):
|
||||||
@permission_required(
|
@permission_required(
|
||||||
["auth_tests.add_customuser", "auth_tests.change_customuser"]
|
["auth_tests.add_customuser", "auth_tests.change_customuser"]
|
||||||
|
@ -147,6 +214,73 @@ class PermissionsRequiredDecoratorTest(TestCase):
|
||||||
with self.assertRaises(PermissionDenied):
|
with self.assertRaises(PermissionDenied):
|
||||||
a_view(request)
|
a_view(request)
|
||||||
|
|
||||||
|
async def test_many_permissions_pass_async_view(self):
|
||||||
|
@permission_required(
|
||||||
|
["auth_tests.add_customuser", "auth_tests.change_customuser"]
|
||||||
|
)
|
||||||
|
async def async_view(request):
|
||||||
|
return HttpResponse()
|
||||||
|
|
||||||
|
request = self.factory.get("/rand")
|
||||||
|
request.auser = self.auser
|
||||||
|
response = await async_view(request)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
async def test_many_permissions_in_set_pass_async_view(self):
|
||||||
|
@permission_required(
|
||||||
|
{"auth_tests.add_customuser", "auth_tests.change_customuser"}
|
||||||
|
)
|
||||||
|
async def async_view(request):
|
||||||
|
return HttpResponse()
|
||||||
|
|
||||||
|
request = self.factory.get("/rand")
|
||||||
|
request.auser = self.auser
|
||||||
|
response = await async_view(request)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
async def test_single_permission_pass_async_view(self):
|
||||||
|
@permission_required("auth_tests.add_customuser")
|
||||||
|
async def async_view(request):
|
||||||
|
return HttpResponse()
|
||||||
|
|
||||||
|
request = self.factory.get("/rand")
|
||||||
|
request.auser = self.auser
|
||||||
|
response = await async_view(request)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
async def test_permissioned_denied_redirect_async_view(self):
|
||||||
|
@permission_required(
|
||||||
|
[
|
||||||
|
"auth_tests.add_customuser",
|
||||||
|
"auth_tests.change_customuser",
|
||||||
|
"nonexistent-permission",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
async def async_view(request):
|
||||||
|
return HttpResponse()
|
||||||
|
|
||||||
|
request = self.factory.get("/rand")
|
||||||
|
request.auser = self.auser
|
||||||
|
response = await async_view(request)
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
|
||||||
|
async def test_permissioned_denied_exception_raised_async_view(self):
|
||||||
|
@permission_required(
|
||||||
|
[
|
||||||
|
"auth_tests.add_customuser",
|
||||||
|
"auth_tests.change_customuser",
|
||||||
|
"nonexistent-permission",
|
||||||
|
],
|
||||||
|
raise_exception=True,
|
||||||
|
)
|
||||||
|
async def async_view(request):
|
||||||
|
return HttpResponse()
|
||||||
|
|
||||||
|
request = self.factory.get("/rand")
|
||||||
|
request.auser = self.auser
|
||||||
|
with self.assertRaises(PermissionDenied):
|
||||||
|
await async_view(request)
|
||||||
|
|
||||||
|
|
||||||
class UserPassesTestDecoratorTest(TestCase):
|
class UserPassesTestDecoratorTest(TestCase):
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
|
@ -162,6 +296,28 @@ class UserPassesTestDecoratorTest(TestCase):
|
||||||
)
|
)
|
||||||
cls.user_pass.user_permissions.add(*perms)
|
cls.user_pass.user_permissions.add(*perms)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def auser_pass(cls):
|
||||||
|
return cls.user_pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def auser_deny(cls):
|
||||||
|
return cls.user_deny
|
||||||
|
|
||||||
|
def test_wrapped_sync_function_is_not_coroutine_function(self):
|
||||||
|
def sync_view(request):
|
||||||
|
return HttpResponse()
|
||||||
|
|
||||||
|
wrapped_view = user_passes_test(lambda user: True)(sync_view)
|
||||||
|
self.assertIs(iscoroutinefunction(wrapped_view), False)
|
||||||
|
|
||||||
|
def test_wrapped_async_function_is_coroutine_function(self):
|
||||||
|
async def async_view(request):
|
||||||
|
return HttpResponse()
|
||||||
|
|
||||||
|
wrapped_view = user_passes_test(lambda user: True)(async_view)
|
||||||
|
self.assertIs(iscoroutinefunction(wrapped_view), True)
|
||||||
|
|
||||||
def test_decorator(self):
|
def test_decorator(self):
|
||||||
def sync_test_func(user):
|
def sync_test_func(user):
|
||||||
return bool(
|
return bool(
|
||||||
|
@ -180,3 +336,56 @@ class UserPassesTestDecoratorTest(TestCase):
|
||||||
request.user = self.user_deny
|
request.user = self.user_deny
|
||||||
response = sync_view(request)
|
response = sync_view(request)
|
||||||
self.assertEqual(response.status_code, 302)
|
self.assertEqual(response.status_code, 302)
|
||||||
|
|
||||||
|
def test_decorator_async_test_func(self):
|
||||||
|
async def async_test_func(user):
|
||||||
|
return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"])
|
||||||
|
|
||||||
|
@user_passes_test(async_test_func)
|
||||||
|
def sync_view(request):
|
||||||
|
return HttpResponse()
|
||||||
|
|
||||||
|
request = self.factory.get("/rand")
|
||||||
|
request.user = self.user_pass
|
||||||
|
response = sync_view(request)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
request.user = self.user_deny
|
||||||
|
response = sync_view(request)
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
|
||||||
|
async def test_decorator_async_view(self):
|
||||||
|
def sync_test_func(user):
|
||||||
|
return bool(
|
||||||
|
models.Group.objects.filter(name__istartswith=user.username).exists()
|
||||||
|
)
|
||||||
|
|
||||||
|
@user_passes_test(sync_test_func)
|
||||||
|
async def async_view(request):
|
||||||
|
return HttpResponse()
|
||||||
|
|
||||||
|
request = self.factory.get("/rand")
|
||||||
|
request.auser = self.auser_pass
|
||||||
|
response = await async_view(request)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
request.auser = self.auser_deny
|
||||||
|
response = await async_view(request)
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
|
||||||
|
async def test_decorator_async_view_async_test_func(self):
|
||||||
|
async def async_test_func(user):
|
||||||
|
return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"])
|
||||||
|
|
||||||
|
@user_passes_test(async_test_func)
|
||||||
|
async def async_view(request):
|
||||||
|
return HttpResponse()
|
||||||
|
|
||||||
|
request = self.factory.get("/rand")
|
||||||
|
request.auser = self.auser_pass
|
||||||
|
response = await async_view(request)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
request.auser = self.auser_deny
|
||||||
|
response = await async_view(request)
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
|
Loading…
Reference in New Issue