import os import threading import Queue import traceback import atexit import weakref import __future__ # note that the whole code of this module (as well as some # other modules) execute not only on the local side but # also on any gateway's remote side. On such remote sides # we cannot assume the py library to be there and # InstallableGateway._remote_bootstrap_gateway() (located # in register.py) will take care to send source fragments # to the other side. Yes, it is fragile but we have a # few tests that try to catch when we mess up. # XXX the following lines should not be here if 'ThreadOut' not in globals(): import py from py.code import Source from py.__.execnet.channel import ChannelFactory, Channel from py.__.execnet.message import Message ThreadOut = py._thread.ThreadOut WorkerPool = py._thread.WorkerPool NamedThreadPool = py._thread.NamedThreadPool import os debug = 0 # open('/tmp/execnet-debug-%d' % os.getpid() , 'wa') sysex = (KeyboardInterrupt, SystemExit) class Gateway(object): _ThreadOut = ThreadOut remoteaddress = "" def __init__(self, io, startcount=2, maxthreads=None): global registered_cleanup self._execpool = WorkerPool(maxthreads=maxthreads) ## self.running = True self._io = io self._outgoing = Queue.Queue() self._channelfactory = ChannelFactory(self, startcount) ## self._exitlock = threading.Lock() if not registered_cleanup: atexit.register(cleanup_atexit) registered_cleanup = True _active_sendqueues[self._outgoing] = True self._pool = NamedThreadPool(receiver = self._thread_receiver, sender = self._thread_sender) def __repr__(self): addr = self.remoteaddress if addr: addr = '[%s]' % (addr,) else: addr = '' try: r = (len(self._pool.getstarted('receiver')) and "receiving" or "not receiving") s = (len(self._pool.getstarted('sender')) and "sending" or "not sending") i = len(self._channelfactory.channels()) except AttributeError: r = s = "uninitialized" i = "no" return "<%s%s %s/%s (%s active channels)>" %( self.__class__.__name__, addr, r, s, i) ## def _local_trystopexec(self): ## self._execpool.shutdown() def _trace(self, *args): if debug: try: l = "\n".join(args).split(os.linesep) id = getid(self) for x in l: print >>debug, x debug.flush() except sysex: raise except: traceback.print_exc() def _traceex(self, excinfo): try: l = traceback.format_exception(*excinfo) errortext = "".join(l) except: errortext = '%s: %s' % (excinfo[0].__name__, excinfo[1]) self._trace(errortext) def _thread_receiver(self): """ thread to read and handle Messages half-sync-half-async. """ try: from sys import exc_info while 1: try: msg = Message.readfrom(self._io) self._trace("received <- %r" % msg) msg.received(self) except sysex: raise except EOFError: break except: self._traceex(exc_info()) break finally: self._outgoing.put(None) self._channelfactory._finished_receiving() self._trace('leaving %r' % threading.currentThread()) def _thread_sender(self): """ thread to send Messages over the wire. """ try: from sys import exc_info while 1: msg = self._outgoing.get() try: if msg is None: self._io.close_write() break msg.writeto(self._io) except: excinfo = exc_info() self._traceex(excinfo) if msg is not None: msg.post_sent(self, excinfo) raise else: self._trace('sent -> %r' % msg) msg.post_sent(self) finally: self._trace('leaving %r' % threading.currentThread()) def _local_redirect_thread_output(self, outid, errid): l = [] for name, id in ('stdout', outid), ('stderr', errid): if id: channel = self._channelfactory.new(outid) out = self._ThreadOut(sys, name) out.setwritefunc(channel.send) l.append((out, channel)) def close(): for out, channel in l: out.delwritefunc() channel.close() return close def _thread_executor(self, channel, (source, outid, errid)): """ worker thread to execute source objects from the execution queue. """ from sys import exc_info try: loc = { 'channel' : channel } self._trace("execution starts:", repr(source)[:50]) close = self._local_redirect_thread_output(outid, errid) try: co = compile(source+'\n', '', 'exec', __future__.CO_GENERATOR_ALLOWED) exec co in loc finally: close() self._trace("execution finished:", repr(source)[:50]) except (KeyboardInterrupt, SystemExit): raise except: excinfo = exc_info() l = traceback.format_exception(*excinfo) errortext = "".join(l) channel.close(errortext) self._trace(errortext) else: channel.close() def _local_schedulexec(self, channel, sourcetask): self._trace("dispatching exec") self._execpool.dispatch(self._thread_executor, channel, sourcetask) def _newredirectchannelid(self, callback): if callback is None: return if hasattr(callback, 'write'): callback = callback.write assert callable(callback) chan = self.newchannel() chan.setcallback(callback) return chan.id # _____________________________________________________________________ # # High Level Interface # _____________________________________________________________________ # def newchannel(self): """ return new channel object. """ return self._channelfactory.new() def remote_exec(self, source, stdout=None, stderr=None): """ return channel object for communicating with the asynchronously executing 'source' code which will have a corresponding 'channel' object in its executing namespace. You may provide callback functions 'stdout' and 'stderr' which will get called with the remote stdout/stderr output piece by piece. """ try: source = str(Source(source)) except NameError: try: import py source = str(py.code.Source(source)) except ImportError: pass channel = self.newchannel() outid = self._newredirectchannelid(stdout) errid = self._newredirectchannelid(stderr) self._outgoing.put(Message.CHANNEL_OPEN(channel.id, (source, outid, errid))) return channel def _remote_redirect(self, stdout=None, stderr=None): """ return a handle representing a redirection of a remote end's stdout to a local file object. with handle.close() the redirection will be reverted. """ clist = [] for name, out in ('stdout', stdout), ('stderr', stderr): if out: outchannel = self.newchannel() outchannel.setcallback(getattr(out, 'write', out)) channel = self.remote_exec(""" import sys outchannel = channel.receive() outchannel.gateway._ThreadOut(sys, %r).setdefaultwriter(outchannel.send) """ % name) channel.send(outchannel) clist.append(channel) for c in clist: c.waitclose(1.0) class Handle: def close(_): for name, out in ('stdout', stdout), ('stderr', stderr): if out: c = self.remote_exec(""" import sys channel.gateway._ThreadOut(sys, %r).resetdefault() """ % name) c.waitclose(1.0) return Handle() ## def exit(self): ## """ initiate full gateway teardown. ## Note that the teardown of sender/receiver threads happens ## asynchronously and timeouts on stopping worker execution ## threads are ignored. You can issue join() or join(joinexec=False) ## if you want to wait for a full teardown (possibly excluding ## execution threads). ## """ ## # note that threads may still be scheduled to start ## # during our execution! ## self._exitlock.acquire() ## try: ## if self.running: ## self.running = False ## if not self._pool.getstarted('sender'): ## raise IOError("sender thread not alive anymore!") ## self._outgoing.put(None) ## self._trace("exit procedure triggered, pid %d " % (os.getpid(),)) ## _gateways.remove(self) ## finally: ## self._exitlock.release() def exit(self): try: del _active_sendqueues[self._outgoing] except KeyError: pass else: self._outgoing.put(None) def join(self, joinexec=True): current = threading.currentThread() for x in self._pool.getstarted(): if x != current: self._trace("joining %s" % x) x.join() self._trace("joining sender/reciver threads finished, current %r" % current) if joinexec: self._execpool.join() self._trace("joining execution threads finished, current %r" % current) def getid(gw, cache={}): name = gw.__class__.__name__ try: return cache.setdefault(name, {})[id(gw)] except KeyError: cache[name][id(gw)] = x = "%s:%s.%d" %(os.getpid(), gw.__class__.__name__, len(cache[name])) return x registered_cleanup = False _active_sendqueues = weakref.WeakKeyDictionary() def cleanup_atexit(): if debug: print >>debug, "="*20 + "cleaning up" + "=" * 20 debug.flush() while True: try: queue, ignored = _active_sendqueues.popitem() except KeyError: break queue.put(None)