From 56dcc9e1f884dc9f5f699c975a303cb0a97ccfa9 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Wed, 10 Jul 2019 11:28:43 +0300 Subject: [PATCH] Type-annotate pytest.raises --- src/_pytest/python_api.py | 63 ++++++++++++++++++++++++++++++++++----- 1 file changed, 56 insertions(+), 7 deletions(-) diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index aae5ced33..9ede24df6 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -7,6 +7,13 @@ from collections.abc import Sized from decimal import Decimal from itertools import filterfalse from numbers import Number +from types import TracebackType +from typing import Any +from typing import Callable +from typing import Optional +from typing import overload +from typing import Pattern +from typing import Tuple from typing import Union from more_itertools.more import always_iterable @@ -15,6 +22,9 @@ import _pytest._code from _pytest.compat import STRING_TYPES from _pytest.outcomes import fail +if False: # TYPE_CHECKING + from typing import Type # noqa: F401 (used in type string) + BASE_TYPE = (type, STRING_TYPES) @@ -528,7 +538,32 @@ def _is_numpy_array(obj): # builtin pytest.raises helper -def raises(expected_exception, *args, match=None, **kwargs): +@overload +def raises( + expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], + *, + match: Optional[Union[str, Pattern]] = ... +) -> "RaisesContext": + ... # pragma: no cover + + +@overload +def raises( + expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], + func: Callable, + *args: Any, + match: Optional[str] = ..., + **kwargs: Any +) -> Optional[_pytest._code.ExceptionInfo]: + ... # pragma: no cover + + +def raises( + expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], + *args: Any, + match: Optional[Union[str, Pattern]] = None, + **kwargs: Any +) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]: r""" Assert that a code block/function call raises ``expected_exception`` or raise a failure exception otherwise. @@ -676,21 +711,35 @@ raises.Exception = fail.Exception # type: ignore class RaisesContext: - def __init__(self, expected_exception, message, match_expr): + def __init__( + self, + expected_exception: Union[ + "Type[BaseException]", Tuple["Type[BaseException]", ...] + ], + message: str, + match_expr: Optional[Union[str, Pattern]] = None, + ) -> None: self.expected_exception = expected_exception self.message = message self.match_expr = match_expr - self.excinfo = None + self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo] - def __enter__(self): + def __enter__(self) -> _pytest._code.ExceptionInfo: self.excinfo = _pytest._code.ExceptionInfo.for_later() return self.excinfo - def __exit__(self, *tp): + def __exit__( + self, + exc_type: Optional["Type[BaseException]"], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: __tracebackhide__ = True - if tp[0] is None: + if exc_type is None: fail(self.message) - self.excinfo.__init__(tp) + assert self.excinfo is not None + # Type ignored because mypy doesn't like calling __init__ directly like this. + self.excinfo.__init__((exc_type, exc_val, exc_tb)) # type: ignore suppress_exception = issubclass(self.excinfo.type, self.expected_exception) if self.match_expr is not None and suppress_exception: self.excinfo.match(self.match_expr)