capture: replace TeeSysCapture with SysCapture(tee=True)

This is more straightforward and does not require duplicating the
initialization logic.
This commit is contained in:
Ran Benita 2020-04-16 09:49:17 +03:00
parent 02c95ea624
commit ea3f44894f
2 changed files with 13 additions and 22 deletions

View File

@ -75,7 +75,9 @@ def _get_multicapture(method: "_CaptureMethod") -> "MultiCapture":
elif method == "no":
return MultiCapture(in_=None, out=None, err=None)
elif method == "tee-sys":
return MultiCapture(in_=None, out=TeeSysCapture(1), err=TeeSysCapture(2))
return MultiCapture(
in_=None, out=SysCapture(1, tee=True), err=SysCapture(2, tee=True)
)
raise ValueError("unknown capturing method: {!r}".format(method))
@ -620,7 +622,7 @@ class SysCaptureBinary:
EMPTY_BUFFER = b""
_state = None
def __init__(self, fd, tmpfile=None):
def __init__(self, fd, tmpfile=None, *, tee=False):
name = patchsysdict[fd]
self._old = getattr(sys, name)
self.name = name
@ -628,7 +630,7 @@ class SysCaptureBinary:
if name == "stdin":
tmpfile = DontReadFromInput()
else:
tmpfile = CaptureIO()
tmpfile = CaptureIO() if not tee else TeeCaptureIO(self._old)
self.tmpfile = tmpfile
def __repr__(self):
@ -684,19 +686,6 @@ class SysCapture(SysCaptureBinary):
self._old.flush()
class TeeSysCapture(SysCapture):
def __init__(self, fd, tmpfile=None):
name = patchsysdict[fd]
self._old = getattr(sys, name)
self.name = name
if tmpfile is None:
if name == "stdin":
tmpfile = DontReadFromInput()
else:
tmpfile = TeeCaptureIO(self._old)
self.tmpfile = tmpfile
class DontReadFromInput:
encoding = None

View File

@ -37,9 +37,9 @@ def StdCapture(out: bool = True, err: bool = True, in_: bool = True) -> MultiCap
def TeeStdCapture(out: bool = True, err: bool = True, in_: bool = True) -> MultiCapture:
return capture.MultiCapture(
in_=capture.TeeSysCapture(0) if in_ else None,
out=capture.TeeSysCapture(1) if out else None,
err=capture.TeeSysCapture(2) if err else None,
in_=capture.SysCapture(0, tee=True) if in_ else None,
out=capture.SysCapture(1, tee=True) if out else None,
err=capture.SysCapture(2, tee=True) if err else None,
)
@ -1292,8 +1292,10 @@ def test_close_and_capture_again(testdir):
)
@pytest.mark.parametrize("method", ["SysCapture", "FDCapture", "TeeSysCapture"])
def test_capturing_and_logging_fundamentals(testdir, method):
@pytest.mark.parametrize(
"method", ["SysCapture(2)", "SysCapture(2, tee=True)", "FDCapture(2)"]
)
def test_capturing_and_logging_fundamentals(testdir, method: str) -> None:
# here we check a fundamental feature
p = testdir.makepyfile(
"""
@ -1303,7 +1305,7 @@ def test_capturing_and_logging_fundamentals(testdir, method):
cap = capture.MultiCapture(
in_=None,
out=None,
err=capture.%s(2),
err=capture.%s,
)
cap.start_capturing()