allow funcarg factories to receive funcargs

This commit is contained in:
holger krekel 2012-07-23 10:55:09 +02:00
parent 76584b53a1
commit d4a487c725
4 changed files with 144 additions and 32 deletions

View File

@ -444,17 +444,26 @@ class FuncargManager:
self.pytest_plugin_registered(plugin) self.pytest_plugin_registered(plugin)
def pytest_generate_tests(self, metafunc): def pytest_generate_tests(self, metafunc):
for argname in metafunc.funcargnames: funcargnames = list(metafunc.funcargnames)
seen = set()
while funcargnames:
argname = funcargnames.pop(0)
if argname in seen:
continue
seen.add(argname)
faclist = self.getfactorylist(argname, metafunc.parentid, faclist = self.getfactorylist(argname, metafunc.parentid,
metafunc.function, raising=False) metafunc.function, raising=False)
if faclist is None: if faclist is None:
continue # will raise at setup time continue # will raise FuncargLookupError at setup time
for fac in faclist: for fac in faclist:
marker = getattr(fac, "funcarg", None) marker = getattr(fac, "funcarg", None)
if marker is not None: if marker is not None:
params = marker.kwargs.get("params") params = marker.kwargs.get("params")
if params is not None: if params is not None:
metafunc.parametrize(argname, params, indirect=True) metafunc.parametrize(argname, params, indirect=True)
newfuncargnames = getfuncargnames(fac)
newfuncargnames.remove("request")
funcargnames.extend(newfuncargnames)
def _parsefactories(self, holderobj, nodeid): def _parsefactories(self, holderobj, nodeid):
if holderobj in self._holderobjseen: if holderobj in self._holderobjseen:
@ -773,3 +782,15 @@ class FuncargLookupErrorRepr(TerminalRepr):
tw.line(" " + line.strip(), red=True) tw.line(" " + line.strip(), red=True)
tw.line() tw.line()
tw.line("%s:%d" % (self.filename, self.firstlineno+1)) tw.line("%s:%d" % (self.filename, self.firstlineno+1))
def getfuncargnames(function, startindex=None):
# XXX merge with main.py's varnames
argnames = py.std.inspect.getargs(py.code.getrawcode(function))[0]
if startindex is None:
startindex = py.std.inspect.ismethod(function) and 1 or 0
defaults = getattr(function, 'func_defaults',
getattr(function, '__defaults__', None)) or ()
numdefaults = len(defaults)
if numdefaults:
return argnames[startindex:-numdefaults]
return argnames[startindex:]

View File

