Type-annotate pytest.raises
This commit is contained in:
parent
55a570e513
commit
56dcc9e1f8
|
@ -7,6 +7,13 @@ from collections.abc import Sized
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from itertools import filterfalse
|
from itertools import filterfalse
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Any
|
||||||
|
from typing import Callable
|
||||||
|
from typing import Optional
|
||||||
|
from typing import overload
|
||||||
|
from typing import Pattern
|
||||||
|
from typing import Tuple
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from more_itertools.more import always_iterable
|
from more_itertools.more import always_iterable
|
||||||
|
@ -15,6 +22,9 @@ import _pytest._code
|
||||||
from _pytest.compat import STRING_TYPES
|
from _pytest.compat import STRING_TYPES
|
||||||
from _pytest.outcomes import fail
|
from _pytest.outcomes import fail
|
||||||
|
|
||||||
|
if False: # TYPE_CHECKING
|
||||||
|
from typing import Type # noqa: F401 (used in type string)
|
||||||
|
|
||||||
BASE_TYPE = (type, STRING_TYPES)
|
BASE_TYPE = (type, STRING_TYPES)
|
||||||
|
|
||||||
|
|
||||||
|
@ -528,7 +538,32 @@ def _is_numpy_array(obj):
|
||||||
# builtin pytest.raises helper
|
# builtin pytest.raises helper
|
||||||
|
|
||||||
|
|
||||||
def raises(expected_exception, *args, match=None, **kwargs):
|
@overload
|
||||||
|
def raises(
|
||||||
|
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
|
||||||
|
*,
|
||||||
|
match: Optional[Union[str, Pattern]] = ...
|
||||||
|
) -> "RaisesContext":
|
||||||
|
... # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def raises(
|
||||||
|
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
|
||||||
|
func: Callable,
|
||||||
|
*args: Any,
|
||||||
|
match: Optional[str] = ...,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> Optional[_pytest._code.ExceptionInfo]:
|
||||||
|
... # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
def raises(
|
||||||
|
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
|
||||||
|
*args: Any,
|
||||||
|
match: Optional[Union[str, Pattern]] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]:
|
||||||
r"""
|
r"""
|
||||||
Assert that a code block/function call raises ``expected_exception``
|
Assert that a code block/function call raises ``expected_exception``
|
||||||
or raise a failure exception otherwise.
|
or raise a failure exception otherwise.
|
||||||
|
@ -676,21 +711,35 @@ raises.Exception = fail.Exception # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class RaisesContext:
|
class RaisesContext:
|
||||||
def __init__(self, expected_exception, message, match_expr):
|
def __init__(
|
||||||
|
self,
|
||||||
|
expected_exception: Union[
|
||||||
|
"Type[BaseException]", Tuple["Type[BaseException]", ...]
|
||||||
|
],
|
||||||
|
message: str,
|
||||||
|
match_expr: Optional[Union[str, Pattern]] = None,
|
||||||
|
) -> None:
|
||||||
self.expected_exception = expected_exception
|
self.expected_exception = expected_exception
|
||||||
self.message = message
|
self.message = message
|
||||||
self.match_expr = match_expr
|
self.match_expr = match_expr
|
||||||
self.excinfo = None
|
self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo]
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self) -> _pytest._code.ExceptionInfo:
|
||||||
self.excinfo = _pytest._code.ExceptionInfo.for_later()
|
self.excinfo = _pytest._code.ExceptionInfo.for_later()
|
||||||
return self.excinfo
|
return self.excinfo
|
||||||
|
|
||||||
def __exit__(self, *tp):
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional["Type[BaseException]"],
|
||||||
|
exc_val: Optional[BaseException],
|
||||||
|
exc_tb: Optional[TracebackType],
|
||||||
|
) -> bool:
|
||||||
__tracebackhide__ = True
|
__tracebackhide__ = True
|
||||||
if tp[0] is None:
|
if exc_type is None:
|
||||||
fail(self.message)
|
fail(self.message)
|
||||||
self.excinfo.__init__(tp)
|
assert self.excinfo is not None
|
||||||
|
# Type ignored because mypy doesn't like calling __init__ directly like this.
|
||||||
|
self.excinfo.__init__((exc_type, exc_val, exc_tb)) # type: ignore
|
||||||
suppress_exception = issubclass(self.excinfo.type, self.expected_exception)
|
suppress_exception = issubclass(self.excinfo.type, self.expected_exception)
|
||||||
if self.match_expr is not None and suppress_exception:
|
if self.match_expr is not None and suppress_exception:
|
||||||
self.excinfo.match(self.match_expr)
|
self.excinfo.match(self.match_expr)
|
||||||
|
|
Loading…
Reference in New Issue