From d4a487c7253e371280b1aad533901fc9989b789c Mon Sep 17 00:00:00 2001 From: holger krekel Date: Mon, 23 Jul 2012 10:55:09 +0200 Subject: [PATCH] allow funcarg factories to receive funcargs --- _pytest/main.py | 25 ++++++++++++-- _pytest/python.py | 74 ++++++++++++++++++++++++++++-------------- doc/en/resources.txt | 9 ++--- testing/test_python.py | 68 ++++++++++++++++++++++++++++++++++++-- 4 files changed, 144 insertions(+), 32 deletions(-) diff --git a/_pytest/main.py b/_pytest/main.py index ad4910afc..fbd7fd2f3 100644 --- a/_pytest/main.py +++ b/_pytest/main.py @@ -444,17 +444,26 @@ class FuncargManager: self.pytest_plugin_registered(plugin) def pytest_generate_tests(self, metafunc): - for argname in metafunc.funcargnames: + funcargnames = list(metafunc.funcargnames) + seen = set() + while funcargnames: + argname = funcargnames.pop(0) + if argname in seen: + continue + seen.add(argname) faclist = self.getfactorylist(argname, metafunc.parentid, metafunc.function, raising=False) if faclist is None: - continue # will raise at setup time + continue # will raise FuncargLookupError at setup time for fac in faclist: marker = getattr(fac, "funcarg", None) if marker is not None: params = marker.kwargs.get("params") if params is not None: metafunc.parametrize(argname, params, indirect=True) + newfuncargnames = getfuncargnames(fac) + newfuncargnames.remove("request") + funcargnames.extend(newfuncargnames) def _parsefactories(self, holderobj, nodeid): if holderobj in self._holderobjseen: @@ -773,3 +782,15 @@ class FuncargLookupErrorRepr(TerminalRepr): tw.line(" " + line.strip(), red=True) tw.line() tw.line("%s:%d" % (self.filename, self.firstlineno+1)) + +def getfuncargnames(function, startindex=None): + # XXX merge with main.py's varnames + argnames = py.std.inspect.getargs(py.code.getrawcode(function))[0] + if startindex is None: + startindex = py.std.inspect.ismethod(function) and 1 or 0 + defaults = getattr(function, 'func_defaults', + getattr(function, '__defaults__', None)) or () + numdefaults = len(defaults) + if numdefaults: + return argnames[startindex:-numdefaults] + return argnames[startindex:] diff --git a/_pytest/python.py b/_pytest/python.py index 6baacb8cf..7821ff47b 100644 --- a/_pytest/python.py +++ b/_pytest/python.py @@ -3,7 +3,7 @@ import py import inspect import sys import pytest -from _pytest.main import getfslineno +from _pytest.main import getfslineno, getfuncargnames from _pytest.monkeypatch import monkeypatch import _pytest @@ -475,17 +475,6 @@ def hasinit(obj): return True -def getfuncargnames(function, startindex=None): - # XXX merge with main.py's varnames - argnames = py.std.inspect.getargs(py.code.getrawcode(function))[0] - if startindex is None: - startindex = py.std.inspect.ismethod(function) and 1 or 0 - defaults = getattr(function, 'func_defaults', - getattr(function, '__defaults__', None)) or () - numdefaults = len(defaults) - if numdefaults: - return argnames[startindex:-numdefaults] - return argnames[startindex:] def fillfuncargs(function): """ fill missing funcargs. """ @@ -887,6 +876,7 @@ class FuncargRequest: self.funcargnames = getfuncargnames(self.function) self.parentid = pyfuncitem.parent.nodeid self.scope = "function" + self._factorystack = [] def _getfaclist(self, argname): faclist = self._name2factory.get(argname, None) @@ -947,7 +937,8 @@ class FuncargRequest: if self.funcargnames: assert not getattr(self._pyfuncitem, '_args', None), ( "yielded functions cannot have funcargs") - for argname in self.funcargnames: + while self.funcargnames: + argname = self.funcargnames.pop(0) if argname not in self._pyfuncitem.funcargs: self._pyfuncitem.funcargs[argname] = \ self.getfuncargvalue(argname) @@ -984,7 +975,10 @@ class FuncargRequest: val = cache[cachekey] except KeyError: __tracebackhide__ = True - check_scope(self.scope, scope) + if scopemismatch(self.scope, scope): + raise ScopeMismatchError("You tried to access a %r scoped " + "resource with a %r scoped request object" %( + (scope, self.scope))) __tracebackhide__ = False val = setup() cache[cachekey] = val @@ -1010,6 +1004,21 @@ class FuncargRequest: pass factorylist = self._getfaclist(argname) funcargfactory = factorylist.pop() + self._factorystack.append(funcargfactory) + try: + return self._getfuncargvalue(funcargfactory, argname) + finally: + self._factorystack.pop() + + def _getfuncargvalue(self, funcargfactory, argname): + # collect funcargs from the factory + newnames = getfuncargnames(funcargfactory) + newnames.remove("request") + factory_kwargs = {"request": self} + def fillfactoryargs(): + for newname in newnames: + factory_kwargs[newname] = self.getfuncargvalue(newname) + node = self._pyfuncitem mp = monkeypatch() mp.setattr(self, '_currentarg', argname) @@ -1027,16 +1036,30 @@ class FuncargRequest: scope = marker.kwargs.get("scope") if scope is not None: __tracebackhide__ = True - check_scope(self.scope, scope) + if scopemismatch(self.scope, scope): + # try to report something helpful + lines = [] + for factory in self._factorystack: + fs, lineno = getfslineno(factory) + p = self._pyfuncitem.session.fspath.bestrelpath(fs) + args = inspect.formatargspec(*inspect.getargspec(factory)) + lines.append("%s:%d\n def %s%s" %( + p, lineno, factory.__name__, args)) + raise ScopeMismatchError("You tried to access the %r scoped " + "funcarg %r with a %r scoped request object, " + "involved factories\n%s" %( + (scope, argname, self.scope, "\n".join(lines)))) __tracebackhide__ = False mp.setattr(self, "scope", scope) kwargs = {} if hasattr(self, "param"): kwargs["extrakey"] = param - val = self.cached_setup(lambda: funcargfactory(request=self), + fillfactoryargs() + val = self.cached_setup(lambda: funcargfactory(**factory_kwargs), scope=scope, **kwargs) else: - val = funcargfactory(request=self) + fillfactoryargs() + val = funcargfactory(**factory_kwargs) mp.undo() self._funcargs[argname] = val return val @@ -1072,11 +1095,14 @@ class ScopeMismatchError(Exception): """ A funcarg factory tries to access a funcargvalue/factory which has a lower scope (e.g. a Session one calls a function one) """ + scopes = "session module class function".split() -def check_scope(currentscope, newscope): - __tracebackhide__ = True - i_currentscope = scopes.index(currentscope) - i_newscope = scopes.index(newscope) - if i_newscope > i_currentscope: - raise ScopeMismatchError("You tried to access a %r scoped funcarg " - "from a %r scoped one." % (newscope, currentscope)) +def scopemismatch(currentscope, newscope): + return scopes.index(newscope) > scopes.index(currentscope) + +def slice_kwargs(names, kwargs): + new_kwargs = {} + for name in names: + new_kwargs[name] = kwargs[name] + return new_kwargs + diff --git a/doc/en/resources.txt b/doc/en/resources.txt index 9a43341e2..a9f043e90 100644 --- a/doc/en/resources.txt +++ b/doc/en/resources.txt @@ -248,10 +248,11 @@ You can also use ``@setup`` inside a test module or class:: def modes(tmpdir, request): # ... -This would execute the ``modes`` function once for each parameter. -In addition to normal funcargs you can also receive the "request" -funcarg which represents a takes on each of the values in the -``params=[1,2,3]`` decorator argument. +This would execute the ``modes`` function once for each parameter +which will be put at ``request.param``. This request object offers +the ``addfinalizer(func)`` helper which allows to register a function +which will be executed when test functions within the specified scope +finished execution. .. note:: diff --git a/testing/test_python.py b/testing/test_python.py index 2a5eb8a98..2e681630f 100644 --- a/testing/test_python.py +++ b/testing/test_python.py @@ -619,7 +619,7 @@ class TestRequest: def test_request_attributes_method(self, testdir): item, = testdir.getitems(""" class TestB: - def pytest_funcarg__something(request): + def pytest_funcarg__something(self, request): return 1 def test_func(self, something): pass @@ -1616,6 +1616,70 @@ class TestRequestAPI: "*1 passed*", ]) +class TestFuncargFactory: + def test_receives_funcargs(self, testdir): + testdir.makepyfile(""" + import pytest + @pytest.mark.funcarg + def arg1(request): + return 1 + + @pytest.mark.funcarg + def arg2(request, arg1): + return arg1 + 1 + + def test_add(arg2): + assert arg2 == 2 + def test_all(arg1, arg2): + assert arg1 == 1 + assert arg2 == 2 + """) + reprec = testdir.inline_run() + reprec.assertoutcome(passed=2) + + def test_receives_funcargs_scope_mismatch(self, testdir): + testdir.makepyfile(""" + import pytest + @pytest.mark.funcarg(scope="function") + def arg1(request): + return 1 + + @pytest.mark.funcarg(scope="module") + def arg2(request, arg1): + return arg1 + 1 + + def test_add(arg2): + assert arg2 == 2 + """) + result = testdir.runpytest() + result.stdout.fnmatch_lines([ + "*ScopeMismatch*involved factories*", + "* def arg2*", + "* def arg1*", + "*1 error*" + ]) + + def test_funcarg_parametrized_and_used_twice(self, testdir): + testdir.makepyfile(""" + import pytest + l = [] + @pytest.mark.funcarg(params=[1,2]) + def arg1(request): + l.append(1) + return request.param + + @pytest.mark.funcarg + def arg2(request, arg1): + return arg1 + 1 + + def test_add(arg1, arg2): + assert arg2 == arg1 + 1 + assert len(l) == arg1 + """) + result = testdir.runpytest() + result.stdout.fnmatch_lines([ + "*2 passed*" + ]) class TestResourceIntegrationFunctional: @@ -1802,7 +1866,7 @@ class TestFuncargMarker: result = testdir.runpytest() assert result.ret != 0 result.stdout.fnmatch_lines([ - "*ScopeMismatch*You tried*function*from*session*", + "*ScopeMismatch*You tried*function*session*request*", ]) def test_register_only_with_mark(self, testdir):