Fixed #3439: vastly improved the performance of django.dispatch (and added tests!). Thanks to Brian Harring for the patch. Note that one of the new tests fails under sqlite currently; it's not clear if this is a sqlite problem or a problem with the tests, but it appears not to be a problem with the dispatcher itself.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@4588 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Jacob Kaplan-Moss 2007-02-26 03:17:04 +00:00
parent 6871a455e2
commit 357e26baf6
7 changed files with 317 additions and 60 deletions

View File

@ -25,7 +25,6 @@ Internal attributes:
deletion, (considerably speeds up the cleanup process deletion, (considerably speeds up the cleanup process
vs. the original code.) vs. the original code.)
""" """
from __future__ import generators
import types, weakref import types, weakref
from django.dispatch import saferef, robustapply, errors from django.dispatch import saferef, robustapply, errors
@ -33,11 +32,6 @@ __author__ = "Patrick K. O'Brien <pobrien@orbtech.com>"
__cvsid__ = "$Id: dispatcher.py,v 1.9 2005/09/17 04:55:57 mcfletch Exp $" __cvsid__ = "$Id: dispatcher.py,v 1.9 2005/09/17 04:55:57 mcfletch Exp $"
__version__ = "$Revision: 1.9 $"[11:-2] __version__ = "$Revision: 1.9 $"[11:-2]
try:
True
except NameError:
True = 1==1
False = 1==0
class _Parameter: class _Parameter:
"""Used to represent default parameter values.""" """Used to represent default parameter values."""
@ -140,10 +134,9 @@ def connect(receiver, signal=Any, sender=Any, weak=True):
if weak: if weak:
receiver = saferef.safeRef(receiver, onDelete=_removeReceiver) receiver = saferef.safeRef(receiver, onDelete=_removeReceiver)
senderkey = id(sender) senderkey = id(sender)
if connections.has_key(senderkey):
signals = connections[senderkey] signals = connections.setdefault(senderkey, {})
else:
connections[senderkey] = signals = {}
# Keep track of senders for cleanup. # Keep track of senders for cleanup.
# Is Anonymous something we want to clean up? # Is Anonymous something we want to clean up?
if sender not in (None, Anonymous, Any): if sender not in (None, Anonymous, Any):
@ -251,10 +244,10 @@ def getReceivers( sender = Any, signal = Any ):
to retrieve the actual receiver objects as an iterable to retrieve the actual receiver objects as an iterable
object. object.
""" """
try: existing = connections.get(id(sender))
return connections[id(sender)][signal] if existing is not None:
except KeyError: return existing.get(signal, [])
return [] return []
def liveReceivers(receivers): def liveReceivers(receivers):
"""Filter sequence of receivers to get resolved, live receivers """Filter sequence of receivers to get resolved, live receivers
@ -278,30 +271,48 @@ def liveReceivers(receivers):
def getAllReceivers( sender = Any, signal = Any ): def getAllReceivers( sender = Any, signal = Any ):
"""Get list of all receivers from global tables """Get list of all receivers from global tables
This gets all receivers which should receive This gets all dereferenced receivers which should receive
the given signal from sender, each receiver should the given signal from sender, each receiver should
be produced only once by the resulting generator be produced only once by the resulting generator
""" """
receivers = {} receivers = {}
for set in ( # Get receivers that receive *this* signal from *this* sender.
# Get receivers that receive *this* signal from *this* sender. # Add receivers that receive *any* signal from *this* sender.
getReceivers( sender, signal ), # Add receivers that receive *this* signal from *any* sender.
# Add receivers that receive *any* signal from *this* sender. # Add receivers that receive *any* signal from *any* sender.
getReceivers( sender, Any ), l = []
# Add receivers that receive *this* signal from *any* sender. i = id(sender)
getReceivers( Any, signal ), if i in connections:
# Add receivers that receive *any* signal from *any* sender. sender_receivers = connections[i]
getReceivers( Any, Any ), if signal in sender_receivers:
): l.extend(sender_receivers[signal])
for receiver in set: if signal is not Any and Any in sender_receivers:
if receiver: # filter out dead instance-method weakrefs l.extend(sender_receivers[Any])
try:
if not receivers.has_key( receiver ): if sender is not Any:
receivers[receiver] = 1 i = id(Any)
yield receiver if i in connections:
except TypeError: sender_receivers = connections[i]
# dead weakrefs raise TypeError on hash... if sender_receivers is not None:
pass if signal in sender_receivers:
l.extend(sender_receivers[signal])
if signal is not Any and Any in sender_receivers:
l.extend(sender_receivers[Any])
for receiver in l:
try:
if not receiver in receivers:
if isinstance(receiver, WEAKREF_TYPES):
receiver = receiver()
# this should only (rough guess) be possible if somehow, deref'ing
# triggered a wipe.
if receiver is None:
continue
receivers[receiver] = 1
yield receiver
except TypeError:
# dead weakrefs raise TypeError on hash...
pass
def send(signal=Any, sender=Anonymous, *arguments, **named): def send(signal=Any, sender=Anonymous, *arguments, **named):
"""Send signal from sender to all connected receivers. """Send signal from sender to all connected receivers.
@ -340,7 +351,7 @@ def send(signal=Any, sender=Anonymous, *arguments, **named):
# Call each receiver with whatever arguments it can accept. # Call each receiver with whatever arguments it can accept.
# Return a list of tuple pairs [(receiver, response), ... ]. # Return a list of tuple pairs [(receiver, response), ... ].
responses = [] responses = []
for receiver in liveReceivers(getAllReceivers(sender, signal)): for receiver in getAllReceivers(sender, signal):
response = robustapply.robustApply( response = robustapply.robustApply(
receiver, receiver,
signal=signal, signal=signal,
@ -350,6 +361,8 @@ def send(signal=Any, sender=Anonymous, *arguments, **named):
) )
responses.append((receiver, response)) responses.append((receiver, response))
return responses return responses
def sendExact( signal=Any, sender=Anonymous, *arguments, **named ): def sendExact( signal=Any, sender=Anonymous, *arguments, **named ):
"""Send signal only to those receivers registered for exact message """Send signal only to those receivers registered for exact message
@ -421,33 +434,18 @@ def _cleanupConnections(senderkey, signal):
def _removeSender(senderkey): def _removeSender(senderkey):
"""Remove senderkey from connections.""" """Remove senderkey from connections."""
_removeBackrefs(senderkey) _removeBackrefs(senderkey)
try:
del connections[senderkey] connections.pop(senderkey, None)
except KeyError: senders.pop(senderkey, None)
pass
# Senderkey will only be in senders dictionary if sender
# could be weakly referenced.
try:
del senders[senderkey]
except:
pass
def _removeBackrefs( senderkey): def _removeBackrefs( senderkey):
"""Remove all back-references to this senderkey""" """Remove all back-references to this senderkey"""
try: for receiver_list in connections.pop(senderkey, {}).values():
signals = connections[senderkey] for receiver in receiver_list:
except KeyError:
signals = None
else:
items = signals.items()
def allReceivers( ):
for signal,set in items:
for item in set:
yield item
for receiver in allReceivers():
_killBackref( receiver, senderkey ) _killBackref( receiver, senderkey )
def _removeOldBackRefs(senderkey, signal, receiver, receivers): def _removeOldBackRefs(senderkey, signal, receiver, receivers):
"""Kill old sendersBack references from receiver """Kill old sendersBack references from receiver
@ -483,13 +481,13 @@ def _removeOldBackRefs(senderkey, signal, receiver, receivers):
def _killBackref( receiver, senderkey ): def _killBackref( receiver, senderkey ):
"""Do the actual removal of back reference from receiver to senderkey""" """Do the actual removal of back reference from receiver to senderkey"""
receiverkey = id(receiver) receiverkey = id(receiver)
set = sendersBack.get( receiverkey, () ) receivers_list = sendersBack.get( receiverkey, () )
while senderkey in set: while senderkey in receivers_list:
try: try:
set.remove( senderkey ) receivers_list.remove( senderkey )
except: except:
break break
if not set: if not receivers_list:
try: try:
del sendersBack[ receiverkey ] del sendersBack[ receiverkey ]
except KeyError: except KeyError:

