diff --git a/_pytest/python.py b/_pytest/python.py index 8a3fcd51a..2ddf83ef9 100644 --- a/_pytest/python.py +++ b/_pytest/python.py @@ -1861,6 +1861,27 @@ class FixtureDef: return ("" % (self.argname, self.scope, self.baseid)) +def handle_mock_module_patching(function, startindex): + """ + Special treatment when test_function is decorated + by mock.patch + """ + for candidate_module_name in ('unittest.mock', 'mock'): + # stdlib comes first + try: + mock = sys.modules[candidate_module_name] + except KeyError: + pass + else: + for patching in getattr(function, "patchings", []): + if (not patching.attribute_name + and patching.new is mock.DEFAULT): + startindex += 1 + break + else: + startindex += len(getattr(function, "patchings", [])) + return startindex + def getfuncargnames(function, startindex=None): # XXX merge with main.py's varnames #assert not inspect.isclass(function) @@ -1870,13 +1891,7 @@ def getfuncargnames(function, startindex=None): if startindex is None: startindex = inspect.ismethod(function) and 1 or 0 if realfunction != function: - mock = sys.modules.get('mock') - if mock is not None: - for patching in getattr(function, "patchings", []): - if not patching.attribute_name and patching.new is mock.DEFAULT: - startindex += 1 - else: - startindex += len(getattr(function, "patchings", [])) + startindex = handle_mock_module_patching(function, startindex) function = realfunction argnames = inspect.getargs(py.code.getrawcode(function))[0] defaults = getattr(function, 'func_defaults',