From b607f6728f2eed51c9a7ce713d3b601dcd04f8db Mon Sep 17 00:00:00 2001 From: Florian Bruhin Date: Wed, 20 Apr 2016 10:25:33 +0200 Subject: [PATCH] Filter selectively with __tracebackhide__ When __tracebackhide__ gets set to an exception type or list/tuple of exception types, only those exceptions get filtered, while the full traceback is shown if another exception (e.g. a bug in a assertion helper) happens. --- CHANGELOG.rst | 3 +++ _pytest/_code/code.py | 37 +++++++++++++++++++++++++++--------- doc/en/example/simple.rst | 20 +++++++++++++++++++ testing/code/test_excinfo.py | 35 +++++++++++++++++++++++++++++++++- 4 files changed, 85 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index dfd5c2236..a58e1dc99 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -23,6 +23,9 @@ Thanks `@omarkohl`_ for the complete PR (`#1502`_) and `@nicoddemus`_ for the implementation tips. +* ``__tracebackhide__`` can now also be set to an exception type (or a list of + exception types) to only filter exceptions of that type. + * **Changes** diff --git a/_pytest/_code/code.py b/_pytest/_code/code.py index 79fcf9f1c..5fe81bc13 100644 --- a/_pytest/_code/code.py +++ b/_pytest/_code/code.py @@ -140,7 +140,8 @@ class TracebackEntry(object): _repr_style = None exprinfo = None - def __init__(self, rawentry): + def __init__(self, rawentry, exctype=None): + self._exctype = exctype self._rawentry = rawentry self.lineno = rawentry.tb_lineno - 1 @@ -217,20 +218,37 @@ class TracebackEntry(object): source = property(getsource) + def _is_exception_type(self, obj): + return isinstance(obj, type) and issubclass(obj, Exception) + def ishidden(self): """ return True if the current frame has a var __tracebackhide__ resolving to True + If __tracebackhide__ is set to an exception type, or a list/tuple, + the traceback is only hidden if the exception which happened is of + the given type(s). + mostly for internal use """ try: - return self.frame.f_locals['__tracebackhide__'] + tbh = self.frame.f_locals['__tracebackhide__'] except KeyError: try: - return self.frame.f_globals['__tracebackhide__'] + tbh = self.frame.f_globals['__tracebackhide__'] except KeyError: return False + if self._is_exception_type(tbh): + assert self._exctype is not None + return issubclass(self._exctype, tbh) + elif (isinstance(tbh, (list, tuple)) and + all(self._is_exception_type(e) for e in tbh)): + assert self._exctype is not None + return issubclass(self._exctype, tuple(tbh)) + else: + return tbh + def __str__(self): try: fn = str(self.path) @@ -254,12 +272,13 @@ class Traceback(list): access to Traceback entries. """ Entry = TracebackEntry - def __init__(self, tb): - """ initialize from given python traceback object. """ + def __init__(self, tb, exctype=None): + """ initialize from given python traceback object and exc type. """ + self._exctype = exctype if hasattr(tb, 'tb_next'): def f(cur): while cur is not None: - yield self.Entry(cur) + yield self.Entry(cur, exctype=exctype) cur = cur.tb_next list.__init__(self, f(tb)) else: @@ -283,7 +302,7 @@ class Traceback(list): not codepath.relto(excludepath)) and (lineno is None or x.lineno == lineno) and (firstlineno is None or x.frame.code.firstlineno == firstlineno)): - return Traceback(x._rawentry) + return Traceback(x._rawentry, self._exctype) return self def __getitem__(self, key): @@ -302,7 +321,7 @@ class Traceback(list): by default this removes all the TracebackItems which are hidden (see ishidden() above) """ - return Traceback(filter(fn, self)) + return Traceback(filter(fn, self), self._exctype) def getcrashentry(self): """ return last non-hidden traceback entry that lead @@ -366,7 +385,7 @@ class ExceptionInfo(object): #: the exception type name self.typename = self.type.__name__ #: the exception traceback (_pytest._code.Traceback instance) - self.traceback = _pytest._code.Traceback(self.tb) + self.traceback = _pytest._code.Traceback(self.tb, exctype=self.type) def __repr__(self): return "" % (self.typename, len(self.traceback)) diff --git a/doc/en/example/simple.rst b/doc/en/example/simple.rst index be12d2afe..ff06c7672 100644 --- a/doc/en/example/simple.rst +++ b/doc/en/example/simple.rst @@ -216,6 +216,26 @@ Let's run our little function:: test_checkconfig.py:8: Failed 1 failed in 0.12 seconds +If you only want to hide certain exception classes, you can also set +``__tracebackhide__`` to an exception type or a list of exception types:: + + import pytest + + class ConfigException(Exception): + pass + + def checkconfig(x): + __tracebackhide__ = ConfigException + if not hasattr(x, "config"): + raise ConfigException("not configured: %s" %(x,)) + + def test_something(): + checkconfig(42) + +This will avoid hiding the exception traceback on unrelated exceptions (i.e. +bugs in assertion helpers). + + Detect if running from within a pytest run -------------------------------------------------------------- diff --git a/testing/code/test_excinfo.py b/testing/code/test_excinfo.py index 0280d1aa3..8cf79166b 100644 --- a/testing/code/test_excinfo.py +++ b/testing/code/test_excinfo.py @@ -144,6 +144,39 @@ class TestTraceback_f_g_h: ntraceback = traceback.filter() assert len(ntraceback) == len(traceback) - 1 + @pytest.mark.parametrize('tracebackhide, matching', [ + (ValueError, True), + (IndexError, False), + ([ValueError, IndexError], True), + ((ValueError, IndexError), True), + ]) + def test_traceback_filter_selective(self, tracebackhide, matching): + def f(): + # + raise ValueError + # + def g(): + # + __tracebackhide__ = tracebackhide + f() + # + def h(): + # + g() + # + + excinfo = pytest.raises(ValueError, h) + traceback = excinfo.traceback + ntraceback = traceback.filter() + print('old: {!r}'.format(traceback)) + print('new: {!r}'.format(ntraceback)) + + if matching: + assert len(ntraceback) == len(traceback) - 2 + else: + # -1 because of the __tracebackhide__ in pytest.raises + assert len(ntraceback) == len(traceback) - 1 + def test_traceback_recursion_index(self): def f(n): if n < 10: @@ -442,7 +475,7 @@ raise ValueError() f_globals = {} class FakeTracebackEntry(_pytest._code.Traceback.Entry): - def __init__(self, tb): + def __init__(self, tb, exctype=None): self.lineno = 5+3 @property