Fixed #33738 -- Allowed handling ASGI http.disconnect in long-lived requests.

This commit is contained in:
th3nn3ss 2022-12-21 14:25:24 -05:00 committed by Mariusz Felisiak
parent 4e4eda6d6c
commit 1d1ddffc27
5 changed files with 157 additions and 3 deletions

View File

@ -1,3 +1,4 @@
import asyncio
import logging
import sys
import tempfile
@ -177,15 +178,49 @@ class ASGIHandler(base.BaseHandler):
body_file.close()
await self.send_response(error_response, send)
return
# Get the response, using the async mode of BaseHandler.
# Try to catch a disconnect while getting response.
tasks = [
asyncio.create_task(self.run_get_response(request)),
asyncio.create_task(self.listen_for_disconnect(receive)),
]
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
done, pending = done.pop(), pending.pop()
# Allow views to handle cancellation.
pending.cancel()
try:
await pending
except asyncio.CancelledError:
# Task re-raised the CancelledError as expected.
pass
try:
response = done.result()
except RequestAborted:
body_file.close()
return
except AssertionError:
body_file.close()
raise
# Send the response.
await self.send_response(response, send)
async def listen_for_disconnect(self, receive):
"""Listen for disconnect from the client."""
message = await receive()
if message["type"] == "http.disconnect":
raise RequestAborted()
# This should never happen.
assert False, "Invalid ASGI message after request body: %s" % message["type"]
async def run_get_response(self, request):
"""Get async response."""
# Use the async mode of BaseHandler.
response = await self.get_response_async(request)
response._handler_class = self.__class__
# Increase chunk size on file responses (ASGI servers handles low-level
# chunking).
if isinstance(response, FileResponse):
response.block_size = self.chunk_size
# Send the response.
await self.send_response(response, send)
return response
async def read_body(self, receive):
"""Reads an HTTP body from an ASGI connection."""

View File

@ -192,6 +192,13 @@ Minor features
* ...
Asynchronous views
~~~~~~~~~~~~~~~~~~
* Under ASGI, ``http.disconnect`` events are now handled. This allows views to
perform any necessary cleanup if a client disconnects before the response is
generated. See :ref:`async-handling-disconnect` for more details.
Cache
~~~~~

View File

@ -136,6 +136,26 @@ a purely synchronous codebase under ASGI because the request-handling code is
still all running asynchronously. In general you will only want to enable ASGI
mode if you have asynchronous code in your project.
.. _async-handling-disconnect:
Handling disconnects
--------------------
.. versionadded:: 5.0
For long-lived requests, a client may disconnect before the view returns a
response. In this case, an ``asyncio.CancelledError`` will be raised in the
view. You can catch this error and handle it if you need to perform any
cleanup::
async def my_view(request):
try:
# Do some work
...
except asyncio.CancelledError:
# Handle disconnect
raise
.. _async-safety:
Async safety

View File

@ -7,8 +7,10 @@ from asgiref.testing import ApplicationCommunicator
from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler
from django.core.asgi import get_asgi_application
from django.core.handlers.asgi import ASGIHandler, ASGIRequest
from django.core.signals import request_finished, request_started
from django.db import close_old_connections
from django.http import HttpResponse
from django.test import (
AsyncRequestFactory,
SimpleTestCase,
@ -16,6 +18,7 @@ from django.test import (
modify_settings,
override_settings,
)
from django.urls import path
from django.utils.http import http_date
from .urls import sync_waiter, test_filename
@ -234,6 +237,34 @@ class ASGITest(SimpleTestCase):
with self.assertRaises(asyncio.TimeoutError):
await communicator.receive_output()
async def test_disconnect_with_body(self):
application = get_asgi_application()
scope = self.async_request_factory._base_scope(path="/")
communicator = ApplicationCommunicator(application, scope)
await communicator.send_input({"type": "http.request", "body": b"some body"})
await communicator.send_input({"type": "http.disconnect"})
with self.assertRaises(asyncio.TimeoutError):
await communicator.receive_output()
async def test_assert_in_listen_for_disconnect(self):
application = get_asgi_application()
scope = self.async_request_factory._base_scope(path="/")
communicator = ApplicationCommunicator(application, scope)
await communicator.send_input({"type": "http.request"})
await communicator.send_input({"type": "http.not_a_real_message"})
msg = "Invalid ASGI message after request body: http.not_a_real_message"
with self.assertRaisesMessage(AssertionError, msg):
await communicator.receive_output()
async def test_delayed_disconnect_with_body(self):
application = get_asgi_application()
scope = self.async_request_factory._base_scope(path="/delayed_hello/")
communicator = ApplicationCommunicator(application, scope)
await communicator.send_input({"type": "http.request", "body": b"some body"})
await communicator.send_input({"type": "http.disconnect"})
with self.assertRaises(asyncio.TimeoutError):
await communicator.receive_output()
async def test_wrong_connection_type(self):
application = get_asgi_application()
scope = self.async_request_factory._base_scope(path="/", type="other")
@ -318,3 +349,56 @@ class ASGITest(SimpleTestCase):
self.assertEqual(len(sync_waiter.active_threads), 2)
sync_waiter.active_threads.clear()
async def test_asyncio_cancel_error(self):
# Flag to check if the view was cancelled.
view_did_cancel = False
# A view that will listen for the cancelled error.
async def view(request):
nonlocal view_did_cancel
try:
await asyncio.sleep(0.2)
return HttpResponse("Hello World!")
except asyncio.CancelledError:
# Set the flag.
view_did_cancel = True
raise
# Request class to use the view.
class TestASGIRequest(ASGIRequest):
urlconf = (path("cancel/", view),)
# Handler to use request class.
class TestASGIHandler(ASGIHandler):
request_class = TestASGIRequest
# Request cycle should complete since no disconnect was sent.
application = TestASGIHandler()
scope = self.async_request_factory._base_scope(path="/cancel/")
communicator = ApplicationCommunicator(application, scope)
await communicator.send_input({"type": "http.request"})
response_start = await communicator.receive_output()
self.assertEqual(response_start["type"], "http.response.start")
self.assertEqual(response_start["status"], 200)
response_body = await communicator.receive_output()
self.assertEqual(response_body["type"], "http.response.body")
self.assertEqual(response_body["body"], b"Hello World!")
# Give response.close() time to finish.
await communicator.wait()
self.assertIs(view_did_cancel, False)
# Request cycle with a disconnect before the view can respond.
application = TestASGIHandler()
scope = self.async_request_factory._base_scope(path="/cancel/")
communicator = ApplicationCommunicator(application, scope)
await communicator.send_input({"type": "http.request"})
# Let the view actually start.
await asyncio.sleep(0.1)
# Disconnect the client.
await communicator.send_input({"type": "http.disconnect"})
# The handler should not send a response.
with self.assertRaises(asyncio.TimeoutError):
await communicator.receive_output()
await communicator.wait()
self.assertIs(view_did_cancel, True)

View File

@ -1,4 +1,5 @@
import threading
import time
from django.http import FileResponse, HttpResponse
from django.urls import path
@ -10,6 +11,12 @@ def hello(request):
return HttpResponse("Hello %s!" % name)
def hello_with_delay(request):
name = request.GET.get("name") or "World"
time.sleep(1)
return HttpResponse(f"Hello {name}!")
def hello_meta(request):
return HttpResponse(
"From %s" % request.META.get("HTTP_REFERER") or "",
@ -46,4 +53,5 @@ urlpatterns = [
path("meta/", hello_meta),
path("post/", post_echo),
path("wait/", sync_waiter),
path("delayed_hello/", hello_with_delay),
]