capture: improve `SysCapture`/`FDCapture` typing

Instead of `SysCapture`/`FDCapture` inheriting from
`SysCaptureBinary`/`FDCaptureBinary`, have both inherit from a common
`SysCaptureBase`/`FDCaptureBase`. This fixes a Liskov substitution
violation.
This commit is contained in:
Ran Benita 2023-01-20 14:08:08 +02:00
parent a3693ce503
commit c746d2b016
1 changed files with 50 additions and 42 deletions

View File

@ -278,13 +278,12 @@ class NoCapture:
__init__ = start = done = suspend = resume = lambda *args: None
class SysCaptureBinary:
EMPTY_BUFFER = b""
def __init__(self, fd: int, tmpfile=None, *, tee: bool = False) -> None:
class SysCaptureBase:
def __init__(
self, fd: int, tmpfile: Optional[TextIO] = None, *, tee: bool = False
) -> None:
name = patchsysdict[fd]
self._old = getattr(sys, name)
self._old: TextIO = getattr(sys, name)
self.name = name
if tmpfile is None:
if name == "stdin":
@ -324,14 +323,6 @@ class SysCaptureBinary:
setattr(sys, self.name, self.tmpfile)
self._state = "started"
def snap(self):
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def done(self) -> None:
self._assert_state("done", ("initialized", "started", "suspended", "done"))
if self._state == "done":
@ -353,36 +344,43 @@ class SysCaptureBinary:
setattr(sys, self.name, self.tmpfile)
self._state = "started"
def writeorg(self, data) -> None:
class SysCaptureBinary(SysCaptureBase):
EMPTY_BUFFER = b""
def snap(self) -> bytes:
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def writeorg(self, data: bytes) -> None:
self._assert_state("writeorg", ("started", "suspended"))
self._old.flush()
self._old.buffer.write(data)
self._old.buffer.flush()
class SysCapture(SysCaptureBinary):
EMPTY_BUFFER = "" # type: ignore[assignment]
class SysCapture(SysCaptureBase):
EMPTY_BUFFER = ""
def snap(self):
def snap(self) -> str:
self._assert_state("snap", ("started", "suspended"))
assert isinstance(self.tmpfile, CaptureIO)
res = self.tmpfile.getvalue()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def writeorg(self, data):
def writeorg(self, data: str) -> None:
self._assert_state("writeorg", ("started", "suspended"))
self._old.write(data)
self._old.flush()
class FDCaptureBinary:
"""Capture IO to/from a given OS-level file descriptor.
snap() produces `bytes`.
"""
EMPTY_BUFFER = b""
class FDCaptureBase:
def __init__(self, targetfd: int) -> None:
self.targetfd = targetfd
@ -447,14 +445,6 @@ class FDCaptureBinary:
self.syscapture.start()
self._state = "started"
def snap(self):
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def done(self) -> None:
"""Stop capturing, restore streams, return original capture file,
seeked to position zero."""
@ -487,22 +477,38 @@ class FDCaptureBinary:
os.dup2(self.tmpfile.fileno(), self.targetfd)
self._state = "started"
def writeorg(self, data):
class FDCaptureBinary(FDCaptureBase):
"""Capture IO to/from a given OS-level file descriptor.
snap() produces `bytes`.
"""
EMPTY_BUFFER = b""
def snap(self) -> bytes:
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def writeorg(self, data: bytes) -> None:
"""Write to original file descriptor."""
self._assert_state("writeorg", ("started", "suspended"))
os.write(self.targetfd_save, data)
class FDCapture(FDCaptureBinary):
class FDCapture(FDCaptureBase):
"""Capture IO to/from a given OS-level file descriptor.
snap() produces text.
"""
# Ignore type because it doesn't match the type in the superclass (bytes).
EMPTY_BUFFER = "" # type: ignore
EMPTY_BUFFER = ""
def snap(self):
def snap(self) -> str:
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.read()
@ -510,9 +516,11 @@ class FDCapture(FDCaptureBinary):
self.tmpfile.truncate()
return res
def writeorg(self, data):
def writeorg(self, data: str) -> None:
"""Write to original file descriptor."""
super().writeorg(data.encode("utf-8")) # XXX use encoding of original stream
self._assert_state("writeorg", ("started", "suspended"))
# XXX use encoding of original stream
os.write(self.targetfd_save, data.encode("utf-8"))
# MultiCapture