diff --git a/django/contrib/auth/views.py b/django/contrib/auth/views.py index 48b27d5a61..2d19d6f884 100644 --- a/django/contrib/auth/views.py +++ b/django/contrib/auth/views.py @@ -35,9 +35,13 @@ UserModel = get_user_model() class RedirectURLMixin: + next_page = None redirect_field_name = REDIRECT_FIELD_NAME success_url_allowed_hosts = set() + def get_success_url(self): + return self.get_redirect_url() or self.get_default_redirect_url() + def get_redirect_url(self): """Return the user-originating redirect URL if it's safe.""" redirect_to = self.request.POST.get( @@ -53,6 +57,12 @@ class RedirectURLMixin: def get_success_url_allowed_hosts(self): return {self.request.get_host(), *self.success_url_allowed_hosts} + def get_default_redirect_url(self): + """Return the default redirect URL.""" + if self.next_page: + return resolve_url(self.next_page) + raise ImproperlyConfigured("No URL to redirect to. Provide a next_page.") + class LoginView(RedirectURLMixin, FormView): """ @@ -61,7 +71,6 @@ class LoginView(RedirectURLMixin, FormView): form_class = AuthenticationForm authentication_form = None - next_page = None template_name = "registration/login.html" redirect_authenticated_user = False extra_context = None @@ -80,9 +89,6 @@ class LoginView(RedirectURLMixin, FormView): return HttpResponseRedirect(redirect_to) return super().dispatch(request, *args, **kwargs) - def get_success_url(self): - return self.get_redirect_url() or self.get_default_redirect_url() - def get_default_redirect_url(self): """Return the default redirect URL.""" if self.next_page: @@ -125,7 +131,6 @@ class LogoutView(RedirectURLMixin, TemplateView): # RemovedInDjango50Warning: when the deprecation ends, remove "get" and # "head" from http_method_names. http_method_names = ["get", "head", "post", "options"] - next_page = None template_name = "registration/logged_out.html" extra_context = None @@ -163,9 +168,6 @@ class LogoutView(RedirectURLMixin, TemplateView): else: return self.request.path - def get_success_url(self): - return self.get_redirect_url() or self.get_default_redirect_url() - def get_context_data(self, **kwargs): context = super().get_context_data(**kwargs) current_site = get_current_site(self.request) diff --git a/tests/auth_tests/test_views.py b/tests/auth_tests/test_views.py index 86424636b7..2941716422 100644 --- a/tests/auth_tests/test_views.py +++ b/tests/auth_tests/test_views.py @@ -18,6 +18,7 @@ from django.contrib.auth.models import Permission, User from django.contrib.auth.views import ( INTERNAL_RESET_SESSION_TOKEN, LoginView, + RedirectURLMixin, logout_then_login, redirect_to_login, ) @@ -40,6 +41,20 @@ from .models import CustomUser, UUIDUser from .settings import AUTH_TEMPLATES +class RedirectURLMixinTests(TestCase): + @override_settings(ROOT_URLCONF="auth_tests.urls") + def test_get_default_redirect_url_next_page(self): + class RedirectURLView(RedirectURLMixin): + next_page = "/custom/" + + self.assertEqual(RedirectURLView().get_default_redirect_url(), "/custom/") + + def test_get_default_redirect_url_no_next_page(self): + msg = "No URL to redirect to. Provide a next_page." + with self.assertRaisesMessage(ImproperlyConfigured, msg): + RedirectURLMixin().get_default_redirect_url() + + @override_settings( LANGUAGES=[("en", "English")], LANGUAGE_CODE="en",