diff --git a/django/contrib/auth/decorators.py b/django/contrib/auth/decorators.py index b6481ea52c..2f216ab758 100644 --- a/django/contrib/auth/decorators.py +++ b/django/contrib/auth/decorators.py @@ -8,19 +8,9 @@ def user_passes_test(test_func, login_url=None, redirect_field_name=REDIRECT_FIE redirecting to the log-in page if necessary. The test should be a callable that takes the user object and returns True if the user passes. """ - if not login_url: - from django.conf import settings - login_url = settings.LOGIN_URL - def _dec(view_func): - def _checklogin(request, *args, **kwargs): - if test_func(request.user): - return view_func(request, *args, **kwargs) - return HttpResponseRedirect('%s?%s=%s' % (login_url, redirect_field_name, urlquote(request.get_full_path()))) - _checklogin.__doc__ = view_func.__doc__ - _checklogin.__dict__ = view_func.__dict__ - - return _checklogin - return _dec + def decorate(view_func): + return _CheckLogin(view_func, test_func, login_url, redirect_field_name) + return decorate def login_required(function=None, redirect_field_name=REDIRECT_FIELD_NAME): """ @@ -42,3 +32,33 @@ def permission_required(perm, login_url=None): """ return user_passes_test(lambda u: u.has_perm(perm), login_url=login_url) +class _CheckLogin(object): + """ + Class that checks that the user passes the given test, redirecting to + the log-in page if necessary. If the test is passed, the view function + is invoked. The test should be a callable that takes the user object + and returns True if the user passes. + + We use a class here so that we can define __get__. This way, when a + _CheckLogin object is used as a method decorator, the view function + is properly bound to its instance. + """ + def __init__(self, view_func, test_func, login_url=None, redirect_field_name=REDIRECT_FIELD_NAME): + if not login_url: + from django.conf import settings + login_url = settings.LOGIN_URL + self.view_func = view_func + self.test_func = test_func + self.login_url = login_url + self.redirect_field_name = redirect_field_name + + def __get__(self, obj, cls=None): + view_func = self.view_func.__get__(obj, cls) + return _CheckLogin(view_func, self.test_func, self.login_url, self.redirect_field_name) + + def __call__(self, request, *args, **kwargs): + if self.test_func(request.user): + return self.view_func(request, *args, **kwargs) + path = urlquote(request.get_full_path()) + tup = self.login_url, self.redirect_field_name, path + return HttpResponseRedirect('%s?%s=%s' % tup) diff --git a/tests/modeltests/test_client/models.py b/tests/modeltests/test_client/models.py index 2df5d3cf77..95f68c6744 100644 --- a/tests/modeltests/test_client/models.py +++ b/tests/modeltests/test_client/models.py @@ -250,6 +250,22 @@ class ClientTest(TestCase): self.assertEqual(response.status_code, 200) self.assertEqual(response.context['user'].username, 'testclient') + def test_view_with_method_login(self): + "Request a page that is protected with a @login_required method" + + # Get the page without logging in. Should result in 302. + response = self.client.get('/test_client/login_protected_method_view/') + self.assertRedirects(response, 'http://testserver/accounts/login/?next=/test_client/login_protected_method_view/') + + # Log in + login = self.client.login(username='testclient', password='password') + self.failUnless(login, 'Could not log in') + + # Request a page that requires a login + response = self.client.get('/test_client/login_protected_method_view/') + self.assertEqual(response.status_code, 200) + self.assertEqual(response.context['user'].username, 'testclient') + def test_view_with_login_and_custom_redirect(self): "Request a page that is protected with @login_required(redirect_field_name='redirect_to')" @@ -295,6 +311,40 @@ class ClientTest(TestCase): response = self.client.get('/test_client/login_protected_view/') self.assertRedirects(response, 'http://testserver/accounts/login/?next=/test_client/login_protected_view/') + def test_view_with_permissions(self): + "Request a page that is protected with @permission_required" + + # Get the page without logging in. Should result in 302. + response = self.client.get('/test_client/permission_protected_view/') + self.assertRedirects(response, 'http://testserver/accounts/login/?next=/test_client/permission_protected_view/') + + # Log in + login = self.client.login(username='testclient', password='password') + self.failUnless(login, 'Could not log in') + + # Log in with wrong permissions. Should result in 302. + response = self.client.get('/test_client/permission_protected_view/') + self.assertRedirects(response, 'http://testserver/accounts/login/?next=/test_client/permission_protected_view/') + + # TODO: Log in with right permissions and request the page again + + def test_view_with_method_permissions(self): + "Request a page that is protected with a @permission_required method" + + # Get the page without logging in. Should result in 302. + response = self.client.get('/test_client/permission_protected_method_view/') + self.assertRedirects(response, 'http://testserver/accounts/login/?next=/test_client/permission_protected_method_view/') + + # Log in + login = self.client.login(username='testclient', password='password') + self.failUnless(login, 'Could not log in') + + # Log in with wrong permissions. Should result in 302. + response = self.client.get('/test_client/permission_protected_method_view/') + self.assertRedirects(response, 'http://testserver/accounts/login/?next=/test_client/permission_protected_method_view/') + + # TODO: Log in with right permissions and request the page again + def test_session_modifying_view(self): "Request a page that modifies the session" # Session value isn't set initially diff --git a/tests/modeltests/test_client/urls.py b/tests/modeltests/test_client/urls.py index 3779a0ecd1..09ee7eaf34 100644 --- a/tests/modeltests/test_client/urls.py +++ b/tests/modeltests/test_client/urls.py @@ -13,7 +13,10 @@ urlpatterns = patterns('', (r'^form_view/$', views.form_view), (r'^form_view_with_template/$', views.form_view_with_template), (r'^login_protected_view/$', views.login_protected_view), + (r'^login_protected_method_view/$', views.login_protected_method_view), (r'^login_protected_view_custom_redirect/$', views.login_protected_view_changed_redirect), + (r'^permission_protected_view/$', views.permission_protected_view), + (r'^permission_protected_method_view/$', views.permission_protected_method_view), (r'^session_view/$', views.session_view), (r'^broken_view/$', views.broken_view), (r'^mail_sending_view/$', views.mail_sending_view), diff --git a/tests/modeltests/test_client/views.py b/tests/modeltests/test_client/views.py index c406e17d30..3f4a54c5bd 100644 --- a/tests/modeltests/test_client/views.py +++ b/tests/modeltests/test_client/views.py @@ -3,7 +3,7 @@ from xml.dom.minidom import parseString from django.core.mail import EmailMessage, SMTPConnection from django.template import Context, Template from django.http import HttpResponse, HttpResponseRedirect, HttpResponseNotFound -from django.contrib.auth.decorators import login_required +from django.contrib.auth.decorators import login_required, permission_required from django.newforms.forms import Form from django.newforms import fields from django.shortcuts import render_to_response @@ -130,6 +130,38 @@ def login_protected_view_changed_redirect(request): return HttpResponse(t.render(c)) login_protected_view_changed_redirect = login_required(redirect_field_name="redirect_to")(login_protected_view_changed_redirect) +def permission_protected_view(request): + "A simple view that is permission protected." + t = Template('This is a permission protected test. ' + 'Username is {{ user.username }}. ' + 'Permissions are {{ user.get_all_permissions }}.' , + name='Permissions Template') + c = Context({'user': request.user}) + return HttpResponse(t.render(c)) +permission_protected_view = permission_required('modeltests.test_perm')(permission_protected_view) + +class _ViewManager(object): + def login_protected_view(self, request): + t = Template('This is a login protected test using a method. ' + 'Username is {{ user.username }}.', + name='Login Method Template') + c = Context({'user': request.user}) + return HttpResponse(t.render(c)) + login_protected_view = login_required(login_protected_view) + + def permission_protected_view(self, request): + t = Template('This is a permission protected test using a method. ' + 'Username is {{ user.username }}. ' + 'Permissions are {{ user.get_all_permissions }}.' , + name='Permissions Template') + c = Context({'user': request.user}) + return HttpResponse(t.render(c)) + permission_protected_view = permission_required('modeltests.test_perm')(permission_protected_view) + +_view_manager = _ViewManager() +login_protected_method_view = _view_manager.login_protected_view +permission_protected_method_view = _view_manager.permission_protected_view + def session_view(request): "A view that modifies the session" request.session['tobacconist'] = 'hovercraft'