From a35800c2e1461d32dfd4322dc21d02e46c49c776 Mon Sep 17 00:00:00 2001
From: Ran Benita <ran@unusedvar.com>
Date: Fri, 17 Apr 2020 10:01:51 +0300
Subject: [PATCH] capture: formalize and check allowed state transition in
 capture classes

There are state transitions start/done/suspend/resume and two additional
operations snap/writeorg.

Previously it was not well defined in what order they can be called, and
which operations are idempotent.

Formalize this and enforce using assert checks with informative error
messages if they fail (rather than random AttributeErrors).
---
 src/_pytest/capture.py  | 48 +++++++++++++++++++++++++++++++++++++----
 testing/test_capture.py | 14 ++++++------
 2 files changed, 51 insertions(+), 11 deletions(-)

diff --git a/src/_pytest/capture.py b/src/_pytest/capture.py
index a07892563..32e83dd21 100644
--- a/src/_pytest/capture.py
+++ b/src/_pytest/capture.py
@@ -11,6 +11,7 @@ from io import UnsupportedOperation
 from tempfile import TemporaryFile
 from typing import Optional
 from typing import TextIO
+from typing import Tuple
 
 import pytest
 from _pytest.compat import TYPE_CHECKING
@@ -245,7 +246,6 @@ class NoCapture:
 class SysCaptureBinary:
 
     EMPTY_BUFFER = b""
-    _state = None
 
     def __init__(self, fd, tmpfile=None, *, tee=False):
         name = patchsysdict[fd]
@@ -257,6 +257,7 @@ class SysCaptureBinary:
             else:
                 tmpfile = CaptureIO() if not tee else TeeCaptureIO(self._old)
         self.tmpfile = tmpfile
+        self._state = "initialized"
 
     def repr(self, class_name: str) -> str:
         return "<{} {} _old={} _state={!r} tmpfile={!r}>".format(
@@ -276,11 +277,20 @@ class SysCaptureBinary:
             self.tmpfile,
         )
 
+    def _assert_state(self, op: str, states: Tuple[str, ...]) -> None:
+        assert (
+            self._state in states
+        ), "cannot {} in state {!r}: expected one of {}".format(
+            op, self._state, ", ".join(states)
+        )
+
     def start(self):
+        self._assert_state("start", ("initialized",))
         setattr(sys, self.name, self.tmpfile)
         self._state = "started"
 
     def snap(self):
+        self._assert_state("snap", ("started", "suspended"))
         self.tmpfile.seek(0)
         res = self.tmpfile.buffer.read()
         self.tmpfile.seek(0)
@@ -288,20 +298,28 @@ class SysCaptureBinary:
         return res
 
     def done(self):
+        self._assert_state("done", ("initialized", "started", "suspended", "done"))
+        if self._state == "done":
+            return
         setattr(sys, self.name, self._old)
         del self._old
         self.tmpfile.close()
         self._state = "done"
 
     def suspend(self):
+        self._assert_state("suspend", ("started", "suspended"))
         setattr(sys, self.name, self._old)
         self._state = "suspended"
 
     def resume(self):
+        self._assert_state("resume", ("started", "suspended"))
+        if self._state == "started":
+            return
         setattr(sys, self.name, self.tmpfile)
-        self._state = "resumed"
+        self._state = "started"
 
     def writeorg(self, data):
+        self._assert_state("writeorg", ("started", "suspended"))
         self._old.flush()
         self._old.buffer.write(data)
         self._old.buffer.flush()
@@ -317,6 +335,7 @@ class SysCapture(SysCaptureBinary):
         return res
 
     def writeorg(self, data):
+        self._assert_state("writeorg", ("started", "suspended"))
         self._old.write(data)
         self._old.flush()
 
