From 3d794b6b3891c673ba6947243dee0e986581b3ec Mon Sep 17 00:00:00 2001 From: holger krekel Date: Sat, 4 Oct 2014 15:49:31 +0200 Subject: [PATCH] factor out a small "wrapping" helper --- _pytest/core.py | 39 +++++++++++++++++++++++++++++---------- testing/test_core.py | 20 ++++++++++++++++++++ 2 files changed, 49 insertions(+), 10 deletions(-) diff --git a/_pytest/core.py b/_pytest/core.py index 937703654..cd54072bb 100644 --- a/_pytest/core.py +++ b/_pytest/core.py @@ -67,6 +67,24 @@ class TagTracerSub: def get(self, name): return self.__class__(self.root, self.tags + (name,)) +def add_method_controller(cls, func): + name = func.__name__ + oldcall = getattr(cls, name) + def wrap_exec(*args, **kwargs): + gen = func(*args, **kwargs) + gen.next() # first yield + res = oldcall(*args, **kwargs) + try: + gen.send(res) + except StopIteration: + pass + else: + raise ValueError("expected StopIteration") + return res + setattr(cls, name, wrap_exec) + return lambda: setattr(cls, name, oldcall) + + class PluginManager(object): def __init__(self, hookspecs=None, prefix="pytest_"): self._name2plugin = {} @@ -80,23 +98,24 @@ class PluginManager(object): def set_tracing(self, writer): self.trace.root.setwriter(writer) - # we reconfigure HookCalling to perform tracing - # and we avoid doing the "do we need to trace" check dynamically - # for speed reasons - assert HookCaller._docall.__name__ == "_docall" - real_docall = HookCaller._docall - def docall_tracing(self, methods, kwargs): + # reconfigure HookCalling to perform tracing + assert not hasattr(self, "_wrapping") + self._wrapping = True + + def _docall(self, methods, kwargs): trace = self.hookrelay.trace trace.root.indent += 1 trace(self.name, kwargs) + res = None try: - res = real_docall(self, methods, kwargs) + res = yield + finally: if res: trace("finish", self.name, "-->", res) - finally: trace.root.indent -= 1 - return res - HookCaller._docall = docall_tracing + + undo = add_method_controller(HookCaller, _docall) + self.add_shutdown(undo) def do_configure(self, config): # backward compatibility diff --git a/testing/test_core.py b/testing/test_core.py index cd8fa8982..6cafc8812 100644 --- a/testing/test_core.py +++ b/testing/test_core.py @@ -765,3 +765,23 @@ def test_importplugin_issue375(testdir): assert "qwe" not in str(excinfo.value) assert "aaaa" in str(excinfo.value) + +def test_wrapping(): + class A: + def f(self): + return "A.f" + + shutdown = [] + l = [] + def f(self): + l.append(1) + x = yield + l.append(2) + undo = add_method_controller(A, f) + + assert A().f() == "A.f" + assert l == [1,2] + undo() + l[:] = [] + assert A().f() == "A.f" + assert l == []