Merge pull request #10680 from bluetech/capture-typing

capture: improve typing
This commit is contained in:
Ran Benita 2023-01-23 14:38:28 +02:00 committed by GitHub
commit 02893139f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 187 additions and 79 deletions

View File

@ -1,4 +1,5 @@
"""Per-test stdout/stderr capturing mechanism.""" """Per-test stdout/stderr capturing mechanism."""
import abc
import collections import collections
import contextlib import contextlib
import io import io
@ -6,14 +7,20 @@ import os
import sys import sys
from io import UnsupportedOperation from io import UnsupportedOperation
from tempfile import TemporaryFile from tempfile import TemporaryFile
from types import TracebackType
from typing import Any from typing import Any
from typing import AnyStr from typing import AnyStr
from typing import BinaryIO
from typing import Generator from typing import Generator
from typing import Generic from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import List
from typing import NamedTuple from typing import NamedTuple
from typing import Optional from typing import Optional
from typing import TextIO from typing import TextIO
from typing import Tuple from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union from typing import Union
@ -29,6 +36,7 @@ from _pytest.nodes import File
from _pytest.nodes import Item from _pytest.nodes import Item
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Final
from typing_extensions import Literal from typing_extensions import Literal
_CaptureMethod = Literal["fd", "sys", "no", "tee-sys"] _CaptureMethod = Literal["fd", "sys", "no", "tee-sys"]
@ -185,19 +193,27 @@ class TeeCaptureIO(CaptureIO):
return self._other.write(s) return self._other.write(s)
class DontReadFromInput: class DontReadFromInput(TextIO):
encoding = None @property
def encoding(self) -> str:
return sys.__stdin__.encoding
def read(self, *args): def read(self, size: int = -1) -> str:
raise OSError( raise OSError(
"pytest: reading from stdin while output is captured! Consider using `-s`." "pytest: reading from stdin while output is captured! Consider using `-s`."
) )
readline = read readline = read
readlines = read
__next__ = read
def __iter__(self): def __next__(self) -> str:
return self.readline()
def readlines(self, hint: Optional[int] = -1) -> List[str]:
raise OSError(
"pytest: reading from stdin while output is captured! Consider using `-s`."
)
def __iter__(self) -> Iterator[str]:
return self return self
def fileno(self) -> int: def fileno(self) -> int:
@ -215,7 +231,7 @@ class DontReadFromInput:
def readable(self) -> bool: def readable(self) -> bool:
return False return False
def seek(self, offset: int) -> int: def seek(self, offset: int, whence: int = 0) -> int:
raise UnsupportedOperation("redirected stdin is pseudofile, has no seek(int)") raise UnsupportedOperation("redirected stdin is pseudofile, has no seek(int)")
def seekable(self) -> bool: def seekable(self) -> bool:
@ -224,41 +240,104 @@ class DontReadFromInput:
def tell(self) -> int: def tell(self) -> int:
raise UnsupportedOperation("redirected stdin is pseudofile, has no tell()") raise UnsupportedOperation("redirected stdin is pseudofile, has no tell()")
def truncate(self, size: int) -> None: def truncate(self, size: Optional[int] = None) -> int:
raise UnsupportedOperation("cannont truncate stdin") raise UnsupportedOperation("cannont truncate stdin")
def write(self, *args) -> None: def write(self, data: str) -> int:
raise UnsupportedOperation("cannot write to stdin") raise UnsupportedOperation("cannot write to stdin")
def writelines(self, *args) -> None: def writelines(self, lines: Iterable[str]) -> None:
raise UnsupportedOperation("Cannot write to stdin") raise UnsupportedOperation("Cannot write to stdin")
def writable(self) -> bool: def writable(self) -> bool:
return False return False
@property def __enter__(self) -> "DontReadFromInput":
def buffer(self):
return self return self
def __exit__(
self,
type: Optional[Type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
pass
@property
def buffer(self) -> BinaryIO:
# The str/bytes doesn't actually matter in this type, so OK to fake.
return self # type: ignore[return-value]
# Capture classes. # Capture classes.
class CaptureBase(abc.ABC, Generic[AnyStr]):
EMPTY_BUFFER: AnyStr
@abc.abstractmethod
def __init__(self, fd: int) -> None:
raise NotImplementedError()
@abc.abstractmethod
def start(self) -> None:
raise NotImplementedError()
@abc.abstractmethod
def done(self) -> None:
raise NotImplementedError()
@abc.abstractmethod
def suspend(self) -> None:
raise NotImplementedError()
@abc.abstractmethod
def resume(self) -> None:
raise NotImplementedError()
@abc.abstractmethod
def writeorg(self, data: AnyStr) -> None:
raise NotImplementedError()
@abc.abstractmethod
def snap(self) -> AnyStr:
raise NotImplementedError()
patchsysdict = {0: "stdin", 1: "stdout", 2: "stderr"} patchsysdict = {0: "stdin", 1: "stdout", 2: "stderr"}
class NoCapture: class NoCapture(CaptureBase[str]):
EMPTY_BUFFER = None EMPTY_BUFFER = ""
__init__ = start = done = suspend = resume = lambda *args: None
def __init__(self, fd: int) -> None:
pass
def start(self) -> None:
pass
def done(self) -> None:
pass
def suspend(self) -> None:
pass
def resume(self) -> None:
pass
def snap(self) -> str:
return ""
def writeorg(self, data: str) -> None:
pass
class SysCaptureBinary: class SysCaptureBase(CaptureBase[AnyStr]):
def __init__(
EMPTY_BUFFER = b"" self, fd: int, tmpfile: Optional[TextIO] = None, *, tee: bool = False
) -> None:
def __init__(self, fd: int, tmpfile=None, *, tee: bool = False) -> None:
name = patchsysdict[fd] name = patchsysdict[fd]
self._old = getattr(sys, name) self._old: TextIO = getattr(sys, name)
self.name = name self.name = name
if tmpfile is None: if tmpfile is None:
if name == "stdin": if name == "stdin":
@ -298,14 +377,6 @@ class SysCaptureBinary:
setattr(sys, self.name, self.tmpfile) setattr(sys, self.name, self.tmpfile)
self._state = "started" self._state = "started"
def snap(self):
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def done(self) -> None: def done(self) -> None:
self._assert_state("done", ("initialized", "started", "suspended", "done")) self._assert_state("done", ("initialized", "started", "suspended", "done"))
if self._state == "done": if self._state == "done":
@ -327,36 +398,43 @@ class SysCaptureBinary:
setattr(sys, self.name, self.tmpfile) setattr(sys, self.name, self.tmpfile)
self._state = "started" self._state = "started"
def writeorg(self, data) -> None:
class SysCaptureBinary(SysCaptureBase[bytes]):
EMPTY_BUFFER = b""
def snap(self) -> bytes:
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def writeorg(self, data: bytes) -> None:
self._assert_state("writeorg", ("started", "suspended")) self._assert_state("writeorg", ("started", "suspended"))
self._old.flush() self._old.flush()
self._old.buffer.write(data) self._old.buffer.write(data)
self._old.buffer.flush() self._old.buffer.flush()
class SysCapture(SysCaptureBinary): class SysCapture(SysCaptureBase[str]):
EMPTY_BUFFER = "" # type: ignore[assignment] EMPTY_BUFFER = ""
def snap(self): def snap(self) -> str:
self._assert_state("snap", ("started", "suspended"))
assert isinstance(self.tmpfile, CaptureIO)
res = self.tmpfile.getvalue() res = self.tmpfile.getvalue()
self.tmpfile.seek(0) self.tmpfile.seek(0)
self.tmpfile.truncate() self.tmpfile.truncate()
return res return res
def writeorg(self, data): def writeorg(self, data: str) -> None:
self._assert_state("writeorg", ("started", "suspended")) self._assert_state("writeorg", ("started", "suspended"))
self._old.write(data) self._old.write(data)
self._old.flush() self._old.flush()
class FDCaptureBinary: class FDCaptureBase(CaptureBase[AnyStr]):
"""Capture IO to/from a given OS-level file descriptor.
snap() produces `bytes`.
"""
EMPTY_BUFFER = b""
def __init__(self, targetfd: int) -> None: def __init__(self, targetfd: int) -> None:
self.targetfd = targetfd self.targetfd = targetfd
@ -382,7 +460,7 @@ class FDCaptureBinary:
if targetfd == 0: if targetfd == 0:
self.tmpfile = open(os.devnull, encoding="utf-8") self.tmpfile = open(os.devnull, encoding="utf-8")
self.syscapture = SysCapture(targetfd) self.syscapture: CaptureBase[str] = SysCapture(targetfd)
else: else:
self.tmpfile = EncodedFile( self.tmpfile = EncodedFile(
TemporaryFile(buffering=0), TemporaryFile(buffering=0),
@ -394,7 +472,7 @@ class FDCaptureBinary:
if targetfd in patchsysdict: if targetfd in patchsysdict:
self.syscapture = SysCapture(targetfd, self.tmpfile) self.syscapture = SysCapture(targetfd, self.tmpfile)
else: else:
self.syscapture = NoCapture() self.syscapture = NoCapture(targetfd)
self._state = "initialized" self._state = "initialized"
@ -421,14 +499,6 @@ class FDCaptureBinary:
self.syscapture.start() self.syscapture.start()
self._state = "started" self._state = "started"
def snap(self):
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def done(self) -> None: def done(self) -> None:
"""Stop capturing, restore streams, return original capture file, """Stop capturing, restore streams, return original capture file,
seeked to position zero.""" seeked to position zero."""
@ -461,22 +531,38 @@ class FDCaptureBinary:
os.dup2(self.tmpfile.fileno(), self.targetfd) os.dup2(self.tmpfile.fileno(), self.targetfd)
self._state = "started" self._state = "started"
def writeorg(self, data):
class FDCaptureBinary(FDCaptureBase[bytes]):
"""Capture IO to/from a given OS-level file descriptor.
snap() produces `bytes`.
"""
EMPTY_BUFFER = b""
def snap(self) -> bytes:
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def writeorg(self, data: bytes) -> None:
"""Write to original file descriptor.""" """Write to original file descriptor."""
self._assert_state("writeorg", ("started", "suspended")) self._assert_state("writeorg", ("started", "suspended"))
os.write(self.targetfd_save, data) os.write(self.targetfd_save, data)
class FDCapture(FDCaptureBinary): class FDCapture(FDCaptureBase[str]):
"""Capture IO to/from a given OS-level file descriptor. """Capture IO to/from a given OS-level file descriptor.
snap() produces text. snap() produces text.
""" """
# Ignore type because it doesn't match the type in the superclass (bytes). EMPTY_BUFFER = ""
EMPTY_BUFFER = "" # type: ignore
def snap(self): def snap(self) -> str:
self._assert_state("snap", ("started", "suspended")) self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0) self.tmpfile.seek(0)
res = self.tmpfile.read() res = self.tmpfile.read()
@ -484,9 +570,11 @@ class FDCapture(FDCaptureBinary):
self.tmpfile.truncate() self.tmpfile.truncate()
return res return res
def writeorg(self, data): def writeorg(self, data: str) -> None:
"""Write to original file descriptor.""" """Write to original file descriptor."""
super().writeorg(data.encode("utf-8")) # XXX use encoding of original stream self._assert_state("writeorg", ("started", "suspended"))
# XXX use encoding of original stream
os.write(self.targetfd_save, data.encode("utf-8"))
# MultiCapture # MultiCapture
@ -516,10 +604,15 @@ class MultiCapture(Generic[AnyStr]):
_state = None _state = None
_in_suspended = False _in_suspended = False
def __init__(self, in_, out, err) -> None: def __init__(
self.in_ = in_ self,
self.out = out in_: Optional[CaptureBase[AnyStr]],
self.err = err out: Optional[CaptureBase[AnyStr]],
err: Optional[CaptureBase[AnyStr]],
) -> None:
self.in_: Optional[CaptureBase[AnyStr]] = in_
self.out: Optional[CaptureBase[AnyStr]] = out
self.err: Optional[CaptureBase[AnyStr]] = err
def __repr__(self) -> str: def __repr__(self) -> str:
return "<MultiCapture out={!r} err={!r} in_={!r} _state={!r} _in_suspended={!r}>".format( return "<MultiCapture out={!r} err={!r} in_={!r} _state={!r} _in_suspended={!r}>".format(
@ -543,8 +636,10 @@ class MultiCapture(Generic[AnyStr]):
"""Pop current snapshot out/err capture and flush to orig streams.""" """Pop current snapshot out/err capture and flush to orig streams."""
out, err = self.readouterr() out, err = self.readouterr()
if out: if out:
assert self.out is not None
self.out.writeorg(out) self.out.writeorg(out)
if err: if err:
assert self.err is not None
self.err.writeorg(err) self.err.writeorg(err)
return out, err return out, err
@ -565,6 +660,7 @@ class MultiCapture(Generic[AnyStr]):
if self.err: if self.err:
self.err.resume() self.err.resume()
if self._in_suspended: if self._in_suspended:
assert self.in_ is not None
self.in_.resume() self.in_.resume()
self._in_suspended = False self._in_suspended = False
@ -587,7 +683,8 @@ class MultiCapture(Generic[AnyStr]):
def readouterr(self) -> CaptureResult[AnyStr]: def readouterr(self) -> CaptureResult[AnyStr]:
out = self.out.snap() if self.out else "" out = self.out.snap() if self.out else ""
err = self.err.snap() if self.err else "" err = self.err.snap() if self.err else ""
return CaptureResult(out, err) # TODO: This type error is real, need to fix.
return CaptureResult(out, err) # type: ignore[arg-type]
def _get_multicapture(method: "_CaptureMethod") -> MultiCapture[str]: def _get_multicapture(method: "_CaptureMethod") -> MultiCapture[str]:
@ -627,7 +724,7 @@ class CaptureManager:
""" """
def __init__(self, method: "_CaptureMethod") -> None: def __init__(self, method: "_CaptureMethod") -> None:
self._method = method self._method: Final = method
self._global_capturing: Optional[MultiCapture[str]] = None self._global_capturing: Optional[MultiCapture[str]] = None
self._capture_fixture: Optional[CaptureFixture[Any]] = None self._capture_fixture: Optional[CaptureFixture[Any]] = None
@ -796,14 +893,18 @@ class CaptureFixture(Generic[AnyStr]):
:fixture:`capfd` and :fixture:`capfdbinary` fixtures.""" :fixture:`capfd` and :fixture:`capfdbinary` fixtures."""
def __init__( def __init__(
self, captureclass, request: SubRequest, *, _ispytest: bool = False self,
captureclass: Type[CaptureBase[AnyStr]],
request: SubRequest,
*,
_ispytest: bool = False,
) -> None: ) -> None:
check_ispytest(_ispytest) check_ispytest(_ispytest)
self.captureclass = captureclass self.captureclass: Type[CaptureBase[AnyStr]] = captureclass
self.request = request self.request = request
self._capture: Optional[MultiCapture[AnyStr]] = None self._capture: Optional[MultiCapture[AnyStr]] = None
self._captured_out = self.captureclass.EMPTY_BUFFER self._captured_out: AnyStr = self.captureclass.EMPTY_BUFFER
self._captured_err = self.captureclass.EMPTY_BUFFER self._captured_err: AnyStr = self.captureclass.EMPTY_BUFFER
def _start(self) -> None: def _start(self) -> None:
if self._capture is None: if self._capture is None:
@ -858,7 +959,9 @@ class CaptureFixture(Generic[AnyStr]):
@contextlib.contextmanager @contextlib.contextmanager
def disabled(self) -> Generator[None, None, None]: def disabled(self) -> Generator[None, None, None]:
"""Temporarily disable capturing while inside the ``with`` block.""" """Temporarily disable capturing while inside the ``with`` block."""
capmanager = self.request.config.pluginmanager.getplugin("capturemanager") capmanager: CaptureManager = self.request.config.pluginmanager.getplugin(
"capturemanager"
)
with capmanager.global_and_fixture_disabled(): with capmanager.global_and_fixture_disabled():
yield yield
@ -885,8 +988,8 @@ def capsys(request: SubRequest) -> Generator[CaptureFixture[str], None, None]:
captured = capsys.readouterr() captured = capsys.readouterr()
assert captured.out == "hello\n" assert captured.out == "hello\n"
""" """
capman = request.config.pluginmanager.getplugin("capturemanager") capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture[str](SysCapture, request, _ispytest=True) capture_fixture = CaptureFixture(SysCapture, request, _ispytest=True)
capman.set_fixture(capture_fixture) capman.set_fixture(capture_fixture)
capture_fixture._start() capture_fixture._start()
yield capture_fixture yield capture_fixture
@ -913,8 +1016,8 @@ def capsysbinary(request: SubRequest) -> Generator[CaptureFixture[bytes], None,
captured = capsysbinary.readouterr() captured = capsysbinary.readouterr()
assert captured.out == b"hello\n" assert captured.out == b"hello\n"
""" """
capman = request.config.pluginmanager.getplugin("capturemanager") capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture[bytes](SysCaptureBinary, request, _ispytest=True) capture_fixture = CaptureFixture(SysCaptureBinary, request, _ispytest=True)
capman.set_fixture(capture_fixture) capman.set_fixture(capture_fixture)
capture_fixture._start() capture_fixture._start()
yield capture_fixture yield capture_fixture
@ -941,8 +1044,8 @@ def capfd(request: SubRequest) -> Generator[CaptureFixture[str], None, None]:
captured = capfd.readouterr() captured = capfd.readouterr()
assert captured.out == "hello\n" assert captured.out == "hello\n"
""" """
capman = request.config.pluginmanager.getplugin("capturemanager") capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture[str](FDCapture, request, _ispytest=True) capture_fixture = CaptureFixture(FDCapture, request, _ispytest=True)
capman.set_fixture(capture_fixture) capman.set_fixture(capture_fixture)
capture_fixture._start() capture_fixture._start()
yield capture_fixture yield capture_fixture
@ -970,8 +1073,8 @@ def capfdbinary(request: SubRequest) -> Generator[CaptureFixture[bytes], None, N
assert captured.out == b"hello\n" assert captured.out == b"hello\n"
""" """
capman = request.config.pluginmanager.getplugin("capturemanager") capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture[bytes](FDCaptureBinary, request, _ispytest=True) capture_fixture = CaptureFixture(FDCaptureBinary, request, _ispytest=True)
capman.set_fixture(capture_fixture) capman.set_fixture(capture_fixture)
capture_fixture._start() capture_fixture._start()
yield capture_fixture yield capture_fixture

View File

@ -890,7 +890,7 @@ def test_dontreadfrominput() -> None:
from _pytest.capture import DontReadFromInput from _pytest.capture import DontReadFromInput
f = DontReadFromInput() f = DontReadFromInput()
assert f.buffer is f assert f.buffer is f # type: ignore[comparison-overlap]
assert not f.isatty() assert not f.isatty()
pytest.raises(OSError, f.read) pytest.raises(OSError, f.read)
pytest.raises(OSError, f.readlines) pytest.raises(OSError, f.readlines)
@ -906,7 +906,10 @@ def test_dontreadfrominput() -> None:
pytest.raises(UnsupportedOperation, f.write, b"") pytest.raises(UnsupportedOperation, f.write, b"")
pytest.raises(UnsupportedOperation, f.writelines, []) pytest.raises(UnsupportedOperation, f.writelines, [])
assert not f.writable() assert not f.writable()
assert isinstance(f.encoding, str)
f.close() # just for completeness f.close() # just for completeness
with f:
pass
def test_captureresult() -> None: def test_captureresult() -> None:
@ -1049,6 +1052,7 @@ class TestFDCapture:
) )
) )
# Should not crash with missing "_old". # Should not crash with missing "_old".
assert isinstance(cap.syscapture, capture.SysCapture)
assert repr(cap.syscapture) == ( assert repr(cap.syscapture) == (
"<SysCapture stdout _old=<UNSET> _state='done' tmpfile={!r}>".format( "<SysCapture stdout _old=<UNSET> _state='done' tmpfile={!r}>".format(
cap.syscapture.tmpfile cap.syscapture.tmpfile
@ -1349,6 +1353,7 @@ def test_capsys_results_accessible_by_attribute(capsys: CaptureFixture[str]) ->
def test_fdcapture_tmpfile_remains_the_same() -> None: def test_fdcapture_tmpfile_remains_the_same() -> None:
cap = StdCaptureFD(out=False, err=True) cap = StdCaptureFD(out=False, err=True)
assert isinstance(cap.err, capture.FDCapture)
try: try:
cap.start_capturing() cap.start_capturing()
capfile = cap.err.tmpfile capfile = cap.err.tmpfile