Fixed #34757 -- Added support for following redirects to AsyncClient.

This commit is contained in:
Olivier Tabone 2023-08-04 09:14:19 +02:00 committed by Mariusz Felisiak
parent 1ac397674b
commit 3f8dbe267d
4 changed files with 252 additions and 12 deletions

View File

@ -705,9 +705,6 @@ class AsyncRequestFactory(RequestFactory):
]
)
s["_body_file"] = FakePayload(data)
follow = extra.pop("follow", None)
if follow is not None:
s["follow"] = follow
if query_string := extra.pop("QUERY_STRING", None):
s["query_string"] = query_string
if headers:
@ -1296,10 +1293,6 @@ class AsyncClient(ClientMixin, AsyncRequestFactory):
query environment, which can be overridden using the arguments to the
request.
"""
if "follow" in request:
raise NotImplementedError(
"AsyncClient request methods do not accept the follow parameter."
)
scope = self._base_scope(**request)
# Curry a data dictionary into an instance of the template renderer
# callback function.
@ -1338,3 +1331,234 @@ class AsyncClient(ClientMixin, AsyncRequestFactory):
if response.cookies:
self.cookies.update(response.cookies)
return response
async def get(
self,
path,
data=None,
follow=False,
secure=False,
*,
headers=None,
**extra,
):
"""Request a response from the server using GET."""
self.extra = extra
self.headers = headers
response = await super().get(
path, data=data, secure=secure, headers=headers, **extra
)
if follow:
response = await self._ahandle_redirects(
response, data=data, headers=headers, **extra
)
return response
async def post(
self,
path,
data=None,
content_type=MULTIPART_CONTENT,
follow=False,
secure=False,
*,
headers=None,
**extra,
):
"""Request a response from the server using POST."""
self.extra = extra
self.headers = headers
response = await super().post(
path,
data=data,
content_type=content_type,
secure=secure,
headers=headers,
**extra,
)
if follow:
response = await self._ahandle_redirects(
response, data=data, content_type=content_type, headers=headers, **extra
)
return response
async def head(
self,
path,
data=None,
follow=False,
secure=False,
*,
headers=None,
**extra,
):
"""Request a response from the server using HEAD."""
self.extra = extra
self.headers = headers
response = await super().head(
path, data=data, secure=secure, headers=headers, **extra
)
if follow:
response = await self._ahandle_redirects(
response, data=data, headers=headers, **extra
)
return response
async def options(
self,
path,
data="",
content_type="application/octet-stream",
follow=False,
secure=False,
*,
headers=None,
**extra,
):
"""Request a response from the server using OPTIONS."""
self.extra = extra
self.headers = headers
response = await super().options(
path,
data=data,
content_type=content_type,
secure=secure,
headers=headers,
**extra,
)
if follow:
response = await self._ahandle_redirects(
response, data=data, content_type=content_type, headers=headers, **extra
)
return response
async def put(
self,
path,
data="",
content_type="application/octet-stream",
follow=False,
secure=False,
*,
headers=None,
**extra,
):
"""Send a resource to the server using PUT."""
self.extra = extra
self.headers = headers
response = await super().put(
path,
data=data,
content_type=content_type,
secure=secure,
headers=headers,
**extra,
)
if follow:
response = await self._ahandle_redirects(
response, data=data, content_type=content_type, headers=headers, **extra
)
return response
async def patch(
self,
path,
data="",
content_type="application/octet-stream",
follow=False,
secure=False,
*,
headers=None,
**extra,
):
"""Send a resource to the server using PATCH."""
self.extra = extra
self.headers = headers
response = await super().patch(
path,
data=data,
content_type=content_type,
secure=secure,
headers=headers,
**extra,
)
if follow:
response = await self._ahandle_redirects(
response, data=data, content_type=content_type, headers=headers, **extra
)
return response
async def delete(
self,
path,
data="",
content_type="application/octet-stream",
follow=False,
secure=False,
*,
headers=None,
**extra,
):
"""Send a DELETE request to the server."""
self.extra = extra
self.headers = headers
response = await super().delete(
path,
data=data,
content_type=content_type,
secure=secure,
headers=headers,
**extra,
)
if follow:
response = await self._ahandle_redirects(
response, data=data, content_type=content_type, headers=headers, **extra
)
return response
async def trace(
self,
path,
data="",
follow=False,
secure=False,
*,
headers=None,
**extra,
):
"""Send a TRACE request to the server."""
self.extra = extra
self.headers = headers
response = await super().trace(
path, data=data, secure=secure, headers=headers, **extra
)
if follow:
response = await self._ahandle_redirects(
response, data=data, headers=headers, **extra
)
return response
async def _ahandle_redirects(
self,
response,
data="",
content_type="",
headers=None,
**extra,
):
"""
Follow any redirects by requesting responses from the server using GET.
"""
response.redirect_chain = []
while response.status_code in REDIRECT_STATUS_CODES:
redirect_chain = response.redirect_chain
response = await self._follow_redirect(
response,
data=data,
content_type=content_type,
headers=headers,
**extra,
)
response.redirect_chain = redirect_chain
self._ensure_redirects_not_cyclic(response)
return response

View File

@ -433,6 +433,8 @@ Tests
:meth:`~django.test.Client.aforce_login`, and
:meth:`~django.test.Client.alogout`.
* :class:`~django.test.AsyncClient` now supports the ``follow`` parameter.
URLs
~~~~

View File

@ -2032,7 +2032,6 @@ test client, with the following exceptions:
* In the initialization, arbitrary keyword arguments in ``defaults`` are added
directly into the ASGI scope.
* The ``follow`` parameter is not supported.
* Headers passed as ``extra`` keyword arguments should not have the ``HTTP_``
prefix required by the synchronous client (see :meth:`Client.get`). For
example, here is how to set an HTTP ``Accept`` header:
@ -2046,6 +2045,10 @@ test client, with the following exceptions:
The ``headers`` parameter was added.
.. versionchanged:: 5.0
Support for the ``follow`` parameter was added to the ``AsyncClient``.
Using ``AsyncClient`` any method that makes a request must be awaited::
async def test_my_thing(self):

View File

@ -1135,8 +1135,11 @@ class AsyncClientTest(TestCase):
response = await self.async_client.get("/middleware_urlconf_view/")
self.assertEqual(response.resolver_match.url_name, "middleware_urlconf_view")
async def test_follow_parameter_not_implemented(self):
msg = "AsyncClient request methods do not accept the follow parameter."
async def test_redirect(self):
response = await self.async_client.get("/redirect_view/")
self.assertEqual(response.status_code, 302)
async def test_follow_redirect(self):
tests = (
"get",
"post",
@ -1150,8 +1153,16 @@ class AsyncClientTest(TestCase):
for method_name in tests:
with self.subTest(method=method_name):
method = getattr(self.async_client, method_name)
with self.assertRaisesMessage(NotImplementedError, msg):
await method("/redirect_view/", follow=True)
response = await method("/redirect_view/", follow=True)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.resolver_match.url_name, "get_view")
async def test_follow_double_redirect(self):
response = await self.async_client.get("/double_redirect_view/", follow=True)
self.assertRedirects(
response, "/get_view/", status_code=302, target_status_code=200
)
self.assertEqual(len(response.redirect_chain), 2)
async def test_get_data(self):
response = await self.async_client.get("/get_view/", {"var": "val"})