rewrite all _pytest.capture uses of py.io to _pytest.capture

This commit is contained in:
Ronny Pfannschmidt 2014-01-22 19:44:20 +01:00
parent 3cc58c2f78
commit ea18e9656b
1 changed files with 59 additions and 30 deletions

View File

@ -29,13 +29,16 @@ def pytest_load_initial_conftests(early_config, parser, args, __multicall__):
method = "sys" method = "sys"
capman = CaptureManager(method) capman = CaptureManager(method)
early_config.pluginmanager.register(capman, "capturemanager") early_config.pluginmanager.register(capman, "capturemanager")
# make sure that capturemanager is properly reset at final shutdown # make sure that capturemanager is properly reset at final shutdown
def teardown(): def teardown():
try: try:
capman.reset_capturings() capman.reset_capturings()
except ValueError: except ValueError:
pass pass
early_config.pluginmanager.add_shutdown(teardown) early_config.pluginmanager.add_shutdown(teardown)
# make sure logging does not raise exceptions at the end # make sure logging does not raise exceptions at the end
def silence_logging_at_shutdown(): def silence_logging_at_shutdown():
if "logging" in sys.modules: if "logging" in sys.modules:
@ -54,21 +57,27 @@ def pytest_load_initial_conftests(early_config, parser, args, __multicall__):
sys.stderr.write(err) sys.stderr.write(err)
raise raise
def addouterr(rep, outerr): def addouterr(rep, outerr):
for secname, content in zip(["out", "err"], outerr): for secname, content in zip(["out", "err"], outerr):
if content: if content:
rep.sections.append(("Captured std%s" % secname, content)) rep.sections.append(("Captured std%s" % secname, content))
class NoCapture: class NoCapture:
def startall(self): def startall(self):
pass pass
def resume(self): def resume(self):
pass pass
def reset(self): def reset(self):
pass pass
def suspend(self): def suspend(self):
return "", "" return "", ""
class CaptureManager: class CaptureManager:
def __init__(self, defaultmethod=None): def __init__(self, defaultmethod=None):
self._method2capture = {} self._method2capture = {}
@ -76,21 +85,25 @@ class CaptureManager:
def _maketempfile(self): def _maketempfile(self):
f = py.std.tempfile.TemporaryFile() f = py.std.tempfile.TemporaryFile()
newf = py.io.dupfile(f, encoding="UTF-8") newf = dupfile(f, encoding="UTF-8")
f.close() f.close()
return newf return newf
def _makestringio(self): def _makestringio(self):
return py.io.TextIO() return TextIO()
def _getcapture(self, method): def _getcapture(self, method):
if method == "fd": if method == "fd":
return py.io.StdCaptureFD(now=False, return StdCaptureFD(
out=self._maketempfile(), err=self._maketempfile() now=False,
out=self._maketempfile(),
err=self._maketempfile(),
) )
elif method == "sys": elif method == "sys":
return py.io.StdCapture(now=False, return StdCapture(
out=self._makestringio(), err=self._makestringio() now=False,
out=self._makestringio(),
err=self._makestringio(),
) )
elif method == "no": elif method == "no":
return NoCapture() return NoCapture()
@ -105,7 +118,7 @@ class CaptureManager:
method = config._conftest.rget("option_capture", path=fspath) method = config._conftest.rget("option_capture", path=fspath)
except KeyError: except KeyError:
method = "fd" method = "fd"
if method == "fd" and not hasattr(os, 'dup'): # e.g. jython if method == "fd" and not hasattr(os, 'dup'): # e.g. jython
method = "sys" method = "sys"
return method return method
@ -116,12 +129,13 @@ class CaptureManager:
def resumecapture_item(self, item): def resumecapture_item(self, item):
method = self._getmethod(item.config, item.fspath) method = self._getmethod(item.config, item.fspath)
if not hasattr(item, 'outerr'): if not hasattr(item, 'outerr'):
item.outerr = ('', '') # we accumulate outerr on the item item.outerr = ('', '') # we accumulate outerr on the item
return self.resumecapture(method) return self.resumecapture(method)
def resumecapture(self, method=None): def resumecapture(self, method=None):
if hasattr(self, '_capturing'): if hasattr(self, '_capturing'):
raise ValueError("cannot resume, already capturing with %r" % raise ValueError(
"cannot resume, already capturing with %r" %
(self._capturing,)) (self._capturing,))
if method is None: if method is None:
method = self._defaultmethod method = self._defaultmethod
@ -170,8 +184,9 @@ class CaptureManager:
try: try:
self.resumecapture(method) self.resumecapture(method)
except ValueError: except ValueError:
return # recursive collect, XXX refactor capturing # recursive collect, XXX refactor capturing
# to allow for more lightweight recursive capturing # to allow for more lightweight recursive capturing
return
try: try:
rep = __multicall__.execute() rep = __multicall__.execute()
finally: finally:
@ -212,6 +227,7 @@ class CaptureManager:
error_capsysfderror = "cannot use capsys and capfd at the same time" error_capsysfderror = "cannot use capsys and capfd at the same time"
def pytest_funcarg__capsys(request): def pytest_funcarg__capsys(request):
"""enables capturing of writes to sys.stdout/sys.stderr and makes """enables capturing of writes to sys.stdout/sys.stderr and makes
captured output available via ``capsys.readouterr()`` method calls captured output available via ``capsys.readouterr()`` method calls
@ -219,7 +235,8 @@ def pytest_funcarg__capsys(request):
""" """
if "capfd" in request._funcargs: if "capfd" in request._funcargs:
raise request.raiseerror(error_capsysfderror) raise request.raiseerror(error_capsysfderror)
return CaptureFixture(py.io.StdCapture) return CaptureFixture(StdCapture)
def pytest_funcarg__capfd(request): def pytest_funcarg__capfd(request):
"""enables capturing of writes to file descriptors 1 and 2 and makes """enables capturing of writes to file descriptors 1 and 2 and makes
@ -230,7 +247,8 @@ def pytest_funcarg__capfd(request):
request.raiseerror(error_capsysfderror) request.raiseerror(error_capsysfderror)
if not hasattr(os, 'dup'): if not hasattr(os, 'dup'):
pytest.skip("capfd funcarg needs os.dup") pytest.skip("capfd funcarg needs os.dup")
return CaptureFixture(py.io.StdCaptureFD) return CaptureFixture(StdCaptureFD)
class CaptureFixture: class CaptureFixture:
def __init__(self, captureclass): def __init__(self, captureclass):
@ -253,9 +271,7 @@ class CaptureFixture:
def close(self): def close(self):
self._finalize() self._finalize()
import os
import sys
import py
import tempfile import tempfile
try: try:
@ -263,11 +279,13 @@ try:
except ImportError: except ImportError:
from StringIO import StringIO from StringIO import StringIO
if sys.version_info < (3,0):
if sys.version_info < (3, 0):
class TextIO(StringIO): class TextIO(StringIO):
def write(self, data): def write(self, data):
if not isinstance(data, unicode): if not isinstance(data, unicode):
data = unicode(data, getattr(self, '_encoding', 'UTF-8'), 'replace') enc = getattr(self, '_encoding', 'UTF-8')
data = unicode(data, enc, 'replace')
StringIO.write(self, data) StringIO.write(self, data)
else: else:
TextIO = StringIO TextIO = StringIO
@ -278,11 +296,12 @@ except ImportError:
class BytesIO(StringIO): class BytesIO(StringIO):
def write(self, data): def write(self, data):
if isinstance(data, unicode): if isinstance(data, unicode):
raise TypeError("not a byte value: %r" %(data,)) raise TypeError("not a byte value: %r" % (data,))
StringIO.write(self, data) StringIO.write(self, data)
patchsysdict = {0: 'stdin', 1: 'stdout', 2: 'stderr'} patchsysdict = {0: 'stdin', 1: 'stdout', 2: 'stderr'}
class FDCapture: class FDCapture:
""" Capture IO to/from a given os-level filedescriptor. """ """ Capture IO to/from a given os-level filedescriptor. """
@ -308,7 +327,8 @@ class FDCapture:
try: try:
os.fstat(self._savefd) os.fstat(self._savefd)
except OSError: except OSError:
raise ValueError("saved filedescriptor not valid, " raise ValueError(
"saved filedescriptor not valid, "
"did you call start() twice?") "did you call start() twice?")
if self.targetfd == 0 and not self.tmpfile: if self.targetfd == 0 and not self.tmpfile:
fd = os.open(devnullpath, os.O_RDONLY) fd = os.open(devnullpath, os.O_RDONLY)
@ -360,7 +380,7 @@ def dupfile(f, mode=None, buffering=0, raising=False, encoding=None):
raise raise
return f return f
newfd = os.dup(fd) newfd = os.dup(fd)
if sys.version_info >= (3,0): if sys.version_info >= (3, 0):
if encoding is not None: if encoding is not None:
mode = mode.replace("b", "") mode = mode.replace("b", "")
buffering = True buffering = True
@ -371,6 +391,7 @@ def dupfile(f, mode=None, buffering=0, raising=False, encoding=None):
return EncodedFile(f, encoding) return EncodedFile(f, encoding)
return f return f
class EncodedFile(object): class EncodedFile(object):
def __init__(self, _stream, encoding): def __init__(self, _stream, encoding):
self._stream = _stream self._stream = _stream
@ -392,6 +413,7 @@ class EncodedFile(object):
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self._stream, name) return getattr(self._stream, name)
class Capture(object): class Capture(object):
def call(cls, func, *args, **kwargs): def call(cls, func, *args, **kwargs):
""" return a (res, out, err) tuple where """ return a (res, out, err) tuple where
@ -437,7 +459,7 @@ class StdCaptureFD(Capture):
is invalid it will not be captured. is invalid it will not be captured.
""" """
def __init__(self, out=True, err=True, mixed=False, def __init__(self, out=True, err=True, mixed=False,
in_=True, patchsys=True, now=True): in_=True, patchsys=True, now=True):
self._options = { self._options = {
"out": out, "out": out,
"err": err, "err": err,
@ -458,7 +480,8 @@ class StdCaptureFD(Capture):
patchsys = self._options['patchsys'] patchsys = self._options['patchsys']
if in_: if in_:
try: try:
self.in_ = FDCapture(0, tmpfile=None, now=False, self.in_ = FDCapture(
0, tmpfile=None, now=False,
patchsys=patchsys) patchsys=patchsys)
except OSError: except OSError:
pass pass
@ -467,8 +490,9 @@ class StdCaptureFD(Capture):
if hasattr(out, 'write'): if hasattr(out, 'write'):
tmpfile = out tmpfile = out
try: try:
self.out = FDCapture(1, tmpfile=tmpfile, self.out = FDCapture(
now=False, patchsys=patchsys) 1, tmpfile=tmpfile,
now=False, patchsys=patchsys)
self._options['out'] = self.out.tmpfile self._options['out'] = self.out.tmpfile
except OSError: except OSError:
pass pass
@ -480,8 +504,9 @@ class StdCaptureFD(Capture):
else: else:
tmpfile = None tmpfile = None
try: try:
self.err = FDCapture(2, tmpfile=tmpfile, self.err = FDCapture(
now=False, patchsys=patchsys) 2, tmpfile=tmpfile,
now=False, patchsys=patchsys)
self._options['err'] = self.err.tmpfile self._options['err'] = self.err.tmpfile
except OSError: except OSError:
pass pass
@ -506,7 +531,7 @@ class StdCaptureFD(Capture):
if hasattr(self, 'err') and not self.err.tmpfile.closed: if hasattr(self, 'err') and not self.err.tmpfile.closed:
errfile = self.err.done() errfile = self.err.done()
if hasattr(self, 'in_'): if hasattr(self, 'in_'):
tmpfile = self.in_.done() self.in_.done()
if save: if save:
self._save() self._save()
return outfile, errfile return outfile, errfile
@ -543,7 +568,7 @@ class StdCapture(Capture):
def __init__(self, out=True, err=True, in_=True, mixed=False, now=True): def __init__(self, out=True, err=True, in_=True, mixed=False, now=True):
self._oldout = sys.stdout self._oldout = sys.stdout
self._olderr = sys.stderr self._olderr = sys.stderr
self._oldin = sys.stdin self._oldin = sys.stdin
if out and not hasattr(out, 'file'): if out and not hasattr(out, 'file'):
out = TextIO() out = TextIO()
self.out = out self.out = out
@ -563,7 +588,7 @@ class StdCapture(Capture):
if self.err: if self.err:
sys.stderr = self.err sys.stderr = self.err
if self.in_: if self.in_:
sys.stdin = self.in_ = DontReadFromInput() sys.stdin = self.in_ = DontReadFromInput()
def done(self, save=True): def done(self, save=True):
""" return (outfile, errfile) and stop capturing. """ """ return (outfile, errfile) and stop capturing. """
@ -597,6 +622,7 @@ class StdCapture(Capture):
self.err.seek(0) self.err.seek(0)
return out, err return out, err
class DontReadFromInput: class DontReadFromInput:
"""Temporary stub class. Ideally when stdin is accessed, the """Temporary stub class. Ideally when stdin is accessed, the
capturing should be turned off, with possibly all data captured capturing should be turned off, with possibly all data captured
@ -612,11 +638,14 @@ class DontReadFromInput:
def fileno(self): def fileno(self):
raise ValueError("redirected Stdin is pseudofile, has no fileno()") raise ValueError("redirected Stdin is pseudofile, has no fileno()")
def isatty(self): def isatty(self):
return False return False
def close(self): def close(self):
pass pass
try: try:
devnullpath = os.devnull devnullpath = os.devnull
except AttributeError: except AttributeError: