code: stop storing weakref to ExceptionInfo on Traceback and TracebackEntry

TracebackEntry needs the excinfo for the `__tracebackhide__ = callback`
functionality, where `callback` accepts the excinfo.

Currently it achieves this by storing a weakref to the excinfo which
created it. I think this is not great, mixing layers and bloating the
objects.

Instead, have `ishidden` (and transitively, `Traceback.filter()`) take
the excinfo as a parameter.
This commit is contained in:
Ran Benita 2023-04-12 23:17:54 +03:00
parent 11965d1c27
commit cc23ec91d0
6 changed files with 46 additions and 41 deletions

View File

@ -31,7 +31,6 @@ from typing import Type
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import TypeVar from typing import TypeVar
from typing import Union from typing import Union
from weakref import ref
import pluggy import pluggy
@ -52,7 +51,6 @@ from _pytest.pathlib import bestrelpath
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Literal from typing_extensions import Literal
from typing_extensions import SupportsIndex from typing_extensions import SupportsIndex
from weakref import ReferenceType
_TracebackStyle = Literal["long", "short", "line", "no", "native", "value", "auto"] _TracebackStyle = Literal["long", "short", "line", "no", "native", "value", "auto"]
@ -194,15 +192,13 @@ class Frame:
class TracebackEntry: class TracebackEntry:
"""A single entry in a Traceback.""" """A single entry in a Traceback."""
__slots__ = ("_rawentry", "_excinfo", "_repr_style") __slots__ = ("_rawentry", "_repr_style")
def __init__( def __init__(
self, self,
rawentry: TracebackType, rawentry: TracebackType,
excinfo: Optional["ReferenceType[ExceptionInfo[BaseException]]"] = None,
) -> None: ) -> None:
self._rawentry = rawentry self._rawentry = rawentry
self._excinfo = excinfo
self._repr_style: Optional['Literal["short", "long"]'] = None self._repr_style: Optional['Literal["short", "long"]'] = None
@property @property
@ -272,7 +268,7 @@ class TracebackEntry:
source = property(getsource) source = property(getsource)
def ishidden(self) -> bool: def ishidden(self, excinfo: Optional["ExceptionInfo[BaseException]"]) -> bool:
"""Return True if the current frame has a var __tracebackhide__ """Return True if the current frame has a var __tracebackhide__
resolving to True. resolving to True.
@ -296,7 +292,7 @@ class TracebackEntry:
else: else:
break break
if tbh and callable(tbh): if tbh and callable(tbh):
return tbh(None if self._excinfo is None else self._excinfo()) return tbh(excinfo)
return tbh return tbh
def __str__(self) -> str: def __str__(self) -> str:
@ -329,16 +325,14 @@ class Traceback(List[TracebackEntry]):
def __init__( def __init__(
self, self,
tb: Union[TracebackType, Iterable[TracebackEntry]], tb: Union[TracebackType, Iterable[TracebackEntry]],
excinfo: Optional["ReferenceType[ExceptionInfo[BaseException]]"] = None,
) -> None: ) -> None:
"""Initialize from given python traceback object and ExceptionInfo.""" """Initialize from given python traceback object and ExceptionInfo."""
self._excinfo = excinfo
if isinstance(tb, TracebackType): if isinstance(tb, TracebackType):
def f(cur: TracebackType) -> Iterable[TracebackEntry]: def f(cur: TracebackType) -> Iterable[TracebackEntry]:
cur_: Optional[TracebackType] = cur cur_: Optional[TracebackType] = cur
while cur_ is not None: while cur_ is not None:
yield TracebackEntry(cur_, excinfo=excinfo) yield TracebackEntry(cur_)
cur_ = cur_.tb_next cur_ = cur_.tb_next
super().__init__(f(tb)) super().__init__(f(tb))
@ -378,7 +372,7 @@ class Traceback(List[TracebackEntry]):
continue continue
if firstlineno is not None and x.frame.code.firstlineno != firstlineno: if firstlineno is not None and x.frame.code.firstlineno != firstlineno:
continue continue
return Traceback(x._rawentry, self._excinfo) return Traceback(x._rawentry)
return self return self
@overload @overload
@ -398,25 +392,36 @@ class Traceback(List[TracebackEntry]):
return super().__getitem__(key) return super().__getitem__(key)
def filter( def filter(
self, fn: Callable[[TracebackEntry], bool] = lambda x: not x.ishidden() self,
# TODO(py38): change to positional only.
_excinfo_or_fn: Union[
"ExceptionInfo[BaseException]",
Callable[[TracebackEntry], bool],
],
) -> "Traceback": ) -> "Traceback":
"""Return a Traceback instance with certain items removed """Return a Traceback instance with certain items removed.
fn is a function that gets a single argument, a TracebackEntry If the filter is an `ExceptionInfo`, removes all the ``TracebackEntry``s
instance, and should return True when the item should be added which are hidden (see ishidden() above).
to the Traceback, False when not.
By default this removes all the TracebackEntries which are hidden Otherwise, the filter is a function that gets a single argument, a
(see ishidden() above). ``TracebackEntry`` instance, and should return True when the item should
be added to the ``Traceback``, False when not.
""" """
return Traceback(filter(fn, self), self._excinfo) if isinstance(_excinfo_or_fn, ExceptionInfo):
fn = lambda x: not x.ishidden(_excinfo_or_fn) # noqa: E731
else:
fn = _excinfo_or_fn
return Traceback(filter(fn, self))
def getcrashentry(self) -> Optional[TracebackEntry]: def getcrashentry(
self, excinfo: Optional["ExceptionInfo[BaseException]"]
) -> Optional[TracebackEntry]:
"""Return last non-hidden traceback entry that lead to the exception of """Return last non-hidden traceback entry that lead to the exception of
a traceback, or None if all hidden.""" a traceback, or None if all hidden."""
for i in range(-1, -len(self) - 1, -1): for i in range(-1, -len(self) - 1, -1):
entry = self[i] entry = self[i]
if not entry.ishidden(): if not entry.ishidden(excinfo):
return entry return entry
return None return None
@ -583,7 +588,7 @@ class ExceptionInfo(Generic[E]):
def traceback(self) -> Traceback: def traceback(self) -> Traceback:
"""The traceback.""" """The traceback."""
if self._traceback is None: if self._traceback is None:
self._traceback = Traceback(self.tb, excinfo=ref(self)) self._traceback = Traceback(self.tb)
return self._traceback return self._traceback
@traceback.setter @traceback.setter
@ -624,7 +629,7 @@ class ExceptionInfo(Generic[E]):
def _getreprcrash(self) -> Optional["ReprFileLocation"]: def _getreprcrash(self) -> Optional["ReprFileLocation"]:
exconly = self.exconly(tryshort=True) exconly = self.exconly(tryshort=True)
entry = self.traceback.getcrashentry() entry = self.traceback.getcrashentry(self)
if entry is None: if entry is None:
return None return None
path, lineno = entry.frame.code.raw.co_filename, entry.lineno path, lineno = entry.frame.code.raw.co_filename, entry.lineno
@ -882,7 +887,7 @@ class FormattedExcinfo:
def repr_traceback(self, excinfo: ExceptionInfo[BaseException]) -> "ReprTraceback": def repr_traceback(self, excinfo: ExceptionInfo[BaseException]) -> "ReprTraceback":
traceback = excinfo.traceback traceback = excinfo.traceback
if self.tbfilter: if self.tbfilter:
traceback = traceback.filter() traceback = traceback.filter(excinfo)
if isinstance(excinfo.value, RecursionError): if isinstance(excinfo.value, RecursionError):
traceback, extraline = self._truncate_recursive_traceback(traceback) traceback, extraline = self._truncate_recursive_traceback(traceback)

