avoid some redundancy by using SysCapture from FDCapture for manipulating sys.std{out,in,err}
This commit is contained in:
parent
69cbac8fb5
commit
ca5e6830c6
|
@ -333,6 +333,8 @@ class MultiCapture(object):
|
||||||
return (self.out.snap() if self.out is not None else "",
|
return (self.out.snap() if self.out is not None else "",
|
||||||
self.err.snap() if self.err is not None else "")
|
self.err.snap() if self.err is not None else "")
|
||||||
|
|
||||||
|
class NoCapture:
|
||||||
|
__init__ = start = done = lambda *args: None
|
||||||
|
|
||||||
class FDCapture:
|
class FDCapture:
|
||||||
""" Capture IO to/from a given os-level filedescriptor. """
|
""" Capture IO to/from a given os-level filedescriptor. """
|
||||||
|
@ -340,36 +342,38 @@ class FDCapture:
|
||||||
def __init__(self, targetfd, tmpfile=None):
|
def __init__(self, targetfd, tmpfile=None):
|
||||||
self.targetfd = targetfd
|
self.targetfd = targetfd
|
||||||
try:
|
try:
|
||||||
self._savefd = os.dup(self.targetfd)
|
self.targetfd_save = os.dup(self.targetfd)
|
||||||
except OSError:
|
except OSError:
|
||||||
self.start = lambda: None
|
self.start = lambda: None
|
||||||
self.done = lambda: None
|
self.done = lambda: None
|
||||||
else:
|
else:
|
||||||
if tmpfile is None:
|
if targetfd == 0:
|
||||||
if targetfd == 0:
|
assert not tmpfile, "cannot set tmpfile with stdin"
|
||||||
tmpfile = open(os.devnull, "r")
|
tmpfile = open(os.devnull, "r")
|
||||||
else:
|
self.syscapture = SysCapture(targetfd)
|
||||||
|
else:
|
||||||
|
if tmpfile is None:
|
||||||
f = TemporaryFile()
|
f = TemporaryFile()
|
||||||
with f:
|
with f:
|
||||||
tmpfile = safe_text_dupfile(f, mode="wb+")
|
tmpfile = safe_text_dupfile(f, mode="wb+")
|
||||||
|
if targetfd in patchsysdict:
|
||||||
|
self.syscapture = SysCapture(targetfd, tmpfile)
|
||||||
|
else:
|
||||||
|
self.syscapture = NoCapture()
|
||||||
self.tmpfile = tmpfile
|
self.tmpfile = tmpfile
|
||||||
if targetfd in patchsysdict:
|
self.tmpfile_fd = tmpfile.fileno()
|
||||||
self._oldsys = getattr(sys, patchsysdict[targetfd])
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<FDCapture %s oldfd=%s>" % (self.targetfd, self._savefd)
|
return "<FDCapture %s oldfd=%s>" % (self.targetfd, self.targetfd_save)
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
""" Start capturing on targetfd using memorized tmpfile. """
|
""" Start capturing on targetfd using memorized tmpfile. """
|
||||||
try:
|
try:
|
||||||
os.fstat(self._savefd)
|
os.fstat(self.targetfd_save)
|
||||||
except OSError:
|
except (AttributeError, OSError):
|
||||||
raise ValueError("saved filedescriptor not valid anymore")
|
raise ValueError("saved filedescriptor not valid anymore")
|
||||||
targetfd = self.targetfd
|
os.dup2(self.tmpfile_fd, self.targetfd)
|
||||||
os.dup2(self.tmpfile.fileno(), targetfd)
|
self.syscapture.start()
|
||||||
if hasattr(self, '_oldsys'):
|
|
||||||
subst = self.tmpfile if targetfd != 0 else DontReadFromInput()
|
|
||||||
setattr(sys, patchsysdict[targetfd], subst)
|
|
||||||
|
|
||||||
def snap(self):
|
def snap(self):
|
||||||
f = self.tmpfile
|
f = self.tmpfile
|
||||||
|
@ -386,28 +390,38 @@ class FDCapture:
|
||||||
def done(self):
|
def done(self):
|
||||||
""" stop capturing, restore streams, return original capture file,
|
""" stop capturing, restore streams, return original capture file,
|
||||||
seeked to position zero. """
|
seeked to position zero. """
|
||||||
os.dup2(self._savefd, self.targetfd)
|
targetfd_save = self.__dict__.pop("targetfd_save")
|
||||||
os.close(self._savefd)
|
os.dup2(targetfd_save, self.targetfd)
|
||||||
if hasattr(self, '_oldsys'):
|
os.close(targetfd_save)
|
||||||
setattr(sys, patchsysdict[self.targetfd], self._oldsys)
|
self.syscapture.done()
|
||||||
self.tmpfile.close()
|
self.tmpfile.close()
|
||||||
|
|
||||||
|
def suspend(self):
|
||||||
|
self.syscapture.suspend()
|
||||||
|
os.dup2(self.targetfd_save, self.targetfd)
|
||||||
|
|
||||||
|
def resume(self):
|
||||||
|
self.syscapture.resume()
|
||||||
|
os.dup2(self.tmpfile_fd, self.targetfd)
|
||||||
|
|
||||||
def writeorg(self, data):
|
def writeorg(self, data):
|
||||||
""" write to original file descriptor. """
|
""" write to original file descriptor. """
|
||||||
if py.builtin._istext(data):
|
if py.builtin._istext(data):
|
||||||
data = data.encode("utf8") # XXX use encoding of original stream
|
data = data.encode("utf8") # XXX use encoding of original stream
|
||||||
os.write(self._savefd, data)
|
os.write(self.targetfd_save, data)
|
||||||
|
|
||||||
|
|
||||||
class SysCapture:
|
class SysCapture:
|
||||||
def __init__(self, fd):
|
def __init__(self, fd, tmpfile=None):
|
||||||
name = patchsysdict[fd]
|
name = patchsysdict[fd]
|
||||||
self._old = getattr(sys, name)
|
self._old = getattr(sys, name)
|
||||||
self.name = name
|
self.name = name
|
||||||
if name == "stdin":
|
if tmpfile is None:
|
||||||
self.tmpfile = DontReadFromInput()
|
if name == "stdin":
|
||||||
else:
|
tmpfile = DontReadFromInput()
|
||||||
self.tmpfile = TextIO()
|
else:
|
||||||
|
tmpfile = TextIO()
|
||||||
|
self.tmpfile = tmpfile
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
setattr(sys, self.name, self.tmpfile)
|
setattr(sys, self.name, self.tmpfile)
|
||||||
|
@ -421,8 +435,15 @@ class SysCapture:
|
||||||
|
|
||||||
def done(self):
|
def done(self):
|
||||||
setattr(sys, self.name, self._old)
|
setattr(sys, self.name, self._old)
|
||||||
|
del self._old
|
||||||
self.tmpfile.close()
|
self.tmpfile.close()
|
||||||
|
|
||||||
|
def suspend(self):
|
||||||
|
setattr(sys, self.name, self._old)
|
||||||
|
|
||||||
|
def resume(self):
|
||||||
|
setattr(sys, self.name, self.tmpfile)
|
||||||
|
|
||||||
def writeorg(self, data):
|
def writeorg(self, data):
|
||||||
self._old.write(data)
|
self._old.write(data)
|
||||||
self._old.flush()
|
self._old.flush()
|
||||||
|
|
|
@ -726,15 +726,11 @@ class TestFDCapture:
|
||||||
assert s == "hello\n"
|
assert s == "hello\n"
|
||||||
|
|
||||||
def test_stdin(self, tmpfile):
|
def test_stdin(self, tmpfile):
|
||||||
tmpfile.write(tobytes("3"))
|
cap = capture.FDCapture(0)
|
||||||
tmpfile.seek(0)
|
|
||||||
cap = capture.FDCapture(0, tmpfile)
|
|
||||||
cap.start()
|
cap.start()
|
||||||
# check with os.read() directly instead of raw_input(), because
|
|
||||||
# sys.stdin itself may be redirected (as pytest now does by default)
|
|
||||||
x = os.read(0, 100).strip()
|
x = os.read(0, 100).strip()
|
||||||
cap.done()
|
cap.done()
|
||||||
assert x == tobytes("3")
|
assert x == tobytes('')
|
||||||
|
|
||||||
def test_writeorg(self, tmpfile):
|
def test_writeorg(self, tmpfile):
|
||||||
data1, data2 = tobytes("foo"), tobytes("bar")
|
data1, data2 = tobytes("foo"), tobytes("bar")
|
||||||
|
@ -751,7 +747,37 @@ class TestFDCapture:
|
||||||
stmp = open(tmpfile.name, 'rb').read()
|
stmp = open(tmpfile.name, 'rb').read()
|
||||||
assert stmp == data2
|
assert stmp == data2
|
||||||
|
|
||||||
|
def test_simple_resume_suspend(self, tmpfile):
|
||||||
|
with saved_fd(1):
|
||||||
|
cap = capture.FDCapture(1)
|
||||||
|
cap.start()
|
||||||
|
data = tobytes("hello")
|
||||||
|
os.write(1, data)
|
||||||
|
sys.stdout.write("whatever")
|
||||||
|
s = cap.snap()
|
||||||
|
assert s == "hellowhatever"
|
||||||
|
cap.suspend()
|
||||||
|
os.write(1, tobytes("world"))
|
||||||
|
sys.stdout.write("qlwkej")
|
||||||
|
assert not cap.snap()
|
||||||
|
cap.resume()
|
||||||
|
os.write(1, tobytes("but now"))
|
||||||
|
sys.stdout.write(" yes\n")
|
||||||
|
s = cap.snap()
|
||||||
|
assert s == "but now yes\n"
|
||||||
|
cap.suspend()
|
||||||
|
cap.done()
|
||||||
|
pytest.raises(AttributeError, cap.suspend)
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def saved_fd(fd):
|
||||||
|
new_fd = os.dup(fd)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
os.dup2(new_fd, fd)
|
||||||
|
|
||||||
|
|
||||||
class TestStdCapture:
|
class TestStdCapture:
|
||||||
captureclass = staticmethod(StdCapture)
|
captureclass = staticmethod(StdCapture)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue