Merge pull request #10680 from bluetech/capture-typing

capture: improve typing
This commit is contained in:
Ran Benita 2023-01-23 14:38:28 +02:00 committed by GitHub
commit 02893139f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 187 additions and 79 deletions

View File

@ -1,4 +1,5 @@
"""Per-test stdout/stderr capturing mechanism."""
import abc
import collections
import contextlib
import io
@ -6,14 +7,20 @@ import os
import sys
from io import UnsupportedOperation
from tempfile import TemporaryFile
from types import TracebackType
from typing import Any
from typing import AnyStr
from typing import BinaryIO
from typing import Generator
from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import TextIO
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
@ -29,6 +36,7 @@ from _pytest.nodes import File
from _pytest.nodes import Item
if TYPE_CHECKING:
from typing_extensions import Final
from typing_extensions import Literal
_CaptureMethod = Literal["fd", "sys", "no", "tee-sys"]
@ -185,19 +193,27 @@ class TeeCaptureIO(CaptureIO):
return self._other.write(s)
class DontReadFromInput:
encoding = None
class DontReadFromInput(TextIO):
@property
def encoding(self) -> str:
return sys.__stdin__.encoding
def read(self, *args):
def read(self, size: int = -1) -> str:
raise OSError(
"pytest: reading from stdin while output is captured! Consider using `-s`."
)
readline = read
readlines = read
__next__ = read
def __iter__(self):
def __next__(self) -> str:
return self.readline()
def readlines(self, hint: Optional[int] = -1) -> List[str]:
raise OSError(
"pytest: reading from stdin while output is captured! Consider using `-s`."
)
def __iter__(self) -> Iterator[str]:
return self
def fileno(self) -> int:
@ -215,7 +231,7 @@ class DontReadFromInput:
def readable(self) -> bool:
return False
def seek(self, offset: int) -> int:
def seek(self, offset: int, whence: int = 0) -> int:
raise UnsupportedOperation("redirected stdin is pseudofile, has no seek(int)")
def seekable(self) -> bool:
@ -224,41 +240,104 @@ class DontReadFromInput:
def tell(self) -> int:
raise UnsupportedOperation("redirected stdin is pseudofile, has no tell()")
def truncate(self, size: int) -> None:
def truncate(self, size: Optional[int] = None) -> int:
raise UnsupportedOperation("cannont truncate stdin")
def write(self, *args) -> None:
def write(self, data: str) -> int:
raise UnsupportedOperation("cannot write to stdin")
def writelines(self, *args) -> None:
def writelines(self, lines: Iterable[str]) -> None:
raise UnsupportedOperation("Cannot write to stdin")
def writable(self) -> bool:
return False
@property
def buffer(self):
def __enter__(self) -> "DontReadFromInput":
return self
def __exit__(
self,
type: Optional[Type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
pass
@property
def buffer(self) -> BinaryIO:
# The str/bytes doesn't actually matter in this type, so OK to fake.
return self # type: ignore[return-value]
# Capture classes.
class CaptureBase(abc.ABC, Generic[AnyStr]):
EMPTY_BUFFER: AnyStr
@abc.abstractmethod
def __init__(self, fd: int) -> None:
raise NotImplementedError()
@abc.abstractmethod
def start(self) -> None:
raise NotImplementedError()
@abc.abstractmethod
def done(self) -> None:
raise NotImplementedError()
@abc.abstractmethod
def suspend(self) -> None:
raise NotImplementedError()
@abc.abstractmethod
def resume(self) -> None:
raise NotImplementedError()
@abc.abstractmethod
def writeorg(self, data: AnyStr) -> None:
raise NotImplementedError()
@abc.abstractmethod
def snap(self) -> AnyStr:
raise NotImplementedError()
patchsysdict = {0: "stdin", 1: "stdout", 2: "stderr"}
class NoCapture:
EMPTY_BUFFER = None
__init__ = start = done = suspend = resume = lambda *args: None
class NoCapture(CaptureBase[str]):
EMPTY_BUFFER = ""
def __init__(self, fd: int) -> None:
pass
def start(self) -> None:
pass
def done(self) -> None:
pass
def suspend(self) -> None:
pass
def resume(self) -> None:
pass
def snap(self) -> str:
return ""
def writeorg(self, data: str) -> None:
pass
class SysCaptureBinary:
EMPTY_BUFFER = b""
def __init__(self, fd: int, tmpfile=None, *, tee: bool = False) -> None:
class SysCaptureBase(CaptureBase[AnyStr]):
def __init__(
self, fd: int, tmpfile: Optional[TextIO] = None, *, tee: bool = False
) -> None:
name = patchsysdict[fd]
self._old = getattr(sys, name)
self._old: TextIO = getattr(sys, name)
self.name = name
if tmpfile is None:
if name == "stdin":
@ -298,14 +377,6 @@ class SysCaptureBinary:
setattr(sys, self.name, self.tmpfile)
self._state = "started"
def snap(self):
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def done(self) -> None:
self._assert_state("done", ("initialized", "started", "suspended", "done"))
if self._state == "done":
@ -327,36 +398,43 @@ class SysCaptureBinary:
setattr(sys, self.name, self.tmpfile)
self._state = "started"
def writeorg(self, data) -> None:
class SysCaptureBinary(SysCaptureBase[bytes]):
EMPTY_BUFFER = b""
def snap(self) -> bytes:
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def writeorg(self, data: bytes) -> None:
self._assert_state("writeorg", ("started", "suspended"))
self._old.flush()
self._old.buffer.write(data)
self._old.buffer.flush()
class SysCapture(SysCaptureBinary):
EMPTY_BUFFER = "" # type: ignore[assignment]
class SysCapture(SysCaptureBase[str]):
EMPTY_BUFFER = ""
def snap(self):
def snap(self) -> str:
self._assert_state("snap", ("started", "suspended"))
assert isinstance(self.tmpfile, CaptureIO)
res = self.tmpfile.getvalue()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def writeorg(self, data):
def writeorg(self, data: str) -> None:
self._assert_state("writeorg", ("started", "suspended"))
self._old.write(data)
self._old.flush()
class FDCaptureBinary:
"""Capture IO to/from a given OS-level file descriptor.
snap() produces `bytes`.
"""
EMPTY_BUFFER = b""
class FDCaptureBase(CaptureBase[AnyStr]):
def __init__(self, targetfd: int) -> None:
self.targetfd = targetfd
@ -382,7 +460,7 @@ class FDCaptureBinary:
if targetfd == 0:
self.tmpfile = open(os.devnull, encoding="utf-8")
self.syscapture = SysCapture(targetfd)
self.syscapture: CaptureBase[str] = SysCapture(targetfd)
else:
self.tmpfile = EncodedFile(
TemporaryFile(buffering=0),
@ -394,7 +472,7 @@ class FDCaptureBinary:
if targetfd in patchsysdict:
self.syscapture = SysCapture(targetfd, self.tmpfile)
else:
self.syscapture = NoCapture()
self.syscapture = NoCapture(targetfd)
self._state = "initialized"
@ -421,14 +499,6 @@ class FDCaptureBinary:
self.syscapture.start()
self._state = "started"
def snap(self):
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def done(self) -> None:
"""Stop capturing, restore streams, return original capture file,
seeked to position zero."""
@ -461,22 +531,38 @@ class FDCaptureBinary:
os.dup2(self.tmpfile.fileno(), self.targetfd)
self._state = "started"
def writeorg(self, data):
class FDCaptureBinary(FDCaptureBase[bytes]):
"""Capture IO to/from a given OS-level file descriptor.
snap() produces `bytes`.
"""
EMPTY_BUFFER = b""
def snap(self) -> bytes:
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def writeorg(self, data: bytes) -> None:
"""Write to original file descriptor."""
self._assert_state("writeorg", ("started", "suspended"))
os.write(self.targetfd_save, data)
class FDCapture(FDCaptureBinary):
class FDCapture(FDCaptureBase[str]):
"""Capture IO to/from a given OS-level file descriptor.
snap() produces text.
"""
# Ignore type because it doesn't match the type in the superclass (bytes).
EMPTY_BUFFER = "" # type: ignore
EMPTY_BUFFER = ""
def snap(self):
def snap(self) -> str:
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.read()
@ -484,9 +570,11 @@ class FDCapture(FDCaptureBinary):
self.tmpfile.truncate()
return res
def writeorg(self, data):
def writeorg(self, data: str) -> None:
"""Write to original file descriptor."""
super().writeorg(data.encode("utf-8")) # XXX use encoding of original stream
self._assert_state("writeorg", ("started", "suspended"))
# XXX use encoding of original stream
os.write(self.targetfd_save, data.encode("utf-8"))
# MultiCapture
@ -516,10 +604,15 @@ class MultiCapture(Generic[AnyStr]):
_state = None
_in_suspended = False
def __init__(self, in_, out, err) -> None:
self.in_ = in_
self.out = out
self.err = err
def __init__(
self,
in_: Optional[CaptureBase[AnyStr]],
out: Optional[CaptureBase[AnyStr]],
err: Optional[CaptureBase[AnyStr]],
) -> None:
self.in_: Optional[CaptureBase[AnyStr]] = in_
self.out: Optional[CaptureBase[AnyStr]] = out
self.err: Optional[CaptureBase[AnyStr]] = err
def __repr__(self) -> str:
return "<MultiCapture out={!r} err={!r} in_={!r} _state={!r} _in_suspended={!r}>".format(
@ -543,8 +636,10 @@ class MultiCapture(Generic[AnyStr]):
"""Pop current snapshot out/err capture and flush to orig streams."""
out, err = self.readouterr()
if out:
assert self.out is not None
self.out.writeorg(out)
if err:
assert self.err is not None
self.err.writeorg(err)
return out, err
@ -565,6 +660,7 @@ class MultiCapture(Generic[AnyStr]):
if self.err:
self.err.resume()
if self._in_suspended:
assert self.in_ is not None
self.in_.resume()
self._in_suspended = False
@ -587,7 +683,8 @@ class MultiCapture(Generic[AnyStr]):
def readouterr(self) -> CaptureResult[AnyStr]:
out = self.out.snap() if self.out else ""
err = self.err.snap() if self.err else ""
return CaptureResult(out, err)
# TODO: This type error is real, need to fix.
return CaptureResult(out, err) # type: ignore[arg-type]
def _get_multicapture(method: "_CaptureMethod") -> MultiCapture[str]:
@ -627,7 +724,7 @@ class CaptureManager:
"""
def __init__(self, method: "_CaptureMethod") -> None:
self._method = method
self._method: Final = method
self._global_capturing: Optional[MultiCapture[str]] = None
self._capture_fixture: Optional[CaptureFixture[Any]] = None
@ -796,14 +893,18 @@ class CaptureFixture(Generic[AnyStr]):
:fixture:`capfd` and :fixture:`capfdbinary` fixtures."""
def __init__(
self, captureclass, request: SubRequest, *, _ispytest: bool = False
self,
captureclass: Type[CaptureBase[AnyStr]],
request: SubRequest,
*,
_ispytest: bool = False,
) -> None:
check_ispytest(_ispytest)
self.captureclass = captureclass
self.captureclass: Type[CaptureBase[AnyStr]] = captureclass
self.request = request
self._capture: Optional[MultiCapture[AnyStr]] = None
self._captured_out = self.captureclass.EMPTY_BUFFER
self._captured_err = self.captureclass.EMPTY_BUFFER
self._captured_out: AnyStr = self.captureclass.EMPTY_BUFFER
self._captured_err: AnyStr = self.captureclass.EMPTY_BUFFER
def _start(self) -> None:
if self._capture is None:
@ -858,7 +959,9 @@ class CaptureFixture(Generic[AnyStr]):
@contextlib.contextmanager
def disabled(self) -> Generator[None, None, None]:
"""Temporarily disable capturing while inside the ``with`` block."""
capmanager = self.request.config.pluginmanager.getplugin("capturemanager")
capmanager: CaptureManager = self.request.config.pluginmanager.getplugin(
"capturemanager"
)
with capmanager.global_and_fixture_disabled():
yield
@ -885,8 +988,8 @@ def capsys(request: SubRequest) -> Generator[CaptureFixture[str], None, None]:
captured = capsys.readouterr()
assert captured.out == "hello\n"
"""
capman = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture[str](SysCapture, request, _ispytest=True)
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture(SysCapture, request, _ispytest=True)
capman.set_fixture(capture_fixture)
capture_fixture._start()
yield capture_fixture
@ -913,8 +1016,8 @@ def capsysbinary(request: SubRequest) -> Generator[CaptureFixture[bytes], None,
captured = capsysbinary.readouterr()
assert captured.out == b"hello\n"
"""
capman = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture[bytes](SysCaptureBinary, request, _ispytest=True)
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture(SysCaptureBinary, request, _ispytest=True)
capman.set_fixture(capture_fixture)
capture_fixture._start()
yield capture_fixture
@ -941,8 +1044,8 @@ def capfd(request: SubRequest) -> Generator[CaptureFixture[str], None, None]:
captured = capfd.readouterr()
assert captured.out == "hello\n"
"""
capman = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture[str](FDCapture, request, _ispytest=True)
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture(FDCapture, request, _ispytest=True)
capman.set_fixture(capture_fixture)
capture_fixture._start()
yield capture_fixture
@ -970,8 +1073,8 @@ def capfdbinary(request: SubRequest) -> Generator[CaptureFixture[bytes], None, N
assert captured.out == b"hello\n"
"""
capman = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture[bytes](FDCaptureBinary, request, _ispytest=True)
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture(FDCaptureBinary, request, _ispytest=True)
capman.set_fixture(capture_fixture)
capture_fixture._start()
yield capture_fixture

View File

@ -890,7 +890,7 @@ def test_dontreadfrominput() -> None:
from _pytest.capture import DontReadFromInput
f = DontReadFromInput()
assert f.buffer is f
assert f.buffer is f # type: ignore[comparison-overlap]
assert not f.isatty()
pytest.raises(OSError, f.read)
pytest.raises(OSError, f.readlines)
@ -906,7 +906,10 @@ def test_dontreadfrominput() -> None:
pytest.raises(UnsupportedOperation, f.write, b"")
pytest.raises(UnsupportedOperation, f.writelines, [])
assert not f.writable()
assert isinstance(f.encoding, str)
f.close() # just for completeness
with f:
pass
def test_captureresult() -> None:
@ -1049,6 +1052,7 @@ class TestFDCapture:
)
)
# Should not crash with missing "_old".
assert isinstance(cap.syscapture, capture.SysCapture)
assert repr(cap.syscapture) == (
"<SysCapture stdout _old=<UNSET> _state='done' tmpfile={!r}>".format(
cap.syscapture.tmpfile
@ -1349,6 +1353,7 @@ def test_capsys_results_accessible_by_attribute(capsys: CaptureFixture[str]) ->
def test_fdcapture_tmpfile_remains_the_same() -> None:
cap = StdCaptureFD(out=False, err=True)
assert isinstance(cap.err, capture.FDCapture)
try:
cap.start_capturing()
capfile = cap.err.tmpfile