add changelog entry and refactor unittest.mock.patch fix a bit

This commit is contained in:
holger krekel 2014-07-28 10:34:01 +02:00
parent 8792261df1
commit ba878c6d9d
3 changed files with 16 additions and 23 deletions

View File

@ -44,3 +44,4 @@ Daniel Grana
Andy Freeland Andy Freeland
Trevor Bekolay Trevor Bekolay
David Mohr David Mohr
Nicolas Delaby

View File

@ -6,6 +6,9 @@ NEXT
or a tuple of exception classes. Thanks David Mohr for the complete or a tuple of exception classes. Thanks David Mohr for the complete
PR. PR.
- fix integration of pytest with unittest.mock.patch decorator when
it uses the "new" argument. Thanks Nicolas Delaby for test and PR.
2.6 2.6
----------------------------------- -----------------------------------

View File

@ -1861,28 +1861,17 @@ class FixtureDef:
return ("<FixtureDef name=%r scope=%r baseid=%r >" % return ("<FixtureDef name=%r scope=%r baseid=%r >" %
(self.argname, self.scope, self.baseid)) (self.argname, self.scope, self.baseid))
def handle_mock_module_patching(function, startindex): def num_mock_patch_args(function):
""" """ return number of arguments used up by mock arguments (if any) """
Special treatment when test_function is decorated patchings = getattr(function, "patchings", None)
by mock.patch if not patchings:
""" return 0
for candidate_module_name in ('mock', 'unittest.mock'): mock = sys.modules.get("mock", sys.modules.get("unittest.mock", None))
# stdlib comes last, because mock might be also installed if mock is not None:
# as a third party with upgraded version compare to return len([p for p in patchings
# unittest.mock if not p.attribute_name and p.new is mock.DEFAULT])
try: return len(patchings)
mock = sys.modules[candidate_module_name]
except KeyError:
pass
else:
for patching in getattr(function, "patchings", []):
if (not patching.attribute_name
and patching.new is mock.DEFAULT):
startindex += 1
break
else:
startindex += len(getattr(function, "patchings", []))
return startindex
def getfuncargnames(function, startindex=None): def getfuncargnames(function, startindex=None):
# XXX merge with main.py's varnames # XXX merge with main.py's varnames
@ -1893,7 +1882,7 @@ def getfuncargnames(function, startindex=None):
if startindex is None: if startindex is None:
startindex = inspect.ismethod(function) and 1 or 0 startindex = inspect.ismethod(function) and 1 or 0
if realfunction != function: if realfunction != function:
startindex = handle_mock_module_patching(function, startindex) startindex += num_mock_patch_args(function)
function = realfunction function = realfunction
argnames = inspect.getargs(py.code.getrawcode(function))[0] argnames = inspect.getargs(py.code.getrawcode(function))[0]
defaults = getattr(function, 'func_defaults', defaults = getattr(function, 'func_defaults',