mirror of https://github.com/django/django.git
Fixed #28337 -- Preserved extra headers of requests made with django.test.Client in assertRedirects().
Co-Authored-By: Hasan Ramezani <hasan.r67@gmail.com>
This commit is contained in:
parent
3ca9df51c7
commit
46e74a5256
|
@ -444,6 +444,7 @@ class Client(RequestFactory):
|
|||
self.handler = ClientHandler(enforce_csrf_checks)
|
||||
self.raise_request_exception = raise_request_exception
|
||||
self.exc_info = None
|
||||
self.extra = None
|
||||
|
||||
def store_exc_info(self, **kwargs):
|
||||
"""Store exceptions when they are generated by a view."""
|
||||
|
@ -515,6 +516,7 @@ class Client(RequestFactory):
|
|||
|
||||
def get(self, path, data=None, follow=False, secure=False, **extra):
|
||||
"""Request a response from the server using GET."""
|
||||
self.extra = extra
|
||||
response = super().get(path, data=data, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, **extra)
|
||||
|
@ -523,6 +525,7 @@ class Client(RequestFactory):
|
|||
def post(self, path, data=None, content_type=MULTIPART_CONTENT,
|
||||
follow=False, secure=False, **extra):
|
||||
"""Request a response from the server using POST."""
|
||||
self.extra = extra
|
||||
response = super().post(path, data=data, content_type=content_type, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
|
@ -530,6 +533,7 @@ class Client(RequestFactory):
|
|||
|
||||
def head(self, path, data=None, follow=False, secure=False, **extra):
|
||||
"""Request a response from the server using HEAD."""
|
||||
self.extra = extra
|
||||
response = super().head(path, data=data, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, **extra)
|
||||
|
@ -538,6 +542,7 @@ class Client(RequestFactory):
|
|||
def options(self, path, data='', content_type='application/octet-stream',
|
||||
follow=False, secure=False, **extra):
|
||||
"""Request a response from the server using OPTIONS."""
|
||||
self.extra = extra
|
||||
response = super().options(path, data=data, content_type=content_type, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
|
@ -546,6 +551,7 @@ class Client(RequestFactory):
|
|||
def put(self, path, data='', content_type='application/octet-stream',
|
||||
follow=False, secure=False, **extra):
|
||||
"""Send a resource to the server using PUT."""
|
||||
self.extra = extra
|
||||
response = super().put(path, data=data, content_type=content_type, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
|
@ -554,6 +560,7 @@ class Client(RequestFactory):
|
|||
def patch(self, path, data='', content_type='application/octet-stream',
|
||||
follow=False, secure=False, **extra):
|
||||
"""Send a resource to the server using PATCH."""
|
||||
self.extra = extra
|
||||
response = super().patch(path, data=data, content_type=content_type, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
|
@ -562,6 +569,7 @@ class Client(RequestFactory):
|
|||
def delete(self, path, data='', content_type='application/octet-stream',
|
||||
follow=False, secure=False, **extra):
|
||||
"""Send a DELETE request to the server."""
|
||||
self.extra = extra
|
||||
response = super().delete(path, data=data, content_type=content_type, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
|
@ -569,6 +577,7 @@ class Client(RequestFactory):
|
|||
|
||||
def trace(self, path, data='', follow=False, secure=False, **extra):
|
||||
"""Send a TRACE request to the server."""
|
||||
self.extra = extra
|
||||
response = super().trace(path, data=data, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(response, data=data, **extra)
|
||||
|
|
|
@ -347,10 +347,15 @@ class SimpleTestCase(unittest.TestCase):
|
|||
"Otherwise, use assertRedirects(..., fetch_redirect_response=False)."
|
||||
% (url, domain)
|
||||
)
|
||||
redirect_response = response.client.get(path, QueryDict(query), secure=(scheme == 'https'))
|
||||
|
||||
# Get the redirection page, using the same client that was used
|
||||
# to obtain the original response.
|
||||
extra = response.client.extra or {}
|
||||
redirect_response = response.client.get(
|
||||
path,
|
||||
QueryDict(query),
|
||||
secure=(scheme == 'https'),
|
||||
**extra,
|
||||
)
|
||||
self.assertEqual(
|
||||
redirect_response.status_code, target_status_code,
|
||||
msg_prefix + "Couldn't retrieve redirection page '%s': response code was %d (expected %d)"
|
||||
|
|
|
@ -508,6 +508,27 @@ class AssertRedirectsTests(SimpleTestCase):
|
|||
with self.assertRaises(AssertionError):
|
||||
self.assertRedirects(response, 'http://testserver/secure_view/', status_code=302)
|
||||
|
||||
def test_redirect_fetch_redirect_response(self):
|
||||
"""Preserve extra headers of requests made with django.test.Client."""
|
||||
methods = (
|
||||
'get', 'post', 'head', 'options', 'put', 'patch', 'delete', 'trace',
|
||||
)
|
||||
for method in methods:
|
||||
with self.subTest(method=method):
|
||||
req_method = getattr(self.client, method)
|
||||
response = req_method(
|
||||
'/redirect_based_on_extra_headers_1/',
|
||||
follow=False,
|
||||
HTTP_REDIRECT='val',
|
||||
)
|
||||
self.assertRedirects(
|
||||
response,
|
||||
'/redirect_based_on_extra_headers_2/',
|
||||
fetch_redirect_response=True,
|
||||
status_code=302,
|
||||
target_status_code=302,
|
||||
)
|
||||
|
||||
|
||||
@override_settings(ROOT_URLCONF='test_client_regress.urls')
|
||||
class AssertFormErrorTests(SimpleTestCase):
|
||||
|
|
|
@ -25,6 +25,8 @@ urlpatterns = [
|
|||
path('circular_redirect_2/', RedirectView.as_view(url='/circular_redirect_3/')),
|
||||
path('circular_redirect_3/', RedirectView.as_view(url='/circular_redirect_1/')),
|
||||
path('redirect_other_host/', RedirectView.as_view(url='https://otherserver:8443/no_template_view/')),
|
||||
path('redirect_based_on_extra_headers_1/', views.redirect_based_on_extra_headers_1_view),
|
||||
path('redirect_based_on_extra_headers_2/', views.redirect_based_on_extra_headers_2_view),
|
||||
path('set_session/', views.set_session_view),
|
||||
path('check_session/', views.check_session_view),
|
||||
path('request_methods/', views.request_methods_view),
|
||||
|
|
|
@ -154,3 +154,15 @@ def render_template_multiple_times(request):
|
|||
"""A view that renders a template multiple times."""
|
||||
return HttpResponse(
|
||||
render_to_string('base.html') + render_to_string('base.html'))
|
||||
|
||||
|
||||
def redirect_based_on_extra_headers_1_view(request):
|
||||
if 'HTTP_REDIRECT' in request.META:
|
||||
return HttpResponseRedirect('/redirect_based_on_extra_headers_2/')
|
||||
return HttpResponse()
|
||||
|
||||
|
||||
def redirect_based_on_extra_headers_2_view(request):
|
||||
if 'HTTP_REDIRECT' in request.META:
|
||||
return HttpResponseRedirect('/redirects/further/more/')
|
||||
return HttpResponse()
|
||||
|
|
Loading…
Reference in New Issue