simplify reset/stop_capturing and fix capturing wrt to capturing simple os.write() calls

This commit is contained in:
holger krekel 2014-03-28 07:11:25 +01:00
parent e18c3ed494
commit a8f4f49a82
2 changed files with 103 additions and 98 deletions

View File

@ -6,7 +6,7 @@ from __future__ import with_statement
import sys import sys
import os import os
import tempfile from tempfile import TemporaryFile
import contextlib import contextlib
import py import py
@ -101,12 +101,6 @@ def pytest_load_initial_conftests(early_config, parser, args, __multicall__):
def maketmpfile():
f = py.std.tempfile.TemporaryFile()
newf = dupfile(f, encoding="UTF-8")
f.close()
return newf
class CaptureManager: class CaptureManager:
def __init__(self, defaultmethod=None): def __init__(self, defaultmethod=None):
self._method2capture = {} self._method2capture = {}
@ -137,7 +131,7 @@ class CaptureManager:
def reset_capturings(self): def reset_capturings(self):
for cap in self._method2capture.values(): for cap in self._method2capture.values():
cap.pop_outerr_to_orig() cap.pop_outerr_to_orig()
cap.reset() cap.stop_capturing()
self._method2capture.clear() self._method2capture.clear()
def resumecapture_item(self, item): def resumecapture_item(self, item):
@ -274,15 +268,16 @@ def pytest_funcarg__capfd(request):
class CaptureFixture: class CaptureFixture:
def __init__(self, captureclass): def __init__(self, captureclass):
self._capture = StdCaptureBase(out=True, err=True, in_=False, self.captureclass = captureclass
Capture=captureclass)
def _start(self): def _start(self):
self._capture = StdCaptureBase(out=True, err=True, in_=False,
Capture=self.captureclass)
self._capture.start_capturing() self._capture.start_capturing()
def _finalize(self): def _finalize(self):
if hasattr(self, '_capture'): if hasattr(self, '_capture'):
outerr = self._outerr = self._capture.reset() outerr = self._outerr = self._capture.stop_capturing()
del self._capture del self._capture
return outerr return outerr
@ -355,21 +350,6 @@ class StdCaptureBase(object):
if err: if err:
self.err = Capture(2) self.err = Capture(2)
def reset(self):
""" reset sys.stdout/stderr and return captured output as strings. """
if hasattr(self, '_reset'):
raise ValueError("was already reset")
self._reset = True
outfile, errfile = self.stop_capturing()
out, err = "", ""
if outfile and not outfile.closed:
out = outfile.read()
outfile.close()
if errfile and errfile != outfile and not errfile.closed:
err = errfile.read()
errfile.close()
return out, err
def start_capturing(self): def start_capturing(self):
if self.in_: if self.in_:
self.in_.start() self.in_.start()
@ -378,17 +358,6 @@ class StdCaptureBase(object):
if self.err: if self.err:
self.err.start() self.err.start()
def stop_capturing(self):
""" return (outfile, errfile) and stop capturing. """
outfile = errfile = None
if self.out:
outfile = self.out.done()
if self.err:
errfile = self.err.done()
if self.in_:
self.in_.done()
return outfile, errfile
def pop_outerr_to_orig(self): def pop_outerr_to_orig(self):
""" pop current snapshot out/err capture and flush to orig streams. """ """ pop current snapshot out/err capture and flush to orig streams. """
out, err = self.readouterr() out, err = self.readouterr()
@ -397,25 +366,27 @@ class StdCaptureBase(object):
if err: if err:
self.err.writeorg(err) self.err.writeorg(err)
def stop_capturing(self):
""" stop capturing and reset capturing streams """
if hasattr(self, '_reset'):
raise ValueError("was already stopped")
self._reset = True
if self.out:
self.out.done()
if self.err:
self.err.done()
if self.in_:
self.in_.done()
def readouterr(self): def readouterr(self):
""" return snapshot unicode value of stdout/stderr capturings. """ """ return snapshot unicode value of stdout/stderr capturings. """
return self._readsnapshot('out'), self._readsnapshot('err') return self._readsnapshot('out'), self._readsnapshot('err')
def _readsnapshot(self, name): def _readsnapshot(self, name):
try: cap = getattr(self, name, None)
f = getattr(self, name).tmpfile if cap is None:
except AttributeError: return ""
return '' return cap.snap()
if f.tell() == 0:
return ''
f.seek(0)
res = f.read()
enc = getattr(f, "encoding", None)
if enc and isinstance(res, bytes):
res = py.builtin._totext(res, enc, "replace")
f.truncate(0)
f.seek(0)
return res
class FDCapture: class FDCapture:
@ -433,11 +404,16 @@ class FDCapture:
if targetfd == 0: if targetfd == 0:
tmpfile = open(os.devnull, "r") tmpfile = open(os.devnull, "r")
else: else:
tmpfile = maketmpfile() f = TemporaryFile()
with f:
tmpfile = dupfile(f, encoding="UTF-8")
self.tmpfile = tmpfile self.tmpfile = tmpfile
if targetfd in patchsysdict: if targetfd in patchsysdict:
self._oldsys = getattr(sys, patchsysdict[targetfd]) self._oldsys = getattr(sys, patchsysdict[targetfd])
def __repr__(self):
return "<FDCapture %s oldfd=%s>" % (self.targetfd, self._savefd)
def start(self): def start(self):
""" Start capturing on targetfd using memorized tmpfile. """ """ Start capturing on targetfd using memorized tmpfile. """
try: try:
@ -450,16 +426,26 @@ class FDCapture:
subst = self.tmpfile if targetfd != 0 else DontReadFromInput() subst = self.tmpfile if targetfd != 0 else DontReadFromInput()
setattr(sys, patchsysdict[targetfd], subst) setattr(sys, patchsysdict[targetfd], subst)
def snap(self):
f = self.tmpfile
f.seek(0)
res = f.read()
if res:
enc = getattr(f, "encoding", None)
if enc and isinstance(res, bytes):
res = py.builtin._totext(res, enc, "replace")
f.truncate(0)
f.seek(0)
return res
def done(self): def done(self):
""" stop capturing, restore streams, return original capture file, """ stop capturing, restore streams, return original capture file,
seeked to position zero. """ seeked to position zero. """
os.dup2(self._savefd, self.targetfd) os.dup2(self._savefd, self.targetfd)
os.close(self._savefd) os.close(self._savefd)
if self.targetfd != 0:
self.tmpfile.seek(0)
if hasattr(self, '_oldsys'): if hasattr(self, '_oldsys'):
setattr(sys, patchsysdict[self.targetfd], self._oldsys) setattr(sys, patchsysdict[self.targetfd], self._oldsys)
return self.tmpfile self.tmpfile.close()
def writeorg(self, data): def writeorg(self, data):
""" write a string to the original file descriptor """ write a string to the original file descriptor
@ -482,18 +468,22 @@ class SysCapture:
def start(self): def start(self):
setattr(sys, self.name, self.tmpfile) setattr(sys, self.name, self.tmpfile)
def snap(self):
f = self.tmpfile
res = f.getvalue()
f.truncate(0)
f.seek(0)
return res
def done(self): def done(self):
setattr(sys, self.name, self._old) setattr(sys, self.name, self._old)
if self.name != "stdin": self.tmpfile.close()
self.tmpfile.seek(0)
return self.tmpfile
def writeorg(self, data): def writeorg(self, data):
self._old.write(data) self._old.write(data)
self._old.flush() self._old.flush()
class DontReadFromInput: class DontReadFromInput:
"""Temporary stub class. Ideally when stdin is accessed, the """Temporary stub class. Ideally when stdin is accessed, the
capturing should be turned off, with possibly all data captured capturing should be turned off, with possibly all data captured

