unify and normalize Sys/FD Capturing classes

* * *
more unification
This commit is contained in:
holger krekel 2014-03-28 07:03:37 +01:00
parent 2263fcf6b7
commit e18c3ed494
2 changed files with 134 additions and 235 deletions

View File

@ -100,44 +100,25 @@ def pytest_load_initial_conftests(early_config, parser, args, __multicall__):
raise
class NoCapture:
def start_capturing(self):
pass
def stop_capturing(self):
pass
def pop_outerr_to_orig(self):
pass
def reset(self):
pass
def readouterr(self):
return "", ""
def maketmpfile():
f = py.std.tempfile.TemporaryFile()
newf = dupfile(f, encoding="UTF-8")
f.close()
return newf
class CaptureManager:
def __init__(self, defaultmethod=None):
self._method2capture = {}
self._defaultmethod = defaultmethod
def _maketempfile(self):
f = py.std.tempfile.TemporaryFile()
newf = dupfile(f, encoding="UTF-8")
f.close()
return newf
def _getcapture(self, method):
if method == "fd":
return StdCaptureFD(
out=self._maketempfile(),
err=self._maketempfile(),
)
return StdCaptureBase(out=True, err=True, Capture=FDCapture)
elif method == "sys":
return StdCapture(out=TextIO(), err=TextIO())
return StdCaptureBase(out=True, err=True, Capture=SysCapture)
elif method == "no":
return NoCapture()
return StdCaptureBase(out=False, err=False, in_=False)
else:
raise ValueError("unknown capturing method: %r" % method)
@ -277,8 +258,7 @@ def pytest_funcarg__capsys(request):
"""
if "capfd" in request._funcargs:
raise request.raiseerror(error_capsysfderror)
return CaptureFixture(StdCapture)
return CaptureFixture(SysCapture)
def pytest_funcarg__capfd(request):
"""enables capturing of writes to file descriptors 1 and 2 and makes
@ -289,12 +269,13 @@ def pytest_funcarg__capfd(request):
request.raiseerror(error_capsysfderror)
if not hasattr(os, 'dup'):
pytest.skip("capfd funcarg needs os.dup")
return CaptureFixture(StdCaptureFD)
return CaptureFixture(FDCapture)
class CaptureFixture:
def __init__(self, captureclass):
self._capture = captureclass(in_=False)
self._capture = StdCaptureBase(out=True, err=True, in_=False,
Capture=captureclass)
def _start(self):
self._capture.start_capturing()
@ -315,63 +296,6 @@ class CaptureFixture:
self._finalize()
class FDCapture:
""" Capture IO to/from a given os-level filedescriptor. """
def __init__(self, targetfd, tmpfile=None, patchsys=False):
""" save targetfd descriptor, and open a new
temporary file there. If no tmpfile is
specified a tempfile.Tempfile() will be opened
in text mode.
"""
self.targetfd = targetfd
if tmpfile is None and targetfd != 0:
# this code path is covered in the tests
# but not used by a regular pytest run
f = tempfile.TemporaryFile('wb+')
tmpfile = dupfile(f, encoding="UTF-8")
f.close()
self.tmpfile = tmpfile
self._savefd = os.dup(self.targetfd)
if patchsys:
self._oldsys = getattr(sys, patchsysdict[targetfd])
def start(self):
try:
os.fstat(self._savefd)
except OSError:
raise ValueError(
"saved filedescriptor not valid, "
"did you call start() twice?")
if self.targetfd == 0 and not self.tmpfile:
fd = os.open(os.devnull, os.O_RDONLY)
os.dup2(fd, 0)
os.close(fd)
if hasattr(self, '_oldsys'):
setattr(sys, patchsysdict[self.targetfd], DontReadFromInput())
else:
os.dup2(self.tmpfile.fileno(), self.targetfd)
if hasattr(self, '_oldsys'):
setattr(sys, patchsysdict[self.targetfd], self.tmpfile)
def done(self):
""" unpatch and clean up, returns the self.tmpfile (file object)
"""
os.dup2(self._savefd, self.targetfd)
os.close(self._savefd)
if self.targetfd != 0:
self.tmpfile.seek(0)
if hasattr(self, '_oldsys'):
setattr(sys, patchsysdict[self.targetfd], self._oldsys)
return self.tmpfile
def writeorg(self, data):
""" write a string to the original file descriptor
"""
if py.builtin._istext(data):
data = data.encode("utf8") # XXX use encoding of original stream
os.write(self._savefd, data)
def dupfile(f, mode=None, buffering=0, raising=False, encoding=None):
""" return a new open file object that's a duplicate of f
@ -421,6 +345,16 @@ class EncodedFile(object):
class StdCaptureBase(object):
out = err = in_ = None
def __init__(self, out=True, err=True, in_=True, Capture=None):
if in_:
self.in_ = Capture(0)
if out:
self.out = Capture(1)
if err:
self.err = Capture(2)
def reset(self):
""" reset sys.stdout/stderr and return captured output as strings. """
if hasattr(self, '_reset'):
@ -436,6 +370,25 @@ class StdCaptureBase(object):
errfile.close()
return out, err
def start_capturing(self):
if self.in_:
self.in_.start()
if self.out:
self.out.start()
if self.err:
self.err.start()
def stop_capturing(self):
""" return (outfile, errfile) and stop capturing. """
outfile = errfile = None
if self.out:
outfile = self.out.done()
if self.err:
errfile = self.err.done()
if self.in_:
self.in_.done()
return outfile, errfile
def pop_outerr_to_orig(self):
""" pop current snapshot out/err capture and flush to orig streams. """
out, err = self.readouterr()
@ -444,61 +397,8 @@ class StdCaptureBase(object):
if err:
self.err.writeorg(err)
class StdCaptureFD(StdCaptureBase):
""" This class allows to capture writes to FD1 and FD2
and may connect a NULL file to FD0 (and prevent
reads from sys.stdin). If any of the 0,1,2 file descriptors
is invalid it will not be captured.
"""
def __init__(self, out=True, err=True, in_=True, patchsys=True):
if in_:
try:
self.in_ = FDCapture(0, tmpfile=None, patchsys=patchsys)
except OSError:
pass
if out:
tmpfile = None
if hasattr(out, 'write'):
tmpfile = out
try:
self.out = FDCapture(1, tmpfile=tmpfile, patchsys=patchsys)
except OSError:
pass
if err:
if hasattr(err, 'write'):
tmpfile = err
else:
tmpfile = None
try:
self.err = FDCapture(2, tmpfile=tmpfile, patchsys=patchsys)
except OSError:
pass
def start_capturing(self):
if hasattr(self, 'in_'):
self.in_.start()
if hasattr(self, 'out'):
self.out.start()
if hasattr(self, 'err'):
self.err.start()
#def pytest_sessionfinish(self):
# self.reset_capturings()
def stop_capturing(self):
""" return (outfile, errfile) and stop capturing. """
outfile = errfile = None
if hasattr(self, 'out') and not self.out.tmpfile.closed:
outfile = self.out.done()
if hasattr(self, 'err') and not self.err.tmpfile.closed:
errfile = self.err.done()
if hasattr(self, 'in_'):
self.in_.done()
return outfile, errfile
def readouterr(self):
""" return snapshot value of stdout/stderr capturings. """
""" return snapshot unicode value of stdout/stderr capturings. """
return self._readsnapshot('out'), self._readsnapshot('err')
def _readsnapshot(self, name):
@ -511,77 +411,87 @@ class StdCaptureFD(StdCaptureBase):
f.seek(0)
res = f.read()
enc = getattr(f, "encoding", None)
if enc:
if enc and isinstance(res, bytes):
res = py.builtin._totext(res, enc, "replace")
f.truncate(0)
f.seek(0)
return res
class TextCapture(TextIO):
def __init__(self, oldout):
super(TextCapture, self).__init__()
self._oldout = oldout
class FDCapture:
""" Capture IO to/from a given os-level filedescriptor. """
def __init__(self, targetfd, tmpfile=None):
self.targetfd = targetfd
try:
self._savefd = os.dup(self.targetfd)
except OSError:
self.start = lambda: None
self.done = lambda: None
else:
if tmpfile is None:
if targetfd == 0:
tmpfile = open(os.devnull, "r")
else:
tmpfile = maketmpfile()
self.tmpfile = tmpfile
if targetfd in patchsysdict:
self._oldsys = getattr(sys, patchsysdict[targetfd])
def start(self):
""" Start capturing on targetfd using memorized tmpfile. """
try:
os.fstat(self._savefd)
except OSError:
raise ValueError("saved filedescriptor not valid anymore")
targetfd = self.targetfd
os.dup2(self.tmpfile.fileno(), targetfd)
if hasattr(self, '_oldsys'):
subst = self.tmpfile if targetfd != 0 else DontReadFromInput()
setattr(sys, patchsysdict[targetfd], subst)
def done(self):
""" stop capturing, restore streams, return original capture file,
seeked to position zero. """
os.dup2(self._savefd, self.targetfd)
os.close(self._savefd)
if self.targetfd != 0:
self.tmpfile.seek(0)
if hasattr(self, '_oldsys'):
setattr(sys, patchsysdict[self.targetfd], self._oldsys)
return self.tmpfile
def writeorg(self, data):
self._oldout.write(data)
self._oldout.flush()
class StdCapture(StdCaptureBase):
""" This class allows to capture writes to sys.stdout|stderr "in-memory"
and will raise errors on tries to read from sys.stdin. It only
modifies sys.stdout|stderr|stdin attributes and does not
touch underlying File Descriptors (use StdCaptureFD for that).
""" write a string to the original file descriptor
"""
def __init__(self, out=True, err=True, in_=True):
self._oldout = sys.stdout
self._olderr = sys.stderr
self._oldin = sys.stdin
if out and not hasattr(out, 'file'):
out = TextCapture(self._oldout)
self.out = out
if err:
if not hasattr(err, 'write'):
err = TextCapture(self._olderr)
self.err = err
self.in_ = in_
def start_capturing(self):
if self.out:
sys.stdout = self.out
if self.err:
sys.stderr = self.err
if self.in_:
sys.stdin = self.in_ = DontReadFromInput()
def stop_capturing(self):
""" return (outfile, errfile) and stop capturing. """
outfile = errfile = None
if self.out and not self.out.closed:
sys.stdout = self._oldout
outfile = self.out
outfile.seek(0)
if self.err and not self.err.closed:
sys.stderr = self._olderr
errfile = self.err
errfile.seek(0)
if self.in_:
sys.stdin = self._oldin
return outfile, errfile
if py.builtin._istext(data):
data = data.encode("utf8") # XXX use encoding of original stream
os.write(self._savefd, data)
def readouterr(self):
""" return snapshot value of stdout/stderr capturings. """
out = err = ""
if self.out:
out = self.out.getvalue()
self.out.truncate(0)
self.out.seek(0)
if self.err:
err = self.err.getvalue()
self.err.truncate(0)
self.err.seek(0)
return out, err
class SysCapture:
def __init__(self, fd):
name = patchsysdict[fd]
self._old = getattr(sys, name)
self.name = name
if name == "stdin":
self.tmpfile = DontReadFromInput()
else:
self.tmpfile = TextIO()
def start(self):
setattr(sys, self.name, self.tmpfile)
def done(self):
setattr(sys, self.name, self._old)
if self.name != "stdin":
self.tmpfile.seek(0)
return self.tmpfile
def writeorg(self, data):
self._old.write(data)
self._old.flush()
class DontReadFromInput:

