avoid some redundancy by using SysCapture from FDCapture for manipulating sys.std{out,in,err}

This commit is contained in:
holger krekel 2014-04-01 14:19:55 +02:00
parent 69cbac8fb5
commit ca5e6830c6
2 changed files with 78 additions and 31 deletions

View File

@ -333,6 +333,8 @@ class MultiCapture(object):
return (self.out.snap() if self.out is not None else "", return (self.out.snap() if self.out is not None else "",
self.err.snap() if self.err is not None else "") self.err.snap() if self.err is not None else "")
class NoCapture:
__init__ = start = done = lambda *args: None
class FDCapture: class FDCapture:
""" Capture IO to/from a given os-level filedescriptor. """ """ Capture IO to/from a given os-level filedescriptor. """
@ -340,36 +342,38 @@ class FDCapture:
def __init__(self, targetfd, tmpfile=None): def __init__(self, targetfd, tmpfile=None):
self.targetfd = targetfd self.targetfd = targetfd
try: try:
self._savefd = os.dup(self.targetfd) self.targetfd_save = os.dup(self.targetfd)
except OSError: except OSError:
self.start = lambda: None self.start = lambda: None
self.done = lambda: None self.done = lambda: None
else: else:
if tmpfile is None:
if targetfd == 0: if targetfd == 0:
assert not tmpfile, "cannot set tmpfile with stdin"
tmpfile = open(os.devnull, "r") tmpfile = open(os.devnull, "r")
self.syscapture = SysCapture(targetfd)
else: else:
if tmpfile is None:
f = TemporaryFile() f = TemporaryFile()
with f: with f:
tmpfile = safe_text_dupfile(f, mode="wb+") tmpfile = safe_text_dupfile(f, mode="wb+")
self.tmpfile = tmpfile
if targetfd in patchsysdict: if targetfd in patchsysdict:
self._oldsys = getattr(sys, patchsysdict[targetfd]) self.syscapture = SysCapture(targetfd, tmpfile)
else:
self.syscapture = NoCapture()
self.tmpfile = tmpfile
self.tmpfile_fd = tmpfile.fileno()
def __repr__(self): def __repr__(self):
return "<FDCapture %s oldfd=%s>" % (self.targetfd, self._savefd) return "<FDCapture %s oldfd=%s>" % (self.targetfd, self.targetfd_save)
def start(self): def start(self):
""" Start capturing on targetfd using memorized tmpfile. """ """ Start capturing on targetfd using memorized tmpfile. """
try: try:
os.fstat(self._savefd) os.fstat(self.targetfd_save)
except OSError: except (AttributeError, OSError):
raise ValueError("saved filedescriptor not valid anymore") raise ValueError("saved filedescriptor not valid anymore")
targetfd = self.targetfd os.dup2(self.tmpfile_fd, self.targetfd)
os.dup2(self.tmpfile.fileno(), targetfd) self.syscapture.start()
if hasattr(self, '_oldsys'):
subst = self.tmpfile if targetfd != 0 else DontReadFromInput()
setattr(sys, patchsysdict[targetfd], subst)
def snap(self): def snap(self):
f = self.tmpfile f = self.tmpfile
@ -386,28 +390,38 @@ class FDCapture:
def done(self): def done(self):
""" stop capturing, restore streams, return original capture file, """ stop capturing, restore streams, return original capture file,
seeked to position zero. """ seeked to position zero. """
os.dup2(self._savefd, self.targetfd) targetfd_save = self.__dict__.pop("targetfd_save")
os.close(self._savefd) os.dup2(targetfd_save, self.targetfd)
if hasattr(self, '_oldsys'): os.close(targetfd_save)
setattr(sys, patchsysdict[self.targetfd], self._oldsys) self.syscapture.done()
self.tmpfile.close() self.tmpfile.close()
def suspend(self):
self.syscapture.suspend()
os.dup2(self.targetfd_save, self.targetfd)
def resume(self):
self.syscapture.resume()
os.dup2(self.tmpfile_fd, self.targetfd)
def writeorg(self, data): def writeorg(self, data):
""" write to original file descriptor. """ """ write to original file descriptor. """
if py.builtin._istext(data): if py.builtin._istext(data):
data = data.encode("utf8") # XXX use encoding of original stream data = data.encode("utf8") # XXX use encoding of original stream
os.write(self._savefd, data) os.write(self.targetfd_save, data)
class SysCapture: class SysCapture:
def __init__(self, fd): def __init__(self, fd, tmpfile=None):
name = patchsysdict[fd] name = patchsysdict[fd]
self._old = getattr(sys, name) self._old = getattr(sys, name)
self.name = name self.name = name
if tmpfile is None:
if name == "stdin": if name == "stdin":
self.tmpfile = DontReadFromInput() tmpfile = DontReadFromInput()
else: else:
self.tmpfile = TextIO() tmpfile = TextIO()
self.tmpfile = tmpfile
def start(self): def start(self):
setattr(sys, self.name, self.tmpfile) setattr(sys, self.name, self.tmpfile)
@ -421,8 +435,15 @@ class SysCapture:
def done(self): def done(self):
setattr(sys, self.name, self._old) setattr(sys, self.name, self._old)
del self._old
self.tmpfile.close() self.tmpfile.close()
def suspend(self):
setattr(sys, self.name, self._old)
def resume(self):
setattr(sys, self.name, self.tmpfile)
def writeorg(self, data): def writeorg(self, data):
self._old.write(data) self._old.write(data)
self._old.flush() self._old.flush()

View File

@ -726,15 +726,11 @@ class TestFDCapture:
assert s == "hello\n" assert s == "hello\n"
def test_stdin(self, tmpfile): def test_stdin(self, tmpfile):
tmpfile.write(tobytes("3")) cap = capture.FDCapture(0)
tmpfile.seek(0)
cap = capture.FDCapture(0, tmpfile)
cap.start() cap.start()
# check with os.read() directly instead of raw_input(), because
# sys.stdin itself may be redirected (as pytest now does by default)
x = os.read(0, 100).strip() x = os.read(0, 100).strip()
cap.done() cap.done()
assert x == tobytes("3") assert x == tobytes('')
def test_writeorg(self, tmpfile): def test_writeorg(self, tmpfile):
data1, data2 = tobytes("foo"), tobytes("bar") data1, data2 = tobytes("foo"), tobytes("bar")
@ -751,6 +747,36 @@ class TestFDCapture:
stmp = open(tmpfile.name, 'rb').read() stmp = open(tmpfile.name, 'rb').read()
assert stmp == data2 assert stmp == data2
def test_simple_resume_suspend(self, tmpfile):
with saved_fd(1):
cap = capture.FDCapture(1)
cap.start()
data = tobytes("hello")
os.write(1, data)
sys.stdout.write("whatever")
s = cap.snap()
assert s == "hellowhatever"
cap.suspend()
os.write(1, tobytes("world"))
sys.stdout.write("qlwkej")
assert not cap.snap()
cap.resume()
os.write(1, tobytes("but now"))
sys.stdout.write(" yes\n")
s = cap.snap()
assert s == "but now yes\n"
cap.suspend()
cap.done()
pytest.raises(AttributeError, cap.suspend)
@contextlib.contextmanager
def saved_fd(fd):
new_fd = os.dup(fd)
try:
yield
finally:
os.dup2(new_fd, fd)
class TestStdCapture: class TestStdCapture:
captureclass = staticmethod(StdCapture) captureclass = staticmethod(StdCapture)