View File

@ -103,7 +103,7 @@ class TestCaptureManager:
assert not out and not err assert not out and not err
capman.reset_capturings() capman.reset_capturings()
finally: finally:
capouter.reset() capouter.stop_capturing()
@needsosdup @needsosdup
def test_juggle_capturings(self, testdir): def test_juggle_capturings(self, testdir):
@ -127,7 +127,7 @@ class TestCaptureManager:
finally: finally:
capman.reset_capturings() capman.reset_capturings()
finally: finally:
capouter.reset() capouter.stop_capturing()
@pytest.mark.parametrize("method", ['fd', 'sys']) @pytest.mark.parametrize("method", ['fd', 'sys'])
@ -696,17 +696,15 @@ class TestFDCapture:
cap = capture.FDCapture(fd) cap = capture.FDCapture(fd)
data = tobytes("hello") data = tobytes("hello")
os.write(fd, data) os.write(fd, data)
f = cap.done() s = cap.snap()
s = f.read() cap.done()
f.close()
assert not s assert not s
cap = capture.FDCapture(fd) cap = capture.FDCapture(fd)
cap.start() cap.start()
os.write(fd, data) os.write(fd, data)
f = cap.done() s = cap.snap()
s = f.read() cap.done()
assert s == "hello" assert s == "hello"
f.close()
def test_simple_many(self, tmpfile): def test_simple_many(self, tmpfile):
for i in range(10): for i in range(10):
@ -720,16 +718,15 @@ class TestFDCapture:
def test_simple_fail_second_start(self, tmpfile): def test_simple_fail_second_start(self, tmpfile):
fd = tmpfile.fileno() fd = tmpfile.fileno()
cap = capture.FDCapture(fd) cap = capture.FDCapture(fd)
f = cap.done() cap.done()
pytest.raises(ValueError, cap.start) pytest.raises(ValueError, cap.start)
f.close()
def test_stderr(self): def test_stderr(self):
cap = capture.FDCapture(2) cap = capture.FDCapture(2)
cap.start() cap.start()
print_("hello", file=sys.stderr) print_("hello", file=sys.stderr)
f = cap.done() s = cap.snap()
s = f.read() cap.done()
assert s == "hello\n" assert s == "hello\n"
def test_stdin(self, tmpfile): def test_stdin(self, tmpfile):
@ -752,8 +749,8 @@ class TestFDCapture:
cap.writeorg(data2) cap.writeorg(data2)
finally: finally:
tmpfile.close() tmpfile.close()
f = cap.done() scap = cap.snap()
scap = f.read() cap.done()
assert scap == totext(data1) assert scap == totext(data1)
stmp = open(tmpfile.name, 'rb').read() stmp = open(tmpfile.name, 'rb').read()
assert stmp == data2 assert stmp == data2
@ -769,17 +766,17 @@ class TestStdCapture:
cap = self.getcapture() cap = self.getcapture()
sys.stdout.write("hello") sys.stdout.write("hello")
sys.stderr.write("world") sys.stderr.write("world")
outfile, errfile = cap.stop_capturing() out, err = cap.readouterr()
s = outfile.read() cap.stop_capturing()
assert s == "hello" assert out == "hello"
s = errfile.read() assert err == "world"
assert s == "world"
def test_capturing_reset_simple(self): def test_capturing_reset_simple(self):
cap = self.getcapture() cap = self.getcapture()
print("hello world") print("hello world")
sys.stderr.write("hello error\n") sys.stderr.write("hello error\n")
out, err = cap.reset() out, err = cap.readouterr()
cap.stop_capturing()
assert out == "hello world\n" assert out == "hello world\n"
assert err == "hello error\n" assert err == "hello error\n"
@ -792,8 +789,9 @@ class TestStdCapture:
assert out == "hello world\n" assert out == "hello world\n"
assert err == "hello error\n" assert err == "hello error\n"
sys.stderr.write("error2") sys.stderr.write("error2")
out, err = cap.readouterr()
finally: finally:
out, err = cap.reset() cap.stop_capturing()
assert err == "error2" assert err == "error2"
def test_capturing_readouterr_unicode(self): def test_capturing_readouterr_unicode(self):
@ -802,7 +800,7 @@ class TestStdCapture:
print ("hx\xc4\x85\xc4\x87") print ("hx\xc4\x85\xc4\x87")
out, err = cap.readouterr() out, err = cap.readouterr()
finally: finally:
cap.reset() cap.stop_capturing()
assert out == py.builtin._totext("hx\xc4\x85\xc4\x87\n", "utf8") assert out == py.builtin._totext("hx\xc4\x85\xc4\x87\n", "utf8")
@pytest.mark.skipif('sys.version_info >= (3,)', @pytest.mark.skipif('sys.version_info >= (3,)',
@ -813,13 +811,14 @@ class TestStdCapture:
print('\xa6') print('\xa6')
out, err = cap.readouterr() out, err = cap.readouterr()
assert out == py.builtin._totext('\ufffd\n', 'unicode-escape') assert out == py.builtin._totext('\ufffd\n', 'unicode-escape')
cap.reset() cap.stop_capturing()
def test_reset_twice_error(self): def test_reset_twice_error(self):
cap = self.getcapture() cap = self.getcapture()
print ("hello") print ("hello")
out, err = cap.reset() out, err = cap.readouterr()
pytest.raises(ValueError, cap.reset) cap.stop_capturing()
pytest.raises(ValueError, cap.stop_capturing)
assert out == "hello\n" assert out == "hello\n"
assert not err assert not err
@ -833,7 +832,8 @@ class TestStdCapture:
sys.stderr = capture.TextIO() sys.stderr = capture.TextIO()
print ("not seen") print ("not seen")
sys.stderr.write("not seen\n") sys.stderr.write("not seen\n")
out, err = cap.reset() out, err = cap.readouterr()
cap.stop_capturing()
assert out == "hello" assert out == "hello"
assert err == "world" assert err == "world"
assert sys.stdout == oldout assert sys.stdout == oldout
@ -844,8 +844,10 @@ class TestStdCapture:
print ("cap1") print ("cap1")
cap2 = self.getcapture() cap2 = self.getcapture()
print ("cap2") print ("cap2")
out2, err2 = cap2.reset() out2, err2 = cap2.readouterr()
out1, err1 = cap1.reset() out1, err1 = cap1.readouterr()
cap2.stop_capturing()
cap1.stop_capturing()
assert out1 == "cap1\n" assert out1 == "cap1\n"
assert out2 == "cap2\n" assert out2 == "cap2\n"
@ -853,7 +855,8 @@ class TestStdCapture:
cap = self.getcapture(out=True, err=False) cap = self.getcapture(out=True, err=False)
sys.stdout.write("hello") sys.stdout.write("hello")
sys.stderr.write("world") sys.stderr.write("world")
out, err = cap.reset() out, err = cap.readouterr()
cap.stop_capturing()
assert out == "hello" assert out == "hello"
assert not err assert not err
@ -861,7 +864,8 @@ class TestStdCapture:
cap = self.getcapture(out=False, err=True) cap = self.getcapture(out=False, err=True)
sys.stdout.write("hello") sys.stdout.write("hello")
sys.stderr.write("world") sys.stderr.write("world")
out, err = cap.reset() out, err = cap.readouterr()
cap.stop_capturing()
assert err == "world" assert err == "world"
assert not out assert not out
@ -869,7 +873,7 @@ class TestStdCapture:
old = sys.stdin old = sys.stdin
cap = self.getcapture(in_=True) cap = self.getcapture(in_=True)
newstdin = sys.stdin newstdin = sys.stdin
out, err = cap.reset() cap.stop_capturing()
assert newstdin != sys.stdin assert newstdin != sys.stdin
assert sys.stdin is old assert sys.stdin is old
@ -879,7 +883,7 @@ class TestStdCapture:
print ("XXX mechanisms") print ("XXX mechanisms")
cap = self.getcapture() cap = self.getcapture()
pytest.raises(IOError, "sys.stdin.read()") pytest.raises(IOError, "sys.stdin.read()")
out, err = cap.reset() cap.stop_capturing()
class TestStdCaptureFD(TestStdCapture): class TestStdCaptureFD(TestStdCapture):
@ -890,6 +894,20 @@ class TestStdCaptureFD(TestStdCapture):
cap.start_capturing() cap.start_capturing()
return cap return cap
def test_simple_only_fd(self, testdir):
testdir.makepyfile("""
import os
def test_x():
os.write(1, "hello\\n".encode("ascii"))
assert 0
""")
result = testdir.runpytest()
result.stdout.fnmatch_lines("""
*test_x*
*assert 0*
*Captured stdout*
""")
def test_intermingling(self): def test_intermingling(self):
cap = self.getcapture() cap = self.getcapture()
oswritebytes(1, "1") oswritebytes(1, "1")
@ -900,7 +918,8 @@ class TestStdCaptureFD(TestStdCapture):
sys.stderr.write("b") sys.stderr.write("b")
sys.stderr.flush() sys.stderr.flush()
oswritebytes(2, "c") oswritebytes(2, "c")
out, err = cap.reset() out, err = cap.readouterr()
cap.stop_capturing()
assert out == "123" assert out == "123"
assert err == "abc" assert err == "abc"
@ -908,7 +927,7 @@ class TestStdCaptureFD(TestStdCapture):
with lsof_check(): with lsof_check():
for i in range(10): for i in range(10):
cap = StdCaptureFD() cap = StdCaptureFD()
cap.reset() cap.stop_capturing()
@ -943,10 +962,6 @@ class TestStdCaptureFDinvalidFD:
def test_capture_not_started_but_reset(): def test_capture_not_started_but_reset():
capsys = StdCapture() capsys = StdCapture()
capsys.stop_capturing() capsys.stop_capturing()
capsys.stop_capturing()
capsys.reset()
@needsosdup @needsosdup
@ -960,7 +975,7 @@ def test_fdcapture_tmpfile_remains_the_same(tmpfile, use):
capfile = cap.err.tmpfile capfile = cap.err.tmpfile
cap.readouterr() cap.readouterr()
finally: finally:
cap.reset() cap.stop_capturing()
capfile2 = cap.err.tmpfile capfile2 = cap.err.tmpfile
assert capfile2 == capfile assert capfile2 == capfile