capture: improve `DontReadFromInput` typing

Have `DontReadFromInput` inherit from `TextIO`, ensuring it's fully
compatible with `sys.stdin` (which has type `TextIO`).
This commit is contained in:
Ran Benita 2023-01-20 14:14:44 +02:00
parent 7d4b40337b
commit a3693ce503
2 changed files with 42 additions and 13 deletions

View File

@ -6,14 +6,20 @@ import os
import sys import sys
from io import UnsupportedOperation from io import UnsupportedOperation
from tempfile import TemporaryFile from tempfile import TemporaryFile
from types import TracebackType
from typing import Any from typing import Any
from typing import AnyStr from typing import AnyStr
from typing import BinaryIO
from typing import Generator from typing import Generator
from typing import Generic from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import List
from typing import NamedTuple from typing import NamedTuple
from typing import Optional from typing import Optional
from typing import TextIO from typing import TextIO
from typing import Tuple from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union from typing import Union
@ -185,19 +191,27 @@ class TeeCaptureIO(CaptureIO):
return self._other.write(s) return self._other.write(s)
class DontReadFromInput: class DontReadFromInput(TextIO):
encoding = None @property
def encoding(self) -> str:
return sys.__stdin__.encoding
def read(self, *args): def read(self, size: int = -1) -> str:
raise OSError( raise OSError(
"pytest: reading from stdin while output is captured! Consider using `-s`." "pytest: reading from stdin while output is captured! Consider using `-s`."
) )
readline = read readline = read
readlines = read
__next__ = read
def __iter__(self): def __next__(self) -> str:
return self.readline()
def readlines(self, hint: Optional[int] = -1) -> List[str]:
raise OSError(
"pytest: reading from stdin while output is captured! Consider using `-s`."
)
def __iter__(self) -> Iterator[str]:
return self return self
def fileno(self) -> int: def fileno(self) -> int:
@ -215,7 +229,7 @@ class DontReadFromInput:
def readable(self) -> bool: def readable(self) -> bool:
return False return False
def seek(self, offset: int) -> int: def seek(self, offset: int, whence: int = 0) -> int:
raise UnsupportedOperation("redirected stdin is pseudofile, has no seek(int)") raise UnsupportedOperation("redirected stdin is pseudofile, has no seek(int)")
def seekable(self) -> bool: def seekable(self) -> bool:
@ -224,22 +238,34 @@ class DontReadFromInput:
def tell(self) -> int: def tell(self) -> int:
raise UnsupportedOperation("redirected stdin is pseudofile, has no tell()") raise UnsupportedOperation("redirected stdin is pseudofile, has no tell()")
def truncate(self, size: int) -> None: def truncate(self, size: Optional[int] = None) -> int:
raise UnsupportedOperation("cannont truncate stdin") raise UnsupportedOperation("cannont truncate stdin")
def write(self, *args) -> None: def write(self, data: str) -> int:
raise UnsupportedOperation("cannot write to stdin") raise UnsupportedOperation("cannot write to stdin")
def writelines(self, *args) -> None: def writelines(self, lines: Iterable[str]) -> None:
raise UnsupportedOperation("Cannot write to stdin") raise UnsupportedOperation("Cannot write to stdin")
def writable(self) -> bool: def writable(self) -> bool:
return False return False
@property def __enter__(self) -> "DontReadFromInput":
def buffer(self):
return self return self
def __exit__(
self,
type: Optional[Type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
pass
@property
def buffer(self) -> BinaryIO:
# The str/bytes doesn't actually matter in this type, so OK to fake.
return self # type: ignore[return-value]
# Capture classes. # Capture classes.

View File

@ -890,7 +890,7 @@ def test_dontreadfrominput() -> None:
from _pytest.capture import DontReadFromInput from _pytest.capture import DontReadFromInput
f = DontReadFromInput() f = DontReadFromInput()
assert f.buffer is f assert f.buffer is f # type: ignore[comparison-overlap]
assert not f.isatty() assert not f.isatty()
pytest.raises(OSError, f.read) pytest.raises(OSError, f.read)
pytest.raises(OSError, f.readlines) pytest.raises(OSError, f.readlines)
@ -906,7 +906,10 @@ def test_dontreadfrominput() -> None:
pytest.raises(UnsupportedOperation, f.write, b"") pytest.raises(UnsupportedOperation, f.write, b"")
pytest.raises(UnsupportedOperation, f.writelines, []) pytest.raises(UnsupportedOperation, f.writelines, [])
assert not f.writable() assert not f.writable()
assert isinstance(f.encoding, str)
f.close() # just for completeness f.close() # just for completeness
with f:
pass
def test_captureresult() -> None: def test_captureresult() -> None: