capture: improve `SysCapture`/`FDCapture` typing
Instead of `SysCapture`/`FDCapture` inheriting from `SysCaptureBinary`/`FDCaptureBinary`, have both inherit from a common `SysCaptureBase`/`FDCaptureBase`. This fixes a Liskov substitution violation.
This commit is contained in:
parent
a3693ce503
commit
c746d2b016
|
@ -278,13 +278,12 @@ class NoCapture:
|
|||
__init__ = start = done = suspend = resume = lambda *args: None
|
||||
|
||||
|
||||
class SysCaptureBinary:
|
||||
|
||||
EMPTY_BUFFER = b""
|
||||
|
||||
def __init__(self, fd: int, tmpfile=None, *, tee: bool = False) -> None:
|
||||
class SysCaptureBase:
|
||||
def __init__(
|
||||
self, fd: int, tmpfile: Optional[TextIO] = None, *, tee: bool = False
|
||||
) -> None:
|
||||
name = patchsysdict[fd]
|
||||
self._old = getattr(sys, name)
|
||||
self._old: TextIO = getattr(sys, name)
|
||||
self.name = name
|
||||
if tmpfile is None:
|
||||
if name == "stdin":
|
||||
|
@ -324,14 +323,6 @@ class SysCaptureBinary:
|
|||
setattr(sys, self.name, self.tmpfile)
|
||||
self._state = "started"
|
||||
|
||||
def snap(self):
|
||||
self._assert_state("snap", ("started", "suspended"))
|
||||
self.tmpfile.seek(0)
|
||||
res = self.tmpfile.buffer.read()
|
||||
self.tmpfile.seek(0)
|
||||
self.tmpfile.truncate()
|
||||
return res
|
||||
|
||||
def done(self) -> None:
|
||||
self._assert_state("done", ("initialized", "started", "suspended", "done"))
|
||||
if self._state == "done":
|
||||
|
@ -353,36 +344,43 @@ class SysCaptureBinary:
|
|||
setattr(sys, self.name, self.tmpfile)
|
||||
self._state = "started"
|
||||
|
||||
def writeorg(self, data) -> None:
|
||||
|
||||
class SysCaptureBinary(SysCaptureBase):
|
||||
EMPTY_BUFFER = b""
|
||||
|
||||
def snap(self) -> bytes:
|
||||
self._assert_state("snap", ("started", "suspended"))
|
||||
self.tmpfile.seek(0)
|
||||
res = self.tmpfile.buffer.read()
|
||||
self.tmpfile.seek(0)
|
||||
self.tmpfile.truncate()
|
||||
return res
|
||||
|
||||
def writeorg(self, data: bytes) -> None:
|
||||
self._assert_state("writeorg", ("started", "suspended"))
|
||||
self._old.flush()
|
||||
self._old.buffer.write(data)
|
||||
self._old.buffer.flush()
|
||||
|
||||
|
||||
class SysCapture(SysCaptureBinary):
|
||||
EMPTY_BUFFER = "" # type: ignore[assignment]
|
||||
class SysCapture(SysCaptureBase):
|
||||
EMPTY_BUFFER = ""
|
||||
|
||||
def snap(self):
|
||||
def snap(self) -> str:
|
||||
self._assert_state("snap", ("started", "suspended"))
|
||||
assert isinstance(self.tmpfile, CaptureIO)
|
||||
res = self.tmpfile.getvalue()
|
||||
self.tmpfile.seek(0)
|
||||
self.tmpfile.truncate()
|
||||
return res
|
||||
|
||||
def writeorg(self, data):
|
||||
def writeorg(self, data: str) -> None:
|
||||
self._assert_state("writeorg", ("started", "suspended"))
|
||||
self._old.write(data)
|
||||
self._old.flush()
|
||||
|
||||
|
||||
class FDCaptureBinary:
|
||||
"""Capture IO to/from a given OS-level file descriptor.
|
||||
|
||||
snap() produces `bytes`.
|
||||
"""
|
||||
|
||||
EMPTY_BUFFER = b""
|
||||
|
||||
class FDCaptureBase:
|
||||
def __init__(self, targetfd: int) -> None:
|
||||
self.targetfd = targetfd
|
||||
|
||||
|
@ -447,14 +445,6 @@ class FDCaptureBinary:
|
|||
self.syscapture.start()
|
||||
self._state = "started"
|
||||
|
||||
def snap(self):
|
||||
self._assert_state("snap", ("started", "suspended"))
|
||||
self.tmpfile.seek(0)
|
||||
res = self.tmpfile.buffer.read()
|
||||
self.tmpfile.seek(0)
|
||||
self.tmpfile.truncate()
|
||||
return res
|
||||
|
||||
def done(self) -> None:
|
||||
"""Stop capturing, restore streams, return original capture file,
|
||||
seeked to position zero."""
|
||||
|
@ -487,22 +477,38 @@ class FDCaptureBinary:
|
|||
os.dup2(self.tmpfile.fileno(), self.targetfd)
|
||||
self._state = "started"
|
||||
|
||||
def writeorg(self, data):
|
||||
|
||||
class FDCaptureBinary(FDCaptureBase):
|
||||
"""Capture IO to/from a given OS-level file descriptor.
|
||||
|
||||
snap() produces `bytes`.
|
||||
"""
|
||||
|
||||
EMPTY_BUFFER = b""
|
||||
|
||||
def snap(self) -> bytes:
|
||||
self._assert_state("snap", ("started", "suspended"))
|
||||
self.tmpfile.seek(0)
|
||||
res = self.tmpfile.buffer.read()
|
||||
self.tmpfile.seek(0)
|
||||
self.tmpfile.truncate()
|
||||
return res
|
||||
|
||||
def writeorg(self, data: bytes) -> None:
|
||||
"""Write to original file descriptor."""
|
||||
self._assert_state("writeorg", ("started", "suspended"))
|
||||
os.write(self.targetfd_save, data)
|
||||
|
||||
|
||||
class FDCapture(FDCaptureBinary):
|
||||
class FDCapture(FDCaptureBase):
|
||||
"""Capture IO to/from a given OS-level file descriptor.
|
||||
|
||||
snap() produces text.
|
||||
"""
|
||||
|
||||
# Ignore type because it doesn't match the type in the superclass (bytes).
|
||||
EMPTY_BUFFER = "" # type: ignore
|
||||
EMPTY_BUFFER = ""
|
||||
|
||||
def snap(self):
|
||||
def snap(self) -> str:
|
||||
self._assert_state("snap", ("started", "suspended"))
|
||||
self.tmpfile.seek(0)
|
||||
res = self.tmpfile.read()
|
||||
|
@ -510,9 +516,11 @@ class FDCapture(FDCaptureBinary):
|
|||
self.tmpfile.truncate()
|
||||
return res
|
||||
|
||||
def writeorg(self, data):
|
||||
def writeorg(self, data: str) -> None:
|
||||
"""Write to original file descriptor."""
|
||||
super().writeorg(data.encode("utf-8")) # XXX use encoding of original stream
|
||||
self._assert_state("writeorg", ("started", "suspended"))
|
||||
# XXX use encoding of original stream
|
||||
os.write(self.targetfd_save, data.encode("utf-8"))
|
||||
|
||||
|
||||
# MultiCapture
|
||||
|
|
Loading…
Reference in New Issue