View File

@ -0,0 +1,2 @@
"""Unit-tests for the dispatch project
"""

View File

View File

@ -0,0 +1,7 @@
"""
Unit-tests for the dispatch project
"""
from test_dispatcher import *
from test_robustapply import *
from test_saferef import *

View File

@ -0,0 +1,147 @@
from django.dispatch.dispatcher import *
from django.dispatch import dispatcher, robust
import unittest, pprint, copy
def x(a):
return a
class Dummy( object ):
pass
class Callable(object):
def __call__( self, a ):
return a
def a( self, a ):
return a
class DispatcherTests(unittest.TestCase):
"""Test suite for dispatcher (barely started)"""
def setUp(self):
# track the initial state, since it's possible that others have bleed receivers in
self.sendersBack = copy.copy(dispatcher.sendersBack)
self.connections = copy.copy(dispatcher.connections)
self.senders = copy.copy(dispatcher.senders)
def _isclean( self ):
"""Assert that everything has been cleaned up automatically"""
self.assertEqual(dispatcher.sendersBack, self.sendersBack)
self.assertEqual(dispatcher.connections, self.connections)
self.assertEqual(dispatcher.senders, self.senders)
def testExact (self):
a = Dummy()
signal = 'this'
connect( x, signal, a )
expected = [(x,a)]
result = send('this',a, a=a)
self.assertEqual(result, expected)
disconnect( x, signal, a )
self.assertEqual(len(list(getAllReceivers(a,signal))), 0)
self._isclean()
def testAnonymousSend(self):
a = Dummy()
signal = 'this'
connect( x, signal )
expected = [(x,a)]
result = send(signal,None, a=a)
assert result == expected,"""Send didn't return expected result:\n\texpected:%s\n\tgot:%s"""% (expected, result)
disconnect( x, signal )
assert len(list(getAllReceivers(None,signal))) == 0
self._isclean()
def testAnyRegistration(self):
a = Dummy()
signal = 'this'
connect( x, signal, Any )
expected = [(x,a)]
result = send('this',object(), a=a)
assert result == expected,"""Send didn't return expected result:\n\texpected:%s\n\tgot:%s"""% (expected, result)
disconnect( x, signal, Any )
expected = []
result = send('this',object(), a=a)
assert result == expected,"""Send didn't return expected result:\n\texpected:%s\n\tgot:%s"""% (expected, result)
assert len(list(getAllReceivers(Any,signal))) == 0
self._isclean()
def testAnyRegistration2(self):
a = Dummy()
signal = 'this'
connect( x, Any, a )
expected = [(x,a)]
result = send('this',a, a=a)
assert result == expected,"""Send didn't return expected result:\n\texpected:%s\n\tgot:%s"""% (expected, result)
disconnect( x, Any, a )
assert len(list(getAllReceivers(a,Any))) == 0
self._isclean()
def testGarbageCollected(self):
a = Callable()
b = Dummy()
signal = 'this'
connect( a.a, signal, b )
expected = []
del a
result = send('this',b, a=b)
assert result == expected,"""Send didn't return expected result:\n\texpected:%s\n\tgot:%s"""% (expected, result)
assert len(list(getAllReceivers(b,signal))) == 0, """Remaining handlers: %s"""%(getAllReceivers(b,signal),)
self._isclean()
def testGarbageCollectedObj(self):
class x:
def __call__( self, a ):
return a
a = Callable()
b = Dummy()
signal = 'this'
connect( a, signal, b )
expected = []
del a
result = send('this',b, a=b)
self.assertEqual(result, expected)
self.assertEqual(len(list(getAllReceivers(b,signal))), 0)
self._isclean()
def testMultipleRegistration(self):
a = Callable()
b = Dummy()
signal = 'this'
connect( a, signal, b )
connect( a, signal, b )
connect( a, signal, b )
connect( a, signal, b )
connect( a, signal, b )
connect( a, signal, b )
result = send('this',b, a=b)
assert len( result ) == 1, result
assert len(list(getAllReceivers(b,signal))) == 1, """Remaining handlers: %s"""%(getAllReceivers(b,signal),)
del a
del b
del result
self._isclean()
def testRobust( self ):
"""Test the sendRobust function"""
def fails( ):
raise ValueError( 'this' )
a = object()
signal = 'this'
connect( fails, Any, a )
result = robust.sendRobust('this',a, a=a)
err = result[0][1]
assert isinstance( err, ValueError )
assert err.args == ('this',)
def getSuite():
return unittest.makeSuite(DispatcherTests,'test')
if __name__ == "__main__":
unittest.main ()

