Make ExceptionInfo generic in the exception type

This way, in

    with pytest.raises(ValueError) as cm:
        ...

cm.value is a ValueError and not a BaseException.
This commit is contained in:
Ran Benita 2019-07-10 20:12:41 +03:00
parent 56dcc9e1f8
commit 14bf4cdf44
2 changed files with 33 additions and 21 deletions

View File

@ -6,9 +6,11 @@ from inspect import CO_VARARGS
from inspect import CO_VARKEYWORDS
from traceback import format_exception_only
from types import TracebackType
from typing import Generic
from typing import Optional
from typing import Pattern
from typing import Tuple
from typing import TypeVar
from typing import Union
from weakref import ref
@ -379,22 +381,25 @@ co_equal = compile(
)
_E = TypeVar("_E", bound=BaseException)
@attr.s(repr=False)
class ExceptionInfo:
class ExceptionInfo(Generic[_E]):
""" wraps sys.exc_info() objects and offers
help for navigating the traceback.
"""
_assert_start_repr = "AssertionError('assert "
_excinfo = attr.ib(
type=Optional[Tuple["Type[BaseException]", BaseException, TracebackType]]
)
_excinfo = attr.ib(type=Optional[Tuple["Type[_E]", "_E", TracebackType]])
_striptext = attr.ib(type=str, default="")
_traceback = attr.ib(type=Optional[Traceback], default=None)
@classmethod
def from_current(cls, exprinfo: Optional[str] = None) -> "ExceptionInfo":
def from_current(
cls, exprinfo: Optional[str] = None
) -> "ExceptionInfo[BaseException]":
"""returns an ExceptionInfo matching the current traceback
.. warning::
@ -422,13 +427,13 @@ class ExceptionInfo:
return cls(tup, _striptext)
@classmethod
def for_later(cls) -> "ExceptionInfo":
def for_later(cls) -> "ExceptionInfo[_E]":
"""return an unfilled ExceptionInfo
"""
return cls(None)
@property
def type(self) -> "Type[BaseException]":
def type(self) -> "Type[_E]":
"""the exception class"""
assert (
self._excinfo is not None
@ -436,7 +441,7 @@ class ExceptionInfo:
return self._excinfo[0]
@property
def value(self) -> BaseException:
def value(self) -> _E:
"""the exception value"""
assert (
self._excinfo is not None

View File

@ -10,10 +10,13 @@ from numbers import Number
from types import TracebackType
from typing import Any
from typing import Callable
from typing import cast
from typing import Generic
from typing import Optional
from typing import overload
from typing import Pattern
from typing import Tuple
from typing import TypeVar
from typing import Union
from more_itertools.more import always_iterable
@ -537,33 +540,35 @@ def _is_numpy_array(obj):
# builtin pytest.raises helper
_E = TypeVar("_E", bound=BaseException)
@overload
def raises(
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
*,
match: Optional[Union[str, Pattern]] = ...
) -> "RaisesContext":
) -> "RaisesContext[_E]":
... # pragma: no cover
@overload
def raises(
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
func: Callable,
*args: Any,
match: Optional[str] = ...,
**kwargs: Any
) -> Optional[_pytest._code.ExceptionInfo]:
) -> Optional[_pytest._code.ExceptionInfo[_E]]:
... # pragma: no cover
def raises(
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
*args: Any,
match: Optional[Union[str, Pattern]] = None,
**kwargs: Any
) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]:
) -> Union["RaisesContext[_E]", Optional[_pytest._code.ExceptionInfo[_E]]]:
r"""
Assert that a code block/function call raises ``expected_exception``
or raise a failure exception otherwise.
@ -703,28 +708,30 @@ def raises(
try:
func(*args[1:], **kwargs)
except expected_exception:
return _pytest._code.ExceptionInfo.from_current()
# Cast to narrow the type to expected_exception (_E).
return cast(
_pytest._code.ExceptionInfo[_E],
_pytest._code.ExceptionInfo.from_current(),
)
fail(message)
raises.Exception = fail.Exception # type: ignore
class RaisesContext:
class RaisesContext(Generic[_E]):
def __init__(
self,
expected_exception: Union[
"Type[BaseException]", Tuple["Type[BaseException]", ...]
],
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
message: str,
match_expr: Optional[Union[str, Pattern]] = None,
) -> None:
self.expected_exception = expected_exception
self.message = message
self.match_expr = match_expr
self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo]
self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo[_E]]
def __enter__(self) -> _pytest._code.ExceptionInfo:
def __enter__(self) -> _pytest._code.ExceptionInfo[_E]:
self.excinfo = _pytest._code.ExceptionInfo.for_later()
return self.excinfo