Fixed #28337 -- Preserved extra headers of requests made with django.test.Client in assertRedirects().

Co-Authored-By: Hasan Ramezani <hasan.r67@gmail.com>
This commit is contained in:
Patrick Jenkins 2017-08-17 17:10:10 -07:00 committed by Mariusz Felisiak
parent 3ca9df51c7
commit 46e74a5256
5 changed files with 51 additions and 2 deletions

View File

@ -444,6 +444,7 @@ class Client(RequestFactory):
self.handler = ClientHandler(enforce_csrf_checks) self.handler = ClientHandler(enforce_csrf_checks)
self.raise_request_exception = raise_request_exception self.raise_request_exception = raise_request_exception
self.exc_info = None self.exc_info = None
self.extra = None
def store_exc_info(self, **kwargs): def store_exc_info(self, **kwargs):
"""Store exceptions when they are generated by a view.""" """Store exceptions when they are generated by a view."""
@ -515,6 +516,7 @@ class Client(RequestFactory):
def get(self, path, data=None, follow=False, secure=False, **extra): def get(self, path, data=None, follow=False, secure=False, **extra):
"""Request a response from the server using GET.""" """Request a response from the server using GET."""
self.extra = extra
response = super().get(path, data=data, secure=secure, **extra) response = super().get(path, data=data, secure=secure, **extra)
if follow: if follow:
response = self._handle_redirects(response, data=data, **extra) response = self._handle_redirects(response, data=data, **extra)
@ -523,6 +525,7 @@ class Client(RequestFactory):
def post(self, path, data=None, content_type=MULTIPART_CONTENT, def post(self, path, data=None, content_type=MULTIPART_CONTENT,
follow=False, secure=False, **extra): follow=False, secure=False, **extra):
"""Request a response from the server using POST.""" """Request a response from the server using POST."""
self.extra = extra
response = super().post(path, data=data, content_type=content_type, secure=secure, **extra) response = super().post(path, data=data, content_type=content_type, secure=secure, **extra)
if follow: if follow:
response = self._handle_redirects(response, data=data, content_type=content_type, **extra) response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
@ -530,6 +533,7 @@ class Client(RequestFactory):
def head(self, path, data=None, follow=False, secure=False, **extra): def head(self, path, data=None, follow=False, secure=False, **extra):
"""Request a response from the server using HEAD.""" """Request a response from the server using HEAD."""
self.extra = extra
response = super().head(path, data=data, secure=secure, **extra) response = super().head(path, data=data, secure=secure, **extra)
if follow: if follow:
response = self._handle_redirects(response, data=data, **extra) response = self._handle_redirects(response, data=data, **extra)
@ -538,6 +542,7 @@ class Client(RequestFactory):
def options(self, path, data='', content_type='application/octet-stream', def options(self, path, data='', content_type='application/octet-stream',
follow=False, secure=False, **extra): follow=False, secure=False, **extra):
"""Request a response from the server using OPTIONS.""" """Request a response from the server using OPTIONS."""
self.extra = extra
response = super().options(path, data=data, content_type=content_type, secure=secure, **extra) response = super().options(path, data=data, content_type=content_type, secure=secure, **extra)
if follow: if follow:
response = self._handle_redirects(response, data=data, content_type=content_type, **extra) response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
@ -546,6 +551,7 @@ class Client(RequestFactory):
def put(self, path, data='', content_type='application/octet-stream', def put(self, path, data='', content_type='application/octet-stream',
follow=False, secure=False, **extra): follow=False, secure=False, **extra):
"""Send a resource to the server using PUT.""" """Send a resource to the server using PUT."""
self.extra = extra
response = super().put(path, data=data, content_type=content_type, secure=secure, **extra) response = super().put(path, data=data, content_type=content_type, secure=secure, **extra)
if follow: if follow:
response = self._handle_redirects(response, data=data, content_type=content_type, **extra) response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
@ -554,6 +560,7 @@ class Client(RequestFactory):
def patch(self, path, data='', content_type='application/octet-stream', def patch(self, path, data='', content_type='application/octet-stream',
follow=False, secure=False, **extra): follow=False, secure=False, **extra):
"""Send a resource to the server using PATCH.""" """Send a resource to the server using PATCH."""
self.extra = extra
response = super().patch(path, data=data, content_type=content_type, secure=secure, **extra) response = super().patch(path, data=data, content_type=content_type, secure=secure, **extra)
if follow: if follow:
response = self._handle_redirects(response, data=data, content_type=content_type, **extra) response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
@ -562,6 +569,7 @@ class Client(RequestFactory):
def delete(self, path, data='', content_type='application/octet-stream', def delete(self, path, data='', content_type='application/octet-stream',
follow=False, secure=False, **extra): follow=False, secure=False, **extra):
"""Send a DELETE request to the server.""" """Send a DELETE request to the server."""
self.extra = extra
response = super().delete(path, data=data, content_type=content_type, secure=secure, **extra) response = super().delete(path, data=data, content_type=content_type, secure=secure, **extra)
if follow: if follow:
response = self._handle_redirects(response, data=data, content_type=content_type, **extra) response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
@ -569,6 +577,7 @@ class Client(RequestFactory):
def trace(self, path, data='', follow=False, secure=False, **extra): def trace(self, path, data='', follow=False, secure=False, **extra):
"""Send a TRACE request to the server.""" """Send a TRACE request to the server."""
self.extra = extra
response = super().trace(path, data=data, secure=secure, **extra) response = super().trace(path, data=data, secure=secure, **extra)
if follow: if follow:
response = self._handle_redirects(response, data=data, **extra) response = self._handle_redirects(response, data=data, **extra)

