From 46e74a525671f2eef09021c06933c45b49f9d421 Mon Sep 17 00:00:00 2001 From: Patrick Jenkins Date: Thu, 17 Aug 2017 17:10:10 -0700 Subject: [PATCH] Fixed #28337 -- Preserved extra headers of requests made with django.test.Client in assertRedirects(). Co-Authored-By: Hasan Ramezani --- django/test/client.py | 9 +++++++++ django/test/testcases.py | 9 +++++++-- tests/test_client_regress/tests.py | 21 +++++++++++++++++++++ tests/test_client_regress/urls.py | 2 ++ tests/test_client_regress/views.py | 12 ++++++++++++ 5 files changed, 51 insertions(+), 2 deletions(-) diff --git a/django/test/client.py b/django/test/client.py index 9b3fdf59364..98ede36499c 100644 --- a/django/test/client.py +++ b/django/test/client.py @@ -444,6 +444,7 @@ class Client(RequestFactory): self.handler = ClientHandler(enforce_csrf_checks) self.raise_request_exception = raise_request_exception self.exc_info = None + self.extra = None def store_exc_info(self, **kwargs): """Store exceptions when they are generated by a view.""" @@ -515,6 +516,7 @@ class Client(RequestFactory): def get(self, path, data=None, follow=False, secure=False, **extra): """Request a response from the server using GET.""" + self.extra = extra response = super().get(path, data=data, secure=secure, **extra) if follow: response = self._handle_redirects(response, data=data, **extra) @@ -523,6 +525,7 @@ class Client(RequestFactory): def post(self, path, data=None, content_type=MULTIPART_CONTENT, follow=False, secure=False, **extra): """Request a response from the server using POST.""" + self.extra = extra response = super().post(path, data=data, content_type=content_type, secure=secure, **extra) if follow: response = self._handle_redirects(response, data=data, content_type=content_type, **extra) @@ -530,6 +533,7 @@ class Client(RequestFactory): def head(self, path, data=None, follow=False, secure=False, **extra): """Request a response from the server using HEAD.""" + self.extra = extra response = super().head(path, data=data, secure=secure, **extra) if follow: response = self._handle_redirects(response, data=data, **extra) @@ -538,6 +542,7 @@ class Client(RequestFactory): def options(self, path, data='', content_type='application/octet-stream', follow=False, secure=False, **extra): """Request a response from the server using OPTIONS.""" + self.extra = extra response = super().options(path, data=data, content_type=content_type, secure=secure, **extra) if follow: response = self._handle_redirects(response, data=data, content_type=content_type, **extra) @@ -546,6 +551,7 @@ class Client(RequestFactory): def put(self, path, data='', content_type='application/octet-stream', follow=False, secure=False, **extra): """Send a resource to the server using PUT.""" + self.extra = extra response = super().put(path, data=data, content_type=content_type, secure=secure, **extra) if follow: response = self._handle_redirects(response, data=data, content_type=content_type, **extra) @@ -554,6 +560,7 @@ class Client(RequestFactory): def patch(self, path, data='', content_type='application/octet-stream', follow=False, secure=False, **extra): """Send a resource to the server using PATCH.""" + self.extra = extra response = super().patch(path, data=data, content_type=content_type, secure=secure, **extra) if follow: response = self._handle_redirects(response, data=data, content_type=content_type, **extra) @@ -562,6 +569,7 @@ class Client(RequestFactory): def delete(self, path, data='', content_type='application/octet-stream', follow=False, secure=False, **extra): """Send a DELETE request to the server.""" + self.extra = extra response = super().delete(path, data=data, content_type=content_type, secure=secure, **extra) if follow: response = self._handle_redirects(response, data=data, content_type=content_type, **extra) @@ -569,6 +577,7 @@ class Client(RequestFactory): def trace(self, path, data='', follow=False, secure=False, **extra): """Send a TRACE request to the server.""" + self.extra = extra response = super().trace(path, data=data, secure=secure, **extra) if follow: response = self._handle_redirects(response, data=data, **extra) diff --git a/django/test/testcases.py b/django/test/testcases.py index 1a7414b91d2..468c0c4fbca 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -347,10 +347,15 @@ class SimpleTestCase(unittest.TestCase): "Otherwise, use assertRedirects(..., fetch_redirect_response=False)." % (url, domain) ) - redirect_response = response.client.get(path, QueryDict(query), secure=(scheme == 'https')) - # Get the redirection page, using the same client that was used # to obtain the original response. + extra = response.client.extra or {} + redirect_response = response.client.get( + path, + QueryDict(query), + secure=(scheme == 'https'), + **extra, + ) self.assertEqual( redirect_response.status_code, target_status_code, msg_prefix + "Couldn't retrieve redirection page '%s': response code was %d (expected %d)" diff --git a/tests/test_client_regress/tests.py b/tests/test_client_regress/tests.py index 0763ae67b17..1a8d069ea1f 100644 --- a/tests/test_client_regress/tests.py +++ b/tests/test_client_regress/tests.py @@ -508,6 +508,27 @@ class AssertRedirectsTests(SimpleTestCase): with self.assertRaises(AssertionError): self.assertRedirects(response, 'http://testserver/secure_view/', status_code=302) + def test_redirect_fetch_redirect_response(self): + """Preserve extra headers of requests made with django.test.Client.""" + methods = ( + 'get', 'post', 'head', 'options', 'put', 'patch', 'delete', 'trace', + ) + for method in methods: + with self.subTest(method=method): + req_method = getattr(self.client, method) + response = req_method( + '/redirect_based_on_extra_headers_1/', + follow=False, + HTTP_REDIRECT='val', + ) + self.assertRedirects( + response, + '/redirect_based_on_extra_headers_2/', + fetch_redirect_response=True, + status_code=302, + target_status_code=302, + ) + @override_settings(ROOT_URLCONF='test_client_regress.urls') class AssertFormErrorTests(SimpleTestCase): diff --git a/tests/test_client_regress/urls.py b/tests/test_client_regress/urls.py index 8590e5953a5..19c9551cc27 100644 --- a/tests/test_client_regress/urls.py +++ b/tests/test_client_regress/urls.py @@ -25,6 +25,8 @@ urlpatterns = [ path('circular_redirect_2/', RedirectView.as_view(url='/circular_redirect_3/')), path('circular_redirect_3/', RedirectView.as_view(url='/circular_redirect_1/')), path('redirect_other_host/', RedirectView.as_view(url='https://otherserver:8443/no_template_view/')), + path('redirect_based_on_extra_headers_1/', views.redirect_based_on_extra_headers_1_view), + path('redirect_based_on_extra_headers_2/', views.redirect_based_on_extra_headers_2_view), path('set_session/', views.set_session_view), path('check_session/', views.check_session_view), path('request_methods/', views.request_methods_view), diff --git a/tests/test_client_regress/views.py b/tests/test_client_regress/views.py index 94b0ab29a8c..db4206ce75f 100644 --- a/tests/test_client_regress/views.py +++ b/tests/test_client_regress/views.py @@ -154,3 +154,15 @@ def render_template_multiple_times(request): """A view that renders a template multiple times.""" return HttpResponse( render_to_string('base.html') + render_to_string('base.html')) + + +def redirect_based_on_extra_headers_1_view(request): + if 'HTTP_REDIRECT' in request.META: + return HttpResponseRedirect('/redirect_based_on_extra_headers_2/') + return HttpResponse() + + +def redirect_based_on_extra_headers_2_view(request): + if 'HTTP_REDIRECT' in request.META: + return HttpResponseRedirect('/redirects/further/more/') + return HttpResponse()