diff --git a/src/_pytest/_code/code.py b/src/_pytest/_code/code.py index d9b06ffd9..203e90287 100644 --- a/src/_pytest/_code/code.py +++ b/src/_pytest/_code/code.py @@ -6,9 +6,11 @@ from inspect import CO_VARARGS from inspect import CO_VARKEYWORDS from traceback import format_exception_only from types import TracebackType +from typing import Generic from typing import Optional from typing import Pattern from typing import Tuple +from typing import TypeVar from typing import Union from weakref import ref @@ -379,22 +381,25 @@ co_equal = compile( ) +_E = TypeVar("_E", bound=BaseException) + + @attr.s(repr=False) -class ExceptionInfo: +class ExceptionInfo(Generic[_E]): """ wraps sys.exc_info() objects and offers help for navigating the traceback. """ _assert_start_repr = "AssertionError('assert " - _excinfo = attr.ib( - type=Optional[Tuple["Type[BaseException]", BaseException, TracebackType]] - ) + _excinfo = attr.ib(type=Optional[Tuple["Type[_E]", "_E", TracebackType]]) _striptext = attr.ib(type=str, default="") _traceback = attr.ib(type=Optional[Traceback], default=None) @classmethod - def from_current(cls, exprinfo: Optional[str] = None) -> "ExceptionInfo": + def from_current( + cls, exprinfo: Optional[str] = None + ) -> "ExceptionInfo[BaseException]": """returns an ExceptionInfo matching the current traceback .. warning:: @@ -422,13 +427,13 @@ class ExceptionInfo: return cls(tup, _striptext) @classmethod - def for_later(cls) -> "ExceptionInfo": + def for_later(cls) -> "ExceptionInfo[_E]": """return an unfilled ExceptionInfo """ return cls(None) @property - def type(self) -> "Type[BaseException]": + def type(self) -> "Type[_E]": """the exception class""" assert ( self._excinfo is not None @@ -436,7 +441,7 @@ class ExceptionInfo: return self._excinfo[0] @property - def value(self) -> BaseException: + def value(self) -> _E: """the exception value""" assert ( self._excinfo is not None diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index 9ede24df6..7ca545878 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -10,10 +10,13 @@ from numbers import Number from types import TracebackType from typing import Any from typing import Callable +from typing import cast +from typing import Generic from typing import Optional from typing import overload from typing import Pattern from typing import Tuple +from typing import TypeVar from typing import Union from more_itertools.more import always_iterable @@ -537,33 +540,35 @@ def _is_numpy_array(obj): # builtin pytest.raises helper +_E = TypeVar("_E", bound=BaseException) + @overload def raises( - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], *, match: Optional[Union[str, Pattern]] = ... -) -> "RaisesContext": +) -> "RaisesContext[_E]": ... # pragma: no cover @overload def raises( - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], func: Callable, *args: Any, match: Optional[str] = ..., **kwargs: Any -) -> Optional[_pytest._code.ExceptionInfo]: +) -> Optional[_pytest._code.ExceptionInfo[_E]]: ... # pragma: no cover def raises( - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], *args: Any, match: Optional[Union[str, Pattern]] = None, **kwargs: Any -) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]: +) -> Union["RaisesContext[_E]", Optional[_pytest._code.ExceptionInfo[_E]]]: r""" Assert that a code block/function call raises ``expected_exception`` or raise a failure exception otherwise. @@ -703,28 +708,30 @@ def raises( try: func(*args[1:], **kwargs) except expected_exception: - return _pytest._code.ExceptionInfo.from_current() + # Cast to narrow the type to expected_exception (_E). + return cast( + _pytest._code.ExceptionInfo[_E], + _pytest._code.ExceptionInfo.from_current(), + ) fail(message) raises.Exception = fail.Exception # type: ignore -class RaisesContext: +class RaisesContext(Generic[_E]): def __init__( self, - expected_exception: Union[ - "Type[BaseException]", Tuple["Type[BaseException]", ...] - ], + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], message: str, match_expr: Optional[Union[str, Pattern]] = None, ) -> None: self.expected_exception = expected_exception self.message = message self.match_expr = match_expr - self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo] + self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo[_E]] - def __enter__(self) -> _pytest._code.ExceptionInfo: + def __enter__(self) -> _pytest._code.ExceptionInfo[_E]: self.excinfo = _pytest._code.ExceptionInfo.for_later() return self.excinfo