avoid some redundancy by using SysCapture from FDCapture for manipulating sys.std{out,in,err}
This commit is contained in:
parent
69cbac8fb5
commit
ca5e6830c6
|
@ -333,6 +333,8 @@ class MultiCapture(object):
|
|||
return (self.out.snap() if self.out is not None else "",
|
||||
self.err.snap() if self.err is not None else "")
|
||||
|
||||
class NoCapture:
|
||||
__init__ = start = done = lambda *args: None
|
||||
|
||||
class FDCapture:
|
||||
""" Capture IO to/from a given os-level filedescriptor. """
|
||||
|
@ -340,36 +342,38 @@ class FDCapture:
|
|||
def __init__(self, targetfd, tmpfile=None):
|
||||
self.targetfd = targetfd
|
||||
try:
|
||||
self._savefd = os.dup(self.targetfd)
|
||||
self.targetfd_save = os.dup(self.targetfd)
|
||||
except OSError:
|
||||
self.start = lambda: None
|
||||
self.done = lambda: None
|
||||
else:
|
||||
if tmpfile is None:
|
||||
if targetfd == 0:
|
||||
tmpfile = open(os.devnull, "r")
|
||||
else:
|
||||
if targetfd == 0:
|
||||
assert not tmpfile, "cannot set tmpfile with stdin"
|
||||
tmpfile = open(os.devnull, "r")
|
||||
self.syscapture = SysCapture(targetfd)
|
||||
else:
|
||||
if tmpfile is None:
|
||||
f = TemporaryFile()
|
||||
with f:
|
||||
tmpfile = safe_text_dupfile(f, mode="wb+")
|
||||
if targetfd in patchsysdict:
|
||||
self.syscapture = SysCapture(targetfd, tmpfile)
|
||||
else:
|
||||
self.syscapture = NoCapture()
|
||||
self.tmpfile = tmpfile
|
||||
if targetfd in patchsysdict:
|
||||
self._oldsys = getattr(sys, patchsysdict[targetfd])
|
||||
self.tmpfile_fd = tmpfile.fileno()
|
||||
|
||||
def __repr__(self):
|
||||
return "<FDCapture %s oldfd=%s>" % (self.targetfd, self._savefd)
|
||||
return "<FDCapture %s oldfd=%s>" % (self.targetfd, self.targetfd_save)
|
||||
|
||||
def start(self):
|
||||
""" Start capturing on targetfd using memorized tmpfile. """
|
||||
try:
|
||||
os.fstat(self._savefd)
|
||||
except OSError:
|
||||
os.fstat(self.targetfd_save)
|
||||
except (AttributeError, OSError):
|
||||
raise ValueError("saved filedescriptor not valid anymore")
|
||||
targetfd = self.targetfd
|
||||
os.dup2(self.tmpfile.fileno(), targetfd)
|
||||
if hasattr(self, '_oldsys'):
|
||||
subst = self.tmpfile if targetfd != 0 else DontReadFromInput()
|
||||
setattr(sys, patchsysdict[targetfd], subst)
|
||||
os.dup2(self.tmpfile_fd, self.targetfd)
|
||||
self.syscapture.start()
|
||||
|
||||
def snap(self):
|
||||
f = self.tmpfile
|
||||
|
@ -386,28 +390,38 @@ class FDCapture:
|
|||
def done(self):
|
||||
""" stop capturing, restore streams, return original capture file,
|
||||
seeked to position zero. """
|
||||
os.dup2(self._savefd, self.targetfd)
|
||||
os.close(self._savefd)
|
||||
if hasattr(self, '_oldsys'):
|
||||
setattr(sys, patchsysdict[self.targetfd], self._oldsys)
|
||||
targetfd_save = self.__dict__.pop("targetfd_save")
|
||||
os.dup2(targetfd_save, self.targetfd)
|
||||
os.close(targetfd_save)
|
||||
self.syscapture.done()
|
||||
self.tmpfile.close()
|
||||
|
||||
def suspend(self):
|
||||
self.syscapture.suspend()
|
||||
os.dup2(self.targetfd_save, self.targetfd)
|
||||
|
||||
def resume(self):
|
||||
self.syscapture.resume()
|
||||
os.dup2(self.tmpfile_fd, self.targetfd)
|
||||
|
||||
def writeorg(self, data):
|
||||
""" write to original file descriptor. """
|
||||
if py.builtin._istext(data):
|
||||
data = data.encode("utf8") # XXX use encoding of original stream
|
||||
os.write(self._savefd, data)
|
||||
os.write(self.targetfd_save, data)
|
||||
|
||||
|
||||
class SysCapture:
|
||||
def __init__(self, fd):
|
||||
def __init__(self, fd, tmpfile=None):
|
||||
name = patchsysdict[fd]
|
||||
self._old = getattr(sys, name)
|
||||
self.name = name
|
||||
if name == "stdin":
|
||||
self.tmpfile = DontReadFromInput()
|
||||
else:
|
||||
self.tmpfile = TextIO()
|
||||
if tmpfile is None:
|
||||
if name == "stdin":
|
||||
tmpfile = DontReadFromInput()
|
||||
else:
|
||||
tmpfile = TextIO()
|
||||
self.tmpfile = tmpfile
|
||||
|
||||
def start(self):
|
||||
setattr(sys, self.name, self.tmpfile)
|
||||
|
@ -421,8 +435,15 @@ class SysCapture:
|
|||
|
||||
def done(self):
|
||||
setattr(sys, self.name, self._old)
|
||||
del self._old
|
||||
self.tmpfile.close()
|
||||
|
||||
def suspend(self):
|
||||
setattr(sys, self.name, self._old)
|
||||
|
||||
def resume(self):
|
||||
setattr(sys, self.name, self.tmpfile)
|
||||
|
||||
def writeorg(self, data):
|
||||
self._old.write(data)
|
||||
self._old.flush()
|
||||
|
|
|
@ -726,15 +726,11 @@ class TestFDCapture:
|
|||
assert s == "hello\n"
|
||||
|
||||
def test_stdin(self, tmpfile):
|
||||
tmpfile.write(tobytes("3"))
|
||||
tmpfile.seek(0)
|
||||
cap = capture.FDCapture(0, tmpfile)
|
||||
cap = capture.FDCapture(0)
|
||||
cap.start()
|
||||
# check with os.read() directly instead of raw_input(), because
|
||||
# sys.stdin itself may be redirected (as pytest now does by default)
|
||||
x = os.read(0, 100).strip()
|
||||
cap.done()
|
||||
assert x == tobytes("3")
|
||||
assert x == tobytes('')
|
||||
|
||||
def test_writeorg(self, tmpfile):
|
||||
data1, data2 = tobytes("foo"), tobytes("bar")
|
||||
|
@ -751,6 +747,36 @@ class TestFDCapture:
|
|||
stmp = open(tmpfile.name, 'rb').read()
|
||||
assert stmp == data2
|
||||
|
||||
def test_simple_resume_suspend(self, tmpfile):
|
||||
with saved_fd(1):
|
||||
cap = capture.FDCapture(1)
|
||||
cap.start()
|
||||
data = tobytes("hello")
|
||||
os.write(1, data)
|
||||
sys.stdout.write("whatever")
|
||||
s = cap.snap()
|
||||
assert s == "hellowhatever"
|
||||
cap.suspend()
|
||||
os.write(1, tobytes("world"))
|
||||
sys.stdout.write("qlwkej")
|
||||
assert not cap.snap()
|
||||
cap.resume()
|
||||
os.write(1, tobytes("but now"))
|
||||
sys.stdout.write(" yes\n")
|
||||
s = cap.snap()
|
||||
assert s == "but now yes\n"
|
||||
cap.suspend()
|
||||
cap.done()
|
||||
pytest.raises(AttributeError, cap.suspend)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def saved_fd(fd):
|
||||
new_fd = os.dup(fd)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
os.dup2(new_fd, fd)
|
||||
|
||||
|
||||
class TestStdCapture:
|
||||
captureclass = staticmethod(StdCapture)
|
||||
|
|
Loading…
Reference in New Issue