[svn r37872] StdCaptureFD and StdCapture now try

to take care of stdin in a mostly uniform way.

--HG--
branch : trunk
This commit is contained in:
hpk 2007-02-03 14:57:25 +01:00
parent d6d7eb8704
commit 9f8035656e
2 changed files with 53 additions and 12 deletions

View File

@ -4,7 +4,6 @@ import py
try: from cStringIO import StringIO try: from cStringIO import StringIO
except ImportError: from StringIO import StringIO except ImportError: from StringIO import StringIO
emptyfile = StringIO()
class Capture(object): class Capture(object):
def call(cls, func, *args, **kwargs): def call(cls, func, *args, **kwargs):
@ -33,10 +32,17 @@ class Capture(object):
class StdCaptureFD(Capture): class StdCaptureFD(Capture):
""" capture Stdout and Stderr both on filedescriptor """ This class allows to capture writes to FD1 and FD2
and sys.stdout/stderr level. and may connect a NULL file to FD0 (and prevent
reads from sys.stdin)
""" """
def __init__(self, out=True, err=True, mixed=False, patchsys=True): def __init__(self, out=True, err=True, mixed=False, in_=True, patchsys=True):
if in_:
self._oldin = (sys.stdin, os.dup(0))
sys.stdin = DontReadFromInput()
fd = os.open(devnullpath, os.O_RDONLY)
os.dup2(fd, 0)
os.close(fd)
if out: if out:
self.out = py.io.FDCapture(1) self.out = py.io.FDCapture(1)
if patchsys: if patchsys:
@ -57,17 +63,23 @@ class StdCaptureFD(Capture):
outfile = self.out.done() outfile = self.out.done()
if hasattr(self, 'err'): if hasattr(self, 'err'):
errfile = self.err.done() errfile = self.err.done()
if hasattr(self, '_oldin'):
oldsys, oldfd = self._oldin
os.dup2(oldfd, 0)
os.close(oldfd)
sys.stdin = oldsys
return outfile, errfile return outfile, errfile
class StdCapture(Capture): class StdCapture(Capture):
""" capture sys.stdout/sys.stderr (but not system level fd 1 and 2). """ This class allows to capture writes to sys.stdout|stderr "in-memory"
and will raise errors on tries to read from sys.stdin. It only
This class allows to capture writes to sys.stdout|stderr "in-memory" modifies sys.stdout|stderr|stdin attributes and does not
and will raise errors on tries to read from sys.stdin. touch underlying File Descriptors (use StdCaptureFD for that).
""" """
def __init__(self, out=True, err=True, mixed=False): def __init__(self, out=True, err=True, in_=True, mixed=False):
self._out = out self._out = out
self._err = err self._err = err
self._in = in_
if out: if out:
self.oldout = sys.stdout self.oldout = sys.stdout
sys.stdout = self.newout = StringIO() sys.stdout = self.newout = StringIO()
@ -78,8 +90,9 @@ class StdCapture(Capture):
else: else:
newerr = StringIO() newerr = StringIO()
sys.stderr = self.newerr = newerr sys.stderr = self.newerr = newerr
self.oldin = sys.stdin if in_:
sys.stdin = self.newin = DontReadFromInput() self.oldin = sys.stdin
sys.stdin = self.newin = DontReadFromInput()
def reset(self): def reset(self):
""" return captured output as strings and restore sys.stdout/err.""" """ return captured output as strings and restore sys.stdout/err."""
@ -106,7 +119,8 @@ class StdCapture(Capture):
del self.olderr del self.olderr
errfile = self.newerr errfile = self.newerr
errfile.seek(0) errfile.seek(0)
sys.stdin = self.oldin if self._in:
sys.stdin = self.oldin
return outfile, errfile return outfile, errfile
class DontReadFromInput: class DontReadFromInput:
@ -121,3 +135,14 @@ class DontReadFromInput:
readline = read readline = read
readlines = read readlines = read
__iter__ = read __iter__ = read
try:
devnullpath = os.devnull
except AttributeError:
if os.name == 'nt':
devnullpath = 'NUL'
else:
devnullpath = '/dev/null'
emptyfile = StringIO()

View File

@ -79,6 +79,22 @@ class TestStdCapture:
assert err == "world\n" assert err == "world\n"
assert not out assert not out
def test_stdin_restored(self):
old = sys.stdin
cap = self.getcapture(in_=True)
newstdin = sys.stdin
out, err = cap.reset()
assert newstdin != sys.stdin
assert sys.stdin is old
def test_stdin_nulled_by_default(self):
print "XXX this test may well hang instead of crashing"
print "XXX which indicates an error in the underlying capturing"
print "XXX mechanisms"
cap = self.getcapture()
py.test.raises(IOError, "sys.stdin.read()")
out, err = cap.reset()
class TestStdCaptureFD(TestStdCapture): class TestStdCaptureFD(TestStdCapture):
def getcapture(self, **kw): def getcapture(self, **kw):
return py.io.StdCaptureFD(**kw) return py.io.StdCaptureFD(**kw)