From 54911acf8dd5af091e91f215b6fca78e06ef5a9d Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Sat, 21 Jan 2023 10:09:28 +0200 Subject: [PATCH] capture: improve `captureclass` typing Previously, the any `captureclass` arguments were Any. We need to introduce another common base class to fix this. --- src/_pytest/capture.py | 84 +++++++++++++++++++++++++++++++---------- testing/test_capture.py | 1 + 2 files changed, 66 insertions(+), 19 deletions(-) diff --git a/src/_pytest/capture.py b/src/_pytest/capture.py index cb5b966c9..0ff583e04 100644 --- a/src/_pytest/capture.py +++ b/src/_pytest/capture.py @@ -1,4 +1,5 @@ """Per-test stdout/stderr capturing mechanism.""" +import abc import collections import contextlib import io @@ -270,6 +271,38 @@ class DontReadFromInput(TextIO): # Capture classes. +class CaptureBase(abc.ABC, Generic[AnyStr]): + EMPTY_BUFFER: AnyStr + + @abc.abstractmethod + def __init__(self, fd: int) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def start(self) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def done(self) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def suspend(self) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def resume(self) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def writeorg(self, data: AnyStr) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def snap(self) -> AnyStr: + raise NotImplementedError() + + patchsysdict = {0: "stdin", 1: "stdout", 2: "stderr"} @@ -278,7 +311,7 @@ class NoCapture: __init__ = start = done = suspend = resume = lambda *args: None -class SysCaptureBase: +class SysCaptureBase(CaptureBase[AnyStr]): def __init__( self, fd: int, tmpfile: Optional[TextIO] = None, *, tee: bool = False ) -> None: @@ -345,7 +378,7 @@ class SysCaptureBase: self._state = "started" -class SysCaptureBinary(SysCaptureBase): +class SysCaptureBinary(SysCaptureBase[bytes]): EMPTY_BUFFER = b"" def snap(self) -> bytes: @@ -363,7 +396,7 @@ class SysCaptureBinary(SysCaptureBase): self._old.buffer.flush() -class SysCapture(SysCaptureBase): +class SysCapture(SysCaptureBase[str]): EMPTY_BUFFER = "" def snap(self) -> str: @@ -380,7 +413,7 @@ class SysCapture(SysCaptureBase): self._old.flush() -class FDCaptureBase: +class FDCaptureBase(CaptureBase[AnyStr]): def __init__(self, targetfd: int) -> None: self.targetfd = targetfd @@ -478,7 +511,7 @@ class FDCaptureBase: self._state = "started" -class FDCaptureBinary(FDCaptureBase): +class FDCaptureBinary(FDCaptureBase[bytes]): """Capture IO to/from a given OS-level file descriptor. snap() produces `bytes`. @@ -500,7 +533,7 @@ class FDCaptureBinary(FDCaptureBase): os.write(self.targetfd_save, data) -class FDCapture(FDCaptureBase): +class FDCapture(FDCaptureBase[str]): """Capture IO to/from a given OS-level file descriptor. snap() produces text. @@ -550,10 +583,15 @@ class MultiCapture(Generic[AnyStr]): _state = None _in_suspended = False - def __init__(self, in_, out, err) -> None: - self.in_ = in_ - self.out = out - self.err = err + def __init__( + self, + in_: Optional[CaptureBase[AnyStr]], + out: Optional[CaptureBase[AnyStr]], + err: Optional[CaptureBase[AnyStr]], + ) -> None: + self.in_: Optional[CaptureBase[AnyStr]] = in_ + self.out: Optional[CaptureBase[AnyStr]] = out + self.err: Optional[CaptureBase[AnyStr]] = err def __repr__(self) -> str: return "".format( @@ -577,8 +615,10 @@ class MultiCapture(Generic[AnyStr]): """Pop current snapshot out/err capture and flush to orig streams.""" out, err = self.readouterr() if out: + assert self.out is not None self.out.writeorg(out) if err: + assert self.err is not None self.err.writeorg(err) return out, err @@ -599,6 +639,7 @@ class MultiCapture(Generic[AnyStr]): if self.err: self.err.resume() if self._in_suspended: + assert self.in_ is not None self.in_.resume() self._in_suspended = False @@ -621,7 +662,8 @@ class MultiCapture(Generic[AnyStr]): def readouterr(self) -> CaptureResult[AnyStr]: out = self.out.snap() if self.out else "" err = self.err.snap() if self.err else "" - return CaptureResult(out, err) + # TODO: This type error is real, need to fix. + return CaptureResult(out, err) # type: ignore[arg-type] def _get_multicapture(method: "_CaptureMethod") -> MultiCapture[str]: @@ -830,14 +872,18 @@ class CaptureFixture(Generic[AnyStr]): :fixture:`capfd` and :fixture:`capfdbinary` fixtures.""" def __init__( - self, captureclass, request: SubRequest, *, _ispytest: bool = False + self, + captureclass: Type[CaptureBase[AnyStr]], + request: SubRequest, + *, + _ispytest: bool = False, ) -> None: check_ispytest(_ispytest) - self.captureclass = captureclass + self.captureclass: Type[CaptureBase[AnyStr]] = captureclass self.request = request self._capture: Optional[MultiCapture[AnyStr]] = None - self._captured_out = self.captureclass.EMPTY_BUFFER - self._captured_err = self.captureclass.EMPTY_BUFFER + self._captured_out: AnyStr = self.captureclass.EMPTY_BUFFER + self._captured_err: AnyStr = self.captureclass.EMPTY_BUFFER def _start(self) -> None: if self._capture is None: @@ -922,7 +968,7 @@ def capsys(request: SubRequest) -> Generator[CaptureFixture[str], None, None]: assert captured.out == "hello\n" """ capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager") - capture_fixture = CaptureFixture[str](SysCapture, request, _ispytest=True) + capture_fixture = CaptureFixture(SysCapture, request, _ispytest=True) capman.set_fixture(capture_fixture) capture_fixture._start() yield capture_fixture @@ -950,7 +996,7 @@ def capsysbinary(request: SubRequest) -> Generator[CaptureFixture[bytes], None, assert captured.out == b"hello\n" """ capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager") - capture_fixture = CaptureFixture[bytes](SysCaptureBinary, request, _ispytest=True) + capture_fixture = CaptureFixture(SysCaptureBinary, request, _ispytest=True) capman.set_fixture(capture_fixture) capture_fixture._start() yield capture_fixture @@ -978,7 +1024,7 @@ def capfd(request: SubRequest) -> Generator[CaptureFixture[str], None, None]: assert captured.out == "hello\n" """ capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager") - capture_fixture = CaptureFixture[str](FDCapture, request, _ispytest=True) + capture_fixture = CaptureFixture(FDCapture, request, _ispytest=True) capman.set_fixture(capture_fixture) capture_fixture._start() yield capture_fixture @@ -1007,7 +1053,7 @@ def capfdbinary(request: SubRequest) -> Generator[CaptureFixture[bytes], None, N """ capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager") - capture_fixture = CaptureFixture[bytes](FDCaptureBinary, request, _ispytest=True) + capture_fixture = CaptureFixture(FDCaptureBinary, request, _ispytest=True) capman.set_fixture(capture_fixture) capture_fixture._start() yield capture_fixture diff --git a/testing/test_capture.py b/testing/test_capture.py index fcd318347..26c1a5f74 100644 --- a/testing/test_capture.py +++ b/testing/test_capture.py @@ -1352,6 +1352,7 @@ def test_capsys_results_accessible_by_attribute(capsys: CaptureFixture[str]) -> def test_fdcapture_tmpfile_remains_the_same() -> None: cap = StdCaptureFD(out=False, err=True) + assert isinstance(cap.err, capture.FDCapture) try: cap.start_capturing() capfile = cap.err.tmpfile