View File

@ -347,10 +347,15 @@ class SimpleTestCase(unittest.TestCase):
"Otherwise, use assertRedirects(..., fetch_redirect_response=False)." "Otherwise, use assertRedirects(..., fetch_redirect_response=False)."
% (url, domain) % (url, domain)
) )
redirect_response = response.client.get(path, QueryDict(query), secure=(scheme == 'https'))
# Get the redirection page, using the same client that was used # Get the redirection page, using the same client that was used
# to obtain the original response. # to obtain the original response.
extra = response.client.extra or {}
redirect_response = response.client.get(
path,
QueryDict(query),
secure=(scheme == 'https'),
**extra,
)
self.assertEqual( self.assertEqual(
redirect_response.status_code, target_status_code, redirect_response.status_code, target_status_code,
msg_prefix + "Couldn't retrieve redirection page '%s': response code was %d (expected %d)" msg_prefix + "Couldn't retrieve redirection page '%s': response code was %d (expected %d)"

View File

@ -508,6 +508,27 @@ class AssertRedirectsTests(SimpleTestCase):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
self.assertRedirects(response, 'http://testserver/secure_view/', status_code=302) self.assertRedirects(response, 'http://testserver/secure_view/', status_code=302)
def test_redirect_fetch_redirect_response(self):
"""Preserve extra headers of requests made with django.test.Client."""
methods = (
'get', 'post', 'head', 'options', 'put', 'patch', 'delete', 'trace',
)
for method in methods:
with self.subTest(method=method):
req_method = getattr(self.client, method)
response = req_method(
'/redirect_based_on_extra_headers_1/',
follow=False,
HTTP_REDIRECT='val',
)
self.assertRedirects(
response,
'/redirect_based_on_extra_headers_2/',
fetch_redirect_response=True,
status_code=302,
target_status_code=302,
)
@override_settings(ROOT_URLCONF='test_client_regress.urls') @override_settings(ROOT_URLCONF='test_client_regress.urls')
class AssertFormErrorTests(SimpleTestCase): class AssertFormErrorTests(SimpleTestCase):

View File

@ -25,6 +25,8 @@ urlpatterns = [
path('circular_redirect_2/', RedirectView.as_view(url='/circular_redirect_3/')), path('circular_redirect_2/', RedirectView.as_view(url='/circular_redirect_3/')),
path('circular_redirect_3/', RedirectView.as_view(url='/circular_redirect_1/')), path('circular_redirect_3/', RedirectView.as_view(url='/circular_redirect_1/')),
path('redirect_other_host/', RedirectView.as_view(url='https://otherserver:8443/no_template_view/')), path('redirect_other_host/', RedirectView.as_view(url='https://otherserver:8443/no_template_view/')),
path('redirect_based_on_extra_headers_1/', views.redirect_based_on_extra_headers_1_view),
path('redirect_based_on_extra_headers_2/', views.redirect_based_on_extra_headers_2_view),
path('set_session/', views.set_session_view), path('set_session/', views.set_session_view),
path('check_session/', views.check_session_view), path('check_session/', views.check_session_view),
path('request_methods/', views.request_methods_view), path('request_methods/', views.request_methods_view),

View File

@ -154,3 +154,15 @@ def render_template_multiple_times(request):
"""A view that renders a template multiple times.""" """A view that renders a template multiple times."""
return HttpResponse( return HttpResponse(
render_to_string('base.html') + render_to_string('base.html')) render_to_string('base.html') + render_to_string('base.html'))
def redirect_based_on_extra_headers_1_view(request):
if 'HTTP_REDIRECT' in request.META:
return HttpResponseRedirect('/redirect_based_on_extra_headers_2/')
return HttpResponse()
def redirect_based_on_extra_headers_2_view(request):
if 'HTTP_REDIRECT' in request.META:
return HttpResponseRedirect('/redirects/further/more/')
return HttpResponse()