diff --git a/src/_pytest/capture.py b/src/_pytest/capture.py index 4b66bb451..275322cc3 100644 --- a/src/_pytest/capture.py +++ b/src/_pytest/capture.py @@ -1,4 +1,5 @@ """Per-test stdout/stderr capturing mechanism.""" +import abc import collections import contextlib import io @@ -6,14 +7,20 @@ import os import sys from io import UnsupportedOperation from tempfile import TemporaryFile +from types import TracebackType from typing import Any from typing import AnyStr +from typing import BinaryIO from typing import Generator from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List from typing import NamedTuple from typing import Optional from typing import TextIO from typing import Tuple +from typing import Type from typing import TYPE_CHECKING from typing import Union @@ -29,6 +36,7 @@ from _pytest.nodes import File from _pytest.nodes import Item if TYPE_CHECKING: + from typing_extensions import Final from typing_extensions import Literal _CaptureMethod = Literal["fd", "sys", "no", "tee-sys"] @@ -185,19 +193,27 @@ class TeeCaptureIO(CaptureIO): return self._other.write(s) -class DontReadFromInput: - encoding = None +class DontReadFromInput(TextIO): + @property + def encoding(self) -> str: + return sys.__stdin__.encoding - def read(self, *args): + def read(self, size: int = -1) -> str: raise OSError( "pytest: reading from stdin while output is captured! Consider using `-s`." ) readline = read - readlines = read - __next__ = read - def __iter__(self): + def __next__(self) -> str: + return self.readline() + + def readlines(self, hint: Optional[int] = -1) -> List[str]: + raise OSError( + "pytest: reading from stdin while output is captured! Consider using `-s`." + ) + + def __iter__(self) -> Iterator[str]: return self def fileno(self) -> int: @@ -215,7 +231,7 @@ class DontReadFromInput: def readable(self) -> bool: return False - def seek(self, offset: int) -> int: + def seek(self, offset: int, whence: int = 0) -> int: raise UnsupportedOperation("redirected stdin is pseudofile, has no seek(int)") def seekable(self) -> bool: @@ -224,41 +240,104 @@ class DontReadFromInput: def tell(self) -> int: raise UnsupportedOperation("redirected stdin is pseudofile, has no tell()") - def truncate(self, size: int) -> None: + def truncate(self, size: Optional[int] = None) -> int: raise UnsupportedOperation("cannont truncate stdin") - def write(self, *args) -> None: + def write(self, data: str) -> int: raise UnsupportedOperation("cannot write to stdin") - def writelines(self, *args) -> None: + def writelines(self, lines: Iterable[str]) -> None: raise UnsupportedOperation("Cannot write to stdin") def writable(self) -> bool: return False - @property - def buffer(self): + def __enter__(self) -> "DontReadFromInput": return self + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + pass + + @property + def buffer(self) -> BinaryIO: + # The str/bytes doesn't actually matter in this type, so OK to fake. + return self # type: ignore[return-value] + # Capture classes. +class CaptureBase(abc.ABC, Generic[AnyStr]): + EMPTY_BUFFER: AnyStr + + @abc.abstractmethod + def __init__(self, fd: int) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def start(self) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def done(self) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def suspend(self) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def resume(self) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def writeorg(self, data: AnyStr) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def snap(self) -> AnyStr: + raise NotImplementedError() + + patchsysdict = {0: "stdin", 1: "stdout", 2: "stderr"} -class NoCapture: - EMPTY_BUFFER = None - __init__ = start = done = suspend = resume = lambda *args: None +class NoCapture(CaptureBase[str]): + EMPTY_BUFFER = "" + + def __init__(self, fd: int) -> None: + pass + + def start(self) -> None: + pass + + def done(self) -> None: + pass + + def suspend(self) -> None: + pass + + def resume(self) -> None: + pass + + def snap(self) -> str: + return "" + + def writeorg(self, data: str) -> None: + pass -class SysCaptureBinary: - - EMPTY_BUFFER = b"" - - def __init__(self, fd: int, tmpfile=None, *, tee: bool = False) -> None: +class SysCaptureBase(CaptureBase[AnyStr]): + def __init__( + self, fd: int, tmpfile: Optional[TextIO] = None, *, tee: bool = False + ) -> None: name = patchsysdict[fd] - self._old = getattr(sys, name) + self._old: TextIO = getattr(sys, name) self.name = name if tmpfile is None: if name == "stdin": @@ -298,14 +377,6 @@ class SysCaptureBinary: setattr(sys, self.name, self.tmpfile) self._state = "started" - def snap(self): - self._assert_state("snap", ("started", "suspended")) - self.tmpfile.seek(0) - res = self.tmpfile.buffer.read() - self.tmpfile.seek(0) - self.tmpfile.truncate() - return res - def done(self) -> None: self._assert_state("done", ("initialized", "started", "suspended", "done")) if self._state == "done": @@ -327,36 +398,43 @@ class SysCaptureBinary: setattr(sys, self.name, self.tmpfile) self._state = "started" - def writeorg(self, data) -> None: + +class SysCaptureBinary(SysCaptureBase[bytes]): + EMPTY_BUFFER = b"" + + def snap(self) -> bytes: + self._assert_state("snap", ("started", "suspended")) + self.tmpfile.seek(0) + res = self.tmpfile.buffer.read() + self.tmpfile.seek(0) + self.tmpfile.truncate() + return res + + def writeorg(self, data: bytes) -> None: self._assert_state("writeorg", ("started", "suspended")) self._old.flush() self._old.buffer.write(data) self._old.buffer.flush() -class SysCapture(SysCaptureBinary): - EMPTY_BUFFER = "" # type: ignore[assignment] +class SysCapture(SysCaptureBase[str]): + EMPTY_BUFFER = "" - def snap(self): + def snap(self) -> str: + self._assert_state("snap", ("started", "suspended")) + assert isinstance(self.tmpfile, CaptureIO) res = self.tmpfile.getvalue() self.tmpfile.seek(0) self.tmpfile.truncate() return res - def writeorg(self, data): + def writeorg(self, data: str) -> None: self._assert_state("writeorg", ("started", "suspended")) self._old.write(data) self._old.flush() -class FDCaptureBinary: - """Capture IO to/from a given OS-level file descriptor. - - snap() produces `bytes`. - """ - - EMPTY_BUFFER = b"" - +class FDCaptureBase(CaptureBase[AnyStr]): def __init__(self, targetfd: int) -> None: self.targetfd = targetfd @@ -382,7 +460,7 @@ class FDCaptureBinary: if targetfd == 0: self.tmpfile = open(os.devnull, encoding="utf-8") - self.syscapture = SysCapture(targetfd) + self.syscapture: CaptureBase[str] = SysCapture(targetfd) else: self.tmpfile = EncodedFile( TemporaryFile(buffering=0), @@ -394,7 +472,7 @@ class FDCaptureBinary: if targetfd in patchsysdict: self.syscapture = SysCapture(targetfd, self.tmpfile) else: - self.syscapture = NoCapture() + self.syscapture = NoCapture(targetfd) self._state = "initialized" @@ -421,14 +499,6 @@ class FDCaptureBinary: self.syscapture.start() self._state = "started" - def snap(self): - self._assert_state("snap", ("started", "suspended")) - self.tmpfile.seek(0) - res = self.tmpfile.buffer.read() - self.tmpfile.seek(0) - self.tmpfile.truncate() - return res - def done(self) -> None: """Stop capturing, restore streams, return original capture file, seeked to position zero.""" @@ -461,22 +531,38 @@ class FDCaptureBinary: os.dup2(self.tmpfile.fileno(), self.targetfd) self._state = "started" - def writeorg(self, data): + +class FDCaptureBinary(FDCaptureBase[bytes]): + """Capture IO to/from a given OS-level file descriptor. + + snap() produces `bytes`. + """ + + EMPTY_BUFFER = b"" + + def snap(self) -> bytes: + self._assert_state("snap", ("started", "suspended")) + self.tmpfile.seek(0) + res = self.tmpfile.buffer.read() + self.tmpfile.seek(0) + self.tmpfile.truncate() + return res + + def writeorg(self, data: bytes) -> None: """Write to original file descriptor.""" self._assert_state("writeorg", ("started", "suspended")) os.write(self.targetfd_save, data) -class FDCapture(FDCaptureBinary): +class FDCapture(FDCaptureBase[str]): """Capture IO to/from a given OS-level file descriptor. snap() produces text. """ - # Ignore type because it doesn't match the type in the superclass (bytes). - EMPTY_BUFFER = "" # type: ignore + EMPTY_BUFFER = "" - def snap(self): + def snap(self) -> str: self._assert_state("snap", ("started", "suspended")) self.tmpfile.seek(0) res = self.tmpfile.read() @@ -484,9 +570,11 @@ class FDCapture(FDCaptureBinary): self.tmpfile.truncate() return res - def writeorg(self, data): + def writeorg(self, data: str) -> None: """Write to original file descriptor.""" - super().writeorg(data.encode("utf-8")) # XXX use encoding of original stream + self._assert_state("writeorg", ("started", "suspended")) + # XXX use encoding of original stream + os.write(self.targetfd_save, data.encode("utf-8")) # MultiCapture @@ -516,10 +604,15 @@ class MultiCapture(Generic[AnyStr]): _state = None _in_suspended = False - def __init__(self, in_, out, err) -> None: - self.in_ = in_ - self.out = out - self.err = err + def __init__( + self, + in_: Optional[CaptureBase[AnyStr]], + out: Optional[CaptureBase[AnyStr]], + err: Optional[CaptureBase[AnyStr]], + ) -> None: + self.in_: Optional[CaptureBase[AnyStr]] = in_ + self.out: Optional[CaptureBase[AnyStr]] = out + self.err: Optional[CaptureBase[AnyStr]] = err def __repr__(self) -> str: return "".format( @@ -543,8 +636,10 @@ class MultiCapture(Generic[AnyStr]): """Pop current snapshot out/err capture and flush to orig streams.""" out, err = self.readouterr() if out: + assert self.out is not None self.out.writeorg(out) if err: + assert self.err is not None self.err.writeorg(err) return out, err @@ -565,6 +660,7 @@ class MultiCapture(Generic[AnyStr]): if self.err: self.err.resume() if self._in_suspended: + assert self.in_ is not None self.in_.resume() self._in_suspended = False @@ -587,7 +683,8 @@ class MultiCapture(Generic[AnyStr]): def readouterr(self) -> CaptureResult[AnyStr]: out = self.out.snap() if self.out else "" err = self.err.snap() if self.err else "" - return CaptureResult(out, err) + # TODO: This type error is real, need to fix. + return CaptureResult(out, err) # type: ignore[arg-type] def _get_multicapture(method: "_CaptureMethod") -> MultiCapture[str]: @@ -627,7 +724,7 @@ class CaptureManager: """ def __init__(self, method: "_CaptureMethod") -> None: - self._method = method + self._method: Final = method self._global_capturing: Optional[MultiCapture[str]] = None self._capture_fixture: Optional[CaptureFixture[Any]] = None @@ -796,14 +893,18 @@ class CaptureFixture(Generic[AnyStr]): :fixture:`capfd` and :fixture:`capfdbinary` fixtures.""" def __init__( - self, captureclass, request: SubRequest, *, _ispytest: bool = False + self, + captureclass: Type[CaptureBase[AnyStr]], + request: SubRequest, + *, + _ispytest: bool = False, ) -> None: check_ispytest(_ispytest) - self.captureclass = captureclass + self.captureclass: Type[CaptureBase[AnyStr]] = captureclass self.request = request self._capture: Optional[MultiCapture[AnyStr]] = None - self._captured_out = self.captureclass.EMPTY_BUFFER - self._captured_err = self.captureclass.EMPTY_BUFFER + self._captured_out: AnyStr = self.captureclass.EMPTY_BUFFER + self._captured_err: AnyStr = self.captureclass.EMPTY_BUFFER def _start(self) -> None: if self._capture is None: @@ -858,7 +959,9 @@ class CaptureFixture(Generic[AnyStr]): @contextlib.contextmanager def disabled(self) -> Generator[None, None, None]: """Temporarily disable capturing while inside the ``with`` block.""" - capmanager = self.request.config.pluginmanager.getplugin("capturemanager") + capmanager: CaptureManager = self.request.config.pluginmanager.getplugin( + "capturemanager" + ) with capmanager.global_and_fixture_disabled(): yield @@ -885,8 +988,8 @@ def capsys(request: SubRequest) -> Generator[CaptureFixture[str], None, None]: captured = capsys.readouterr() assert captured.out == "hello\n" """ - capman = request.config.pluginmanager.getplugin("capturemanager") - capture_fixture = CaptureFixture[str](SysCapture, request, _ispytest=True) + capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager") + capture_fixture = CaptureFixture(SysCapture, request, _ispytest=True) capman.set_fixture(capture_fixture) capture_fixture._start() yield capture_fixture @@ -913,8 +1016,8 @@ def capsysbinary(request: SubRequest) -> Generator[CaptureFixture[bytes], None, captured = capsysbinary.readouterr() assert captured.out == b"hello\n" """ - capman = request.config.pluginmanager.getplugin("capturemanager") - capture_fixture = CaptureFixture[bytes](SysCaptureBinary, request, _ispytest=True) + capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager") + capture_fixture = CaptureFixture(SysCaptureBinary, request, _ispytest=True) capman.set_fixture(capture_fixture) capture_fixture._start() yield capture_fixture @@ -941,8 +1044,8 @@ def capfd(request: SubRequest) -> Generator[CaptureFixture[str], None, None]: captured = capfd.readouterr() assert captured.out == "hello\n" """ - capman = request.config.pluginmanager.getplugin("capturemanager") - capture_fixture = CaptureFixture[str](FDCapture, request, _ispytest=True) + capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager") + capture_fixture = CaptureFixture(FDCapture, request, _ispytest=True) capman.set_fixture(capture_fixture) capture_fixture._start() yield capture_fixture @@ -970,8 +1073,8 @@ def capfdbinary(request: SubRequest) -> Generator[CaptureFixture[bytes], None, N assert captured.out == b"hello\n" """ - capman = request.config.pluginmanager.getplugin("capturemanager") - capture_fixture = CaptureFixture[bytes](FDCaptureBinary, request, _ispytest=True) + capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager") + capture_fixture = CaptureFixture(FDCaptureBinary, request, _ispytest=True) capman.set_fixture(capture_fixture) capture_fixture._start() yield capture_fixture diff --git a/testing/test_capture.py b/testing/test_capture.py index 00cab1933..5d6ef64ef 100644 --- a/testing/test_capture.py +++ b/testing/test_capture.py @@ -890,7 +890,7 @@ def test_dontreadfrominput() -> None: from _pytest.capture import DontReadFromInput f = DontReadFromInput() - assert f.buffer is f + assert f.buffer is f # type: ignore[comparison-overlap] assert not f.isatty() pytest.raises(OSError, f.read) pytest.raises(OSError, f.readlines) @@ -906,7 +906,10 @@ def test_dontreadfrominput() -> None: pytest.raises(UnsupportedOperation, f.write, b"") pytest.raises(UnsupportedOperation, f.writelines, []) assert not f.writable() + assert isinstance(f.encoding, str) f.close() # just for completeness + with f: + pass def test_captureresult() -> None: @@ -1049,6 +1052,7 @@ class TestFDCapture: ) ) # Should not crash with missing "_old". + assert isinstance(cap.syscapture, capture.SysCapture) assert repr(cap.syscapture) == ( " _state='done' tmpfile={!r}>".format( cap.syscapture.tmpfile @@ -1349,6 +1353,7 @@ def test_capsys_results_accessible_by_attribute(capsys: CaptureFixture[str]) -> def test_fdcapture_tmpfile_remains_the_same() -> None: cap = StdCaptureFD(out=False, err=True) + assert isinstance(cap.err, capture.FDCapture) try: cap.start_capturing() capfile = cap.err.tmpfile