simplify _scan_plugin implementation and store argnames on HookCaller

This commit is contained in:
holger krekel 2014-10-01 13:57:35 +02:00
parent 28c785a0d1
commit f250e912eb
2 changed files with 50 additions and 54 deletions

View File

@ -68,7 +68,7 @@ class TagTracerSub:
return self.__class__(self.root, self.tags + (name,))
class PluginManager(object):
def __init__(self, hookspecs=None):
def __init__(self, hookspecs=None, prefix="pytest_"):
self._name2plugin = {}
self._listattrcache = {}
self._plugins = []
@ -77,7 +77,7 @@ class PluginManager(object):
self.trace = TagTracer().get("pluginmanage")
self._plugin_distinfo = []
self._shutdown = []
self.hook = HookRelay(hookspecs or [], pm=self)
self.hook = HookRelay(hookspecs or [], pm=self, prefix=prefix)
def do_configure(self, config):
# backward compatibility
@ -384,22 +384,25 @@ class HookRelay:
self._hookspecs = []
self._pm = pm
self.trace = pm.trace.root.get("hook")
self.prefix = prefix
for hookspec in hookspecs:
self._addhooks(hookspec, prefix)
def _addhooks(self, hookspecs, prefix):
self._hookspecs.append(hookspecs)
def _addhooks(self, hookspec, prefix):
self._hookspecs.append(hookspec)
added = False
for name, method in vars(hookspecs).items():
for name in dir(hookspec):
if name.startswith(prefix):
method = getattr(hookspec, name)
firstresult = getattr(method, 'firstresult', False)
hc = HookCaller(self, name, firstresult=firstresult)
hc = HookCaller(self, name, firstresult=firstresult,
argnames=varnames(method))
setattr(self, name, hc)
added = True
#print ("setting new hook", name)
if not added:
raise ValueError("did not find new %r hooks in %r" %(
prefix, hookspecs,))
prefix, hookspec,))
def _getcaller(self, name, plugins):
caller = getattr(self, name)
@ -409,62 +412,44 @@ class HookRelay:
return caller
def _scan_plugin(self, plugin):
methods = collectattr(plugin)
hooks = {}
for hookspec in self._hookspecs:
hooks.update(collectattr(hookspec))
def fail(msg, *args):
name = getattr(plugin, '__name__', plugin)
raise PluginValidationError("plugin %r\n%s" %(name, msg % args))
stringio = py.io.TextIO()
def Print(*args):
if args:
stringio.write(" ".join(map(str, args)))
stringio.write("\n")
fail = False
while methods:
name, method = methods.popitem()
#print "checking", name
if isgenerichook(name):
for name in dir(plugin):
if not name.startswith(self.prefix):
continue
if name not in hooks:
if not getattr(method, 'optionalhook', False):
Print("found unknown hook:", name)
fail = True
else:
#print "checking", method
method_args = list(varnames(method))
if '__multicall__' in method_args:
method_args.remove('__multicall__')
hook = hooks[name]
hookargs = varnames(hook)
for arg in method_args:
if arg not in hookargs:
Print("argument %r not available" %(arg, ))
Print("actual definition: %s" %(formatdef(method)))
Print("available hook arguments: %s" %
", ".join(hookargs))
fail = True
break
#if not fail:
# print "matching hook:", formatdef(method)
getattr(self, name).clear_method_cache()
if fail:
name = getattr(plugin, '__name__', plugin)
raise PluginValidationError("%s:\n%s" % (name, stringio.getvalue()))
hook = getattr(self, name, None)
method = getattr(plugin, name)
if hook is None:
is_optional = getattr(method, 'optionalhook', False)
if not isgenerichook(name) and not is_optional:
fail("found unknown hook: %r", name)
continue
for arg in varnames(method):
if arg not in hook.argnames:
fail("argument %r not available\n"
"actual definition: %s\n"
"available hookargs: %s",
arg, formatdef(method),
", ".join(hook.argnames))
getattr(self, name).clear_method_cache()
class HookCaller:
def __init__(self, hookrelay, name, firstresult, methods=None):
def __init__(self, hookrelay, name, firstresult, argnames, methods=None):
self.hookrelay = hookrelay
self.name = name
self.firstresult = firstresult
self.trace = self.hookrelay.trace
self.methods = methods
self.argnames = ["__multicall__"]
self.argnames.extend(argnames)
assert "self" not in argnames
def new_cached_caller(self, methods):
return HookCaller(self.hookrelay, self.name, self.firstresult,
methods=methods)
argnames=self.argnames, methods=methods)
def __repr__(self):
return "<HookCaller %r>" %(self.name,)
@ -474,14 +459,13 @@ class HookCaller:
def __call__(self, **kwargs):
methods = self.methods
if self.methods is None:
if methods is None:
self.methods = methods = self.hookrelay._pm.listattr(self.name)
methods = self.methods
return self._docall(methods, kwargs)
def callextra(self, methods, **kwargs):
if self.methods is None:
self.reload_methods()
#if self.methods is None:
# self.reload_methods()
return self._docall(self.methods + methods, kwargs)
def _docall(self, methods, kwargs):

View File

@ -630,6 +630,18 @@ class TestHookRelay:
assert l == [4]
assert not hasattr(mcm, 'world')
def test_argmismatch(self):
class Api:
def hello(self, arg):
"api hook 1"
pm = PluginManager(Api, prefix="he")
class Plugin:
def hello(self, argwrong):
return arg + 1
with pytest.raises(PluginValidationError) as exc:
pm.register(Plugin())
assert "argwrong" in str(exc.value)
def test_only_kwargs(self):
pm = PluginManager()
class Api: