Refs #33865 -- Improved implementation of FakePayload.

FakePayload is a wrapper around io.BytesIO and is expected to
masquerade as though it is a file-like object. For that reason it makes
sense that it should inherit the correct signatures from io.BytesIO
methods.

Crucially an implementation of .readline() is added which will be
necessary for this to behave more like the expected file-like objects as
LimitedStream will be changed to defer to the wrapped stream object
rather than rolling its own implementation for improved performance.

It should be safe to adjust these signatures because FakePayload is
only used internally within test client helpers, is undocumented, and
thus private.
This commit is contained in:
Nick Pope 2022-07-22 20:58:38 +01:00 committed by Mariusz Felisiak
parent 95182a8593
commit 57f5669d23
3 changed files with 46 additions and 21 deletions

View File

@ -6,7 +6,7 @@ from copy import copy
from functools import partial
from http import HTTPStatus
from importlib import import_module
from io import BytesIO
from io import BytesIO, IOBase
from urllib.parse import unquote_to_bytes, urljoin, urlparse, urlsplit
from asgiref.sync import sync_to_async
@ -55,7 +55,7 @@ class RedirectCycleError(Exception):
self.redirect_chain = last_response.redirect_chain
class FakePayload:
class FakePayload(IOBase):
"""
A wrapper around BytesIO that restricts what can be read since data from
the network can't be sought and cannot be read outside of its content
@ -63,39 +63,49 @@ class FakePayload:
that wouldn't work in real life.
"""
def __init__(self, content=None):
def __init__(self, initial_bytes=None):
self.__content = BytesIO()
self.__len = 0
self.read_started = False
if content is not None:
self.write(content)
if initial_bytes is not None:
self.write(initial_bytes)
def __len__(self):
return self.__len
def read(self, num_bytes=None):
def read(self, size=-1, /):
if not self.read_started:
self.__content.seek(0)
self.read_started = True
if num_bytes is None:
num_bytes = self.__len or 0
if size == -1 or size is None:
size = self.__len
assert (
self.__len >= num_bytes
self.__len >= size
), "Cannot read more than the available bytes from the HTTP incoming data."
content = self.__content.read(num_bytes)
self.__len -= num_bytes
content = self.__content.read(size)
self.__len -= len(content)
return content
def write(self, content):
def readline(self, size=-1, /):
if not self.read_started:
self.__content.seek(0)
self.read_started = True
if size == -1 or size is None:
size = self.__len
assert (
self.__len >= size
), "Cannot read more than the available bytes from the HTTP incoming data."
content = self.__content.readline(size)
self.__len -= len(content)
return content
def write(self, b, /):
if self.read_started:
raise ValueError("Unable to write a payload after it's been read")
content = force_bytes(content)
content = force_bytes(b)
self.__content.write(content)
self.__len += len(content)
def close(self):
pass
def closing_iterator_wrapper(iterable, close):
try:

View File

@ -290,7 +290,7 @@ class RequestsTests(SimpleTestCase):
self.assertEqual(stream.read(2), b"")
self.assertEqual(stream.read(), b"")
def test_stream(self):
def test_stream_read(self):
payload = FakePayload("name=value")
request = WSGIRequest(
{
@ -302,6 +302,19 @@ class RequestsTests(SimpleTestCase):
)
self.assertEqual(request.read(), b"name=value")
def test_stream_readline(self):
payload = FakePayload("name=value\nother=string")
request = WSGIRequest(
{
"REQUEST_METHOD": "POST",
"CONTENT_TYPE": "application/x-www-form-urlencoded",
"CONTENT_LENGTH": len(payload),
"wsgi.input": payload,
},
)
self.assertEqual(request.readline(), b"name=value\n")
self.assertEqual(request.readline(), b"other=string")
def test_read_after_value(self):
"""
Reading from request is allowed after accessing request contents as

View File

@ -5,7 +5,9 @@ from django.test.client import FakePayload
class FakePayloadTests(SimpleTestCase):
def test_write_after_read(self):
payload = FakePayload()
payload.read()
msg = "Unable to write a payload after it's been read"
with self.assertRaisesMessage(ValueError, msg):
payload.write(b"abc")
for operation in [payload.read, payload.readline]:
with self.subTest(operation=operation.__name__):
operation()
msg = "Unable to write a payload after it's been read"
with self.assertRaisesMessage(ValueError, msg):
payload.write(b"abc")