* more tests and fixes for cross-python compatibility

* use byte-buffer files if available for io
* shift receivelock to gateway object
* kill dead code

--HG--
branch : trunk
This commit is contained in:
holger krekel 2009-09-02 18:56:43 +02:00
parent 5d2504df0a
commit e30aeed876
4 changed files with 159 additions and 125 deletions

View File

@ -35,7 +35,7 @@ class GatewayCleanup:
if debug: if debug:
debug.writeslines(["="*20, "cleaning up", "=" * 20]) debug.writeslines(["="*20, "cleaning up", "=" * 20])
debug.flush() debug.flush()
for gw in self._activegateways.keys(): for gw in list(self._activegateways):
gw.exit() gw.exit()
#gw.join() # should work as well #gw.join() # should work as well
@ -70,6 +70,12 @@ class InitiatingGateway(BaseGateway):
return "<%s%s %s/%s (%s active channels)>" %( return "<%s%s %s/%s (%s active channels)>" %(
self.__class__.__name__, addr, r, s, i) self.__class__.__name__, addr, r, s, i)
def exit(self):
""" Try to stop all exec and IO activity. """
self._cleanup.unregister(self)
self._stopexec()
self._stopsend()
self.hook.pyexecnet_gateway_exit(gateway=self)
def _remote_bootstrap_gateway(self, io, extra=''): def _remote_bootstrap_gateway(self, io, extra=''):
""" return Gateway with a asynchronously remotely """ return Gateway with a asynchronously remotely
@ -93,16 +99,8 @@ class InitiatingGateway(BaseGateway):
def _rinfo(self, update=False): def _rinfo(self, update=False):
""" return some sys/env information from remote. """ """ return some sys/env information from remote. """
if update or not hasattr(self, '_cache_rinfo'): if update or not hasattr(self, '_cache_rinfo'):
self._cache_rinfo = RInfo(**self.remote_exec(""" ch = self.remote_exec(rinfo_source)
import sys, os self._cache_rinfo = RInfo(**ch.receive())
channel.send(dict(
executable = sys.executable,
version_info = sys.version_info,
platform = sys.platform,
cwd = os.getcwd(),
pid = os.getpid(),
))
""").receive())
return self._cache_rinfo return self._cache_rinfo
def remote_exec(self, source, stdout=None, stderr=None): def remote_exec(self, source, stdout=None, stderr=None):
@ -193,14 +191,24 @@ class RInfo:
for item in self.__dict__.items()]) for item in self.__dict__.items()])
return "<RInfo %r>" % info return "<RInfo %r>" % info
rinfo_source = """
import sys, os
channel.send(dict(
executable = sys.executable,
version_info = tuple([sys.version_info[i] for i in range(5)]),
platform = sys.platform,
cwd = os.getcwd(),
pid = os.getpid(),
))
"""
class PopenCmdGateway(InitiatingGateway): class PopenCmdGateway(InitiatingGateway):
def __init__(self, cmd): def __init__(self, cmd):
# on win close_fds=True does not work, not sure it'd needed # on win close_fds=True does not work, not sure it'd needed
#p = Popen(cmd, shell=True, stdin=PIPE, stdout=PIPE, close_fds=True) #p = Popen(cmd, shell=True, stdin=PIPE, stdout=PIPE, close_fds=True)
self._popen = p = Popen(cmd, shell=True, stdin=PIPE, stdout=PIPE) self._popen = p = Popen(cmd, shell=True, stdin=PIPE, stdout=PIPE)
infile, outfile = p.stdin, p.stdout
self._cmd = cmd self._cmd = cmd
io = Popen2IO(infile, outfile) io = Popen2IO(p.stdin, p.stdout)
super(PopenCmdGateway, self).__init__(io=io) super(PopenCmdGateway, self).__init__(io=io)
def exit(self): def exit(self):
@ -217,8 +225,8 @@ class PopenGateway(PopenCmdGateway):
""" """
if not python: if not python:
python = sys.executable python = sys.executable
cmd = '%s -u -c "exec input()"' % python cmd = ('%s -u -c "import sys ; '
cmd = '%s -u -c "import sys ; exec(eval(sys.stdin.readline()))"' % python 'exec(eval(sys.stdin.readline()))"' % python)
super(PopenGateway, self).__init__(cmd) super(PopenGateway, self).__init__(cmd)
def _remote_bootstrap_gateway(self, io, extra=''): def _remote_bootstrap_gateway(self, io, extra=''):