@@ -328,7 +347,6 @@ class FDCaptureBinary:
     """
 
     EMPTY_BUFFER = b""
-    _state = None
 
     def __init__(self, targetfd):
         self.targetfd = targetfd
@@ -368,6 +386,8 @@ class FDCaptureBinary:
             else:
                 self.syscapture = NoCapture()
 
+        self._state = "initialized"
+
     def __repr__(self):
         return "<{} {} oldfd={} _state={!r} tmpfile={!r}>".format(
             self.__class__.__name__,
@@ -377,13 +397,22 @@ class FDCaptureBinary:
             self.tmpfile,
         )
 
+    def _assert_state(self, op: str, states: Tuple[str, ...]) -> None:
+        assert (
+            self._state in states
+        ), "cannot {} in state {!r}: expected one of {}".format(
+            op, self._state, ", ".join(states)
+        )
+
     def start(self):
         """ Start capturing on targetfd using memorized tmpfile. """
+        self._assert_state("start", ("initialized",))
         os.dup2(self.tmpfile.fileno(), self.targetfd)
         self.syscapture.start()
         self._state = "started"
 
     def snap(self):
+        self._assert_state("snap", ("started", "suspended"))
         self.tmpfile.seek(0)
         res = self.tmpfile.buffer.read()
         self.tmpfile.seek(0)
@@ -393,6 +422,9 @@ class FDCaptureBinary:
     def done(self):
         """ stop capturing, restore streams, return original capture file,
         seeked to position zero. """
+        self._assert_state("done", ("initialized", "started", "suspended", "done"))
+        if self._state == "done":
+            return
         os.dup2(self.targetfd_save, self.targetfd)
         os.close(self.targetfd_save)
         if self.targetfd_invalid is not None:
@@ -404,17 +436,24 @@ class FDCaptureBinary:
         self._state = "done"
 
     def suspend(self):
+        self._assert_state("suspend", ("started", "suspended"))
+        if self._state == "suspended":
+            return
         self.syscapture.suspend()
         os.dup2(self.targetfd_save, self.targetfd)
         self._state = "suspended"
 
     def resume(self):
+        self._assert_state("resume", ("started", "suspended"))
+        if self._state == "started":
+            return
         self.syscapture.resume()
         os.dup2(self.tmpfile.fileno(), self.targetfd)
-        self._state = "resumed"
+        self._state = "started"
 
     def writeorg(self, data):
         """ write to original file descriptor. """
+        self._assert_state("writeorg", ("started", "suspended"))
         os.write(self.targetfd_save, data)
 
 
@@ -428,6 +467,7 @@ class FDCapture(FDCaptureBinary):
     EMPTY_BUFFER = ""  # type: ignore
 
     def snap(self):
+        self._assert_state("snap", ("started", "suspended"))
         self.tmpfile.seek(0)
         res = self.tmpfile.read()
         self.tmpfile.seek(0)
diff --git a/testing/test_capture.py b/testing/test_capture.py
index 5a0998da7..95f2d748a 100644
--- a/testing/test_capture.py
+++ b/testing/test_capture.py
@@ -878,9 +878,8 @@ class TestFDCapture:
         cap = capture.FDCapture(fd)
         data = b"hello"
         os.write(fd, data)
-        s = cap.snap()
+        pytest.raises(AssertionError, cap.snap)
         cap.done()
-        assert not s
         cap = capture.FDCapture(fd)
         cap.start()
         os.write(fd, data)
@@ -901,7 +900,7 @@ class TestFDCapture:
         fd = tmpfile.fileno()
         cap = capture.FDCapture(fd)
         cap.done()
-        pytest.raises(ValueError, cap.start)
+        pytest.raises(AssertionError, cap.start)
 
     def test_stderr(self):
         cap = capture.FDCapture(2)
@@ -952,7 +951,7 @@ class TestFDCapture:
             assert s == "but now yes\n"
             cap.suspend()
             cap.done()
-            pytest.raises(AttributeError, cap.suspend)
+            pytest.raises(AssertionError, cap.suspend)
 
             assert repr(cap) == (
                 "<FDCapture 1 oldfd={} _state='done' tmpfile={!r}>".format(
@@ -1154,6 +1153,7 @@ class TestStdCaptureFD(TestStdCapture):
         with lsof_check():
             for i in range(10):
                 cap = StdCaptureFD()
+                cap.start_capturing()
                 cap.stop_capturing()
 
 
@@ -1175,7 +1175,7 @@ class TestStdCaptureFDinvalidFD:
             def test_stdout():
                 os.close(1)
                 cap = StdCaptureFD(out=True, err=False, in_=False)
-                assert fnmatch(repr(cap.out), "<FDCapture 1 oldfd=* _state=None tmpfile=*>")
+                assert fnmatch(repr(cap.out), "<FDCapture 1 oldfd=* _state='initialized' tmpfile=*>")
                 cap.start_capturing()
                 os.write(1, b"stdout")
                 assert cap.readouterr() == ("stdout", "")
@@ -1184,7 +1184,7 @@ class TestStdCaptureFDinvalidFD:
             def test_stderr():
                 os.close(2)
                 cap = StdCaptureFD(out=False, err=True, in_=False)
-                assert fnmatch(repr(cap.err), "<FDCapture 2 oldfd=* _state=None tmpfile=*>")
+                assert fnmatch(repr(cap.err), "<FDCapture 2 oldfd=* _state='initialized' tmpfile=*>")
                 cap.start_capturing()
                 os.write(2, b"stderr")
                 assert cap.readouterr() == ("", "stderr")
@@ -1193,7 +1193,7 @@ class TestStdCaptureFDinvalidFD:
             def test_stdin():
                 os.close(0)
                 cap = StdCaptureFD(out=False, err=False, in_=True)
-                assert fnmatch(repr(cap.in_), "<FDCapture 0 oldfd=* _state=None tmpfile=*>")
+                assert fnmatch(repr(cap.in_), "<FDCapture 0 oldfd=* _state='initialized' tmpfile=*>")
                 cap.stop_capturing()
         """
         )