From 14bf4cdf44be4d8e2482b1f2b9cafeba06c03550 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Wed, 10 Jul 2019 20:12:41 +0300 Subject: [PATCH] Make ExceptionInfo generic in the exception type This way, in with pytest.raises(ValueError) as cm: ... cm.value is a ValueError and not a BaseException. --- src/_pytest/_code/code.py | 21 +++++++++++++-------- src/_pytest/python_api.py | 33 ++++++++++++++++++++------------- 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/src/_pytest/_code/code.py b/src/_pytest/_code/code.py index d9b06ffd9..203e90287 100644 --- a/src/_pytest/_code/code.py +++ b/src/_pytest/_code/code.py @@ -6,9 +6,11 @@ from inspect import CO_VARARGS from inspect import CO_VARKEYWORDS from traceback import format_exception_only from types import TracebackType +from typing import Generic from typing import Optional from typing import Pattern from typing import Tuple +from typing import TypeVar from typing import Union from weakref import ref @@ -379,22 +381,25 @@ co_equal = compile( ) +_E = TypeVar("_E", bound=BaseException) + + @attr.s(repr=False) -class ExceptionInfo: +class ExceptionInfo(Generic[_E]): """ wraps sys.exc_info() objects and offers help for navigating the traceback. """ _assert_start_repr = "AssertionError('assert " - _excinfo = attr.ib( - type=Optional[Tuple["Type[BaseException]", BaseException, TracebackType]] - ) + _excinfo = attr.ib(type=Optional[Tuple["Type[_E]", "_E", TracebackType]]) _striptext = attr.ib(type=str, default="") _traceback = attr.ib(type=Optional[Traceback], default=None) @classmethod - def from_current(cls, exprinfo: Optional[str] = None) -> "ExceptionInfo": + def from_current( + cls, exprinfo: Optional[str] = None + ) -> "ExceptionInfo[BaseException]": """returns an ExceptionInfo matching the current traceback .. warning:: @@ -422,13 +427,13 @@ class ExceptionInfo: return cls(tup, _striptext) @classmethod - def for_later(cls) -> "ExceptionInfo": + def for_later(cls) -> "ExceptionInfo[_E]": """return an unfilled ExceptionInfo """ return cls(None) @property - def type(self) -> "Type[BaseException]": + def type(self) -> "Type[_E]": """the exception class""" assert ( self._excinfo is not None @@ -436,7 +441,7 @@ class ExceptionInfo: return self._excinfo[0] @property - def value(self) -> BaseException: + def value(self) -> _E: """the exception value""" assert ( self._excinfo is not None diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index 9ede24df6..7ca545878 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -10,10 +10,13 @@ from numbers import Number from types import TracebackType from typing import Any from typing import Callable +from typing import cast +from typing import Generic from typing import Optional from typing import overload from typing import Pattern from typing import Tuple +from typing import TypeVar from typing import Union from more_itertools.more import always_iterable @@ -537,33 +540,35 @@ def _is_numpy_array(obj): # builtin pytest.raises helper +_E = TypeVar("_E", bound=BaseException) + @overload def raises( - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], *, match: Optional[Union[str, Pattern]] = ... -) -> "RaisesContext": +) -> "RaisesContext[_E]": ... # pragma: no cover @overload def raises( - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], func: Callable, *args: Any, match: Optional[str] = ..., **kwargs: Any -) -> Optional[_pytest._code.ExceptionInfo]: +) -> Optional[_pytest._code.ExceptionInfo[_E]]: ... # pragma: no cover def raises( - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], *args: Any, match: Optional[Union[str, Pattern]] = None, **kwargs: Any -) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]: +) -> Union["RaisesContext[_E]", Optional[_pytest._code.ExceptionInfo[_E]]]: r""" Assert that a code block/function call raises ``expected_exception`` or raise a failure exception otherwise. @@ -703,28 +708,30 @@ def raises( try: func(*args[1:], **kwargs) except expected_exception: - return _pytest._code.ExceptionInfo.from_current() + # Cast to narrow the type to expected_exception (_E). + return cast( + _pytest._code.ExceptionInfo[_E], + _pytest._code.ExceptionInfo.from_current(), + ) fail(message) raises.Exception = fail.Exception # type: ignore -class RaisesContext: +class RaisesContext(Generic[_E]): def __init__( self, - expected_exception: Union[ - "Type[BaseException]", Tuple["Type[BaseException]", ...] - ], + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], message: str, match_expr: Optional[Union[str, Pattern]] = None, ) -> None: self.expected_exception = expected_exception self.message = message self.match_expr = match_expr - self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo] + self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo[_E]] - def __enter__(self) -> _pytest._code.ExceptionInfo: + def __enter__(self) -> _pytest._code.ExceptionInfo[_E]: self.excinfo = _pytest._code.ExceptionInfo.for_later() return self.excinfo