diff --git a/py/io/stdcapture.py b/py/io/stdcapture.py index be14cce99..4812cdad1 100644 --- a/py/io/stdcapture.py +++ b/py/io/stdcapture.py @@ -4,6 +4,8 @@ import py try: from cStringIO import StringIO except ImportError: from StringIO import StringIO +emptyfile = StringIO() + class Capture(object): def call(cls, func, *args, **kwargs): """ return a (res, out, err) tuple where @@ -24,13 +26,17 @@ class StdCaptureFD(Capture): """ capture Stdout and Stderr both on filedescriptor and sys.stdout/stderr level. """ - def __init__(self, out=True, err=True, patchsys=True): + def __init__(self, out=True, err=True, mixed=False, patchsys=True): if out: self.out = py.io.FDCapture(1) if patchsys: self.out.setasfile('stdout') if err: - self.err = py.io.FDCapture(2) + if mixed and out: + tmpfile = self.out.tmpfile + else: + tmpfile = None + self.err = py.io.FDCapture(2, tmpfile=tmpfile) if patchsys: self.err.setasfile('stderr') @@ -40,14 +46,12 @@ class StdCaptureFD(Capture): returns a tuple of file objects (out, err) for the captured data """ - out = err = "" + outfile = errfile = emptyfile if hasattr(self, 'out'): outfile = self.out.done() - out = outfile.read() if hasattr(self, 'err'): errfile = self.err.done() - err = errfile.read() - return out, err + return outfile.read(), errfile.read() class StdCapture(Capture): """ capture sys.stdout/sys.stderr (but not system level fd 1 and 2). @@ -55,13 +59,21 @@ class StdCapture(Capture): This class allows to capture writes to sys.stdout|stderr "in-memory" and will raise errors on tries to read from sys.stdin. """ - def __init__(self): + def __init__(self, out=True, err=True, mixed=False): + self._out = out + self._err = err + if out: + self.oldout = sys.stdout + sys.stdout = self.newout = StringIO() + if err: + self.olderr = sys.stderr + if out and mixed: + newerr = self.newout + else: + newerr = StringIO() + sys.stderr = self.newerr = newerr self.oldin = sys.stdin - self.oldout = sys.stdout - self.olderr = sys.stderr sys.stdin = self.newin = DontReadFromInput() - sys.stdout = self.newout = StringIO() - sys.stderr = self.newerr = StringIO() def reset(self): """ return captured output and restore sys.stdout/err.""" @@ -70,13 +82,25 @@ class StdCapture(Capture): def done(self): o,e = sys.stdout, sys.stderr - sys.stdin, sys.stdout, sys.stderr = ( - self.oldin, self.oldout, self.olderr) - del self.oldin, self.oldout, self.olderr - o, e = self.newout, self.newerr - o.seek(0) - e.seek(0) - return o,e + outfile = errfile = emptyfile + if self._out: + try: + sys.stdout = self.oldout + except AttributeError: + raise IOError("stdout capturing already reset") + del self.oldout + outfile = self.newout + outfile.seek(0) + if self._err: + try: + sys.stderr = self.olderr + except AttributeError: + raise IOError("stderr capturing already reset") + del self.olderr + errfile = self.newerr + errfile.seek(0) + sys.stdin = self.oldin + return outfile, errfile class DontReadFromInput: """Temporary stub class. Ideally when stdin is accessed, the diff --git a/py/io/test/test_capture.py b/py/io/test/test_capture.py index 09357f56c..d2bd70f31 100644 --- a/py/io/test/test_capture.py +++ b/py/io/test/test_capture.py @@ -57,9 +57,9 @@ class TestFDCapture: tmpfp.close() f = cap.done() -class TestCapturing: - def getcapture(self): - return py.io.StdCaptureFD() +class TestStdCapture: + def getcapture(self, **kw): + return py.io.StdCapture(**kw) def test_capturing_simple(self): cap = self.getcapture() @@ -69,6 +69,15 @@ class TestCapturing: assert out == "hello world\n" assert err == "hello error\n" + def test_capturing_mixed(self): + cap = self.getcapture(mixed=True) + print "hello", + print >>sys.stderr, "world", + print >>sys.stdout, ".", + out, err = cap.reset() + assert out.strip() == "hello world ." + assert not err + def test_capturing_twice_error(self): cap = self.getcapture() print "hello" @@ -101,6 +110,26 @@ class TestCapturing: out1, err1 = cap1.reset() assert out1 == "cap1\n" assert out2 == "cap2\n" + + def test_just_out_capture(self): + cap = self.getcapture(out=True, err=False) + print >>sys.stdout, "hello" + print >>sys.stderr, "world" + out, err = cap.reset() + assert out == "hello\n" + assert not err + + def test_just_err_capture(self): + cap = self.getcapture(out=False, err=True) + print >>sys.stdout, "hello" + print >>sys.stderr, "world" + out, err = cap.reset() + assert err == "world\n" + assert not out + +class TestStdCaptureFD(TestStdCapture): + def getcapture(self, **kw): + return py.io.StdCaptureFD(**kw) def test_intermingling(self): cap = self.getcapture() @@ -114,32 +143,16 @@ class TestCapturing: assert out == "123" assert err == "abc" -def test_callcapture(): - def func(x, y): - print x - print >>py.std.sys.stderr, y - return 42 - - res, out, err = py.io.StdCaptureFD.call(func, 3, y=4) - assert res == 42 - assert out.startswith("3") - assert err.startswith("4") - -def test_just_out_capture(): - cap = py.io.StdCaptureFD(out=True, err=False) - print >>sys.stdout, "hello" - print >>sys.stderr, "world" - out, err = cap.reset() - assert out == "hello\n" - assert not err - -def test_just_err_capture(): - cap = py.io.StdCaptureFD(out=False, err=True) - print >>sys.stdout, "hello" - print >>sys.stderr, "world" - out, err = cap.reset() - assert err == "world\n" - assert not out + def test_callcapture(self): + def func(x, y): + print x + print >>py.std.sys.stderr, y + return 42 + + res, out, err = py.io.StdCaptureFD.call(func, 3, y=4) + assert res == 42 + assert out.startswith("3") + assert err.startswith("4") def test_capture_no_sys(): cap = py.io.StdCaptureFD(patchsys=False) @@ -151,6 +164,15 @@ def test_capture_no_sys(): assert out == "1" assert err == "2" -#class TestCapturingOnFDs(TestCapturingOnSys): -# def getcapture(self): -# return Capture() +def test_callcapture_nofd(): + def func(x, y): + os.write(1, "hello") + os.write(2, "hello") + print x + print >>py.std.sys.stderr, y + return 42 + + res, out, err = py.io.StdCapture.call(func, 3, y=4) + assert res == 42 + assert out.startswith("3") + assert err.startswith("4") diff --git a/py/io/test/test_simplecapture.py b/py/io/test/test_simplecapture.py deleted file mode 100644 index ff6547e77..000000000 --- a/py/io/test/test_simplecapture.py +++ /dev/null @@ -1,67 +0,0 @@ -import os, sys -import py - -class TestCapturingOnSys: - - def getcapture(self): - return py.io.StdCapture() - - def test_capturing_simple(self): - cap = self.getcapture() - print "hello world" - print >>sys.stderr, "hello error" - out, err = cap.reset() - assert out == "hello world\n" - assert err == "hello error\n" - - def test_capturing_twice_error(self): - cap = self.getcapture() - print "hello" - cap.reset() - py.test.raises(AttributeError, "cap.reset()") - - def test_capturing_modify_sysouterr_in_between(self): - oldout = sys.stdout - olderr = sys.stderr - cap = self.getcapture() - print "hello", - print >>sys.stderr, "world", - sys.stdout = py.std.StringIO.StringIO() - sys.stderr = py.std.StringIO.StringIO() - print "not seen" - print >>sys.stderr, "not seen" - out, err = cap.reset() - assert out == "hello" - assert err == "world" - assert sys.stdout == oldout - assert sys.stderr == olderr - - def test_capturing_error_recursive(self): - cap1 = self.getcapture() - print "cap1" - cap2 = self.getcapture() - print "cap2" - out2, err2 = cap2.reset() - py.test.raises(AttributeError, "cap2.reset()") - out1, err1 = cap1.reset() - assert out1 == "cap1\n" - assert out2 == "cap2\n" - - def test_reading_stdin_while_captured_doesnt_hang(self): - cap = self.getcapture() - try: - py.test.raises(IOError, raw_input) - finally: - cap.reset() - -def test_callcapture_nofd(): - def func(x, y): - print x - print >>py.std.sys.stderr, y - return 42 - - res, out, err = py.io.StdCapture.call(func, 3, y=4) - assert res == 42 - assert out.startswith("3") - assert err.startswith("4") -