diff --git a/src/_pytest/_code/code.py b/src/_pytest/_code/code.py index 5bc78f478..26db750e6 100644 --- a/src/_pytest/_code/code.py +++ b/src/_pytest/_code/code.py @@ -31,7 +31,6 @@ from typing import Type from typing import TYPE_CHECKING from typing import TypeVar from typing import Union -from weakref import ref import pluggy @@ -52,7 +51,6 @@ from _pytest.pathlib import bestrelpath if TYPE_CHECKING: from typing_extensions import Literal from typing_extensions import SupportsIndex - from weakref import ReferenceType _TracebackStyle = Literal["long", "short", "line", "no", "native", "value", "auto"] @@ -194,15 +192,13 @@ class Frame: class TracebackEntry: """A single entry in a Traceback.""" - __slots__ = ("_rawentry", "_excinfo", "_repr_style") + __slots__ = ("_rawentry", "_repr_style") def __init__( self, rawentry: TracebackType, - excinfo: Optional["ReferenceType[ExceptionInfo[BaseException]]"] = None, ) -> None: self._rawentry = rawentry - self._excinfo = excinfo self._repr_style: Optional['Literal["short", "long"]'] = None @property @@ -272,7 +268,7 @@ class TracebackEntry: source = property(getsource) - def ishidden(self) -> bool: + def ishidden(self, excinfo: Optional["ExceptionInfo[BaseException]"]) -> bool: """Return True if the current frame has a var __tracebackhide__ resolving to True. @@ -296,7 +292,7 @@ class TracebackEntry: else: break if tbh and callable(tbh): - return tbh(None if self._excinfo is None else self._excinfo()) + return tbh(excinfo) return tbh def __str__(self) -> str: @@ -329,16 +325,14 @@ class Traceback(List[TracebackEntry]): def __init__( self, tb: Union[TracebackType, Iterable[TracebackEntry]], - excinfo: Optional["ReferenceType[ExceptionInfo[BaseException]]"] = None, ) -> None: """Initialize from given python traceback object and ExceptionInfo.""" - self._excinfo = excinfo if isinstance(tb, TracebackType): def f(cur: TracebackType) -> Iterable[TracebackEntry]: cur_: Optional[TracebackType] = cur while cur_ is not None: - yield TracebackEntry(cur_, excinfo=excinfo) + yield TracebackEntry(cur_) cur_ = cur_.tb_next super().__init__(f(tb)) @@ -378,7 +372,7 @@ class Traceback(List[TracebackEntry]): continue if firstlineno is not None and x.frame.code.firstlineno != firstlineno: continue - return Traceback(x._rawentry, self._excinfo) + return Traceback(x._rawentry) return self @overload @@ -398,25 +392,36 @@ class Traceback(List[TracebackEntry]): return super().__getitem__(key) def filter( - self, fn: Callable[[TracebackEntry], bool] = lambda x: not x.ishidden() + self, + # TODO(py38): change to positional only. + _excinfo_or_fn: Union[ + "ExceptionInfo[BaseException]", + Callable[[TracebackEntry], bool], + ], ) -> "Traceback": - """Return a Traceback instance with certain items removed + """Return a Traceback instance with certain items removed. - fn is a function that gets a single argument, a TracebackEntry - instance, and should return True when the item should be added - to the Traceback, False when not. + If the filter is an `ExceptionInfo`, removes all the ``TracebackEntry``s + which are hidden (see ishidden() above). - By default this removes all the TracebackEntries which are hidden - (see ishidden() above). + Otherwise, the filter is a function that gets a single argument, a + ``TracebackEntry`` instance, and should return True when the item should + be added to the ``Traceback``, False when not. """ - return Traceback(filter(fn, self), self._excinfo) + if isinstance(_excinfo_or_fn, ExceptionInfo): + fn = lambda x: not x.ishidden(_excinfo_or_fn) # noqa: E731 + else: + fn = _excinfo_or_fn + return Traceback(filter(fn, self)) - def getcrashentry(self) -> Optional[TracebackEntry]: + def getcrashentry( + self, excinfo: Optional["ExceptionInfo[BaseException]"] + ) -> Optional[TracebackEntry]: """Return last non-hidden traceback entry that lead to the exception of a traceback, or None if all hidden.""" for i in range(-1, -len(self) - 1, -1): entry = self[i] - if not entry.ishidden(): + if not entry.ishidden(excinfo): return entry return None @@ -583,7 +588,7 @@ class ExceptionInfo(Generic[E]): def traceback(self) -> Traceback: """The traceback.""" if self._traceback is None: - self._traceback = Traceback(self.tb, excinfo=ref(self)) + self._traceback = Traceback(self.tb) return self._traceback @traceback.setter @@ -624,7 +629,7 @@ class ExceptionInfo(Generic[E]): def _getreprcrash(self) -> Optional["ReprFileLocation"]: exconly = self.exconly(tryshort=True) - entry = self.traceback.getcrashentry() + entry = self.traceback.getcrashentry(self) if entry is None: return None path, lineno = entry.frame.code.raw.co_filename, entry.lineno @@ -882,7 +887,7 @@ class FormattedExcinfo: def repr_traceback(self, excinfo: ExceptionInfo[BaseException]) -> "ReprTraceback": traceback = excinfo.traceback if self.tbfilter: - traceback = traceback.filter() + traceback = traceback.filter(excinfo) if isinstance(excinfo.value, RecursionError): traceback, extraline = self._truncate_recursive_traceback(traceback) diff --git a/src/_pytest/nodes.py b/src/_pytest/nodes.py index ea016786e..738ab97e9 100644 --- a/src/_pytest/nodes.py +++ b/src/_pytest/nodes.py @@ -560,7 +560,7 @@ class Collector(Node): ntraceback = traceback.cut(path=self.path) if ntraceback == traceback: ntraceback = ntraceback.cut(excludepath=tracebackcutdir) - excinfo.traceback = ntraceback.filter() + excinfo.traceback = ntraceback.filter(excinfo) def _check_initialpaths_for_relpath(session: "Session", path: Path) -> Optional[str]: diff --git a/src/_pytest/python.py b/src/_pytest/python.py index d04b6fa4d..c65c41b97 100644 --- a/src/_pytest/python.py +++ b/src/_pytest/python.py @@ -1814,7 +1814,7 @@ class Function(PyobjMixin, nodes.Item): if not ntraceback: ntraceback = traceback - excinfo.traceback = ntraceback.filter() + excinfo.traceback = ntraceback.filter(excinfo) # issue364: mark all but first and last frames to # only show a single-line message for each frame. if self.config.getoption("tbstyle", "auto") == "auto": diff --git a/src/_pytest/unittest.py b/src/_pytest/unittest.py index c660aa75d..7a5e73661 100644 --- a/src/_pytest/unittest.py +++ b/src/_pytest/unittest.py @@ -339,7 +339,7 @@ class TestCaseFunction(Function): ) -> None: super()._prunetraceback(excinfo) traceback = excinfo.traceback.filter( - lambda x: not x.frame.f_globals.get("__unittest") + lambda x: not x.frame.f_globals.get("__unittest"), ) if traceback: excinfo.traceback = traceback diff --git a/testing/code/test_excinfo.py b/testing/code/test_excinfo.py index 6a720e64c..d0eb5e646 100644 --- a/testing/code/test_excinfo.py +++ b/testing/code/test_excinfo.py @@ -186,7 +186,7 @@ class TestTraceback_f_g_h: def test_traceback_filter(self): traceback = self.excinfo.traceback - ntraceback = traceback.filter() + ntraceback = traceback.filter(self.excinfo) assert len(ntraceback) == len(traceback) - 1 @pytest.mark.parametrize( @@ -217,7 +217,7 @@ class TestTraceback_f_g_h: excinfo = pytest.raises(ValueError, h) traceback = excinfo.traceback - ntraceback = traceback.filter() + ntraceback = traceback.filter(excinfo) print(f"old: {traceback!r}") print(f"new: {ntraceback!r}") @@ -307,7 +307,7 @@ class TestTraceback_f_g_h: excinfo = pytest.raises(ValueError, f) tb = excinfo.traceback - entry = tb.getcrashentry() + entry = tb.getcrashentry(excinfo) assert entry is not None co = _pytest._code.Code.from_function(h) assert entry.frame.code.path == co.path @@ -324,7 +324,7 @@ class TestTraceback_f_g_h: g() excinfo = pytest.raises(ValueError, f) - assert excinfo.traceback.getcrashentry() is None + assert excinfo.traceback.getcrashentry(excinfo) is None def test_excinfo_exconly(): @@ -626,7 +626,7 @@ raise ValueError() """ ) excinfo = pytest.raises(ValueError, mod.func1) - excinfo.traceback = excinfo.traceback.filter() + excinfo.traceback = excinfo.traceback.filter(excinfo) p = FormattedExcinfo() reprtb = p.repr_traceback_entry(excinfo.traceback[-1]) @@ -659,7 +659,7 @@ raise ValueError() """ ) excinfo = pytest.raises(ValueError, mod.func1, "m" * 90, 5, 13, "z" * 120) - excinfo.traceback = excinfo.traceback.filter() + excinfo.traceback = excinfo.traceback.filter(excinfo) entry = excinfo.traceback[-1] p = FormattedExcinfo(funcargs=True) reprfuncargs = p.repr_args(entry) @@ -686,7 +686,7 @@ raise ValueError() """ ) excinfo = pytest.raises(ValueError, mod.func1, "a", "b", c="d") - excinfo.traceback = excinfo.traceback.filter() + excinfo.traceback = excinfo.traceback.filter(excinfo) entry = excinfo.traceback[-1] p = FormattedExcinfo(funcargs=True) reprfuncargs = p.repr_args(entry) @@ -960,7 +960,7 @@ raise ValueError() """ ) excinfo = pytest.raises(ValueError, mod.f) - excinfo.traceback = excinfo.traceback.filter() + excinfo.traceback = excinfo.traceback.filter(excinfo) repr = excinfo.getrepr() repr.toterminal(tw_mock) assert tw_mock.lines[0] == "" @@ -994,7 +994,7 @@ raise ValueError() ) excinfo = pytest.raises(ValueError, mod.f) tmp_path.joinpath("mod.py").unlink() - excinfo.traceback = excinfo.traceback.filter() + excinfo.traceback = excinfo.traceback.filter(excinfo) repr = excinfo.getrepr() repr.toterminal(tw_mock) assert tw_mock.lines[0] == "" @@ -1026,7 +1026,7 @@ raise ValueError() ) excinfo = pytest.raises(ValueError, mod.f) tmp_path.joinpath("mod.py").write_text("asdf") - excinfo.traceback = excinfo.traceback.filter() + excinfo.traceback = excinfo.traceback.filter(excinfo) repr = excinfo.getrepr() repr.toterminal(tw_mock) assert tw_mock.lines[0] == "" @@ -1123,7 +1123,7 @@ raise ValueError() """ ) excinfo = pytest.raises(ValueError, mod.f) - excinfo.traceback = excinfo.traceback.filter() + excinfo.traceback = excinfo.traceback.filter(excinfo) excinfo.traceback[1].set_repr_style("short") excinfo.traceback[2].set_repr_style("short") r = excinfo.getrepr(style="long") @@ -1391,7 +1391,7 @@ raise ValueError() with pytest.raises(TypeError) as excinfo: mod.f() # previously crashed with `AttributeError: list has no attribute get` - excinfo.traceback.filter() + excinfo.traceback.filter(excinfo) @pytest.mark.parametrize("style", ["short", "long"]) diff --git a/testing/python/collect.py b/testing/python/collect.py index ac3edd395..41845517b 100644 --- a/testing/python/collect.py +++ b/testing/python/collect.py @@ -1003,9 +1003,9 @@ class TestTracebackCutting: with pytest.raises(pytest.skip.Exception) as excinfo: pytest.skip("xxx") assert excinfo.traceback[-1].frame.code.name == "skip" - assert excinfo.traceback[-1].ishidden() + assert excinfo.traceback[-1].ishidden(excinfo) assert excinfo.traceback[-2].frame.code.name == "test_skip_simple" - assert not excinfo.traceback[-2].ishidden() + assert not excinfo.traceback[-2].ishidden(excinfo) def test_traceback_argsetup(self, pytester: Pytester) -> None: pytester.makeconftest(