View File

@ -27,18 +27,18 @@ if sys.version_info > (3, 0):
exec("""def do_exec(co, loc): exec("""def do_exec(co, loc):
exec(co, loc)""") exec(co, loc)""")
unicode = str unicode = str
sysex = Exception
else: else:
exec("""def do_exec(co, loc): exec("""def do_exec(co, loc):
exec co in loc""") exec co in loc""")
bytes = str bytes = str
sysex = (KeyboardInterrupt, SystemExit)
def str(*args): def str(*args):
raise EnvironmentError( raise EnvironmentError(
"use unicode or bytes, not cross-python ambigous 'str'") "use unicode or bytes, not cross-python ambigous 'str'")
default_encoding = "UTF-8" default_encoding = "UTF-8"
sysex = (KeyboardInterrupt, SystemExit)
debug = 0 # open('/tmp/execnet-debug-%d' % os.getpid() , 'w') debug = 0 # open('/tmp/execnet-debug-%d' % os.getpid() , 'w')
@ -94,13 +94,15 @@ class SocketIO:
class Popen2IO: class Popen2IO:
server_stmt = """ server_stmt = """
import os, sys, tempfile import os, sys, tempfile
#io = Popen2IO(os.fdopen(1, 'wb'), os.fdopen(0, 'rb'))
io = Popen2IO(sys.stdout, sys.stdin) io = Popen2IO(sys.stdout, sys.stdin)
sys.stdout = sys.stderr = tempfile.TemporaryFile() sys.stdout = tempfile.TemporaryFile()
sys.stdin = tempfile.TemporaryFile()
""" """
error = (IOError, OSError, EOFError) error = (IOError, OSError, EOFError)
def __init__(self, infile, outfile): def __init__(self, outfile, infile):
self.outfile, self.infile = infile, outfile self.outfile, self.infile = outfile, infile
if sys.platform == "win32": if sys.platform == "win32":
import msvcrt import msvcrt
msvcrt.setmode(infile.fileno(), os.O_BINARY) msvcrt.setmode(infile.fileno(), os.O_BINARY)
@ -110,17 +112,22 @@ sys.stdout = sys.stderr = tempfile.TemporaryFile()
def read(self, numbytes): def read(self, numbytes):
"""Read exactly 'bytes' bytes from the pipe. """ """Read exactly 'bytes' bytes from the pipe. """
data = self.infile.read(numbytes) infile = self.infile
if hasattr(infile, 'buffer'):
infile = infile.buffer
data = infile.read(numbytes)
if len(data) < numbytes: if len(data) < numbytes:
raise EOFError raise EOFError
assert isinstance(data, bytes)
return data return data
def write(self, data): def write(self, data):
"""write out all bytes to the pipe. """ """write out all bytes to the pipe. """
assert isinstance(data, bytes) assert isinstance(data, bytes)
self.outfile.write(data) outfile = self.outfile
self.outfile.flush() if hasattr(outfile, 'buffer'):
outfile = outfile.buffer
outfile.write(data)
outfile.flush()
def close_read(self): def close_read(self):
if self.readable: if self.readable:
@ -179,9 +186,6 @@ class Message:
return msg return msg
readfrom = classmethod(readfrom) readfrom = classmethod(readfrom)
def post_sent(self, gateway, excinfo=None):
pass
def __repr__(self): def __repr__(self):
r = repr(self.data) r = repr(self.data)
if len(r) > 50: if len(r) > 50:
@ -193,8 +197,6 @@ class Message:
def _setupmessages(): def _setupmessages():
# XXX use metaclass for registering
class CHANNEL_OPEN(Message): class CHANNEL_OPEN(Message):
def received(self, gateway): def received(self, gateway):
channel = gateway._channelfactory.new(self.channelid) channel = gateway._channelfactory.new(self.channelid)
@ -275,7 +277,7 @@ class Channel(object):
# after having cleared the queue we register # after having cleared the queue we register
# the callback only if the channel is not closed already. # the callback only if the channel is not closed already.
_callbacks = self.gateway._channelfactory._callbacks _callbacks = self.gateway._channelfactory._callbacks
_receivelock = self.gateway._channelfactory._receivelock _receivelock = self.gateway._receivelock
_receivelock.acquire() _receivelock.acquire()
try: try:
if self._items is None: if self._items is None:
@ -426,6 +428,7 @@ class Channel(object):
return self.receive() return self.receive()
except EOFError: except EOFError:
raise StopIteration raise StopIteration
__next__ = next
ENDMARKER = object() ENDMARKER = object()
@ -436,7 +439,6 @@ class ChannelFactory(object):
self._channels = weakref.WeakValueDictionary() self._channels = weakref.WeakValueDictionary()
self._callbacks = {} self._callbacks = {}
self._writelock = threading.Lock() self._writelock = threading.Lock()
self._receivelock = threading.RLock()
self.gateway = gateway self.gateway = gateway
self.count = startcount self.count = startcount
self.finished = False self.finished = False
@ -596,6 +598,7 @@ class BaseGateway(object):
""" """
self._io = io self._io = io
self._channelfactory = ChannelFactory(self, _startcount) self._channelfactory = ChannelFactory(self, _startcount)
self._receivelock = threading.RLock()
def _initreceive(self, requestqueue=False): def _initreceive(self, requestqueue=False):
if requestqueue: if requestqueue:
@ -605,18 +608,15 @@ class BaseGateway(object):
self._receiverthread.setDaemon(1) self._receiverthread.setDaemon(1)
self._receiverthread.start() self._receiverthread.start()
def _trace(self, *args): def _trace(self, msg):
if debug: if debug:
try: try:
l = "\n".join(args).split(os.linesep) debug.write(unicode(msg) + "\n")
id = getid(self)
for x in l:
debug.write(x+"\n")
debug.flush() debug.flush()
except sysex: except sysex:
raise raise
except: except:
traceback.print_exc() sys.stderr.write("exception during tracing\n")
def _traceex(self, excinfo): def _traceex(self, excinfo):
try: try:
@ -629,12 +629,13 @@ class BaseGateway(object):
def _thread_receiver(self): def _thread_receiver(self):
""" thread to read and handle Messages half-sync-half-async. """ """ thread to read and handle Messages half-sync-half-async. """
self._trace("starting to receive")
try: try:
while 1: while 1:
try: try:
msg = Message.readfrom(self._io) msg = Message.readfrom(self._io)
self._trace("received <- %r" % msg) self._trace("received <- %r" % msg)
_receivelock = self._channelfactory._receivelock _receivelock = self._receivelock
_receivelock.acquire() _receivelock.acquire()
try: try:
msg.received(self) msg.received(self)
@ -669,11 +670,16 @@ class BaseGateway(object):
except: except:
excinfo = self.exc_info() excinfo = self.exc_info()
self._traceex(excinfo) self._traceex(excinfo)
msg.post_sent(self, excinfo)
else: else:
msg.post_sent(self)
self._trace('sent -> %r' % msg) self._trace('sent -> %r' % msg)
def _stopsend(self):
self._send(None)
def _stopexec(self):
if self._requestqueue is not None:
self._requestqueue.put(None)
def _local_redirect_thread_output(self, outid, errid): def _local_redirect_thread_output(self, outid, errid):
l = [] l = []
for name, id in ('stdout', outid), ('stderr', errid): for name, id in ('stdout', outid), ('stderr', errid):
@ -719,14 +725,14 @@ class BaseGateway(object):
try: try:
loc = { 'channel' : channel, '__name__': '__channelexec__'} loc = { 'channel' : channel, '__name__': '__channelexec__'}
#open("task.py", 'w').write(source) #open("task.py", 'w').write(source)
self._trace("execution starts:", repr(source)[:50]) self._trace("execution starts: %s" % repr(source)[:50])
close = self._local_redirect_thread_output(outid, errid) close = self._local_redirect_thread_output(outid, errid)
try: try:
co = compile(source+'\n', '', 'exec') co = compile(source+'\n', '', 'exec')
do_exec(co, loc) do_exec(co, loc)
finally: finally:
close() close()
self._trace("execution finished:", repr(source)[:50]) self._trace("execution finished")
except sysex: except sysex:
pass pass
except self._StopExecLoop: except self._StopExecLoop:
@ -734,10 +740,10 @@ class BaseGateway(object):
raise raise
except: except:
excinfo = self.exc_info() excinfo = self.exc_info()
self._trace("got exception %s" % excinfo[1])
l = traceback.format_exception(*excinfo) l = traceback.format_exception(*excinfo)
errortext = "".join(l) errortext = "".join(l)
channel.close(errortext) channel.close(errortext)
self._trace(errortext)
else: else:
channel.close() channel.close()
@ -760,25 +766,3 @@ class BaseGateway(object):
self._trace("joining receiver thread") self._trace("joining receiver thread")
self._receiverthread.join() self._receiverthread.join()
def exit(self):
""" Try to stop all exec and IO activity. """
self._cleanup.unregister(self)
self._stopexec()
self._stopsend()
self.hook.pyexecnet_gateway_exit(gateway=self)
def _stopsend(self):
self._send(None)
def _stopexec(self):
if self._requestqueue is not None:
self._requestqueue.put(None)
def getid(gw, cache={}):
name = gw.__class__.__name__
try:
return cache.setdefault(name, {})[id(gw)]
except KeyError:
cache[name][id(gw)] = x = "%s:%s.%d" %(os.getpid(), gw.__class__.__name__, len(cache[name]))
return x