View File

@ -0,0 +1,27 @@
from django.dispatch.robustapply import *
import unittest
def noArgument():
pass
def oneArgument (blah):
pass
def twoArgument(blah, other):
pass
class TestCases( unittest.TestCase ):
def test01( self ):
robustApply(noArgument )
def test02( self ):
self.assertRaises( TypeError, robustApply, noArgument, "this" )
def test03( self ):
self.assertRaises( TypeError, robustApply, oneArgument )
def test04( self ):
"""Raise error on duplication of a particular argument"""
self.assertRaises( TypeError, robustApply, oneArgument, "this", blah = "that" )
def getSuite():
return unittest.makeSuite(TestCases,'test')
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,76 @@
from django.dispatch.saferef import *
import unittest
class Test1( object):
def x( self ):
pass
def test2(obj):
pass
class Test2( object ):
def __call__( self, obj ):
pass
class Tester (unittest.TestCase):
def setUp (self):
ts = []
ss = []
for x in xrange( 5000 ):
t = Test1()
ts.append( t)
s = safeRef(t.x, self._closure )
ss.append( s)
ts.append( test2 )
ss.append( safeRef(test2, self._closure) )
for x in xrange( 30 ):
t = Test2()
ts.append( t)
s = safeRef(t, self._closure )
ss.append( s)
self.ts = ts
self.ss = ss
self.closureCount = 0
def tearDown( self ):
del self.ts
del self.ss
def testIn(self):
"""Test the "in" operator for safe references (cmp)"""
for t in self.ts[:50]:
assert safeRef(t.x) in self.ss
def testValid(self):
"""Test that the references are valid (return instance methods)"""
for s in self.ss:
assert s()
def testShortCircuit (self):
"""Test that creation short-circuits to reuse existing references"""
sd = {}
for s in self.ss:
sd[s] = 1
for t in self.ts:
if hasattr( t, 'x'):
assert sd.has_key( safeRef(t.x))
else:
assert sd.has_key( safeRef(t))
def testRepresentation (self):
"""Test that the reference object's representation works
XXX Doesn't currently check the results, just that no error
is raised
"""
repr( self.ss[-1] )
def test(self):
self.closureCount = 0
wholeI = len(self.ts)
for i in xrange( len(self.ts)-1, -1, -1):
del self.ts[i]
if wholeI-i != self.closureCount:
"""Unexpected number of items closed, expected %s, got %s closed"""%( wholeI-i,self.closureCount)
def _closure(self, ref):
"""Dumb utility mechanism to increment deletion counter"""
self.closureCount +=1
def getSuite():
return unittest.makeSuite(Tester,'test')
if __name__ == "__main__":
unittest.main ()