Make ExceptionInfo generic in the exception type

This way, in

    with pytest.raises(ValueError) as cm:
        ...

cm.value is a ValueError and not a BaseException.
This commit is contained in:
Ran Benita 2019-07-10 20:12:41 +03:00
parent 56dcc9e1f8
commit 14bf4cdf44
2 changed files with 33 additions and 21 deletions

View File

@ -6,9 +6,11 @@ from inspect import CO_VARARGS
from inspect import CO_VARKEYWORDS from inspect import CO_VARKEYWORDS
from traceback import format_exception_only from traceback import format_exception_only
from types import TracebackType from types import TracebackType
from typing import Generic
from typing import Optional from typing import Optional
from typing import Pattern from typing import Pattern
from typing import Tuple from typing import Tuple
from typing import TypeVar
from typing import Union from typing import Union
from weakref import ref from weakref import ref
@ -379,22 +381,25 @@ co_equal = compile(
) )
_E = TypeVar("_E", bound=BaseException)
@attr.s(repr=False) @attr.s(repr=False)
class ExceptionInfo: class ExceptionInfo(Generic[_E]):
""" wraps sys.exc_info() objects and offers """ wraps sys.exc_info() objects and offers
help for navigating the traceback. help for navigating the traceback.
""" """
_assert_start_repr = "AssertionError('assert " _assert_start_repr = "AssertionError('assert "
_excinfo = attr.ib( _excinfo = attr.ib(type=Optional[Tuple["Type[_E]", "_E", TracebackType]])
type=Optional[Tuple["Type[BaseException]", BaseException, TracebackType]]
)
_striptext = attr.ib(type=str, default="") _striptext = attr.ib(type=str, default="")
_traceback = attr.ib(type=Optional[Traceback], default=None) _traceback = attr.ib(type=Optional[Traceback], default=None)
@classmethod @classmethod
def from_current(cls, exprinfo: Optional[str] = None) -> "ExceptionInfo": def from_current(
cls, exprinfo: Optional[str] = None
) -> "ExceptionInfo[BaseException]":
"""returns an ExceptionInfo matching the current traceback """returns an ExceptionInfo matching the current traceback
.. warning:: .. warning::
@ -422,13 +427,13 @@ class ExceptionInfo:
return cls(tup, _striptext) return cls(tup, _striptext)
@classmethod @classmethod
def for_later(cls) -> "ExceptionInfo": def for_later(cls) -> "ExceptionInfo[_E]":
"""return an unfilled ExceptionInfo """return an unfilled ExceptionInfo
""" """
return cls(None) return cls(None)
@property @property
def type(self) -> "Type[BaseException]": def type(self) -> "Type[_E]":
"""the exception class""" """the exception class"""
assert ( assert (
self._excinfo is not None self._excinfo is not None
@ -436,7 +441,7 @@ class ExceptionInfo:
return self._excinfo[0] return self._excinfo[0]
@property @property
def value(self) -> BaseException: def value(self) -> _E:
"""the exception value""" """the exception value"""
assert ( assert (
self._excinfo is not None self._excinfo is not None

View File

@ -10,10 +10,13 @@ from numbers import Number
from types import TracebackType from types import TracebackType
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import cast
from typing import Generic
from typing import Optional from typing import Optional
from typing import overload from typing import overload
from typing import Pattern from typing import Pattern
from typing import Tuple from typing import Tuple
from typing import TypeVar
from typing import Union from typing import Union
from more_itertools.more import always_iterable from more_itertools.more import always_iterable
@ -537,33 +540,35 @@ def _is_numpy_array(obj):
# builtin pytest.raises helper # builtin pytest.raises helper
_E = TypeVar("_E", bound=BaseException)
@overload @overload
def raises( def raises(
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
*, *,
match: Optional[Union[str, Pattern]] = ... match: Optional[Union[str, Pattern]] = ...
) -> "RaisesContext": ) -> "RaisesContext[_E]":
... # pragma: no cover ... # pragma: no cover
@overload @overload
def raises( def raises(
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
func: Callable, func: Callable,
*args: Any, *args: Any,
match: Optional[str] = ..., match: Optional[str] = ...,
**kwargs: Any **kwargs: Any
) -> Optional[_pytest._code.ExceptionInfo]: ) -> Optional[_pytest._code.ExceptionInfo[_E]]:
... # pragma: no cover ... # pragma: no cover
def raises( def raises(
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
*args: Any, *args: Any,
match: Optional[Union[str, Pattern]] = None, match: Optional[Union[str, Pattern]] = None,
**kwargs: Any **kwargs: Any
) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]: ) -> Union["RaisesContext[_E]", Optional[_pytest._code.ExceptionInfo[_E]]]:
r""" r"""
Assert that a code block/function call raises ``expected_exception`` Assert that a code block/function call raises ``expected_exception``
or raise a failure exception otherwise. or raise a failure exception otherwise.
@ -703,28 +708,30 @@ def raises(
try: try:
func(*args[1:], **kwargs) func(*args[1:], **kwargs)
except expected_exception: except expected_exception:
return _pytest._code.ExceptionInfo.from_current() # Cast to narrow the type to expected_exception (_E).
return cast(
_pytest._code.ExceptionInfo[_E],
_pytest._code.ExceptionInfo.from_current(),
)
fail(message) fail(message)
raises.Exception = fail.Exception # type: ignore raises.Exception = fail.Exception # type: ignore
class RaisesContext: class RaisesContext(Generic[_E]):
def __init__( def __init__(
self, self,
expected_exception: Union[ expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
"Type[BaseException]", Tuple["Type[BaseException]", ...]
],
message: str, message: str,
match_expr: Optional[Union[str, Pattern]] = None, match_expr: Optional[Union[str, Pattern]] = None,
) -> None: ) -> None:
self.expected_exception = expected_exception self.expected_exception = expected_exception
self.message = message self.message = message
self.match_expr = match_expr self.match_expr = match_expr
self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo] self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo[_E]]
def __enter__(self) -> _pytest._code.ExceptionInfo: def __enter__(self) -> _pytest._code.ExceptionInfo[_E]:
self.excinfo = _pytest._code.ExceptionInfo.for_later() self.excinfo = _pytest._code.ExceptionInfo.for_later()
return self.excinfo return self.excinfo