View File

@ -3,12 +3,88 @@ import os, sys, time, signal
import py import py
from py.__.execnet.gateway_base import Message, Channel, ChannelFactory from py.__.execnet.gateway_base import Message, Channel, ChannelFactory
from py.__.execnet.gateway_base import ExecnetAPI, queue, Popen2IO from py.__.execnet.gateway_base import ExecnetAPI, queue, Popen2IO
from py.__.execnet import gateway_base, gateway
from py.__.execnet.gateway import startup_modules, getsource from py.__.execnet.gateway import startup_modules, getsource
pytest_plugins = "pytester" pytest_plugins = "pytester"
TESTTIMEOUT = 10.0 # seconds TESTTIMEOUT = 10.0 # seconds
def pytest_generate_tests(metafunc):
if 'pythonpath' in metafunc.funcargnames:
for name in 'python2.4', 'python2.5', 'python2.6', 'python3.1':
metafunc.addcall(id=name, param=name)
def pytest_funcarg__pythonpath(request):
name = request.param
executable = py.path.local.sysfind(name)
if executable is None:
py.test.skip("no %s found" % (name,))
return executable
def test_io_message(pythonpath, tmpdir):
check = tmpdir.join("check.py")
check.write(py.code.Source(gateway_base, """
try:
from io import BytesIO
except ImportError:
from StringIO import StringIO as BytesIO
import tempfile
temp_out = BytesIO()
temp_in = BytesIO()
io = Popen2IO(temp_out, temp_in)
for i, msg_cls in Message._types.items():
print ("checking %s %s" %(i, msg_cls))
for data in "hello", "hello".encode('ascii'):
msg1 = msg_cls(i, data)
msg1.writeto(io)
x = io.outfile.getvalue()
io.outfile.truncate(0)
io.outfile.seek(0)
io.infile.seek(0)
io.infile.write(x)
io.infile.seek(0)
msg2 = Message.readfrom(io)
assert msg1.channelid == msg2.channelid, (msg1, msg2)
assert msg1.data == msg2.data
print ("all passed")
"""))
#out = py.process.cmdexec("%s %s" %(executable,check))
out = pythonpath.sysexec(check)
print (out)
assert "all passed" in out
def test_popen_io(pythonpath, tmpdir):
check = tmpdir.join("check.py")
check.write(py.code.Source(gateway_base, """
do_exec(Popen2IO.server_stmt, globals())
io.write("hello".encode('ascii'))
s = io.read(1)
assert s == "x".encode('ascii')
"""))
from subprocess import Popen, PIPE
args = [str(pythonpath), str(check)]
proc = Popen(args, stdin=PIPE, stdout=PIPE, stderr=PIPE)
proc.stdin.write("x".encode('ascii'))
stdout, stderr = proc.communicate()
print (stderr)
ret = proc.wait()
assert "hello".encode('ascii') in stdout
def test_rinfo_source(pythonpath, tmpdir):
check = tmpdir.join("check.py")
check.write(py.code.Source("""
class Channel:
def send(self, data):
assert eval(repr(data), {}) == data
channel = Channel()
""", gateway.rinfo_source, """
print ('all passed')
"""))
out = pythonpath.sysexec(check)
print (out)
assert "all passed" in out
class TestExecnetEvents: class TestExecnetEvents:
def test_popengateway(self, _pytest): def test_popengateway(self, _pytest):
rec = _pytest.gethookrecorder(ExecnetAPI) rec = _pytest.gethookrecorder(ExecnetAPI)
@ -112,7 +188,7 @@ class BasicRemoteExecution:
def test_correct_setup_no_py(self): def test_correct_setup_no_py(self):
channel = self.gw.remote_exec(""" channel = self.gw.remote_exec("""
import sys import sys
channel.send(sys.modules.keys()) channel.send(list(sys.modules))
""") """)
remotemodules = channel.receive() remotemodules = channel.receive()
assert 'py' not in remotemodules, ( assert 'py' not in remotemodules, (
@ -201,7 +277,7 @@ class BasicRemoteExecution:
channel.send(x) channel.send(x)
""") """)
l = list(channel) l = list(channel)
assert l == range(3) assert l == [0, 1, 2]
def test_channel_passing_over_channel(self): def test_channel_passing_over_channel(self):
channel = self.gw.remote_exec(''' channel = self.gw.remote_exec('''
@ -272,7 +348,11 @@ class BasicRemoteExecution:
# with 'earlyfree==True', this tests the "sendonly" channel state. # with 'earlyfree==True', this tests the "sendonly" channel state.
l = [] l = []
channel = self.gw.remote_exec(source=''' channel = self.gw.remote_exec(source='''
import thread, time try:
import thread
except ImportError:
import _thread as thread
import time
def producer(subchannel): def producer(subchannel):
for i in range(5): for i in range(5):
time.sleep(0.15) time.sleep(0.15)
@ -472,23 +552,6 @@ class BasicCmdbasedRemoteExecution(BasicRemoteExecution):
def test_cmdattr(self): def test_cmdattr(self):
assert hasattr(self.gw, '_cmd') assert hasattr(self.gw, '_cmd')
def test_channel_endmarker_remote_killterm():
gw = py.execnet.PopenGateway()
try:
q = queue.Queue()
channel = gw.remote_exec('''
import os
os.kill(os.getpid(), 15)
''')
channel.setcallback(q.put, endmarker=999)
val = q.get(TESTTIMEOUT)
assert val == 999
err = channel._getremoteerror()
finally:
gw.exit()
py.test.skip("provide information on causes/signals "
"of dying remote gateways")
#class TestBlockingIssues: #class TestBlockingIssues:
# def test_join_blocked_execution_gateway(self): # def test_join_blocked_execution_gateway(self):
# gateway = py.execnet.PopenGateway() # gateway = py.execnet.PopenGateway()
@ -656,3 +719,21 @@ def test_threads_twice():
def test_nodebug(): def test_nodebug():
from py.__.execnet import gateway_base from py.__.execnet import gateway_base
assert not gateway_base.debug assert not gateway_base.debug
def test_channel_endmarker_remote_killterm():
gw = py.execnet.PopenGateway()
try:
q = queue.Queue()
channel = gw.remote_exec('''
import os
os.kill(os.getpid(), 15)
''')
channel.setcallback(q.put, endmarker=999)
val = q.get(TESTTIMEOUT)
assert val == 999
err = channel._getremoteerror()
finally:
gw.exit()
py.test.skip("provide information on causes/signals "
"of dying remote gateways")

View File

@ -1,39 +0,0 @@
import py
from py.__.execnet import gateway_base
@py.test.mark.multi(ver=["2.4", "2.5", "2.6", "3.1"])
def test_io_message(ver, tmpdir):
executable = py.path.local.sysfind("python" + ver)
if executable is None:
py.test.skip("no python%s found" % (ver,))
check = tmpdir.join("check.py")
check.write(py.code.Source(gateway_base, """
try:
from io import BytesIO
except ImportError:
from StringIO import StringIO as BytesIO
import tempfile
temp_out = BytesIO()
temp_in = BytesIO()
io = Popen2IO(temp_out, temp_in)
for i, msg_cls in Message._types.items():
print ("checking %s %s" %(i, msg_cls))
for data in "hello", "hello".encode('ascii'):
msg1 = msg_cls(i, data)
msg1.writeto(io)
x = io.outfile.getvalue()
io.outfile.truncate(0)
io.outfile.seek(0)
io.infile.seek(0)
io.infile.write(x)
io.infile.seek(0)
msg2 = Message.readfrom(io)
assert msg1.channelid == msg2.channelid, (msg1, msg2)
assert msg1.data == msg2.data
print ("all passed")
"""))
#out = py.process.cmdexec("%s %s" %(executable,check))
out = executable.sysexec(check)
print (out)
assert "all passed" in out