capture: improve `captureclass` typing
Previously, the any `captureclass` arguments were Any. We need to introduce another common base class to fix this.
This commit is contained in:
parent
c746d2b016
commit
54911acf8d
|
@ -1,4 +1,5 @@
|
||||||
"""Per-test stdout/stderr capturing mechanism."""
|
"""Per-test stdout/stderr capturing mechanism."""
|
||||||
|
import abc
|
||||||
import collections
|
import collections
|
||||||
import contextlib
|
import contextlib
|
||||||
import io
|
import io
|
||||||
|
@ -270,6 +271,38 @@ class DontReadFromInput(TextIO):
|
||||||
# Capture classes.
|
# Capture classes.
|
||||||
|
|
||||||
|
|
||||||
|
class CaptureBase(abc.ABC, Generic[AnyStr]):
|
||||||
|
EMPTY_BUFFER: AnyStr
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def __init__(self, fd: int) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def start(self) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def done(self) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def suspend(self) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def resume(self) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def writeorg(self, data: AnyStr) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def snap(self) -> AnyStr:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
patchsysdict = {0: "stdin", 1: "stdout", 2: "stderr"}
|
patchsysdict = {0: "stdin", 1: "stdout", 2: "stderr"}
|
||||||
|
|
||||||
|
|
||||||
|
@ -278,7 +311,7 @@ class NoCapture:
|
||||||
__init__ = start = done = suspend = resume = lambda *args: None
|
__init__ = start = done = suspend = resume = lambda *args: None
|
||||||
|
|
||||||
|
|
||||||
class SysCaptureBase:
|
class SysCaptureBase(CaptureBase[AnyStr]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, fd: int, tmpfile: Optional[TextIO] = None, *, tee: bool = False
|
self, fd: int, tmpfile: Optional[TextIO] = None, *, tee: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -345,7 +378,7 @@ class SysCaptureBase:
|
||||||
self._state = "started"
|
self._state = "started"
|
||||||
|
|
||||||
|
|
||||||
class SysCaptureBinary(SysCaptureBase):
|
class SysCaptureBinary(SysCaptureBase[bytes]):
|
||||||
EMPTY_BUFFER = b""
|
EMPTY_BUFFER = b""
|
||||||
|
|
||||||
def snap(self) -> bytes:
|
def snap(self) -> bytes:
|
||||||
|
@ -363,7 +396,7 @@ class SysCaptureBinary(SysCaptureBase):
|
||||||
self._old.buffer.flush()
|
self._old.buffer.flush()
|
||||||
|
|
||||||
|
|
||||||
class SysCapture(SysCaptureBase):
|
class SysCapture(SysCaptureBase[str]):
|
||||||
EMPTY_BUFFER = ""
|
EMPTY_BUFFER = ""
|
||||||
|
|
||||||
def snap(self) -> str:
|
def snap(self) -> str:
|
||||||
|
@ -380,7 +413,7 @@ class SysCapture(SysCaptureBase):
|
||||||
self._old.flush()
|
self._old.flush()
|
||||||
|
|
||||||
|
|
||||||
class FDCaptureBase:
|
class FDCaptureBase(CaptureBase[AnyStr]):
|
||||||
def __init__(self, targetfd: int) -> None:
|
def __init__(self, targetfd: int) -> None:
|
||||||
self.targetfd = targetfd
|
self.targetfd = targetfd
|
||||||
|
|
||||||
|
@ -478,7 +511,7 @@ class FDCaptureBase:
|
||||||
self._state = "started"
|
self._state = "started"
|
||||||
|
|
||||||
|
|
||||||
class FDCaptureBinary(FDCaptureBase):
|
class FDCaptureBinary(FDCaptureBase[bytes]):
|
||||||
"""Capture IO to/from a given OS-level file descriptor.
|
"""Capture IO to/from a given OS-level file descriptor.
|
||||||
|
|
||||||
snap() produces `bytes`.
|
snap() produces `bytes`.
|
||||||
|
@ -500,7 +533,7 @@ class FDCaptureBinary(FDCaptureBase):
|
||||||
os.write(self.targetfd_save, data)
|
os.write(self.targetfd_save, data)
|
||||||
|
|
||||||
|
|
||||||
class FDCapture(FDCaptureBase):
|
class FDCapture(FDCaptureBase[str]):
|
||||||
"""Capture IO to/from a given OS-level file descriptor.
|
"""Capture IO to/from a given OS-level file descriptor.
|
||||||
|
|
||||||
snap() produces text.
|
snap() produces text.
|
||||||
|
@ -550,10 +583,15 @@ class MultiCapture(Generic[AnyStr]):
|
||||||
_state = None
|
_state = None
|
||||||
_in_suspended = False
|
_in_suspended = False
|
||||||
|
|
||||||
def __init__(self, in_, out, err) -> None:
|
def __init__(
|
||||||
self.in_ = in_
|
self,
|
||||||
self.out = out
|
in_: Optional[CaptureBase[AnyStr]],
|
||||||
self.err = err
|
out: Optional[CaptureBase[AnyStr]],
|
||||||
|
err: Optional[CaptureBase[AnyStr]],
|
||||||
|
) -> None:
|
||||||
|
self.in_: Optional[CaptureBase[AnyStr]] = in_
|
||||||
|
self.out: Optional[CaptureBase[AnyStr]] = out
|
||||||
|
self.err: Optional[CaptureBase[AnyStr]] = err
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return "<MultiCapture out={!r} err={!r} in_={!r} _state={!r} _in_suspended={!r}>".format(
|
return "<MultiCapture out={!r} err={!r} in_={!r} _state={!r} _in_suspended={!r}>".format(
|
||||||
|
@ -577,8 +615,10 @@ class MultiCapture(Generic[AnyStr]):
|
||||||
"""Pop current snapshot out/err capture and flush to orig streams."""
|
"""Pop current snapshot out/err capture and flush to orig streams."""
|
||||||
out, err = self.readouterr()
|
out, err = self.readouterr()
|
||||||
if out:
|
if out:
|
||||||
|
assert self.out is not None
|
||||||
self.out.writeorg(out)
|
self.out.writeorg(out)
|
||||||
if err:
|
if err:
|
||||||
|
assert self.err is not None
|
||||||
self.err.writeorg(err)
|
self.err.writeorg(err)
|
||||||
return out, err
|
return out, err
|
||||||
|
|
||||||
|
@ -599,6 +639,7 @@ class MultiCapture(Generic[AnyStr]):
|
||||||
if self.err:
|
if self.err:
|
||||||
self.err.resume()
|
self.err.resume()
|
||||||
if self._in_suspended:
|
if self._in_suspended:
|
||||||
|
assert self.in_ is not None
|
||||||
self.in_.resume()
|
self.in_.resume()
|
||||||
self._in_suspended = False
|
self._in_suspended = False
|
||||||
|
|
||||||
|
@ -621,7 +662,8 @@ class MultiCapture(Generic[AnyStr]):
|
||||||
def readouterr(self) -> CaptureResult[AnyStr]:
|
def readouterr(self) -> CaptureResult[AnyStr]:
|
||||||
out = self.out.snap() if self.out else ""
|
out = self.out.snap() if self.out else ""
|
||||||
err = self.err.snap() if self.err else ""
|
err = self.err.snap() if self.err else ""
|
||||||
return CaptureResult(out, err)
|
# TODO: This type error is real, need to fix.
|
||||||
|
return CaptureResult(out, err) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
def _get_multicapture(method: "_CaptureMethod") -> MultiCapture[str]:
|
def _get_multicapture(method: "_CaptureMethod") -> MultiCapture[str]:
|
||||||
|
@ -830,14 +872,18 @@ class CaptureFixture(Generic[AnyStr]):
|
||||||
:fixture:`capfd` and :fixture:`capfdbinary` fixtures."""
|
:fixture:`capfd` and :fixture:`capfdbinary` fixtures."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, captureclass, request: SubRequest, *, _ispytest: bool = False
|
self,
|
||||||
|
captureclass: Type[CaptureBase[AnyStr]],
|
||||||
|
request: SubRequest,
|
||||||
|
*,
|
||||||
|
_ispytest: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
check_ispytest(_ispytest)
|
check_ispytest(_ispytest)
|
||||||
self.captureclass = captureclass
|
self.captureclass: Type[CaptureBase[AnyStr]] = captureclass
|
||||||
self.request = request
|
self.request = request
|
||||||
self._capture: Optional[MultiCapture[AnyStr]] = None
|
self._capture: Optional[MultiCapture[AnyStr]] = None
|
||||||
self._captured_out = self.captureclass.EMPTY_BUFFER
|
self._captured_out: AnyStr = self.captureclass.EMPTY_BUFFER
|
||||||
self._captured_err = self.captureclass.EMPTY_BUFFER
|
self._captured_err: AnyStr = self.captureclass.EMPTY_BUFFER
|
||||||
|
|
||||||
def _start(self) -> None:
|
def _start(self) -> None:
|
||||||
if self._capture is None:
|
if self._capture is None:
|
||||||
|
@ -922,7 +968,7 @@ def capsys(request: SubRequest) -> Generator[CaptureFixture[str], None, None]:
|
||||||
assert captured.out == "hello\n"
|
assert captured.out == "hello\n"
|
||||||
"""
|
"""
|
||||||
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
|
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
|
||||||
capture_fixture = CaptureFixture[str](SysCapture, request, _ispytest=True)
|
capture_fixture = CaptureFixture(SysCapture, request, _ispytest=True)
|
||||||
capman.set_fixture(capture_fixture)
|
capman.set_fixture(capture_fixture)
|
||||||
capture_fixture._start()
|
capture_fixture._start()
|
||||||
yield capture_fixture
|
yield capture_fixture
|
||||||
|
@ -950,7 +996,7 @@ def capsysbinary(request: SubRequest) -> Generator[CaptureFixture[bytes], None,
|
||||||
assert captured.out == b"hello\n"
|
assert captured.out == b"hello\n"
|
||||||
"""
|
"""
|
||||||
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
|
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
|
||||||
capture_fixture = CaptureFixture[bytes](SysCaptureBinary, request, _ispytest=True)
|
capture_fixture = CaptureFixture(SysCaptureBinary, request, _ispytest=True)
|
||||||
capman.set_fixture(capture_fixture)
|
capman.set_fixture(capture_fixture)
|
||||||
capture_fixture._start()
|
capture_fixture._start()
|
||||||
yield capture_fixture
|
yield capture_fixture
|
||||||
|
@ -978,7 +1024,7 @@ def capfd(request: SubRequest) -> Generator[CaptureFixture[str], None, None]:
|
||||||
assert captured.out == "hello\n"
|
assert captured.out == "hello\n"
|
||||||
"""
|
"""
|
||||||
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
|
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
|
||||||
capture_fixture = CaptureFixture[str](FDCapture, request, _ispytest=True)
|
capture_fixture = CaptureFixture(FDCapture, request, _ispytest=True)
|
||||||
capman.set_fixture(capture_fixture)
|
capman.set_fixture(capture_fixture)
|
||||||
capture_fixture._start()
|
capture_fixture._start()
|
||||||
yield capture_fixture
|
yield capture_fixture
|
||||||
|
@ -1007,7 +1053,7 @@ def capfdbinary(request: SubRequest) -> Generator[CaptureFixture[bytes], None, N
|
||||||
|
|
||||||
"""
|
"""
|
||||||
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
|
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
|
||||||
capture_fixture = CaptureFixture[bytes](FDCaptureBinary, request, _ispytest=True)
|
capture_fixture = CaptureFixture(FDCaptureBinary, request, _ispytest=True)
|
||||||
capman.set_fixture(capture_fixture)
|
capman.set_fixture(capture_fixture)
|
||||||
capture_fixture._start()
|
capture_fixture._start()
|
||||||
yield capture_fixture
|
yield capture_fixture
|
||||||
|
|
|
@ -1352,6 +1352,7 @@ def test_capsys_results_accessible_by_attribute(capsys: CaptureFixture[str]) ->
|
||||||
|
|
||||||
def test_fdcapture_tmpfile_remains_the_same() -> None:
|
def test_fdcapture_tmpfile_remains_the_same() -> None:
|
||||||
cap = StdCaptureFD(out=False, err=True)
|
cap = StdCaptureFD(out=False, err=True)
|
||||||
|
assert isinstance(cap.err, capture.FDCapture)
|
||||||
try:
|
try:
|
||||||
cap.start_capturing()
|
cap.start_capturing()
|
||||||
capfile = cap.err.tmpfile
|
capfile = cap.err.tmpfile
|
||||||
|
|
Loading…
Reference in New Issue