diff --git a/src/_pytest/capture.py b/src/_pytest/capture.py index a07892563..32e83dd21 100644 --- a/src/_pytest/capture.py +++ b/src/_pytest/capture.py @@ -11,6 +11,7 @@ from io import UnsupportedOperation from tempfile import TemporaryFile from typing import Optional from typing import TextIO +from typing import Tuple import pytest from _pytest.compat import TYPE_CHECKING @@ -245,7 +246,6 @@ class NoCapture: class SysCaptureBinary: EMPTY_BUFFER = b"" - _state = None def __init__(self, fd, tmpfile=None, *, tee=False): name = patchsysdict[fd] @@ -257,6 +257,7 @@ class SysCaptureBinary: else: tmpfile = CaptureIO() if not tee else TeeCaptureIO(self._old) self.tmpfile = tmpfile + self._state = "initialized" def repr(self, class_name: str) -> str: return "<{} {} _old={} _state={!r} tmpfile={!r}>".format( @@ -276,11 +277,20 @@ class SysCaptureBinary: self.tmpfile, ) + def _assert_state(self, op: str, states: Tuple[str, ...]) -> None: + assert ( + self._state in states + ), "cannot {} in state {!r}: expected one of {}".format( + op, self._state, ", ".join(states) + ) + def start(self): + self._assert_state("start", ("initialized",)) setattr(sys, self.name, self.tmpfile) self._state = "started" def snap(self): + self._assert_state("snap", ("started", "suspended")) self.tmpfile.seek(0) res = self.tmpfile.buffer.read() self.tmpfile.seek(0) @@ -288,20 +298,28 @@ class SysCaptureBinary: return res def done(self): + self._assert_state("done", ("initialized", "started", "suspended", "done")) + if self._state == "done": + return setattr(sys, self.name, self._old) del self._old self.tmpfile.close() self._state = "done" def suspend(self): + self._assert_state("suspend", ("started", "suspended")) setattr(sys, self.name, self._old) self._state = "suspended" def resume(self): + self._assert_state("resume", ("started", "suspended")) + if self._state == "started": + return setattr(sys, self.name, self.tmpfile) - self._state = "resumed" + self._state = "started" def writeorg(self, data): + self._assert_state("writeorg", ("started", "suspended")) self._old.flush() self._old.buffer.write(data) self._old.buffer.flush() @@ -317,6 +335,7 @@ class SysCapture(SysCaptureBinary): return res def writeorg(self, data): + self._assert_state("writeorg", ("started", "suspended")) self._old.write(data) self._old.flush() @@ -328,7 +347,6 @@ class FDCaptureBinary: """ EMPTY_BUFFER = b"" - _state = None def __init__(self, targetfd): self.targetfd = targetfd @@ -368,6 +386,8 @@ class FDCaptureBinary: else: self.syscapture = NoCapture() + self._state = "initialized" + def __repr__(self): return "<{} {} oldfd={} _state={!r} tmpfile={!r}>".format( self.__class__.__name__, @@ -377,13 +397,22 @@ class FDCaptureBinary: self.tmpfile, ) + def _assert_state(self, op: str, states: Tuple[str, ...]) -> None: + assert ( + self._state in states + ), "cannot {} in state {!r}: expected one of {}".format( + op, self._state, ", ".join(states) + ) + def start(self): """ Start capturing on targetfd using memorized tmpfile. """ + self._assert_state("start", ("initialized",)) os.dup2(self.tmpfile.fileno(), self.targetfd) self.syscapture.start() self._state = "started" def snap(self): + self._assert_state("snap", ("started", "suspended")) self.tmpfile.seek(0) res = self.tmpfile.buffer.read() self.tmpfile.seek(0) @@ -393,6 +422,9 @@ class FDCaptureBinary: def done(self): """ stop capturing, restore streams, return original capture file, seeked to position zero. """ + self._assert_state("done", ("initialized", "started", "suspended", "done")) + if self._state == "done": + return os.dup2(self.targetfd_save, self.targetfd) os.close(self.targetfd_save) if self.targetfd_invalid is not None: @@ -404,17 +436,24 @@ class FDCaptureBinary: self._state = "done" def suspend(self): + self._assert_state("suspend", ("started", "suspended")) + if self._state == "suspended": + return self.syscapture.suspend() os.dup2(self.targetfd_save, self.targetfd) self._state = "suspended" def resume(self): + self._assert_state("resume", ("started", "suspended")) + if self._state == "started": + return self.syscapture.resume() os.dup2(self.tmpfile.fileno(), self.targetfd) - self._state = "resumed" + self._state = "started" def writeorg(self, data): """ write to original file descriptor. """ + self._assert_state("writeorg", ("started", "suspended")) os.write(self.targetfd_save, data) @@ -428,6 +467,7 @@ class FDCapture(FDCaptureBinary): EMPTY_BUFFER = "" # type: ignore def snap(self): + self._assert_state("snap", ("started", "suspended")) self.tmpfile.seek(0) res = self.tmpfile.read() self.tmpfile.seek(0) diff --git a/testing/test_capture.py b/testing/test_capture.py index 5a0998da7..95f2d748a 100644 --- a/testing/test_capture.py +++ b/testing/test_capture.py @@ -878,9 +878,8 @@ class TestFDCapture: cap = capture.FDCapture(fd) data = b"hello" os.write(fd, data) - s = cap.snap() + pytest.raises(AssertionError, cap.snap) cap.done() - assert not s cap = capture.FDCapture(fd) cap.start() os.write(fd, data) @@ -901,7 +900,7 @@ class TestFDCapture: fd = tmpfile.fileno() cap = capture.FDCapture(fd) cap.done() - pytest.raises(ValueError, cap.start) + pytest.raises(AssertionError, cap.start) def test_stderr(self): cap = capture.FDCapture(2) @@ -952,7 +951,7 @@ class TestFDCapture: assert s == "but now yes\n" cap.suspend() cap.done() - pytest.raises(AttributeError, cap.suspend) + pytest.raises(AssertionError, cap.suspend) assert repr(cap) == ( "".format( @@ -1154,6 +1153,7 @@ class TestStdCaptureFD(TestStdCapture): with lsof_check(): for i in range(10): cap = StdCaptureFD() + cap.start_capturing() cap.stop_capturing() @@ -1175,7 +1175,7 @@ class TestStdCaptureFDinvalidFD: def test_stdout(): os.close(1) cap = StdCaptureFD(out=True, err=False, in_=False) - assert fnmatch(repr(cap.out), "") + assert fnmatch(repr(cap.out), "") cap.start_capturing() os.write(1, b"stdout") assert cap.readouterr() == ("stdout", "") @@ -1184,7 +1184,7 @@ class TestStdCaptureFDinvalidFD: def test_stderr(): os.close(2) cap = StdCaptureFD(out=False, err=True, in_=False) - assert fnmatch(repr(cap.err), "") + assert fnmatch(repr(cap.err), "") cap.start_capturing() os.write(2, b"stderr") assert cap.readouterr() == ("", "stderr") @@ -1193,7 +1193,7 @@ class TestStdCaptureFDinvalidFD: def test_stdin(): os.close(0) cap = StdCaptureFD(out=False, err=False, in_=True) - assert fnmatch(repr(cap.in_), "") + assert fnmatch(repr(cap.in_), "") cap.stop_capturing() """ )