From a5e69d2035dc6381b9453e8b8e5761925e1a40de Mon Sep 17 00:00:00 2001 From: hpk Date: Tue, 7 Aug 2007 19:34:59 +0200 Subject: [PATCH] [svn r45539] merge the execnet lessthreads branch (using the branch'es history): * now by default Gateways DO NOT SPAWN execution threads you can call "remote_init_threads(NUM)" on an already instantiated gateway, which will install a loop on the other side which will dispatch each execution task to its own thread. * execution is dissallowed on the side which initiates a gateway (rarely used, anyway) * some cleanups (hopefully) --HG-- branch : trunk --- py/execnet/gateway.py | 181 +++++++++++++++++------------ py/execnet/inputoutput.py | 10 +- py/execnet/register.py | 5 +- py/execnet/testing/test_gateway.py | 35 +++++- 4 files changed, 152 insertions(+), 79 deletions(-) diff --git a/py/execnet/gateway.py b/py/execnet/gateway.py index a85b3df3d..0c1c757af 100644 --- a/py/execnet/gateway.py +++ b/py/execnet/gateway.py @@ -22,33 +22,38 @@ if 'ThreadOut' not in globals(): from py.__.execnet.channel import ChannelFactory, Channel from py.__.execnet.message import Message ThreadOut = py._thread.ThreadOut - WorkerPool = py._thread.WorkerPool - NamedThreadPool = py._thread.NamedThreadPool import os -debug = 0 # open('/tmp/execnet-debug-%d' % os.getpid() , 'wa') +debug = open('/tmp/execnet-debug-%d' % os.getpid() , 'wa') sysex = (KeyboardInterrupt, SystemExit) +class StopExecLoop(Exception): + pass class Gateway(object): _ThreadOut = ThreadOut remoteaddress = "" - def __init__(self, io, execthreads=None, _startcount=2): + _requestqueue = None + + def __init__(self, io, _startcount=2): """ initialize core gateway, using the given - inputoutput object and 'execthreads' execution - threads. + inputoutput object. """ - global registered_cleanup - self._execpool = WorkerPool(maxthreads=execthreads) + global registered_cleanup, _activegateways self._io = io - self._outgoing = Queue.Queue() self._channelfactory = ChannelFactory(self, _startcount) if not registered_cleanup: atexit.register(cleanup_atexit) registered_cleanup = True - _active_sendqueues[self._outgoing] = True - self._pool = NamedThreadPool(receiver = self._thread_receiver, - sender = self._thread_sender) + _activegateways[self] = True + + def _initreceive(self, requestqueue=False): + if requestqueue: + self._requestqueue = Queue.Queue() + self._receiverthread = threading.Thread(name="receiver", + target=self._thread_receiver) + self._receiverthread.setDaemon(0) + self._receiverthread.start() def __repr__(self): """ return string representing gateway type and status. """ @@ -58,10 +63,9 @@ class Gateway(object): else: addr = '' try: - r = (len(self._pool.getstarted('receiver')) - and "receiving" or "not receiving") - s = (len(self._pool.getstarted('sender')) - and "sending" or "not sending") + r = (self._receiverthread.isAlive() and "receiving" or + "not receiving") + s = "sending" # XXX i = len(self._channelfactory.channels()) except AttributeError: r = s = "uninitialized" @@ -69,9 +73,6 @@ class Gateway(object): return "<%s%s %s/%s (%s active channels)>" %( self.__class__.__name__, addr, r, s, i) -## def _local_trystopexec(self): -## self._execpool.shutdown() - def _trace(self, *args): if debug: try: @@ -111,35 +112,25 @@ class Gateway(object): self._traceex(exc_info()) break finally: - self._send(None) + self._stopexec() + self._stopsend() self._channelfactory._finished_receiving() self._trace('leaving %r' % threading.currentThread()) def _send(self, msg): - self._outgoing.put(msg) - - def _thread_sender(self): - """ thread to send Messages over the wire. """ - try: - from sys import exc_info - while 1: - msg = self._outgoing.get() - try: - if msg is None: - self._io.close_write() - break - msg.writeto(self._io) - except: - excinfo = exc_info() - self._traceex(excinfo) - if msg is not None: - msg.post_sent(self, excinfo) - break - else: - self._trace('sent -> %r' % msg) - msg.post_sent(self) - finally: - self._trace('leaving %r' % threading.currentThread()) + from sys import exc_info + if msg is None: + self._io.close_write() + else: + try: + msg.writeto(self._io) + except: + excinfo = exc_info() + self._traceex(excinfo) + msg.post_sent(self, excinfo) + else: + msg.post_sent(self) + self._trace('sent -> %r' % msg) def _local_redirect_thread_output(self, outid, errid): l = [] @@ -155,9 +146,58 @@ class Gateway(object): channel.close() return close - def _thread_executor(self, channel, (source, outid, errid)): - """ worker thread to execute source objects from the execution queue. """ + def _local_schedulexec(self, channel, sourcetask): + if self._requestqueue is not None: + self._requestqueue.put((channel, sourcetask)) + else: + # we will not execute, let's send back an error + # to inform the other side + channel.close("execution disallowed") + + def _servemain(self, joining=True): from sys import exc_info + self._initreceive(requestqueue=True) + try: + while 1: + item = self._requestqueue.get() + if item is None: + self._stopsend() + break + try: + self._executetask(item) + except StopExecLoop: + break + finally: + self._trace("_servemain finished") + if self.joining: + self.join() + + def remote_init_threads(self, num=None): + """ start up to 'num' threads for subsequent + remote_exec() invocations to allow concurrent + execution. + """ + if hasattr(self, '_remotechannelthread'): + raise IOError("remote threads already running") + from py.__.thread import pool + source = py.code.Source(pool, """ + execpool = WorkerPool(maxthreads=%r) + gw = channel.gateway + while 1: + task = gw._requestqueue.get() + if task is None: + gw._stopsend() + execpool.shutdown() + execpool.join() + raise StopExecLoop + execpool.dispatch(gw._executetask, task) + """ % num) + self._remotechannelthread = self.remote_exec(source) + + def _executetask(self, item): + """ execute channel/source items. """ + from sys import exc_info + channel, (source, outid, errid) = item try: loc = { 'channel' : channel } self._trace("execution starts:", repr(source)[:50]) @@ -171,6 +211,9 @@ class Gateway(object): self._trace("execution finished:", repr(source)[:50]) except (KeyboardInterrupt, SystemExit): pass + except StopExecLoop: + channel.close() + raise except: excinfo = exc_info() l = traceback.format_exception(*excinfo) @@ -180,10 +223,6 @@ class Gateway(object): else: channel.close() - def _local_schedulexec(self, channel, sourcetask): - self._trace("dispatching exec") - self._execpool.dispatch(self._thread_executor, channel, sourcetask) - def _newredirectchannelid(self, callback): if callback is None: return @@ -257,27 +296,25 @@ class Gateway(object): return Handle() def exit(self): - """ Try to stop all IO activity. """ - try: - del _active_sendqueues[self._outgoing] - except KeyError: - pass - else: - self._send(None) + """ Try to stop all exec and IO activity. """ + self._stopexec() + self._stopsend() + + def _stopsend(self): + self._send(None) + + def _stopexec(self): + if self._requestqueue is not None: + self._requestqueue.put(None) def join(self, joinexec=True): """ Wait for all IO (and by default all execution activity) - to stop. + to stop. the joinexec parameter is obsolete. """ current = threading.currentThread() - for x in self._pool.getstarted(): - if x != current: - self._trace("joining %s" % x) - x.join() - self._trace("joining sender/reciver threads finished, current %r" % current) - if joinexec: - self._execpool.join() - self._trace("joining execution threads finished, current %r" % current) + if self._receiverthread.isAlive(): + self._trace("joining receiver thread") + self._receiverthread.join() def getid(gw, cache={}): name = gw.__class__.__name__ @@ -288,14 +325,12 @@ def getid(gw, cache={}): return x registered_cleanup = False -_active_sendqueues = weakref.WeakKeyDictionary() +_activegateways = weakref.WeakKeyDictionary() def cleanup_atexit(): if debug: print >>debug, "="*20 + "cleaning up" + "=" * 20 debug.flush() - while True: - try: - queue, ignored = _active_sendqueues.popitem() - except KeyError: - break - queue.put(None) + while _activegateways: + gw, ignored = _activegateways.popitem() + gw.exit() + #gw.join() should work as well diff --git a/py/execnet/inputoutput.py b/py/execnet/inputoutput.py index 4d188fa68..23facc6b0 100644 --- a/py/execnet/inputoutput.py +++ b/py/execnet/inputoutput.py @@ -43,11 +43,17 @@ import sys def close_read(self): if self.readable: - self.sock.shutdown(0) + try: + self.sock.shutdown(0) + except socket.error: + pass self.readable = None def close_write(self): if self.writeable: - self.sock.shutdown(1) + try: + self.sock.shutdown(1) + except socket.error: + pass self.writeable = None class Popen2IO: diff --git a/py/execnet/register.py b/py/execnet/register.py index 4bc382688..7422366d4 100644 --- a/py/execnet/register.py +++ b/py/execnet/register.py @@ -11,7 +11,6 @@ import py startup_modules = [ 'py.__.thread.io', - 'py.__.thread.pool', 'py.__.execnet.inputoutput', 'py.__.execnet.gateway', 'py.__.execnet.message', @@ -29,6 +28,8 @@ class InstallableGateway(gateway.Gateway): def __init__(self, io): self._remote_bootstrap_gateway(io) super(InstallableGateway, self).__init__(io=io, _startcount=1) + # XXX we dissallow execution form the other side + self._initreceive(requestqueue=False) def _remote_bootstrap_gateway(self, io, extra=''): """ return Gateway with a asynchronously remotely @@ -41,7 +42,7 @@ class InstallableGateway(gateway.Gateway): bootstrap = [extra] bootstrap += [getsource(x) for x in startup_modules] bootstrap += [io.server_stmt, - "Gateway(io=io, _startcount=2).join(joinexec=False)", + "Gateway(io=io, _startcount=2)._servemain()", ] source = "\n".join(bootstrap) self._trace("sending gateway bootstrap code") diff --git a/py/execnet/testing/test_gateway.py b/py/execnet/testing/test_gateway.py index 682ca65cc..bec53771e 100644 --- a/py/execnet/testing/test_gateway.py +++ b/py/execnet/testing/test_gateway.py @@ -83,8 +83,7 @@ class PopenGatewayTestSetup: class BasicRemoteExecution: def test_correct_setup(self): - for x in 'sender', 'receiver': - assert self.gw._pool.getstarted(x) + assert self.gw._receiverthread.isAlive() def test_repr_doesnt_crash(self): assert isinstance(repr(self), str) @@ -373,6 +372,18 @@ class BasicRemoteExecution: res = channel.receive() assert res == 42 + def test_non_reverse_execution(self): + gw = self.gw + c1 = gw.remote_exec(""" + c = channel.gateway.remote_exec("pass") + try: + c.waitclose() + except c.RemoteError, e: + channel.send(str(e)) + """) + text = c1.receive() + assert text.find("execution disallowed") != -1 + #class TestBlockingIssues: # def test_join_blocked_execution_gateway(self): # gateway = py.execnet.PopenGateway() @@ -486,3 +497,23 @@ class TestSshGateway(BasicRemoteExecution): # now it did py.test.raises(IOError, gw.remote_exec, "...") +def test_threads(): + gw = py.execnet.PopenGateway() + gw.remote_init_threads(3) + c1 = gw.remote_exec("channel.send(channel.receive())") + c2 = gw.remote_exec("channel.send(channel.receive())") + c2.send(1) + res = c2.receive() + assert res == 1 + c1.send(42) + res = c1.receive() + assert res == 42 + gw.exit() + +def test_threads_twice(): + gw = py.execnet.PopenGateway() + gw.remote_init_threads(3) + py.test.raises(IOError, gw.remote_init_threads, 3) + gw.exit() + +