Support for tests created with functools.partial

Fix #811
This commit is contained in:
Bruno Oliveira 2015-07-02 23:13:59 -03:00
parent 330de0a93d
commit dcdc823dd2
2 changed files with 69 additions and 5 deletions

View File

@ -18,10 +18,19 @@ callable = py.builtin.callable
# used to work around a python2 exception info leak
exc_clear = getattr(sys, 'exc_clear', lambda: None)
def getfslineno(obj):
# xxx let decorators etc specify a sane ordering
def get_real_func(obj):
"""gets the real function object of the (possibly) wrapped object by
functools.wraps or functools.partial.
"""
while hasattr(obj, "__wrapped__"):
obj = obj.__wrapped__
if isinstance(obj, py.std.functools.partial):
obj = obj.func
return obj
def getfslineno(obj):
# xxx let decorators etc specify a sane ordering
obj = get_real_func(obj)
if hasattr(obj, 'place_as'):
obj = obj.place_as
fslineno = py.code.getfslineno(obj)
@ -594,7 +603,10 @@ class FunctionMixin(PyobjMixin):
def _prunetraceback(self, excinfo):
if hasattr(self, '_obj') and not self.config.option.fulltrace:
code = py.code.Code(self.obj)
if isinstance(self.obj, py.std.functools.partial):
code = py.code.Code(self.obj.func)
else:
code = py.code.Code(self.obj)
path, firstlineno = code.path, code.firstlineno
traceback = excinfo.traceback
ntraceback = traceback.cut(path=path, firstlineno=firstlineno)
@ -1537,7 +1549,7 @@ class FixtureLookupError(LookupError):
for function in stack:
fspath, lineno = getfslineno(function)
try:
lines, _ = inspect.getsourcelines(function)
lines, _ = inspect.getsourcelines(get_real_func(function))
except IOError:
error_msg = "file %s, line %s: source code not available"
addline(error_msg % (fspath, lineno+1))
@ -1937,7 +1949,15 @@ def getfuncargnames(function, startindex=None):
if realfunction != function:
startindex += num_mock_patch_args(function)
function = realfunction
argnames = inspect.getargs(py.code.getrawcode(function))[0]
if isinstance(function, py.std.functools.partial):
argnames = inspect.getargs(py.code.getrawcode(function.func))[0]
partial = function
argnames = argnames[len(partial.args):]
if partial.keywords:
for kw in partial.keywords:
argnames.remove(kw)
else:
argnames = inspect.getargs(py.code.getrawcode(function))[0]
defaults = getattr(function, 'func_defaults',
getattr(function, '__defaults__', None)) or ()
numdefaults = len(defaults)

View File

@ -851,3 +851,47 @@ def test_unorderable_types(testdir):
result = testdir.runpytest()
assert "TypeError" not in result.stdout.str()
assert result.ret == 0
def test_collect_functools_partial(testdir):
"""
Test that collection of functools.partial object works, and arguments
to the wrapped functions are dealt correctly (see #811).
"""
testdir.makepyfile("""
import functools
import pytest
@pytest.fixture
def fix1():
return 'fix1'
@pytest.fixture
def fix2():
return 'fix2'
def check1(i, fix1):
assert i == 2
assert fix1 == 'fix1'
def check2(fix1, i):
assert i == 2
assert fix1 == 'fix1'
def check3(fix1, i, fix2):
assert i == 2
assert fix1 == 'fix1'
assert fix2 == 'fix2'
test_ok_1 = functools.partial(check1, i=2)
test_ok_2 = functools.partial(check1, i=2, fix1='fix1')
test_ok_3 = functools.partial(check1, 2)
test_ok_4 = functools.partial(check2, i=2)
test_ok_5 = functools.partial(check3, i=2)
test_ok_6 = functools.partial(check3, i=2, fix1='fix1')
test_fail_1 = functools.partial(check2, 2)
test_fail_2 = functools.partial(check3, 2)
""")
result = testdir.inline_run()
result.assertoutcome(passed=6, failed=2)