diff --git a/_pytest/core.py b/_pytest/core.py index d1fdf1dba..e7a441634 100644 --- a/_pytest/core.py +++ b/_pytest/core.py @@ -95,10 +95,11 @@ class PluginManager(object): raise ValueError("Plugin already registered: %s=%s\n%s" %( name, plugin, self._name2plugin)) #self.trace("registering", name, plugin) - self._name2plugin[name] = plugin reg = getattr(self, "_registercallback", None) if reg is not None: - reg(plugin, name) + reg(plugin, name) # may call addhooks + self.hook._scan_plugin(plugin) + self._name2plugin[name] = plugin if conftest: self._conftestplugins.append(plugin) else: @@ -403,41 +404,86 @@ class HookRelay: def _getcaller(self, name, plugins): caller = getattr(self, name) methods = self._pm.listattr(name, plugins=plugins) - return CachedHookCaller(caller, methods) - - -class CachedHookCaller: - def __init__(self, hookmethod, methods): - self.hookmethod = hookmethod - self.methods = methods - - def __call__(self, **kwargs): - return self.hookmethod._docall(self.methods, kwargs) - - def callextra(self, methods, **kwargs): - # XXX in theory we should respect "tryfirst/trylast" if set - # on the added methods but we currently only use it for - # pytest_generate_tests and it doesn't make sense there i'd think - all = self.methods if methods: - all = all + methods - return self.hookmethod._docall(all, kwargs) + return caller.new_cached_caller(methods) + return caller + + def _scan_plugin(self, plugin): + methods = collectattr(plugin) + hooks = {} + for hookspec in self._hookspecs: + hooks.update(collectattr(hookspec)) + + stringio = py.io.TextIO() + def Print(*args): + if args: + stringio.write(" ".join(map(str, args))) + stringio.write("\n") + + fail = False + while methods: + name, method = methods.popitem() + #print "checking", name + if isgenerichook(name): + continue + if name not in hooks: + if not getattr(method, 'optionalhook', False): + Print("found unknown hook:", name) + fail = True + else: + #print "checking", method + method_args = list(varnames(method)) + if '__multicall__' in method_args: + method_args.remove('__multicall__') + hook = hooks[name] + hookargs = varnames(hook) + for arg in method_args: + if arg not in hookargs: + Print("argument %r not available" %(arg, )) + Print("actual definition: %s" %(formatdef(method))) + Print("available hook arguments: %s" % + ", ".join(hookargs)) + fail = True + break + #if not fail: + # print "matching hook:", formatdef(method) + getattr(self, name).clear_method_cache() + + if fail: + name = getattr(plugin, '__name__', plugin) + raise PluginValidationError("%s:\n%s" % (name, stringio.getvalue())) class HookCaller: - def __init__(self, hookrelay, name, firstresult): + def __init__(self, hookrelay, name, firstresult, methods=None): self.hookrelay = hookrelay self.name = name self.firstresult = firstresult self.trace = self.hookrelay.trace + self.methods = methods + + def new_cached_caller(self, methods): + return HookCaller(self.hookrelay, self.name, self.firstresult, + methods=methods) def __repr__(self): return "" %(self.name,) + def clear_method_cache(self): + self.methods = None + def __call__(self, **kwargs): - methods = self.hookrelay._pm.listattr(self.name) + methods = self.methods + if self.methods is None: + self.methods = methods = self.hookrelay._pm.listattr(self.name) + methods = self.methods return self._docall(methods, kwargs) + def callextra(self, methods, **kwargs): + if self.methods is None: + self.reload_methods() + return self._docall(self.methods + methods, kwargs) + def _docall(self, methods, kwargs): self.trace(self.name, kwargs) self.trace.root.indent += 1 @@ -450,3 +496,25 @@ class HookCaller: self.trace.root.indent -= 1 return res + + +class PluginValidationError(Exception): + """ plugin failed validation. """ + +def isgenerichook(name): + return name == "pytest_plugins" or \ + name.startswith("pytest_funcarg__") + +def collectattr(obj): + methods = {} + for apiname in dir(obj): + if apiname.startswith("pytest_"): + methods[apiname] = getattr(obj, apiname) + return methods + +def formatdef(func): + return "%s%s" % ( + func.__name__, + inspect.formatargspec(*inspect.getargspec(func)) + ) + diff --git a/_pytest/helpconfig.py b/_pytest/helpconfig.py index 79c331145..028b3e3f1 100644 --- a/_pytest/helpconfig.py +++ b/_pytest/helpconfig.py @@ -127,70 +127,3 @@ def pytest_report_header(config): return lines -# ===================================================== -# validate plugin syntax and hooks -# ===================================================== - -def pytest_plugin_registered(manager, plugin): - methods = collectattr(plugin) - hooks = {} - for hookspec in manager.hook._hookspecs: - hooks.update(collectattr(hookspec)) - - stringio = py.io.TextIO() - def Print(*args): - if args: - stringio.write(" ".join(map(str, args))) - stringio.write("\n") - - fail = False - while methods: - name, method = methods.popitem() - #print "checking", name - if isgenerichook(name): - continue - if name not in hooks: - if not getattr(method, 'optionalhook', False): - Print("found unknown hook:", name) - fail = True - else: - #print "checking", method - method_args = list(varnames(method)) - if '__multicall__' in method_args: - method_args.remove('__multicall__') - hook = hooks[name] - hookargs = varnames(hook) - for arg in method_args: - if arg not in hookargs: - Print("argument %r not available" %(arg, )) - Print("actual definition: %s" %(formatdef(method))) - Print("available hook arguments: %s" % - ", ".join(hookargs)) - fail = True - break - #if not fail: - # print "matching hook:", formatdef(method) - if fail: - name = getattr(plugin, '__name__', plugin) - raise PluginValidationError("%s:\n%s" % (name, stringio.getvalue())) - -class PluginValidationError(Exception): - """ plugin failed validation. """ - -def isgenerichook(name): - return name == "pytest_plugins" or \ - name.startswith("pytest_funcarg__") - -def collectattr(obj): - methods = {} - for apiname in dir(obj): - if apiname.startswith("pytest_"): - methods[apiname] = getattr(obj, apiname) - return methods - -def formatdef(func): - return "%s%s" % ( - func.__name__, - inspect.formatargspec(*inspect.getargspec(func)) - ) - diff --git a/_pytest/main.py b/_pytest/main.py index 5377764de..025573bdb 100644 --- a/_pytest/main.py +++ b/_pytest/main.py @@ -164,28 +164,6 @@ class FSHookProxy(object): self.__dict__[name] = x return x - hookmethod = getattr(self.config.hook, name) - methods = self.config.pluginmanager.listattr(name, plugins=plugins) - self.__dict__[name] = x = HookCaller(hookmethod, methods) - return x - - -class HookCaller: - def __init__(self, hookmethod, methods): - self.hookmethod = hookmethod - self.methods = methods - - def __call__(self, **kwargs): - return self.hookmethod._docall(self.methods, kwargs) - - def callextra(self, methods, **kwargs): - # XXX in theory we should respect "tryfirst/trylast" if set - # on the added methods but we currently only use it for - # pytest_generate_tests and it doesn't make sense there i'd think - all = self.methods - if methods: - all = all + methods - return self.hookmethod._docall(all, kwargs) def compatproperty(name): def fget(self): diff --git a/_pytest/pytester.py b/_pytest/pytester.py index 9e987ae03..63e223af2 100644 --- a/_pytest/pytester.py +++ b/_pytest/pytester.py @@ -47,7 +47,6 @@ class PytestArg: def gethookrecorder(self, hook): hookrecorder = HookRecorder(hook._pm) - hookrecorder.start_recording(hook._hookspecs) self.request.addfinalizer(hookrecorder.finish_recording) return hookrecorder @@ -69,9 +68,7 @@ class HookRecorder: self.calls = [] self._recorders = {} - def start_recording(self, hookspecs): - if not isinstance(hookspecs, (list, tuple)): - hookspecs = [hookspecs] + hookspecs = self._pluginmanager.hook._hookspecs for hookspec in hookspecs: assert hookspec not in self._recorders class RecordCalls: diff --git a/testing/test_helpconfig.py b/testing/test_helpconfig.py index df78ccecc..7d4c7cab1 100644 --- a/testing/test_helpconfig.py +++ b/testing/test_helpconfig.py @@ -1,5 +1,5 @@ import py, pytest -from _pytest.helpconfig import collectattr +from _pytest.core import collectattr def test_version(testdir, pytestconfig): result = testdir.runpytest("--version") diff --git a/testing/test_pytester.py b/testing/test_pytester.py index b3c6cf795..59a294674 100644 --- a/testing/test_pytester.py +++ b/testing/test_pytester.py @@ -71,28 +71,38 @@ def test_testdir_runs_with_plugin(testdir): "*1 passed*" ]) -def test_hookrecorder_basic(): - rec = HookRecorder(PluginManager()) - class ApiClass: + +def make_holder(): + class apiclass: def pytest_xyz(self, arg): "x" - rec.start_recording(ApiClass) + def pytest_xyz_noarg(self): + "x" + + apimod = type(os)('api') + def pytest_xyz(arg): + "x" + def pytest_xyz_noarg(): + "x" + apimod.pytest_xyz = pytest_xyz + apimod.pytest_xyz_noarg = pytest_xyz_noarg + return apiclass, apimod + + +@pytest.mark.parametrize("holder", make_holder()) +def test_hookrecorder_basic(holder): + pm = PluginManager() + pm.hook._addhooks(holder, "pytest_") + rec = HookRecorder(pm) rec.hook.pytest_xyz(arg=123) call = rec.popcall("pytest_xyz") assert call.arg == 123 assert call._name == "pytest_xyz" pytest.raises(pytest.fail.Exception, "rec.popcall('abc')") + rec.hook.pytest_xyz_noarg() + call = rec.popcall("pytest_xyz_noarg") + assert call._name == "pytest_xyz_noarg" -def test_hookrecorder_basic_no_args_hook(): - rec = HookRecorder(PluginManager()) - apimod = type(os)('api') - def pytest_xyz(): - "x" - apimod.pytest_xyz = pytest_xyz - rec.start_recording(apimod) - rec.hook.pytest_xyz() - call = rec.popcall("pytest_xyz") - assert call._name == "pytest_xyz" def test_functional(testdir, linecomp): reprec = testdir.inline_runsource(""" @@ -102,8 +112,9 @@ def test_functional(testdir, linecomp): def test_func(_pytest): class ApiClass: def pytest_xyz(self, arg): "x" - hook = HookRelay([ApiClass], PluginManager()) - rec = _pytest.gethookrecorder(hook) + pm = PluginManager() + pm.hook._addhooks(ApiClass, "pytest_") + rec = _pytest.gethookrecorder(pm.hook) class Plugin: def pytest_xyz(self, arg): return arg + 1