View File

@ -4,6 +4,7 @@ from __future__ import with_statement
import os
import sys
import py
import tempfile
import pytest
import contextlib
@ -44,6 +45,13 @@ def oswritebytes(fd, obj):
def StdCaptureFD(out=True, err=True, in_=True):
return capture.StdCaptureBase(out, err, in_, Capture=capture.FDCapture)
def StdCapture(out=True, err=True, in_=True):
return capture.StdCaptureBase(out, err, in_, Capture=capture.SysCapture)
class TestCaptureManager:
def test_getmethod_default_no_fd(self, testdir, monkeypatch):
config = testdir.parseconfig(testdir.tmpdir)
@ -75,7 +83,7 @@ class TestCaptureManager:
@needsosdup
@pytest.mark.parametrize("method", ['no', 'fd', 'sys'])
def test_capturing_basic_api(self, method):
capouter = capture.StdCaptureFD()
capouter = StdCaptureFD()
old = sys.stdout, sys.stderr, sys.stdin
try:
capman = CaptureManager()
@ -99,7 +107,7 @@ class TestCaptureManager:
@needsosdup
def test_juggle_capturings(self, testdir):
capouter = capture.StdCaptureFD()
capouter = StdCaptureFD()
try:
#config = testdir.parseconfig(testdir.tmpdir)
capman = CaptureManager()
@ -717,7 +725,7 @@ class TestFDCapture:
f.close()
def test_stderr(self):
cap = capture.FDCapture(2, patchsys=True)
cap = capture.FDCapture(2)
cap.start()
print_("hello", file=sys.stderr)
f = cap.done()
@ -727,7 +735,7 @@ class TestFDCapture:
def test_stdin(self, tmpfile):
tmpfile.write(tobytes("3"))
tmpfile.seek(0)
cap = capture.FDCapture(0, tmpfile=tmpfile)
cap = capture.FDCapture(0, tmpfile)
cap.start()
# check with os.read() directly instead of raw_input(), because
# sys.stdin itself may be redirected (as pytest now does by default)
@ -753,7 +761,7 @@ class TestFDCapture:
class TestStdCapture:
def getcapture(self, **kw):
cap = capture.StdCapture(**kw)
cap = StdCapture(**kw)
cap.start_capturing()
return cap
@ -878,7 +886,7 @@ class TestStdCaptureFD(TestStdCapture):
pytestmark = needsosdup
def getcapture(self, **kw):
cap = capture.StdCaptureFD(**kw)
cap = StdCaptureFD(**kw)
cap.start_capturing()
return cap
@ -899,18 +907,10 @@ class TestStdCaptureFD(TestStdCapture):
def test_many(self, capfd):
with lsof_check():
for i in range(10):
cap = capture.StdCaptureFD()
cap = StdCaptureFD()
cap.reset()
@needsosdup
def test_stdcapture_fd_tmpfile(tmpfile):
capfd = capture.StdCaptureFD(out=tmpfile)
os.write(1, "hello".encode("ascii"))
os.write(2, "world".encode("ascii"))
outf, errf = capfd.stop_capturing()
assert outf == tmpfile
class TestStdCaptureFDinvalidFD:
pytestmark = needsosdup
@ -918,7 +918,10 @@ class TestStdCaptureFDinvalidFD:
def test_stdcapture_fd_invalid_fd(self, testdir):
testdir.makepyfile("""
import os
from _pytest.capture import StdCaptureFD
from _pytest import capture
def StdCaptureFD(out=True, err=True, in_=True):
return capture.StdCaptureBase(out, err, in_,
Capture=capture.FDCapture)
def test_stdout():
os.close(1)
cap = StdCaptureFD(out=True, err=False, in_=False)
@ -938,27 +941,12 @@ class TestStdCaptureFDinvalidFD:
def test_capture_not_started_but_reset():
capsys = capture.StdCapture()
capsys = StdCapture()
capsys.stop_capturing()
capsys.stop_capturing()
capsys.reset()
@needsosdup
def test_capture_no_sys():
capsys = capture.StdCapture()
try:
cap = capture.StdCaptureFD(patchsys=False)
cap.start_capturing()
sys.stdout.write("hello")
sys.stderr.write("world")
oswritebytes(1, "1")
oswritebytes(2, "2")
out, err = cap.reset()
assert out == "1"
assert err == "2"
finally:
capsys.reset()
@needsosdup
@ -966,7 +954,7 @@ def test_capture_no_sys():
def test_fdcapture_tmpfile_remains_the_same(tmpfile, use):
if not use:
tmpfile = True
cap = capture.StdCaptureFD(out=False, err=tmpfile)
cap = StdCaptureFD(out=False, err=tmpfile)
try:
cap.start_capturing()
capfile = cap.err.tmpfile
@ -977,7 +965,7 @@ def test_fdcapture_tmpfile_remains_the_same(tmpfile, use):
assert capfile2 == capfile
@pytest.mark.parametrize('method', ['StdCapture', 'StdCaptureFD'])
@pytest.mark.parametrize('method', ['SysCapture', 'FDCapture'])
def test_capturing_and_logging_fundamentals(testdir, method):
if method == "StdCaptureFD" and not hasattr(os, 'dup'):
pytest.skip("need os.dup")
@ -986,7 +974,8 @@ def test_capturing_and_logging_fundamentals(testdir, method):
import sys, os
import py, logging
from _pytest import capture
cap = capture.%s(out=False, in_=False)
cap = capture.StdCaptureBase(out=False, in_=False,
Capture=capture.%s)
cap.start_capturing()
logging.warn("hello1")