diff --git a/django/dispatch/dispatcher.py b/django/dispatch/dispatcher.py index 1a617b5946..029c59fd3f 100644 --- a/django/dispatch/dispatcher.py +++ b/django/dispatch/dispatcher.py @@ -25,7 +25,6 @@ Internal attributes: deletion, (considerably speeds up the cleanup process vs. the original code.) """ -from __future__ import generators import types, weakref from django.dispatch import saferef, robustapply, errors @@ -33,11 +32,6 @@ __author__ = "Patrick K. O'Brien " __cvsid__ = "$Id: dispatcher.py,v 1.9 2005/09/17 04:55:57 mcfletch Exp $" __version__ = "$Revision: 1.9 $"[11:-2] -try: - True -except NameError: - True = 1==1 - False = 1==0 class _Parameter: """Used to represent default parameter values.""" @@ -140,10 +134,9 @@ def connect(receiver, signal=Any, sender=Any, weak=True): if weak: receiver = saferef.safeRef(receiver, onDelete=_removeReceiver) senderkey = id(sender) - if connections.has_key(senderkey): - signals = connections[senderkey] - else: - connections[senderkey] = signals = {} + + signals = connections.setdefault(senderkey, {}) + # Keep track of senders for cleanup. # Is Anonymous something we want to clean up? if sender not in (None, Anonymous, Any): @@ -251,10 +244,10 @@ def getReceivers( sender = Any, signal = Any ): to retrieve the actual receiver objects as an iterable object. """ - try: - return connections[id(sender)][signal] - except KeyError: - return [] + existing = connections.get(id(sender)) + if existing is not None: + return existing.get(signal, []) + return [] def liveReceivers(receivers): """Filter sequence of receivers to get resolved, live receivers @@ -278,30 +271,48 @@ def liveReceivers(receivers): def getAllReceivers( sender = Any, signal = Any ): """Get list of all receivers from global tables - This gets all receivers which should receive + This gets all dereferenced receivers which should receive the given signal from sender, each receiver should be produced only once by the resulting generator """ receivers = {} - for set in ( - # Get receivers that receive *this* signal from *this* sender. - getReceivers( sender, signal ), - # Add receivers that receive *any* signal from *this* sender. - getReceivers( sender, Any ), - # Add receivers that receive *this* signal from *any* sender. - getReceivers( Any, signal ), - # Add receivers that receive *any* signal from *any* sender. - getReceivers( Any, Any ), - ): - for receiver in set: - if receiver: # filter out dead instance-method weakrefs - try: - if not receivers.has_key( receiver ): - receivers[receiver] = 1 - yield receiver - except TypeError: - # dead weakrefs raise TypeError on hash... - pass + # Get receivers that receive *this* signal from *this* sender. + # Add receivers that receive *any* signal from *this* sender. + # Add receivers that receive *this* signal from *any* sender. + # Add receivers that receive *any* signal from *any* sender. + l = [] + i = id(sender) + if i in connections: + sender_receivers = connections[i] + if signal in sender_receivers: + l.extend(sender_receivers[signal]) + if signal is not Any and Any in sender_receivers: + l.extend(sender_receivers[Any]) + + if sender is not Any: + i = id(Any) + if i in connections: + sender_receivers = connections[i] + if sender_receivers is not None: + if signal in sender_receivers: + l.extend(sender_receivers[signal]) + if signal is not Any and Any in sender_receivers: + l.extend(sender_receivers[Any]) + + for receiver in l: + try: + if not receiver in receivers: + if isinstance(receiver, WEAKREF_TYPES): + receiver = receiver() + # this should only (rough guess) be possible if somehow, deref'ing + # triggered a wipe. + if receiver is None: + continue + receivers[receiver] = 1 + yield receiver + except TypeError: + # dead weakrefs raise TypeError on hash... + pass def send(signal=Any, sender=Anonymous, *arguments, **named): """Send signal from sender to all connected receivers. @@ -340,7 +351,7 @@ def send(signal=Any, sender=Anonymous, *arguments, **named): # Call each receiver with whatever arguments it can accept. # Return a list of tuple pairs [(receiver, response), ... ]. responses = [] - for receiver in liveReceivers(getAllReceivers(sender, signal)): + for receiver in getAllReceivers(sender, signal): response = robustapply.robustApply( receiver, signal=signal, @@ -350,6 +361,8 @@ def send(signal=Any, sender=Anonymous, *arguments, **named): ) responses.append((receiver, response)) return responses + + def sendExact( signal=Any, sender=Anonymous, *arguments, **named ): """Send signal only to those receivers registered for exact message @@ -421,33 +434,18 @@ def _cleanupConnections(senderkey, signal): def _removeSender(senderkey): """Remove senderkey from connections.""" _removeBackrefs(senderkey) - try: - del connections[senderkey] - except KeyError: - pass - # Senderkey will only be in senders dictionary if sender - # could be weakly referenced. - try: - del senders[senderkey] - except: - pass + + connections.pop(senderkey, None) + senders.pop(senderkey, None) def _removeBackrefs( senderkey): """Remove all back-references to this senderkey""" - try: - signals = connections[senderkey] - except KeyError: - signals = None - else: - items = signals.items() - def allReceivers( ): - for signal,set in items: - for item in set: - yield item - for receiver in allReceivers(): + for receiver_list in connections.pop(senderkey, {}).values(): + for receiver in receiver_list: _killBackref( receiver, senderkey ) + def _removeOldBackRefs(senderkey, signal, receiver, receivers): """Kill old sendersBack references from receiver @@ -483,13 +481,13 @@ def _removeOldBackRefs(senderkey, signal, receiver, receivers): def _killBackref( receiver, senderkey ): """Do the actual removal of back reference from receiver to senderkey""" receiverkey = id(receiver) - set = sendersBack.get( receiverkey, () ) - while senderkey in set: + receivers_list = sendersBack.get( receiverkey, () ) + while senderkey in receivers_list: try: - set.remove( senderkey ) + receivers_list.remove( senderkey ) except: break - if not set: + if not receivers_list: try: del sendersBack[ receiverkey ] except KeyError: diff --git a/tests/regressiontests/dispatch/__init__.py b/tests/regressiontests/dispatch/__init__.py new file mode 100644 index 0000000000..679895bb5c --- /dev/null +++ b/tests/regressiontests/dispatch/__init__.py @@ -0,0 +1,2 @@ +"""Unit-tests for the dispatch project +""" diff --git a/tests/regressiontests/dispatch/models.py b/tests/regressiontests/dispatch/models.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/regressiontests/dispatch/tests/__init__.py b/tests/regressiontests/dispatch/tests/__init__.py new file mode 100644 index 0000000000..0fdefe48a7 --- /dev/null +++ b/tests/regressiontests/dispatch/tests/__init__.py @@ -0,0 +1,7 @@ +""" +Unit-tests for the dispatch project +""" + +from test_dispatcher import * +from test_robustapply import * +from test_saferef import * diff --git a/tests/regressiontests/dispatch/tests/test_dispatcher.py b/tests/regressiontests/dispatch/tests/test_dispatcher.py new file mode 100644 index 0000000000..bfa3ee9a97 --- /dev/null +++ b/tests/regressiontests/dispatch/tests/test_dispatcher.py @@ -0,0 +1,147 @@ +from django.dispatch.dispatcher import * +from django.dispatch import dispatcher, robust + +import unittest, pprint, copy + +def x(a): + return a + +class Dummy( object ): + pass + + +class Callable(object): + + def __call__( self, a ): + return a + + def a( self, a ): + return a + + +class DispatcherTests(unittest.TestCase): + """Test suite for dispatcher (barely started)""" + + def setUp(self): + # track the initial state, since it's possible that others have bleed receivers in + self.sendersBack = copy.copy(dispatcher.sendersBack) + self.connections = copy.copy(dispatcher.connections) + self.senders = copy.copy(dispatcher.senders) + + def _isclean( self ): + """Assert that everything has been cleaned up automatically""" + self.assertEqual(dispatcher.sendersBack, self.sendersBack) + self.assertEqual(dispatcher.connections, self.connections) + self.assertEqual(dispatcher.senders, self.senders) + + def testExact (self): + a = Dummy() + signal = 'this' + connect( x, signal, a ) + expected = [(x,a)] + result = send('this',a, a=a) + self.assertEqual(result, expected) + disconnect( x, signal, a ) + self.assertEqual(len(list(getAllReceivers(a,signal))), 0) + self._isclean() + + def testAnonymousSend(self): + a = Dummy() + signal = 'this' + connect( x, signal ) + expected = [(x,a)] + result = send(signal,None, a=a) + assert result == expected,"""Send didn't return expected result:\n\texpected:%s\n\tgot:%s"""% (expected, result) + disconnect( x, signal ) + assert len(list(getAllReceivers(None,signal))) == 0 + self._isclean() + + def testAnyRegistration(self): + a = Dummy() + signal = 'this' + connect( x, signal, Any ) + expected = [(x,a)] + result = send('this',object(), a=a) + assert result == expected,"""Send didn't return expected result:\n\texpected:%s\n\tgot:%s"""% (expected, result) + disconnect( x, signal, Any ) + expected = [] + result = send('this',object(), a=a) + assert result == expected,"""Send didn't return expected result:\n\texpected:%s\n\tgot:%s"""% (expected, result) + assert len(list(getAllReceivers(Any,signal))) == 0 + + self._isclean() + + def testAnyRegistration2(self): + a = Dummy() + signal = 'this' + connect( x, Any, a ) + expected = [(x,a)] + result = send('this',a, a=a) + assert result == expected,"""Send didn't return expected result:\n\texpected:%s\n\tgot:%s"""% (expected, result) + disconnect( x, Any, a ) + assert len(list(getAllReceivers(a,Any))) == 0 + self._isclean() + + def testGarbageCollected(self): + a = Callable() + b = Dummy() + signal = 'this' + connect( a.a, signal, b ) + expected = [] + del a + result = send('this',b, a=b) + assert result == expected,"""Send didn't return expected result:\n\texpected:%s\n\tgot:%s"""% (expected, result) + assert len(list(getAllReceivers(b,signal))) == 0, """Remaining handlers: %s"""%(getAllReceivers(b,signal),) + self._isclean() + + def testGarbageCollectedObj(self): + class x: + def __call__( self, a ): + return a + a = Callable() + b = Dummy() + signal = 'this' + connect( a, signal, b ) + expected = [] + del a + result = send('this',b, a=b) + self.assertEqual(result, expected) + self.assertEqual(len(list(getAllReceivers(b,signal))), 0) + self._isclean() + + + def testMultipleRegistration(self): + a = Callable() + b = Dummy() + signal = 'this' + connect( a, signal, b ) + connect( a, signal, b ) + connect( a, signal, b ) + connect( a, signal, b ) + connect( a, signal, b ) + connect( a, signal, b ) + result = send('this',b, a=b) + assert len( result ) == 1, result + assert len(list(getAllReceivers(b,signal))) == 1, """Remaining handlers: %s"""%(getAllReceivers(b,signal),) + del a + del b + del result + self._isclean() + + def testRobust( self ): + """Test the sendRobust function""" + def fails( ): + raise ValueError( 'this' ) + a = object() + signal = 'this' + connect( fails, Any, a ) + result = robust.sendRobust('this',a, a=a) + err = result[0][1] + assert isinstance( err, ValueError ) + assert err.args == ('this',) + +def getSuite(): + return unittest.makeSuite(DispatcherTests,'test') + +if __name__ == "__main__": + unittest.main () diff --git a/tests/regressiontests/dispatch/tests/test_robustapply.py b/tests/regressiontests/dispatch/tests/test_robustapply.py new file mode 100644 index 0000000000..a70d968c6b --- /dev/null +++ b/tests/regressiontests/dispatch/tests/test_robustapply.py @@ -0,0 +1,27 @@ +from django.dispatch.robustapply import * + +import unittest +def noArgument(): + pass +def oneArgument (blah): + pass +def twoArgument(blah, other): + pass +class TestCases( unittest.TestCase ): + def test01( self ): + robustApply(noArgument ) + def test02( self ): + self.assertRaises( TypeError, robustApply, noArgument, "this" ) + def test03( self ): + self.assertRaises( TypeError, robustApply, oneArgument ) + def test04( self ): + """Raise error on duplication of a particular argument""" + self.assertRaises( TypeError, robustApply, oneArgument, "this", blah = "that" ) + +def getSuite(): + return unittest.makeSuite(TestCases,'test') + + +if __name__ == "__main__": + unittest.main() + diff --git a/tests/regressiontests/dispatch/tests/test_saferef.py b/tests/regressiontests/dispatch/tests/test_saferef.py new file mode 100644 index 0000000000..eac7d1e5a1 --- /dev/null +++ b/tests/regressiontests/dispatch/tests/test_saferef.py @@ -0,0 +1,76 @@ +from django.dispatch.saferef import * + +import unittest +class Test1( object): + def x( self ): + pass +def test2(obj): + pass +class Test2( object ): + def __call__( self, obj ): + pass +class Tester (unittest.TestCase): + def setUp (self): + ts = [] + ss = [] + for x in xrange( 5000 ): + t = Test1() + ts.append( t) + s = safeRef(t.x, self._closure ) + ss.append( s) + ts.append( test2 ) + ss.append( safeRef(test2, self._closure) ) + for x in xrange( 30 ): + t = Test2() + ts.append( t) + s = safeRef(t, self._closure ) + ss.append( s) + self.ts = ts + self.ss = ss + self.closureCount = 0 + def tearDown( self ): + del self.ts + del self.ss + def testIn(self): + """Test the "in" operator for safe references (cmp)""" + for t in self.ts[:50]: + assert safeRef(t.x) in self.ss + def testValid(self): + """Test that the references are valid (return instance methods)""" + for s in self.ss: + assert s() + def testShortCircuit (self): + """Test that creation short-circuits to reuse existing references""" + sd = {} + for s in self.ss: + sd[s] = 1 + for t in self.ts: + if hasattr( t, 'x'): + assert sd.has_key( safeRef(t.x)) + else: + assert sd.has_key( safeRef(t)) + def testRepresentation (self): + """Test that the reference object's representation works + + XXX Doesn't currently check the results, just that no error + is raised + """ + repr( self.ss[-1] ) + + def test(self): + self.closureCount = 0 + wholeI = len(self.ts) + for i in xrange( len(self.ts)-1, -1, -1): + del self.ts[i] + if wholeI-i != self.closureCount: + """Unexpected number of items closed, expected %s, got %s closed"""%( wholeI-i,self.closureCount) + + def _closure(self, ref): + """Dumb utility mechanism to increment deletion counter""" + self.closureCount +=1 + +def getSuite(): + return unittest.makeSuite(Tester,'test') + +if __name__ == "__main__": + unittest.main ()