avoid some redundancy by using SysCapture from FDCapture for manipulating sys.std{out,in,err}

This commit is contained in:
holger krekel 2014-04-01 14:19:55 +02:00
parent 69cbac8fb5
commit ca5e6830c6
2 changed files with 78 additions and 31 deletions

View File

@ -333,6 +333,8 @@ class MultiCapture(object):
return (self.out.snap() if self.out is not None else "",
self.err.snap() if self.err is not None else "")
class NoCapture:
__init__ = start = done = lambda *args: None
class FDCapture:
""" Capture IO to/from a given os-level filedescriptor. """
@ -340,36 +342,38 @@ class FDCapture:
def __init__(self, targetfd, tmpfile=None):
self.targetfd = targetfd
try:
self._savefd = os.dup(self.targetfd)
self.targetfd_save = os.dup(self.targetfd)
except OSError:
self.start = lambda: None
self.done = lambda: None
else:
if tmpfile is None:
if targetfd == 0:
tmpfile = open(os.devnull, "r")
else:
if targetfd == 0:
assert not tmpfile, "cannot set tmpfile with stdin"
tmpfile = open(os.devnull, "r")
self.syscapture = SysCapture(targetfd)
else:
if tmpfile is None:
f = TemporaryFile()
with f:
tmpfile = safe_text_dupfile(f, mode="wb+")
if targetfd in patchsysdict:
self.syscapture = SysCapture(targetfd, tmpfile)
else:
self.syscapture = NoCapture()
self.tmpfile = tmpfile
if targetfd in patchsysdict:
self._oldsys = getattr(sys, patchsysdict[targetfd])
self.tmpfile_fd = tmpfile.fileno()
def __repr__(self):
return "<FDCapture %s oldfd=%s>" % (self.targetfd, self._savefd)
return "<FDCapture %s oldfd=%s>" % (self.targetfd, self.targetfd_save)
def start(self):
""" Start capturing on targetfd using memorized tmpfile. """
try:
os.fstat(self._savefd)
except OSError:
os.fstat(self.targetfd_save)
except (AttributeError, OSError):
raise ValueError("saved filedescriptor not valid anymore")
targetfd = self.targetfd
os.dup2(self.tmpfile.fileno(), targetfd)
if hasattr(self, '_oldsys'):
subst = self.tmpfile if targetfd != 0 else DontReadFromInput()
setattr(sys, patchsysdict[targetfd], subst)
os.dup2(self.tmpfile_fd, self.targetfd)
self.syscapture.start()
def snap(self):
f = self.tmpfile
@ -386,28 +390,38 @@ class FDCapture:
def done(self):
""" stop capturing, restore streams, return original capture file,
seeked to position zero. """
os.dup2(self._savefd, self.targetfd)
os.close(self._savefd)
if hasattr(self, '_oldsys'):
setattr(sys, patchsysdict[self.targetfd], self._oldsys)
targetfd_save = self.__dict__.pop("targetfd_save")
os.dup2(targetfd_save, self.targetfd)
os.close(targetfd_save)
self.syscapture.done()
self.tmpfile.close()
def suspend(self):
self.syscapture.suspend()
os.dup2(self.targetfd_save, self.targetfd)
def resume(self):
self.syscapture.resume()
os.dup2(self.tmpfile_fd, self.targetfd)
def writeorg(self, data):
""" write to original file descriptor. """
if py.builtin._istext(data):
data = data.encode("utf8") # XXX use encoding of original stream
os.write(self._savefd, data)
os.write(self.targetfd_save, data)
class SysCapture:
def __init__(self, fd):
def __init__(self, fd, tmpfile=None):
name = patchsysdict[fd]
self._old = getattr(sys, name)
self.name = name
if name == "stdin":
self.tmpfile = DontReadFromInput()
else:
self.tmpfile = TextIO()
if tmpfile is None:
if name == "stdin":
tmpfile = DontReadFromInput()
else:
tmpfile = TextIO()
self.tmpfile = tmpfile
def start(self):
setattr(sys, self.name, self.tmpfile)
@ -421,8 +435,15 @@ class SysCapture:
def done(self):
setattr(sys, self.name, self._old)
del self._old
self.tmpfile.close()
def suspend(self):
setattr(sys, self.name, self._old)
def resume(self):
setattr(sys, self.name, self.tmpfile)
def writeorg(self, data):
self._old.write(data)
self._old.flush()

View File

@ -726,15 +726,11 @@ class TestFDCapture:
assert s == "hello\n"
def test_stdin(self, tmpfile):
tmpfile.write(tobytes("3"))
tmpfile.seek(0)
cap = capture.FDCapture(0, tmpfile)
cap = capture.FDCapture(0)
cap.start()
# check with os.read() directly instead of raw_input(), because
# sys.stdin itself may be redirected (as pytest now does by default)
x = os.read(0, 100).strip()
cap.done()
assert x == tobytes("3")
assert x == tobytes('')
def test_writeorg(self, tmpfile):
data1, data2 = tobytes("foo"), tobytes("bar")
@ -751,7 +747,37 @@ class TestFDCapture:
stmp = open(tmpfile.name, 'rb').read()
assert stmp == data2
def test_simple_resume_suspend(self, tmpfile):
with saved_fd(1):
cap = capture.FDCapture(1)
cap.start()
data = tobytes("hello")
os.write(1, data)
sys.stdout.write("whatever")
s = cap.snap()
assert s == "hellowhatever"
cap.suspend()
os.write(1, tobytes("world"))
sys.stdout.write("qlwkej")
assert not cap.snap()
cap.resume()
os.write(1, tobytes("but now"))
sys.stdout.write(" yes\n")
s = cap.snap()
assert s == "but now yes\n"
cap.suspend()
cap.done()
pytest.raises(AttributeError, cap.suspend)
@contextlib.contextmanager
def saved_fd(fd):
new_fd = os.dup(fd)
try:
yield
finally:
os.dup2(new_fd, fd)
class TestStdCapture:
captureclass = staticmethod(StdCapture)