View File

@ -560,7 +560,7 @@ class Collector(Node):
ntraceback = traceback.cut(path=self.path) ntraceback = traceback.cut(path=self.path)
if ntraceback == traceback: if ntraceback == traceback:
ntraceback = ntraceback.cut(excludepath=tracebackcutdir) ntraceback = ntraceback.cut(excludepath=tracebackcutdir)
excinfo.traceback = ntraceback.filter() excinfo.traceback = ntraceback.filter(excinfo)
def _check_initialpaths_for_relpath(session: "Session", path: Path) -> Optional[str]: def _check_initialpaths_for_relpath(session: "Session", path: Path) -> Optional[str]:

View File

@ -1814,7 +1814,7 @@ class Function(PyobjMixin, nodes.Item):
if not ntraceback: if not ntraceback:
ntraceback = traceback ntraceback = traceback
excinfo.traceback = ntraceback.filter() excinfo.traceback = ntraceback.filter(excinfo)
# issue364: mark all but first and last frames to # issue364: mark all but first and last frames to
# only show a single-line message for each frame. # only show a single-line message for each frame.
if self.config.getoption("tbstyle", "auto") == "auto": if self.config.getoption("tbstyle", "auto") == "auto":

View File

@ -339,7 +339,7 @@ class TestCaseFunction(Function):
) -> None: ) -> None:
super()._prunetraceback(excinfo) super()._prunetraceback(excinfo)
traceback = excinfo.traceback.filter( traceback = excinfo.traceback.filter(
lambda x: not x.frame.f_globals.get("__unittest") lambda x: not x.frame.f_globals.get("__unittest"),
) )
if traceback: if traceback:
excinfo.traceback = traceback excinfo.traceback = traceback

