diff --git a/py/io/stdcapture.py b/py/io/stdcapture.py index deabb7133..40a6e3175 100644 --- a/py/io/stdcapture.py +++ b/py/io/stdcapture.py @@ -4,7 +4,6 @@ import py try: from cStringIO import StringIO except ImportError: from StringIO import StringIO -emptyfile = StringIO() class Capture(object): def call(cls, func, *args, **kwargs): @@ -33,10 +32,17 @@ class Capture(object): class StdCaptureFD(Capture): - """ capture Stdout and Stderr both on filedescriptor - and sys.stdout/stderr level. + """ This class allows to capture writes to FD1 and FD2 + and may connect a NULL file to FD0 (and prevent + reads from sys.stdin) """ - def __init__(self, out=True, err=True, mixed=False, patchsys=True): + def __init__(self, out=True, err=True, mixed=False, in_=True, patchsys=True): + if in_: + self._oldin = (sys.stdin, os.dup(0)) + sys.stdin = DontReadFromInput() + fd = os.open(devnullpath, os.O_RDONLY) + os.dup2(fd, 0) + os.close(fd) if out: self.out = py.io.FDCapture(1) if patchsys: @@ -57,17 +63,23 @@ class StdCaptureFD(Capture): outfile = self.out.done() if hasattr(self, 'err'): errfile = self.err.done() + if hasattr(self, '_oldin'): + oldsys, oldfd = self._oldin + os.dup2(oldfd, 0) + os.close(oldfd) + sys.stdin = oldsys return outfile, errfile class StdCapture(Capture): - """ capture sys.stdout/sys.stderr (but not system level fd 1 and 2). - - This class allows to capture writes to sys.stdout|stderr "in-memory" - and will raise errors on tries to read from sys.stdin. + """ This class allows to capture writes to sys.stdout|stderr "in-memory" + and will raise errors on tries to read from sys.stdin. It only + modifies sys.stdout|stderr|stdin attributes and does not + touch underlying File Descriptors (use StdCaptureFD for that). """ - def __init__(self, out=True, err=True, mixed=False): + def __init__(self, out=True, err=True, in_=True, mixed=False): self._out = out self._err = err + self._in = in_ if out: self.oldout = sys.stdout sys.stdout = self.newout = StringIO() @@ -78,8 +90,9 @@ class StdCapture(Capture): else: newerr = StringIO() sys.stderr = self.newerr = newerr - self.oldin = sys.stdin - sys.stdin = self.newin = DontReadFromInput() + if in_: + self.oldin = sys.stdin + sys.stdin = self.newin = DontReadFromInput() def reset(self): """ return captured output as strings and restore sys.stdout/err.""" @@ -106,7 +119,8 @@ class StdCapture(Capture): del self.olderr errfile = self.newerr errfile.seek(0) - sys.stdin = self.oldin + if self._in: + sys.stdin = self.oldin return outfile, errfile class DontReadFromInput: @@ -121,3 +135,14 @@ class DontReadFromInput: readline = read readlines = read __iter__ = read + +try: + devnullpath = os.devnull +except AttributeError: + if os.name == 'nt': + devnullpath = 'NUL' + else: + devnullpath = '/dev/null' + +emptyfile = StringIO() + diff --git a/py/io/test/test_stdcapture.py b/py/io/test/test_stdcapture.py index c2f815e99..a83d9d6ac 100644 --- a/py/io/test/test_stdcapture.py +++ b/py/io/test/test_stdcapture.py @@ -79,6 +79,22 @@ class TestStdCapture: assert err == "world\n" assert not out + def test_stdin_restored(self): + old = sys.stdin + cap = self.getcapture(in_=True) + newstdin = sys.stdin + out, err = cap.reset() + assert newstdin != sys.stdin + assert sys.stdin is old + + def test_stdin_nulled_by_default(self): + print "XXX this test may well hang instead of crashing" + print "XXX which indicates an error in the underlying capturing" + print "XXX mechanisms" + cap = self.getcapture() + py.test.raises(IOError, "sys.stdin.read()") + out, err = cap.reset() + class TestStdCaptureFD(TestStdCapture): def getcapture(self, **kw): return py.io.StdCaptureFD(**kw)