Type-annotate pytest.raises

This commit is contained in:
Ran Benita 2019-07-10 11:28:43 +03:00
parent 55a570e513
commit 56dcc9e1f8
1 changed files with 56 additions and 7 deletions

View File

@ -7,6 +7,13 @@ from collections.abc import Sized
from decimal import Decimal from decimal import Decimal
from itertools import filterfalse from itertools import filterfalse
from numbers import Number from numbers import Number
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Optional
from typing import overload
from typing import Pattern
from typing import Tuple
from typing import Union from typing import Union
from more_itertools.more import always_iterable from more_itertools.more import always_iterable
@ -15,6 +22,9 @@ import _pytest._code
from _pytest.compat import STRING_TYPES from _pytest.compat import STRING_TYPES
from _pytest.outcomes import fail from _pytest.outcomes import fail
if False: # TYPE_CHECKING
from typing import Type # noqa: F401 (used in type string)
BASE_TYPE = (type, STRING_TYPES) BASE_TYPE = (type, STRING_TYPES)
@ -528,7 +538,32 @@ def _is_numpy_array(obj):
# builtin pytest.raises helper # builtin pytest.raises helper
def raises(expected_exception, *args, match=None, **kwargs): @overload
def raises(
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
*,
match: Optional[Union[str, Pattern]] = ...
) -> "RaisesContext":
... # pragma: no cover
@overload
def raises(
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
func: Callable,
*args: Any,
match: Optional[str] = ...,
**kwargs: Any
) -> Optional[_pytest._code.ExceptionInfo]:
... # pragma: no cover
def raises(
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
*args: Any,
match: Optional[Union[str, Pattern]] = None,
**kwargs: Any
) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]:
r""" r"""
Assert that a code block/function call raises ``expected_exception`` Assert that a code block/function call raises ``expected_exception``
or raise a failure exception otherwise. or raise a failure exception otherwise.
@ -676,21 +711,35 @@ raises.Exception = fail.Exception # type: ignore
class RaisesContext: class RaisesContext:
def __init__(self, expected_exception, message, match_expr): def __init__(
self,
expected_exception: Union[
"Type[BaseException]", Tuple["Type[BaseException]", ...]
],
message: str,
match_expr: Optional[Union[str, Pattern]] = None,
) -> None:
self.expected_exception = expected_exception self.expected_exception = expected_exception
self.message = message self.message = message
self.match_expr = match_expr self.match_expr = match_expr
self.excinfo = None self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo]
def __enter__(self): def __enter__(self) -> _pytest._code.ExceptionInfo:
self.excinfo = _pytest._code.ExceptionInfo.for_later() self.excinfo = _pytest._code.ExceptionInfo.for_later()
return self.excinfo return self.excinfo
def __exit__(self, *tp): def __exit__(
self,
exc_type: Optional["Type[BaseException]"],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
__tracebackhide__ = True __tracebackhide__ = True
if tp[0] is None: if exc_type is None:
fail(self.message) fail(self.message)
self.excinfo.__init__(tp) assert self.excinfo is not None
# Type ignored because mypy doesn't like calling __init__ directly like this.
self.excinfo.__init__((exc_type, exc_val, exc_tb)) # type: ignore
suppress_exception = issubclass(self.excinfo.type, self.expected_exception) suppress_exception = issubclass(self.excinfo.type, self.expected_exception)
if self.match_expr is not None and suppress_exception: if self.match_expr is not None and suppress_exception:
self.excinfo.match(self.match_expr) self.excinfo.match(self.match_expr)