* simplify lock acquiration for received messages, review code

* try to fix seldomly occuring race condition with setcallback/receive and closing of channel

--HG--
branch : trunk
This commit is contained in:
holger krekel 2009-09-02 15:45:59 +02:00
parent f636ed8ced
commit 5d2504df0a
1 changed files with 38 additions and 47 deletions

View File

@ -223,7 +223,7 @@ def _setupmessages():
class CHANNEL_LAST_MESSAGE(Message): class CHANNEL_LAST_MESSAGE(Message):
def received(self, gateway): def received(self, gateway):
gateway._channelfactory._local_last_message(self.channelid) gateway._channelfactory._local_close(self.channelid, sendonly=True)
classes = [CHANNEL_OPEN, CHANNEL_NEW, CHANNEL_DATA, classes = [CHANNEL_OPEN, CHANNEL_NEW, CHANNEL_DATA,
CHANNEL_CLOSE, CHANNEL_CLOSE_ERROR, CHANNEL_LAST_MESSAGE] CHANNEL_CLOSE, CHANNEL_CLOSE_ERROR, CHANNEL_LAST_MESSAGE]
@ -269,31 +269,36 @@ class Channel(object):
self._remoteerrors = [] self._remoteerrors = []
def setcallback(self, callback, endmarker=NO_ENDMARKER_WANTED): def setcallback(self, callback, endmarker=NO_ENDMARKER_WANTED):
items = self._items # we first execute the callback on all already received
lock = self.gateway._channelfactory._receivelock # items. We need to hold the receivelock to prevent
lock.acquire() # race conditions with newly arriving items.
try: # after having cleared the queue we register
# the callback only if the channel is not closed already.
_callbacks = self.gateway._channelfactory._callbacks _callbacks = self.gateway._channelfactory._callbacks
dictvalue = (callback, endmarker) _receivelock = self.gateway._channelfactory._receivelock
if _callbacks.setdefault(self.id, dictvalue) != dictvalue: _receivelock.acquire()
try:
if self._items is None:
raise IOError("%r has callback already registered" %(self,)) raise IOError("%r has callback already registered" %(self,))
items = self._items
self._items = None self._items = None
while 1: while 1:
try: try:
olditem = items.get(block=False) olditem = items.get(block=False)
except queue.Empty: except queue.Empty:
if not (self._closed or self._receiveclosed.isSet()):
_callbacks[self.id] = (callback, endmarker)
break break
else: else:
if olditem is ENDMARKER: if olditem is ENDMARKER:
items.put(olditem) items.put(olditem) # for other receivers
if endmarker is not NO_ENDMARKER_WANTED:
callback(endmarker)
break break
else: else:
callback(olditem) callback(olditem)
if self._closed or self._receiveclosed.isSet():
# no need to keep a callback
self.gateway._channelfactory._close_callback(self.id)
finally: finally:
lock.release() _receivelock.release()
def __repr__(self): def __repr__(self):
flag = self.isclosed() and "closed" or "open" flag = self.isclosed() and "closed" or "open"
@ -462,9 +467,6 @@ class ChannelFactory(object):
del self._channels[id] del self._channels[id]
except KeyError: except KeyError:
pass pass
self._close_callback(id)
def _close_callback(self, id):
try: try:
callback, endmarker = self._callbacks.pop(id) callback, endmarker = self._callbacks.pop(id)
except KeyError: except KeyError:
@ -473,7 +475,7 @@ class ChannelFactory(object):
if endmarker is not NO_ENDMARKER_WANTED: if endmarker is not NO_ENDMARKER_WANTED:
callback(endmarker) callback(endmarker)
def _local_close(self, id, remoteerror=None): def _local_close(self, id, remoteerror=None, sendonly=False):
channel = self._channels.get(id) channel = self._channels.get(id)
if channel is None: if channel is None:
# channel already in "deleted" state # channel already in "deleted" state
@ -483,6 +485,7 @@ class ChannelFactory(object):
# state transition to "closed" state # state transition to "closed" state
if remoteerror: if remoteerror:
channel._remoteerrors.append(remoteerror) channel._remoteerrors.append(remoteerror)
if not sendonly: # otherwise #--> "sendonly"
channel._closed = True # --> "closed" channel._closed = True # --> "closed"
channel._receiveclosed.set() channel._receiveclosed.set()
queue = channel._items queue = channel._items
@ -490,23 +493,8 @@ class ChannelFactory(object):
queue.put(ENDMARKER) queue.put(ENDMARKER)
self._no_longer_opened(id) self._no_longer_opened(id)
def _local_last_message(self, id):
channel = self._channels.get(id)
if channel is None:
# channel already in "deleted" state
pass
else:
# state transition: if "opened", change to "sendonly"
channel._receiveclosed.set()
queue = channel._items
if queue is not None:
queue.put(ENDMARKER)
self._no_longer_opened(id)
def _local_receive(self, id, data): def _local_receive(self, id, data):
# executes in receiver thread # executes in receiver thread
self._receivelock.acquire()
try:
try: try:
callback, endmarker = self._callbacks[id] callback, endmarker = self._callbacks[id]
except KeyError: except KeyError:
@ -518,8 +506,6 @@ class ChannelFactory(object):
queue.put(data) queue.put(data)
else: else:
callback(data) # even if channel may be already closed callback(data) # even if channel may be already closed
finally:
self._receivelock.release()
def _finished_receiving(self): def _finished_receiving(self):
self._writelock.acquire() self._writelock.acquire()
@ -528,9 +514,9 @@ class ChannelFactory(object):
finally: finally:
self._writelock.release() self._writelock.release()
for id in list(self._channels): for id in list(self._channels):
self._local_last_message(id) self._local_close(id, sendonly=True)
for id in list(self._callbacks): for id in list(self._callbacks):
self._close_callback(id) self._no_longer_opened(id)
class ChannelFile(object): class ChannelFile(object):
def __init__(self, channel, proxyclose=True): def __init__(self, channel, proxyclose=True):
@ -648,7 +634,12 @@ class BaseGateway(object):
try: try:
msg = Message.readfrom(self._io) msg = Message.readfrom(self._io)
self._trace("received <- %r" % msg) self._trace("received <- %r" % msg)
_receivelock = self._channelfactory._receivelock
_receivelock.acquire()
try:
msg.received(self) msg.received(self)
finally:
_receivelock.release()
except sysex: except sysex:
break break
except EOFError: except EOFError:
@ -736,7 +727,7 @@ class BaseGateway(object):
finally: finally:
close() close()
self._trace("execution finished:", repr(source)[:50]) self._trace("execution finished:", repr(source)[:50])
except (KeyboardInterrupt, SystemExit): except sysex:
pass pass
except self._StopExecLoop: except self._StopExecLoop:
channel.close() channel.close()