View File

@ -186,7 +186,7 @@ class TestTraceback_f_g_h:
def test_traceback_filter(self): def test_traceback_filter(self):
traceback = self.excinfo.traceback traceback = self.excinfo.traceback
ntraceback = traceback.filter() ntraceback = traceback.filter(self.excinfo)
assert len(ntraceback) == len(traceback) - 1 assert len(ntraceback) == len(traceback) - 1
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -217,7 +217,7 @@ class TestTraceback_f_g_h:
excinfo = pytest.raises(ValueError, h) excinfo = pytest.raises(ValueError, h)
traceback = excinfo.traceback traceback = excinfo.traceback
ntraceback = traceback.filter() ntraceback = traceback.filter(excinfo)
print(f"old: {traceback!r}") print(f"old: {traceback!r}")
print(f"new: {ntraceback!r}") print(f"new: {ntraceback!r}")
@ -307,7 +307,7 @@ class TestTraceback_f_g_h:
excinfo = pytest.raises(ValueError, f) excinfo = pytest.raises(ValueError, f)
tb = excinfo.traceback tb = excinfo.traceback
entry = tb.getcrashentry() entry = tb.getcrashentry(excinfo)
assert entry is not None assert entry is not None
co = _pytest._code.Code.from_function(h) co = _pytest._code.Code.from_function(h)
assert entry.frame.code.path == co.path assert entry.frame.code.path == co.path
@ -324,7 +324,7 @@ class TestTraceback_f_g_h:
g() g()
excinfo = pytest.raises(ValueError, f) excinfo = pytest.raises(ValueError, f)
assert excinfo.traceback.getcrashentry() is None assert excinfo.traceback.getcrashentry(excinfo) is None
def test_excinfo_exconly(): def test_excinfo_exconly():
@ -626,7 +626,7 @@ raise ValueError()
""" """
) )
excinfo = pytest.raises(ValueError, mod.func1) excinfo = pytest.raises(ValueError, mod.func1)
excinfo.traceback = excinfo.traceback.filter() excinfo.traceback = excinfo.traceback.filter(excinfo)
p = FormattedExcinfo() p = FormattedExcinfo()
reprtb = p.repr_traceback_entry(excinfo.traceback[-1]) reprtb = p.repr_traceback_entry(excinfo.traceback[-1])
@ -659,7 +659,7 @@ raise ValueError()
""" """
) )
excinfo = pytest.raises(ValueError, mod.func1, "m" * 90, 5, 13, "z" * 120) excinfo = pytest.raises(ValueError, mod.func1, "m" * 90, 5, 13, "z" * 120)
excinfo.traceback = excinfo.traceback.filter() excinfo.traceback = excinfo.traceback.filter(excinfo)
entry = excinfo.traceback[-1] entry = excinfo.traceback[-1]
p = FormattedExcinfo(funcargs=True) p = FormattedExcinfo(funcargs=True)
reprfuncargs = p.repr_args(entry) reprfuncargs = p.repr_args(entry)
@ -686,7 +686,7 @@ raise ValueError()
""" """
) )
excinfo = pytest.raises(ValueError, mod.func1, "a", "b", c="d") excinfo = pytest.raises(ValueError, mod.func1, "a", "b", c="d")
excinfo.traceback = excinfo.traceback.filter() excinfo.traceback = excinfo.traceback.filter(excinfo)
entry = excinfo.traceback[-1] entry = excinfo.traceback[-1]
p = FormattedExcinfo(funcargs=True) p = FormattedExcinfo(funcargs=True)
reprfuncargs = p.repr_args(entry) reprfuncargs = p.repr_args(entry)
@ -960,7 +960,7 @@ raise ValueError()
""" """
) )
excinfo = pytest.raises(ValueError, mod.f) excinfo = pytest.raises(ValueError, mod.f)
excinfo.traceback = excinfo.traceback.filter() excinfo.traceback = excinfo.traceback.filter(excinfo)
repr = excinfo.getrepr() repr = excinfo.getrepr()
repr.toterminal(tw_mock) repr.toterminal(tw_mock)
assert tw_mock.lines[0] == "" assert tw_mock.lines[0] == ""
@ -994,7 +994,7 @@ raise ValueError()
) )
excinfo = pytest.raises(ValueError, mod.f) excinfo = pytest.raises(ValueError, mod.f)
tmp_path.joinpath("mod.py").unlink() tmp_path.joinpath("mod.py").unlink()
excinfo.traceback = excinfo.traceback.filter() excinfo.traceback = excinfo.traceback.filter(excinfo)
repr = excinfo.getrepr() repr = excinfo.getrepr()
repr.toterminal(tw_mock) repr.toterminal(tw_mock)
assert tw_mock.lines[0] == "" assert tw_mock.lines[0] == ""
@ -1026,7 +1026,7 @@ raise ValueError()
) )
excinfo = pytest.raises(ValueError, mod.f) excinfo = pytest.raises(ValueError, mod.f)
tmp_path.joinpath("mod.py").write_text("asdf") tmp_path.joinpath("mod.py").write_text("asdf")
excinfo.traceback = excinfo.traceback.filter() excinfo.traceback = excinfo.traceback.filter(excinfo)
repr = excinfo.getrepr() repr = excinfo.getrepr()
repr.toterminal(tw_mock) repr.toterminal(tw_mock)
assert tw_mock.lines[0] == "" assert tw_mock.lines[0] == ""
@ -1123,7 +1123,7 @@ raise ValueError()
""" """
) )
excinfo = pytest.raises(ValueError, mod.f) excinfo = pytest.raises(ValueError, mod.f)
excinfo.traceback = excinfo.traceback.filter() excinfo.traceback = excinfo.traceback.filter(excinfo)
excinfo.traceback[1].set_repr_style("short") excinfo.traceback[1].set_repr_style("short")
excinfo.traceback[2].set_repr_style("short") excinfo.traceback[2].set_repr_style("short")
r = excinfo.getrepr(style="long") r = excinfo.getrepr(style="long")
@ -1391,7 +1391,7 @@ raise ValueError()
with pytest.raises(TypeError) as excinfo: with pytest.raises(TypeError) as excinfo:
mod.f() mod.f()
# previously crashed with `AttributeError: list has no attribute get` # previously crashed with `AttributeError: list has no attribute get`
excinfo.traceback.filter() excinfo.traceback.filter(excinfo)
@pytest.mark.parametrize("style", ["short", "long"]) @pytest.mark.parametrize("style", ["short", "long"])

View File

@ -1003,9 +1003,9 @@ class TestTracebackCutting:
with pytest.raises(pytest.skip.Exception) as excinfo: with pytest.raises(pytest.skip.Exception) as excinfo:
pytest.skip("xxx") pytest.skip("xxx")
assert excinfo.traceback[-1].frame.code.name == "skip" assert excinfo.traceback[-1].frame.code.name == "skip"
assert excinfo.traceback[-1].ishidden() assert excinfo.traceback[-1].ishidden(excinfo)
assert excinfo.traceback[-2].frame.code.name == "test_skip_simple" assert excinfo.traceback[-2].frame.code.name == "test_skip_simple"
assert not excinfo.traceback[-2].ishidden() assert not excinfo.traceback[-2].ishidden(excinfo)
def test_traceback_argsetup(self, pytester: Pytester) -> None: def test_traceback_argsetup(self, pytester: Pytester) -> None:
pytester.makeconftest( pytester.makeconftest(