Allow creating ExceptionInfo from existing exc_info for better typing

This way the ExceptionInfo generic parameter can be inferred from the
passed-in exc_info. See for example the replaced cast().
This commit is contained in:
Ran Benita 2019-07-14 11:36:33 +03:00
parent 3f1fb62584
commit 11f1f79222
3 changed files with 47 additions and 20 deletions

View File

@ -396,6 +396,33 @@ class ExceptionInfo(Generic[_E]):
_striptext = attr.ib(type=str, default="") _striptext = attr.ib(type=str, default="")
_traceback = attr.ib(type=Optional[Traceback], default=None) _traceback = attr.ib(type=Optional[Traceback], default=None)
@classmethod
def from_exc_info(
cls,
exc_info: Tuple["Type[_E]", "_E", TracebackType],
exprinfo: Optional[str] = None,
) -> "ExceptionInfo[_E]":
"""returns an ExceptionInfo for an existing exc_info tuple.
.. warning::
Experimental API
:param exprinfo: a text string helping to determine if we should
strip ``AssertionError`` from the output, defaults
to the exception message/``__str__()``
"""
_striptext = ""
if exprinfo is None and isinstance(exc_info[1], AssertionError):
exprinfo = getattr(exc_info[1], "msg", None)
if exprinfo is None:
exprinfo = saferepr(exc_info[1])
if exprinfo and exprinfo.startswith(cls._assert_start_repr):
_striptext = "AssertionError: "
return cls(exc_info, _striptext)
@classmethod @classmethod
def from_current( def from_current(
cls, exprinfo: Optional[str] = None cls, exprinfo: Optional[str] = None
@ -411,20 +438,12 @@ class ExceptionInfo(Generic[_E]):
strip ``AssertionError`` from the output, defaults strip ``AssertionError`` from the output, defaults
to the exception message/``__str__()`` to the exception message/``__str__()``
""" """
tup_ = sys.exc_info() tup = sys.exc_info()
assert tup_[0] is not None, "no current exception" assert tup[0] is not None, "no current exception"
assert tup_[1] is not None, "no current exception" assert tup[1] is not None, "no current exception"
assert tup_[2] is not None, "no current exception" assert tup[2] is not None, "no current exception"
tup = (tup_[0], tup_[1], tup_[2]) exc_info = (tup[0], tup[1], tup[2])
_striptext = "" return cls.from_exc_info(exc_info)
if exprinfo is None and isinstance(tup[1], AssertionError):
exprinfo = getattr(tup[1], "msg", None)
if exprinfo is None:
exprinfo = saferepr(tup[1])
if exprinfo and exprinfo.startswith(cls._assert_start_repr):
_striptext = "AssertionError: "
return cls(tup, _striptext)
@classmethod @classmethod
def for_later(cls) -> "ExceptionInfo[_E]": def for_later(cls) -> "ExceptionInfo[_E]":

View File

@ -707,11 +707,11 @@ def raises(
) )
try: try:
func(*args[1:], **kwargs) func(*args[1:], **kwargs)
except expected_exception: except expected_exception as e:
# Cast to narrow the type to expected_exception (_E). # We just caught the exception - there is a traceback.
return cast( assert e.__traceback__ is not None
_pytest._code.ExceptionInfo[_E], return _pytest._code.ExceptionInfo.from_exc_info(
_pytest._code.ExceptionInfo.from_current(), (type(e), e, e.__traceback__)
) )
fail(message) fail(message)

View File

@ -58,7 +58,7 @@ class TWMock:
fullwidth = 80 fullwidth = 80
def test_excinfo_simple(): def test_excinfo_simple() -> None:
try: try:
raise ValueError raise ValueError
except ValueError: except ValueError:
@ -66,6 +66,14 @@ def test_excinfo_simple():
assert info.type == ValueError assert info.type == ValueError
def test_excinfo_from_exc_info_simple():
try:
raise ValueError
except ValueError as e:
info = _pytest._code.ExceptionInfo.from_exc_info((type(e), e, e.__traceback__))
assert info.type == ValueError
def test_excinfo_getstatement(): def test_excinfo_getstatement():
def g(): def g():
raise ValueError raise ValueError