mirror of https://github.com/django/django.git
Fixed #34394 -- Added FORCE_SCRIPT_NAME handling to ASGIRequest.
Co-authored-by: Mariusz Felisiak <felisiak.mariusz@gmail.com>
This commit is contained in:
parent
c3d7a71f83
commit
041b0a359a
|
@ -26,6 +26,15 @@ from django.utils.functional import cached_property
|
||||||
logger = logging.getLogger("django.request")
|
logger = logging.getLogger("django.request")
|
||||||
|
|
||||||
|
|
||||||
|
def get_script_prefix(scope):
|
||||||
|
"""
|
||||||
|
Return the script prefix to use from either the scope or a setting.
|
||||||
|
"""
|
||||||
|
if settings.FORCE_SCRIPT_NAME:
|
||||||
|
return settings.FORCE_SCRIPT_NAME
|
||||||
|
return scope.get("root_path", "") or ""
|
||||||
|
|
||||||
|
|
||||||
class ASGIRequest(HttpRequest):
|
class ASGIRequest(HttpRequest):
|
||||||
"""
|
"""
|
||||||
Custom request subclass that decodes from an ASGI-standard request dict
|
Custom request subclass that decodes from an ASGI-standard request dict
|
||||||
|
@ -41,7 +50,7 @@ class ASGIRequest(HttpRequest):
|
||||||
self._post_parse_error = False
|
self._post_parse_error = False
|
||||||
self._read_started = False
|
self._read_started = False
|
||||||
self.resolver_match = None
|
self.resolver_match = None
|
||||||
self.script_name = self.scope.get("root_path", "")
|
self.script_name = get_script_prefix(scope)
|
||||||
if self.script_name:
|
if self.script_name:
|
||||||
# TODO: Better is-prefix checking, slash handling?
|
# TODO: Better is-prefix checking, slash handling?
|
||||||
self.path_info = scope["path"].removeprefix(self.script_name)
|
self.path_info = scope["path"].removeprefix(self.script_name)
|
||||||
|
@ -170,7 +179,7 @@ class ASGIHandler(base.BaseHandler):
|
||||||
except RequestAborted:
|
except RequestAborted:
|
||||||
return
|
return
|
||||||
# Request is complete and can be served.
|
# Request is complete and can be served.
|
||||||
set_script_prefix(self.get_script_prefix(scope))
|
set_script_prefix(get_script_prefix(scope))
|
||||||
await signals.request_started.asend(sender=self.__class__, scope=scope)
|
await signals.request_started.asend(sender=self.__class__, scope=scope)
|
||||||
# Get the request and check for basic issues.
|
# Get the request and check for basic issues.
|
||||||
request, error_response = self.create_request(scope, body_file)
|
request, error_response = self.create_request(scope, body_file)
|
||||||
|
@ -344,11 +353,3 @@ class ASGIHandler(base.BaseHandler):
|
||||||
(position + cls.chunk_size) >= len(data),
|
(position + cls.chunk_size) >= len(data),
|
||||||
)
|
)
|
||||||
position += cls.chunk_size
|
position += cls.chunk_size
|
||||||
|
|
||||||
def get_script_prefix(self, scope):
|
|
||||||
"""
|
|
||||||
Return the script prefix to use from either the scope or a setting.
|
|
||||||
"""
|
|
||||||
if settings.FORCE_SCRIPT_NAME:
|
|
||||||
return settings.FORCE_SCRIPT_NAME
|
|
||||||
return scope.get("root_path", "") or ""
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ from django.core.handlers.wsgi import WSGIHandler, WSGIRequest, get_script_name
|
||||||
from django.core.signals import request_finished, request_started
|
from django.core.signals import request_finished, request_started
|
||||||
from django.db import close_old_connections, connection
|
from django.db import close_old_connections, connection
|
||||||
from django.test import (
|
from django.test import (
|
||||||
|
AsyncRequestFactory,
|
||||||
RequestFactory,
|
RequestFactory,
|
||||||
SimpleTestCase,
|
SimpleTestCase,
|
||||||
TransactionTestCase,
|
TransactionTestCase,
|
||||||
|
@ -328,6 +329,12 @@ class AsyncHandlerRequestTests(SimpleTestCase):
|
||||||
with self.assertRaisesMessage(ValueError, msg):
|
with self.assertRaisesMessage(ValueError, msg):
|
||||||
await self.async_client.get("/unawaited/")
|
await self.async_client.get("/unawaited/")
|
||||||
|
|
||||||
|
@override_settings(FORCE_SCRIPT_NAME="/FORCED_PREFIX/")
|
||||||
|
def test_force_script_name(self):
|
||||||
|
async_request_factory = AsyncRequestFactory()
|
||||||
|
request = async_request_factory.request(**{"path": "/somepath/"})
|
||||||
|
self.assertEqual(request.path, "/FORCED_PREFIX/somepath/")
|
||||||
|
|
||||||
async def test_sync_streaming(self):
|
async def test_sync_streaming(self):
|
||||||
response = await self.async_client.get("/streaming/")
|
response = await self.async_client.get("/streaming/")
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
Loading…
Reference in New Issue