diff --git a/py/_io/capture.py b/py/_io/capture.py index d715d85bb..09157fbca 100644 --- a/py/_io/capture.py +++ b/py/_io/capture.py @@ -26,23 +26,26 @@ except ImportError: raise TypeError("not a byte value: %r" %(data,)) StringIO.write(self, data) +patchsysdict = {0: 'stdin', 1: 'stdout', 2: 'stderr'} + class FDCapture: """ Capture IO to/from a given os-level filedescriptor. """ - def __init__(self, targetfd, tmpfile=None, now=True): + def __init__(self, targetfd, tmpfile=None, now=True, patchsys=False): """ save targetfd descriptor, and open a new temporary file there. If no tmpfile is specified a tempfile.Tempfile() will be opened in text mode. """ self.targetfd = targetfd - self._patched = [] if tmpfile is None: f = tempfile.TemporaryFile('wb+') tmpfile = dupfile(f, encoding="UTF-8") f.close() self.tmpfile = tmpfile self._savefd = os.dup(self.targetfd) + if patchsys: + self._oldsys = getattr(sys, patchsysdict[targetfd]) if now: self.start() @@ -53,20 +56,8 @@ class FDCapture: raise ValueError("saved filedescriptor not valid, " "did you call start() twice?") os.dup2(self.tmpfile.fileno(), self.targetfd) - - def setasfile(self, name, module=sys): - """ patch . to self.tmpfile - """ - key = (module, name) - self._patched.append((key, getattr(module, name))) - setattr(module, name, self.tmpfile) - - def unsetfiles(self): - """ unpatch all patched items - """ - while self._patched: - (module, name), value = self._patched.pop() - setattr(module, name, value) + if hasattr(self, '_oldsys'): + setattr(sys, patchsysdict[self.targetfd], self.tmpfile) def done(self): """ unpatch and clean up, returns the self.tmpfile (file object) @@ -74,7 +65,8 @@ class FDCapture: os.dup2(self._savefd, self.targetfd) os.close(self._savefd) self.tmpfile.seek(0) - self.unsetfiles() + if hasattr(self, '_oldsys'): + setattr(sys, patchsysdict[self.targetfd], self._oldsys) return self.tmpfile def writeorg(self, data): @@ -182,7 +174,6 @@ class StdCaptureFD(Capture): in_=True, patchsys=True, now=True): self._options = locals() self._save() - self.patchsys = patchsys if now: self.startall() @@ -191,10 +182,17 @@ class StdCaptureFD(Capture): out = self._options['out'] err = self._options['err'] mixed = self._options['mixed'] - self.in_ = in_ + patchsys = self._options['patchsys'] if in_: + if hasattr(in_, 'read'): + tmpfile = in_ + else: + fd = os.open(devnullpath, os.O_RDONLY) + tmpfile = os.fdopen(fd) try: - self._oldin = (sys.stdin, os.dup(0)) + self.in_ = FDCapture(0, tmpfile=tmpfile, now=False, + patchsys=patchsys) + self._options['in_'] = self.in_.tmpfile except OSError: pass if out: @@ -202,7 +200,8 @@ class StdCaptureFD(Capture): if hasattr(out, 'write'): tmpfile = out try: - self.out = FDCapture(1, tmpfile=tmpfile, now=False) + self.out = FDCapture(1, tmpfile=tmpfile, + now=False, patchsys=patchsys) self._options['out'] = self.out.tmpfile except OSError: pass @@ -214,28 +213,23 @@ class StdCaptureFD(Capture): else: tmpfile = None try: - self.err = FDCapture(2, tmpfile=tmpfile, now=False) + self.err = FDCapture(2, tmpfile=tmpfile, + now=False, patchsys=patchsys) self._options['err'] = self.err.tmpfile except OSError: pass def startall(self): - if self.in_: + if hasattr(self, 'in_'): + self.in_.start() sys.stdin = DontReadFromInput() - fd = os.open(devnullpath, os.O_RDONLY) - os.dup2(fd, 0) - os.close(fd) out = getattr(self, 'out', None) if out: out.start() - if self.patchsys: - out.setasfile('stdout') err = getattr(self, 'err', None) if err: err.start() - if self.patchsys: - err.setasfile('stderr') def resume(self): """ resume capturing with original temp files. """ @@ -248,11 +242,8 @@ class StdCaptureFD(Capture): outfile = self.out.done() if hasattr(self, 'err') and not self.err.tmpfile.closed: errfile = self.err.done() - if hasattr(self, '_oldin'): - oldsys, oldfd = self._oldin - os.dup2(oldfd, 0) - os.close(oldfd) - sys.stdin = oldsys + if hasattr(self, 'in_'): + tmpfile = self.in_.done() self._save() return outfile, errfile diff --git a/testing/io_/test_capture.py b/testing/io_/test_capture.py index a48701d5d..817103d42 100644 --- a/testing/io_/test_capture.py +++ b/testing/io_/test_capture.py @@ -144,8 +144,7 @@ class TestFDCapture: f.close() def test_stderr(self): - cap = py.io.FDCapture(2) - cap.setasfile('stderr') + cap = py.io.FDCapture(2, patchsys=True) print_("hello", file=sys.stderr) f = cap.done() s = f.read()