@ -3,7 +3,7 @@ import py
import inspect import inspect
import sys import sys
import pytest import pytest
from _pytest.main import getfslineno from _pytest.main import getfslineno, getfuncargnames
from _pytest.monkeypatch import monkeypatch from _pytest.monkeypatch import monkeypatch
import _pytest import _pytest
@ -475,17 +475,6 @@ def hasinit(obj):
return True return True
def getfuncargnames(function, startindex=None):
# XXX merge with main.py's varnames
argnames = py.std.inspect.getargs(py.code.getrawcode(function))[0]
if startindex is None:
startindex = py.std.inspect.ismethod(function) and 1 or 0
defaults = getattr(function, 'func_defaults',
getattr(function, '__defaults__', None)) or ()
numdefaults = len(defaults)
if numdefaults:
return argnames[startindex:-numdefaults]
return argnames[startindex:]
def fillfuncargs(function): def fillfuncargs(function):
""" fill missing funcargs. """ """ fill missing funcargs. """
@ -887,6 +876,7 @@ class FuncargRequest:
self.funcargnames = getfuncargnames(self.function) self.funcargnames = getfuncargnames(self.function)
self.parentid = pyfuncitem.parent.nodeid self.parentid = pyfuncitem.parent.nodeid
self.scope = "function" self.scope = "function"
self._factorystack = []
def _getfaclist(self, argname): def _getfaclist(self, argname):
faclist = self._name2factory.get(argname, None) faclist = self._name2factory.get(argname, None)
@ -947,7 +937,8 @@ class FuncargRequest:
if self.funcargnames: if self.funcargnames:
assert not getattr(self._pyfuncitem, '_args', None), ( assert not getattr(self._pyfuncitem, '_args', None), (
"yielded functions cannot have funcargs") "yielded functions cannot have funcargs")
for argname in self.funcargnames: while self.funcargnames:
argname = self.funcargnames.pop(0)
if argname not in self._pyfuncitem.funcargs: if argname not in self._pyfuncitem.funcargs:
self._pyfuncitem.funcargs[argname] = \ self._pyfuncitem.funcargs[argname] = \
self.getfuncargvalue(argname) self.getfuncargvalue(argname)
@ -984,7 +975,10 @@ class FuncargRequest:
val = cache[cachekey] val = cache[cachekey]
except KeyError: except KeyError:
__tracebackhide__ = True __tracebackhide__ = True
check_scope(self.scope, scope) if scopemismatch(self.scope, scope):
raise ScopeMismatchError("You tried to access a %r scoped "
"resource with a %r scoped request object" %(
(scope, self.scope)))
__tracebackhide__ = False __tracebackhide__ = False
val = setup() val = setup()
cache[cachekey] = val cache[cachekey] = val
@ -1010,6 +1004,21 @@ class FuncargRequest:
pass pass
factorylist = self._getfaclist(argname) factorylist = self._getfaclist(argname)
funcargfactory = factorylist.pop() funcargfactory = factorylist.pop()
self._factorystack.append(funcargfactory)
try:
return self._getfuncargvalue(funcargfactory, argname)
finally:
self._factorystack.pop()
def _getfuncargvalue(self, funcargfactory, argname):
# collect funcargs from the factory
newnames = getfuncargnames(funcargfactory)
newnames.remove("request")
factory_kwargs = {"request": self}
def fillfactoryargs():
for newname in newnames:
factory_kwargs[newname] = self.getfuncargvalue(newname)
node = self._pyfuncitem node = self._pyfuncitem
mp = monkeypatch() mp = monkeypatch()
mp.setattr(self, '_currentarg', argname) mp.setattr(self, '_currentarg', argname)
@ -1027,16 +1036,30 @@ class FuncargRequest:
scope = marker.kwargs.get("scope") scope = marker.kwargs.get("scope")
if scope is not None: if scope is not None:
__tracebackhide__ = True __tracebackhide__ = True
check_scope(self.scope, scope) if scopemismatch(self.scope, scope):
# try to report something helpful
lines = []
for factory in self._factorystack:
fs, lineno = getfslineno(factory)
p = self._pyfuncitem.session.fspath.bestrelpath(fs)
args = inspect.formatargspec(*inspect.getargspec(factory))
lines.append("%s:%d\n def %s%s" %(
p, lineno, factory.__name__, args))
raise ScopeMismatchError("You tried to access the %r scoped "
"funcarg %r with a %r scoped request object, "
"involved factories\n%s" %(
(scope, argname, self.scope, "\n".join(lines))))
__tracebackhide__ = False __tracebackhide__ = False
mp.setattr(self, "scope", scope) mp.setattr(self, "scope", scope)
kwargs = {} kwargs = {}
if hasattr(self, "param"): if hasattr(self, "param"):
kwargs["extrakey"] = param kwargs["extrakey"] = param
val = self.cached_setup(lambda: funcargfactory(request=self), fillfactoryargs()
val = self.cached_setup(lambda: funcargfactory(**factory_kwargs),
scope=scope, **kwargs) scope=scope, **kwargs)
else: else:
val = funcargfactory(request=self) fillfactoryargs()
val = funcargfactory(**factory_kwargs)
mp.undo() mp.undo()
self._funcargs[argname] = val self._funcargs[argname] = val
return val return val
@ -1072,11 +1095,14 @@ class ScopeMismatchError(Exception):
""" A funcarg factory tries to access a funcargvalue/factory """ A funcarg factory tries to access a funcargvalue/factory
which has a lower scope (e.g. a Session one calls a function one) which has a lower scope (e.g. a Session one calls a function one)
""" """
scopes = "session module class function".split() scopes = "session module class function".split()
def check_scope(currentscope, newscope): def scopemismatch(currentscope, newscope):
__tracebackhide__ = True return scopes.index(newscope) > scopes.index(currentscope)
i_currentscope = scopes.index(currentscope)
i_newscope = scopes.index(newscope) def slice_kwargs(names, kwargs):
if i_newscope > i_currentscope: new_kwargs = {}
raise ScopeMismatchError("You tried to access a %r scoped funcarg " for name in names:
"from a %r scoped one." % (newscope, currentscope)) new_kwargs[name] = kwargs[name]
return new_kwargs

