capture: improve `SysCapture`/`FDCapture` typing

Instead of `SysCapture`/`FDCapture` inheriting from
`SysCaptureBinary`/`FDCaptureBinary`, have both inherit from a common
`SysCaptureBase`/`FDCaptureBase`. This fixes a Liskov substitution
violation.
This commit is contained in:
Ran Benita 2023-01-20 14:08:08 +02:00
parent a3693ce503
commit c746d2b016
1 changed files with 50 additions and 42 deletions

View File

@ -278,13 +278,12 @@ class NoCapture:
__init__ = start = done = suspend = resume = lambda *args: None __init__ = start = done = suspend = resume = lambda *args: None
class SysCaptureBinary: class SysCaptureBase:
def __init__(
EMPTY_BUFFER = b"" self, fd: int, tmpfile: Optional[TextIO] = None, *, tee: bool = False
) -> None:
def __init__(self, fd: int, tmpfile=None, *, tee: bool = False) -> None:
name = patchsysdict[fd] name = patchsysdict[fd]
self._old = getattr(sys, name) self._old: TextIO = getattr(sys, name)
self.name = name self.name = name
if tmpfile is None: if tmpfile is None:
if name == "stdin": if name == "stdin":
@ -324,14 +323,6 @@ class SysCaptureBinary:
setattr(sys, self.name, self.tmpfile) setattr(sys, self.name, self.tmpfile)
self._state = "started" self._state = "started"
def snap(self):
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def done(self) -> None: def done(self) -> None:
self._assert_state("done", ("initialized", "started", "suspended", "done")) self._assert_state("done", ("initialized", "started", "suspended", "done"))
if self._state == "done": if self._state == "done":
@ -353,36 +344,43 @@ class SysCaptureBinary:
setattr(sys, self.name, self.tmpfile) setattr(sys, self.name, self.tmpfile)
self._state = "started" self._state = "started"
def writeorg(self, data) -> None:
class SysCaptureBinary(SysCaptureBase):
EMPTY_BUFFER = b""
def snap(self) -> bytes:
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def writeorg(self, data: bytes) -> None:
self._assert_state("writeorg", ("started", "suspended")) self._assert_state("writeorg", ("started", "suspended"))
self._old.flush() self._old.flush()
self._old.buffer.write(data) self._old.buffer.write(data)
self._old.buffer.flush() self._old.buffer.flush()
class SysCapture(SysCaptureBinary): class SysCapture(SysCaptureBase):
EMPTY_BUFFER = "" # type: ignore[assignment] EMPTY_BUFFER = ""
def snap(self): def snap(self) -> str:
self._assert_state("snap", ("started", "suspended"))
assert isinstance(self.tmpfile, CaptureIO)
res = self.tmpfile.getvalue() res = self.tmpfile.getvalue()
self.tmpfile.seek(0) self.tmpfile.seek(0)
self.tmpfile.truncate() self.tmpfile.truncate()
return res return res
def writeorg(self, data): def writeorg(self, data: str) -> None:
self._assert_state("writeorg", ("started", "suspended")) self._assert_state("writeorg", ("started", "suspended"))
self._old.write(data) self._old.write(data)
self._old.flush() self._old.flush()
class FDCaptureBinary: class FDCaptureBase:
"""Capture IO to/from a given OS-level file descriptor.
snap() produces `bytes`.
"""
EMPTY_BUFFER = b""
def __init__(self, targetfd: int) -> None: def __init__(self, targetfd: int) -> None:
self.targetfd = targetfd self.targetfd = targetfd
@ -447,14 +445,6 @@ class FDCaptureBinary:
self.syscapture.start() self.syscapture.start()
self._state = "started" self._state = "started"
def snap(self):
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def done(self) -> None: def done(self) -> None:
"""Stop capturing, restore streams, return original capture file, """Stop capturing, restore streams, return original capture file,
seeked to position zero.""" seeked to position zero."""
@ -487,22 +477,38 @@ class FDCaptureBinary:
os.dup2(self.tmpfile.fileno(), self.targetfd) os.dup2(self.tmpfile.fileno(), self.targetfd)
self._state = "started" self._state = "started"
def writeorg(self, data):
class FDCaptureBinary(FDCaptureBase):
"""Capture IO to/from a given OS-level file descriptor.
snap() produces `bytes`.
"""
EMPTY_BUFFER = b""
def snap(self) -> bytes:
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def writeorg(self, data: bytes) -> None:
"""Write to original file descriptor.""" """Write to original file descriptor."""
self._assert_state("writeorg", ("started", "suspended")) self._assert_state("writeorg", ("started", "suspended"))
os.write(self.targetfd_save, data) os.write(self.targetfd_save, data)
class FDCapture(FDCaptureBinary): class FDCapture(FDCaptureBase):
"""Capture IO to/from a given OS-level file descriptor. """Capture IO to/from a given OS-level file descriptor.
snap() produces text. snap() produces text.
""" """
# Ignore type because it doesn't match the type in the superclass (bytes). EMPTY_BUFFER = ""
EMPTY_BUFFER = "" # type: ignore
def snap(self): def snap(self) -> str:
self._assert_state("snap", ("started", "suspended")) self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0) self.tmpfile.seek(0)
res = self.tmpfile.read() res = self.tmpfile.read()
@ -510,9 +516,11 @@ class FDCapture(FDCaptureBinary):
self.tmpfile.truncate() self.tmpfile.truncate()
return res return res
def writeorg(self, data): def writeorg(self, data: str) -> None:
"""Write to original file descriptor.""" """Write to original file descriptor."""
super().writeorg(data.encode("utf-8")) # XXX use encoding of original stream self._assert_state("writeorg", ("started", "suspended"))
# XXX use encoding of original stream
os.write(self.targetfd_save, data.encode("utf-8"))
# MultiCapture # MultiCapture