Use a callable __tracebackhide__ for filtering

While this leads to slightly more complicated user code for the common
case (checking if the exception is of a given type) it's easier to
implement and more flexible.
This commit is contained in:
Florian Bruhin 2016-04-20 11:07:34 +02:00
parent 4c552d4ef7
commit 75160547f2
4 changed files with 27 additions and 32 deletions

View File

@ -23,8 +23,9 @@
Thanks `@omarkohl`_ for the complete PR (`#1502`_) and `@nicoddemus`_ for the Thanks `@omarkohl`_ for the complete PR (`#1502`_) and `@nicoddemus`_ for the
implementation tips. implementation tips.
* ``__tracebackhide__`` can now also be set to an exception type (or a list of * ``__tracebackhide__`` can now also be set to a callable which then can decide
exception types) to only filter exceptions of that type. whether to filter the traceback based on the ``ExceptionInfo`` object passed
to it.
* *

View File

@ -140,8 +140,8 @@ class TracebackEntry(object):
_repr_style = None _repr_style = None
exprinfo = None exprinfo = None
def __init__(self, rawentry, exctype=None): def __init__(self, rawentry, excinfo=None):
self._exctype = exctype self._excinfo = excinfo
self._rawentry = rawentry self._rawentry = rawentry
self.lineno = rawentry.tb_lineno - 1 self.lineno = rawentry.tb_lineno - 1
@ -218,16 +218,12 @@ class TracebackEntry(object):
source = property(getsource) source = property(getsource)
def _is_exception_type(self, obj):
return isinstance(obj, type) and issubclass(obj, Exception)
def ishidden(self): def ishidden(self):
""" return True if the current frame has a var __tracebackhide__ """ return True if the current frame has a var __tracebackhide__
resolving to True resolving to True
If __tracebackhide__ is set to an exception type, or a list/tuple, If __tracebackhide__ is a callable, it gets called with the
the traceback is only hidden if the exception which happened is of ExceptionInfo instance and can decide whether to hide the traceback.
the given type(s).
mostly for internal use mostly for internal use
""" """
@ -239,13 +235,8 @@ class TracebackEntry(object):
except KeyError: except KeyError:
return False return False
if self._is_exception_type(tbh): if callable(tbh):
assert self._exctype is not None return tbh(self._excinfo)
return issubclass(self._exctype, tbh)
elif (isinstance(tbh, (list, tuple)) and
all(self._is_exception_type(e) for e in tbh)):
assert self._exctype is not None
return issubclass(self._exctype, tuple(tbh))
else: else:
return tbh return tbh
@ -272,13 +263,13 @@ class Traceback(list):
access to Traceback entries. access to Traceback entries.
""" """
Entry = TracebackEntry Entry = TracebackEntry
def __init__(self, tb, exctype=None): def __init__(self, tb, excinfo=None):
""" initialize from given python traceback object and exc type. """ """ initialize from given python traceback object and ExceptionInfo """
self._exctype = exctype self._excinfo = excinfo
if hasattr(tb, 'tb_next'): if hasattr(tb, 'tb_next'):
def f(cur): def f(cur):
while cur is not None: while cur is not None:
yield self.Entry(cur, exctype=exctype) yield self.Entry(cur, excinfo=excinfo)
cur = cur.tb_next cur = cur.tb_next
list.__init__(self, f(tb)) list.__init__(self, f(tb))
else: else:
@ -302,7 +293,7 @@ class Traceback(list):
not codepath.relto(excludepath)) and not codepath.relto(excludepath)) and
(lineno is None or x.lineno == lineno) and (lineno is None or x.lineno == lineno) and
(firstlineno is None or x.frame.code.firstlineno == firstlineno)): (firstlineno is None or x.frame.code.firstlineno == firstlineno)):
return Traceback(x._rawentry, self._exctype) return Traceback(x._rawentry, self._excinfo)
return self return self
def __getitem__(self, key): def __getitem__(self, key):
@ -321,7 +312,7 @@ class Traceback(list):
by default this removes all the TracebackItems which are hidden by default this removes all the TracebackItems which are hidden
(see ishidden() above) (see ishidden() above)
""" """
return Traceback(filter(fn, self), self._exctype) return Traceback(filter(fn, self), self._excinfo)
def getcrashentry(self): def getcrashentry(self):
""" return last non-hidden traceback entry that lead """ return last non-hidden traceback entry that lead
@ -385,7 +376,7 @@ class ExceptionInfo(object):
#: the exception type name #: the exception type name
self.typename = self.type.__name__ self.typename = self.type.__name__
#: the exception traceback (_pytest._code.Traceback instance) #: the exception traceback (_pytest._code.Traceback instance)
self.traceback = _pytest._code.Traceback(self.tb, exctype=self.type) self.traceback = _pytest._code.Traceback(self.tb, excinfo=self)
def __repr__(self): def __repr__(self):
return "<ExceptionInfo %s tblen=%d>" % (self.typename, len(self.traceback)) return "<ExceptionInfo %s tblen=%d>" % (self.typename, len(self.traceback))

View File

@ -216,16 +216,18 @@ Let's run our little function::
test_checkconfig.py:8: Failed test_checkconfig.py:8: Failed
1 failed in 0.12 seconds 1 failed in 0.12 seconds
If you only want to hide certain exception classes, you can also set If you only want to hide certain exceptions, you can set ``__tracebackhide__``
``__tracebackhide__`` to an exception type or a list of exception types:: to a callable which gets the ``ExceptionInfo`` object. You can for example use
this to make sure unexpected exception types aren't hidden::
import operator
import pytest import pytest
class ConfigException(Exception): class ConfigException(Exception):
pass pass
def checkconfig(x): def checkconfig(x):
__tracebackhide__ = ConfigException __tracebackhide__ = operator.methodcaller('errisinstance', ConfigException)
if not hasattr(x, "config"): if not hasattr(x, "config"):
raise ConfigException("not configured: %s" %(x,)) raise ConfigException("not configured: %s" %(x,))

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import operator
import _pytest import _pytest
import py import py
import pytest import pytest
@ -145,10 +146,10 @@ class TestTraceback_f_g_h:
assert len(ntraceback) == len(traceback) - 1 assert len(ntraceback) == len(traceback) - 1
@pytest.mark.parametrize('tracebackhide, matching', [ @pytest.mark.parametrize('tracebackhide, matching', [
(ValueError, True), (lambda info: True, True),
(IndexError, False), (lambda info: False, False),
([ValueError, IndexError], True), (operator.methodcaller('errisinstance', ValueError), True),
((ValueError, IndexError), True), (operator.methodcaller('errisinstance', IndexError), False),
]) ])
def test_traceback_filter_selective(self, tracebackhide, matching): def test_traceback_filter_selective(self, tracebackhide, matching):
def f(): def f():
@ -475,7 +476,7 @@ raise ValueError()
f_globals = {} f_globals = {}
class FakeTracebackEntry(_pytest._code.Traceback.Entry): class FakeTracebackEntry(_pytest._code.Traceback.Entry):
def __init__(self, tb, exctype=None): def __init__(self, tb, excinfo=None):
self.lineno = 5+3 self.lineno = 5+3
@property @property