View File

@ -248,10 +248,11 @@ You can also use ``@setup`` inside a test module or class::
def modes(tmpdir, request): def modes(tmpdir, request):
# ... # ...
This would execute the ``modes`` function once for each parameter. This would execute the ``modes`` function once for each parameter
In addition to normal funcargs you can also receive the "request" which will be put at ``request.param``. This request object offers
funcarg which represents a takes on each of the values in the the ``addfinalizer(func)`` helper which allows to register a function
``params=[1,2,3]`` decorator argument. which will be executed when test functions within the specified scope
finished execution.
.. note:: .. note::

View File

@ -619,7 +619,7 @@ class TestRequest:
def test_request_attributes_method(self, testdir): def test_request_attributes_method(self, testdir):
item, = testdir.getitems(""" item, = testdir.getitems("""
class TestB: class TestB:
def pytest_funcarg__something(request): def pytest_funcarg__something(self, request):
return 1 return 1
def test_func(self, something): def test_func(self, something):
pass pass
@ -1616,6 +1616,70 @@ class TestRequestAPI:
"*1 passed*", "*1 passed*",
]) ])
class TestFuncargFactory:
def test_receives_funcargs(self, testdir):
testdir.makepyfile("""
import pytest
@pytest.mark.funcarg
def arg1(request):
return 1
@pytest.mark.funcarg
def arg2(request, arg1):
return arg1 + 1
def test_add(arg2):
assert arg2 == 2
def test_all(arg1, arg2):
assert arg1 == 1
assert arg2 == 2
""")
reprec = testdir.inline_run()
reprec.assertoutcome(passed=2)
def test_receives_funcargs_scope_mismatch(self, testdir):
testdir.makepyfile("""
import pytest
@pytest.mark.funcarg(scope="function")
def arg1(request):
return 1
@pytest.mark.funcarg(scope="module")
def arg2(request, arg1):
return arg1 + 1
def test_add(arg2):
assert arg2 == 2
""")
result = testdir.runpytest()
result.stdout.fnmatch_lines([
"*ScopeMismatch*involved factories*",
"* def arg2*",
"* def arg1*",
"*1 error*"
])
def test_funcarg_parametrized_and_used_twice(self, testdir):
testdir.makepyfile("""
import pytest
l = []
@pytest.mark.funcarg(params=[1,2])
def arg1(request):
l.append(1)
return request.param
@pytest.mark.funcarg
def arg2(request, arg1):
return arg1 + 1
def test_add(arg1, arg2):
assert arg2 == arg1 + 1
assert len(l) == arg1
""")
result = testdir.runpytest()
result.stdout.fnmatch_lines([
"*2 passed*"
])
class TestResourceIntegrationFunctional: class TestResourceIntegrationFunctional:
@ -1802,7 +1866,7 @@ class TestFuncargMarker:
result = testdir.runpytest() result = testdir.runpytest()
assert result.ret != 0 assert result.ret != 0
result.stdout.fnmatch_lines([ result.stdout.fnmatch_lines([
"*ScopeMismatch*You tried*function*from*session*", "*ScopeMismatch*You tried*function*session*request*",
]) ])
def test_register_only_with_mark(self, testdir): def test_register_only_with_mark(self, testdir):