From c746d2b016a149338c68511d2c5bee8cc5a31325 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Fri, 20 Jan 2023 14:08:08 +0200 Subject: [PATCH] capture: improve `SysCapture`/`FDCapture` typing Instead of `SysCapture`/`FDCapture` inheriting from `SysCaptureBinary`/`FDCaptureBinary`, have both inherit from a common `SysCaptureBase`/`FDCaptureBase`. This fixes a Liskov substitution violation. --- src/_pytest/capture.py | 92 +++++++++++++++++++++++------------------- 1 file changed, 50 insertions(+), 42 deletions(-) diff --git a/src/_pytest/capture.py b/src/_pytest/capture.py index 29d4f1524..cb5b966c9 100644 --- a/src/_pytest/capture.py +++ b/src/_pytest/capture.py @@ -278,13 +278,12 @@ class NoCapture: __init__ = start = done = suspend = resume = lambda *args: None -class SysCaptureBinary: - - EMPTY_BUFFER = b"" - - def __init__(self, fd: int, tmpfile=None, *, tee: bool = False) -> None: +class SysCaptureBase: + def __init__( + self, fd: int, tmpfile: Optional[TextIO] = None, *, tee: bool = False + ) -> None: name = patchsysdict[fd] - self._old = getattr(sys, name) + self._old: TextIO = getattr(sys, name) self.name = name if tmpfile is None: if name == "stdin": @@ -324,14 +323,6 @@ class SysCaptureBinary: setattr(sys, self.name, self.tmpfile) self._state = "started" - def snap(self): - self._assert_state("snap", ("started", "suspended")) - self.tmpfile.seek(0) - res = self.tmpfile.buffer.read() - self.tmpfile.seek(0) - self.tmpfile.truncate() - return res - def done(self) -> None: self._assert_state("done", ("initialized", "started", "suspended", "done")) if self._state == "done": @@ -353,36 +344,43 @@ class SysCaptureBinary: setattr(sys, self.name, self.tmpfile) self._state = "started" - def writeorg(self, data) -> None: + +class SysCaptureBinary(SysCaptureBase): + EMPTY_BUFFER = b"" + + def snap(self) -> bytes: + self._assert_state("snap", ("started", "suspended")) + self.tmpfile.seek(0) + res = self.tmpfile.buffer.read() + self.tmpfile.seek(0) + self.tmpfile.truncate() + return res + + def writeorg(self, data: bytes) -> None: self._assert_state("writeorg", ("started", "suspended")) self._old.flush() self._old.buffer.write(data) self._old.buffer.flush() -class SysCapture(SysCaptureBinary): - EMPTY_BUFFER = "" # type: ignore[assignment] +class SysCapture(SysCaptureBase): + EMPTY_BUFFER = "" - def snap(self): + def snap(self) -> str: + self._assert_state("snap", ("started", "suspended")) + assert isinstance(self.tmpfile, CaptureIO) res = self.tmpfile.getvalue() self.tmpfile.seek(0) self.tmpfile.truncate() return res - def writeorg(self, data): + def writeorg(self, data: str) -> None: self._assert_state("writeorg", ("started", "suspended")) self._old.write(data) self._old.flush() -class FDCaptureBinary: - """Capture IO to/from a given OS-level file descriptor. - - snap() produces `bytes`. - """ - - EMPTY_BUFFER = b"" - +class FDCaptureBase: def __init__(self, targetfd: int) -> None: self.targetfd = targetfd @@ -447,14 +445,6 @@ class FDCaptureBinary: self.syscapture.start() self._state = "started" - def snap(self): - self._assert_state("snap", ("started", "suspended")) - self.tmpfile.seek(0) - res = self.tmpfile.buffer.read() - self.tmpfile.seek(0) - self.tmpfile.truncate() - return res - def done(self) -> None: """Stop capturing, restore streams, return original capture file, seeked to position zero.""" @@ -487,22 +477,38 @@ class FDCaptureBinary: os.dup2(self.tmpfile.fileno(), self.targetfd) self._state = "started" - def writeorg(self, data): + +class FDCaptureBinary(FDCaptureBase): + """Capture IO to/from a given OS-level file descriptor. + + snap() produces `bytes`. + """ + + EMPTY_BUFFER = b"" + + def snap(self) -> bytes: + self._assert_state("snap", ("started", "suspended")) + self.tmpfile.seek(0) + res = self.tmpfile.buffer.read() + self.tmpfile.seek(0) + self.tmpfile.truncate() + return res + + def writeorg(self, data: bytes) -> None: """Write to original file descriptor.""" self._assert_state("writeorg", ("started", "suspended")) os.write(self.targetfd_save, data) -class FDCapture(FDCaptureBinary): +class FDCapture(FDCaptureBase): """Capture IO to/from a given OS-level file descriptor. snap() produces text. """ - # Ignore type because it doesn't match the type in the superclass (bytes). - EMPTY_BUFFER = "" # type: ignore + EMPTY_BUFFER = "" - def snap(self): + def snap(self) -> str: self._assert_state("snap", ("started", "suspended")) self.tmpfile.seek(0) res = self.tmpfile.read() @@ -510,9 +516,11 @@ class FDCapture(FDCaptureBinary): self.tmpfile.truncate() return res - def writeorg(self, data): + def writeorg(self, data: str) -> None: """Write to original file descriptor.""" - super().writeorg(data.encode("utf-8")) # XXX use encoding of original stream + self._assert_state("writeorg", ("started", "suspended")) + # XXX use encoding of original stream + os.write(self.targetfd_save, data.encode("utf-8")) # MultiCapture