Fixed #33755 -- Moved ASGI body-file cleanup into request class.

This commit is contained in:
Jonas Lundberg 2022-06-02 23:19:16 +02:00 committed by Carlton Gibson
parent c32858a8ce
commit e96320c917
7 changed files with 50 additions and 16 deletions

View File

@ -99,3 +99,8 @@ class ASGIStaticFilesHandler(StaticFilesHandlerMixin, ASGIHandler):
return await super().__call__(scope, receive, send) return await super().__call__(scope, receive, send)
# Hand off to the main app # Hand off to the main app
return await self.application(scope, receive, send) return await self.application(scope, receive, send)
async def get_response_async(self, request):
response = await super().get_response_async(request)
response._resource_closers.append(request.close)
return response

View File

@ -128,6 +128,10 @@ class ASGIRequest(HttpRequest):
def COOKIES(self): def COOKIES(self):
return parse_cookie(self.META.get("HTTP_COOKIE", "")) return parse_cookie(self.META.get("HTTP_COOKIE", ""))
def close(self):
super().close()
self._stream.close()
class ASGIHandler(base.BaseHandler): class ASGIHandler(base.BaseHandler):
"""Handler for ASGI requests.""" """Handler for ASGI requests."""
@ -164,21 +168,19 @@ class ASGIHandler(base.BaseHandler):
except RequestAborted: except RequestAborted:
return return
# Request is complete and can be served. # Request is complete and can be served.
try: set_script_prefix(self.get_script_prefix(scope))
set_script_prefix(self.get_script_prefix(scope)) await sync_to_async(signals.request_started.send, thread_sensitive=True)(
await sync_to_async(signals.request_started.send, thread_sensitive=True)( sender=self.__class__, scope=scope
sender=self.__class__, scope=scope )
) # Get the request and check for basic issues.
# Get the request and check for basic issues. request, error_response = self.create_request(scope, body_file)
request, error_response = self.create_request(scope, body_file) if request is None:
if request is None:
await self.send_response(error_response, send)
return
# Get the response, using the async mode of BaseHandler.
response = await self.get_response_async(request)
response._handler_class = self.__class__
finally:
body_file.close() body_file.close()
await self.send_response(error_response, send)
return
# Get the response, using the async mode of BaseHandler.
response = await self.get_response_async(request)
response._handler_class = self.__class__
# Increase chunk size on file responses (ASGI servers handles low-level # Increase chunk size on file responses (ASGI servers handles low-level
# chunking). # chunking).
if isinstance(response, FileResponse): if isinstance(response, FileResponse):

View File

@ -59,6 +59,9 @@ class LimitedStream:
self.buffer = sio.read() self.buffer = sio.read()
return line return line
def close(self):
pass
class WSGIRequest(HttpRequest): class WSGIRequest(HttpRequest):
def __init__(self, environ): def __init__(self, environ):

View File

@ -340,6 +340,8 @@ class HttpRequest:
self._body = self.read() self._body = self.read()
except OSError as e: except OSError as e:
raise UnreadablePostError(*e.args) from e raise UnreadablePostError(*e.args) from e
finally:
self._stream.close()
self._stream = BytesIO(self._body) self._stream = BytesIO(self._body)
return self._body return self._body

View File

@ -93,6 +93,9 @@ class FakePayload:
self.__content.write(content) self.__content.write(content)
self.__len += len(content) self.__len += len(content)
def close(self):
pass
def closing_iterator_wrapper(iterable, close): def closing_iterator_wrapper(iterable, close):
try: try:

View File

@ -165,7 +165,11 @@ class ASGITest(SimpleTestCase):
async def test_post_body(self): async def test_post_body(self):
application = get_asgi_application() application = get_asgi_application()
scope = self.async_request_factory._base_scope(method="POST", path="/post/") scope = self.async_request_factory._base_scope(
method="POST",
path="/post/",
query_string="echo=1",
)
communicator = ApplicationCommunicator(application, scope) communicator = ApplicationCommunicator(application, scope)
await communicator.send_input({"type": "http.request", "body": b"Echo!"}) await communicator.send_input({"type": "http.request", "body": b"Echo!"})
response_start = await communicator.receive_output() response_start = await communicator.receive_output()
@ -175,6 +179,18 @@ class ASGITest(SimpleTestCase):
self.assertEqual(response_body["type"], "http.response.body") self.assertEqual(response_body["type"], "http.response.body")
self.assertEqual(response_body["body"], b"Echo!") self.assertEqual(response_body["body"], b"Echo!")
async def test_untouched_request_body_gets_closed(self):
application = get_asgi_application()
scope = self.async_request_factory._base_scope(method="POST", path="/post/")
communicator = ApplicationCommunicator(application, scope)
await communicator.send_input({"type": "http.request"})
response_start = await communicator.receive_output()
self.assertEqual(response_start["type"], "http.response.start")
self.assertEqual(response_start["status"], 204)
response_body = await communicator.receive_output()
self.assertEqual(response_body["type"], "http.response.body")
self.assertEqual(response_body["body"], b"")
async def test_get_query_string(self): async def test_get_query_string(self):
application = get_asgi_application() application = get_asgi_application()
for query_string in (b"name=Andrew", "name=Andrew"): for query_string in (b"name=Andrew", "name=Andrew"):

View File

@ -26,7 +26,10 @@ def sync_waiter(request):
@csrf_exempt @csrf_exempt
def post_echo(request): def post_echo(request):
return HttpResponse(request.body) if request.GET.get("echo"):
return HttpResponse(request.body)
else:
return HttpResponse(status=204)
sync_waiter.active_threads = set() sync_waiter.active_threads = set()