Merge pull request #5593 from bluetech/type-annotations-1
Type-annotate pytest.{exit,skip,fail,xfail,importorskip,warns,raises}
This commit is contained in:
commit
faf222f8fb
|
@ -5,6 +5,13 @@ import traceback
|
|||
from inspect import CO_VARARGS
|
||||
from inspect import CO_VARKEYWORDS
|
||||
from traceback import format_exception_only
|
||||
from types import TracebackType
|
||||
from typing import Generic
|
||||
from typing import Optional
|
||||
from typing import Pattern
|
||||
from typing import Tuple
|
||||
from typing import TypeVar
|
||||
from typing import Union
|
||||
from weakref import ref
|
||||
|
||||
import attr
|
||||
|
@ -15,6 +22,9 @@ import _pytest
|
|||
from _pytest._io.saferepr import safeformat
|
||||
from _pytest._io.saferepr import saferepr
|
||||
|
||||
if False: # TYPE_CHECKING
|
||||
from typing import Type
|
||||
|
||||
|
||||
class Code:
|
||||
""" wrapper around Python code objects """
|
||||
|
@ -371,20 +381,52 @@ co_equal = compile(
|
|||
)
|
||||
|
||||
|
||||
_E = TypeVar("_E", bound=BaseException)
|
||||
|
||||
|
||||
@attr.s(repr=False)
|
||||
class ExceptionInfo:
|
||||
class ExceptionInfo(Generic[_E]):
|
||||
""" wraps sys.exc_info() objects and offers
|
||||
help for navigating the traceback.
|
||||
"""
|
||||
|
||||
_assert_start_repr = "AssertionError('assert "
|
||||
|
||||
_excinfo = attr.ib()
|
||||
_striptext = attr.ib(default="")
|
||||
_traceback = attr.ib(default=None)
|
||||
_excinfo = attr.ib(type=Optional[Tuple["Type[_E]", "_E", TracebackType]])
|
||||
_striptext = attr.ib(type=str, default="")
|
||||
_traceback = attr.ib(type=Optional[Traceback], default=None)
|
||||
|
||||
@classmethod
|
||||
def from_current(cls, exprinfo=None):
|
||||
def from_exc_info(
|
||||
cls,
|
||||
exc_info: Tuple["Type[_E]", "_E", TracebackType],
|
||||
exprinfo: Optional[str] = None,
|
||||
) -> "ExceptionInfo[_E]":
|
||||
"""returns an ExceptionInfo for an existing exc_info tuple.
|
||||
|
||||
.. warning::
|
||||
|
||||
Experimental API
|
||||
|
||||
|
||||
:param exprinfo: a text string helping to determine if we should
|
||||
strip ``AssertionError`` from the output, defaults
|
||||
to the exception message/``__str__()``
|
||||
"""
|
||||
_striptext = ""
|
||||
if exprinfo is None and isinstance(exc_info[1], AssertionError):
|
||||
exprinfo = getattr(exc_info[1], "msg", None)
|
||||
if exprinfo is None:
|
||||
exprinfo = saferepr(exc_info[1])
|
||||
if exprinfo and exprinfo.startswith(cls._assert_start_repr):
|
||||
_striptext = "AssertionError: "
|
||||
|
||||
return cls(exc_info, _striptext)
|
||||
|
||||
@classmethod
|
||||
def from_current(
|
||||
cls, exprinfo: Optional[str] = None
|
||||
) -> "ExceptionInfo[BaseException]":
|
||||
"""returns an ExceptionInfo matching the current traceback
|
||||
|
||||
.. warning::
|
||||
|
@ -398,59 +440,71 @@ class ExceptionInfo:
|
|||
"""
|
||||
tup = sys.exc_info()
|
||||
assert tup[0] is not None, "no current exception"
|
||||
_striptext = ""
|
||||
if exprinfo is None and isinstance(tup[1], AssertionError):
|
||||
exprinfo = getattr(tup[1], "msg", None)
|
||||
if exprinfo is None:
|
||||
exprinfo = saferepr(tup[1])
|
||||
if exprinfo and exprinfo.startswith(cls._assert_start_repr):
|
||||
_striptext = "AssertionError: "
|
||||
|
||||
return cls(tup, _striptext)
|
||||
assert tup[1] is not None, "no current exception"
|
||||
assert tup[2] is not None, "no current exception"
|
||||
exc_info = (tup[0], tup[1], tup[2])
|
||||
return cls.from_exc_info(exc_info)
|
||||
|
||||
@classmethod
|
||||
def for_later(cls):
|
||||
def for_later(cls) -> "ExceptionInfo[_E]":
|
||||
"""return an unfilled ExceptionInfo
|
||||
"""
|
||||
return cls(None)
|
||||
|
||||
def fill_unfilled(self, exc_info: Tuple["Type[_E]", _E, TracebackType]) -> None:
|
||||
"""fill an unfilled ExceptionInfo created with for_later()"""
|
||||
assert self._excinfo is None, "ExceptionInfo was already filled"
|
||||
self._excinfo = exc_info
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
def type(self) -> "Type[_E]":
|
||||
"""the exception class"""
|
||||
assert (
|
||||
self._excinfo is not None
|
||||
), ".type can only be used after the context manager exits"
|
||||
return self._excinfo[0]
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
def value(self) -> _E:
|
||||
"""the exception value"""
|
||||
assert (
|
||||
self._excinfo is not None
|
||||
), ".value can only be used after the context manager exits"
|
||||
return self._excinfo[1]
|
||||
|
||||
@property
|
||||
def tb(self):
|
||||
def tb(self) -> TracebackType:
|
||||
"""the exception raw traceback"""
|
||||
assert (
|
||||
self._excinfo is not None
|
||||
), ".tb can only be used after the context manager exits"
|
||||
return self._excinfo[2]
|
||||
|
||||
@property
|
||||
def typename(self):
|
||||
def typename(self) -> str:
|
||||
"""the type name of the exception"""
|
||||
assert (
|
||||
self._excinfo is not None
|
||||
), ".typename can only be used after the context manager exits"
|
||||
return self.type.__name__
|
||||
|
||||
@property
|
||||
def traceback(self):
|
||||
def traceback(self) -> Traceback:
|
||||
"""the traceback"""
|
||||
if self._traceback is None:
|
||||
self._traceback = Traceback(self.tb, excinfo=ref(self))
|
||||
return self._traceback
|
||||
|
||||
@traceback.setter
|
||||
def traceback(self, value):
|
||||
def traceback(self, value: Traceback) -> None:
|
||||
self._traceback = value
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
if self._excinfo is None:
|
||||
return "<ExceptionInfo for raises contextmanager>"
|
||||
return "<ExceptionInfo %s tblen=%d>" % (self.typename, len(self.traceback))
|
||||
|
||||
def exconly(self, tryshort=False):
|
||||
def exconly(self, tryshort: bool = False) -> str:
|
||||
""" return the exception as a string
|
||||
|
||||
when 'tryshort' resolves to True, and the exception is a
|
||||
|
@ -466,11 +520,11 @@ class ExceptionInfo:
|
|||
text = text[len(self._striptext) :]
|
||||
return text
|
||||
|
||||
def errisinstance(self, exc):
|
||||
def errisinstance(self, exc: "Type[BaseException]") -> bool:
|
||||
""" return True if the exception is an instance of exc """
|
||||
return isinstance(self.value, exc)
|
||||
|
||||
def _getreprcrash(self):
|
||||
def _getreprcrash(self) -> "ReprFileLocation":
|
||||
exconly = self.exconly(tryshort=True)
|
||||
entry = self.traceback.getcrashentry()
|
||||
path, lineno = entry.frame.code.raw.co_filename, entry.lineno
|
||||
|
@ -478,13 +532,13 @@ class ExceptionInfo:
|
|||
|
||||
def getrepr(
|
||||
self,
|
||||
showlocals=False,
|
||||
style="long",
|
||||
abspath=False,
|
||||
tbfilter=True,
|
||||
funcargs=False,
|
||||
truncate_locals=True,
|
||||
chain=True,
|
||||
showlocals: bool = False,
|
||||
style: str = "long",
|
||||
abspath: bool = False,
|
||||
tbfilter: bool = True,
|
||||
funcargs: bool = False,
|
||||
truncate_locals: bool = True,
|
||||
chain: bool = True,
|
||||
):
|
||||
"""
|
||||
Return str()able representation of this exception info.
|
||||
|
@ -535,7 +589,7 @@ class ExceptionInfo:
|
|||
)
|
||||
return fmt.repr_excinfo(self)
|
||||
|
||||
def match(self, regexp):
|
||||
def match(self, regexp: Union[str, Pattern]) -> bool:
|
||||
"""
|
||||
Check whether the regular expression 'regexp' is found in the string
|
||||
representation of the exception using ``re.search``. If it matches
|
||||
|
|
|
@ -3,21 +3,26 @@ exception classes and constants handling test outcomes
|
|||
as well as functions creating them
|
||||
"""
|
||||
import sys
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from packaging.version import Version
|
||||
|
||||
if False: # TYPE_CHECKING
|
||||
from typing import NoReturn
|
||||
|
||||
|
||||
class OutcomeException(BaseException):
|
||||
""" OutcomeException and its subclass instances indicate and
|
||||
contain info about test and collection outcomes.
|
||||
"""
|
||||
|
||||
def __init__(self, msg=None, pytrace=True):
|
||||
def __init__(self, msg: Optional[str] = None, pytrace: bool = True) -> None:
|
||||
BaseException.__init__(self, msg)
|
||||
self.msg = msg
|
||||
self.pytrace = pytrace
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
if self.msg:
|
||||
val = self.msg
|
||||
if isinstance(val, bytes):
|
||||
|
@ -36,7 +41,12 @@ class Skipped(OutcomeException):
|
|||
# in order to have Skipped exception printing shorter/nicer
|
||||
__module__ = "builtins"
|
||||
|
||||
def __init__(self, msg=None, pytrace=True, allow_module_level=False):
|
||||
def __init__(
|
||||
self,
|
||||
msg: Optional[str] = None,
|
||||
pytrace: bool = True,
|
||||
allow_module_level: bool = False,
|
||||
) -> None:
|
||||
OutcomeException.__init__(self, msg=msg, pytrace=pytrace)
|
||||
self.allow_module_level = allow_module_level
|
||||
|
||||
|
@ -50,7 +60,9 @@ class Failed(OutcomeException):
|
|||
class Exit(Exception):
|
||||
""" raised for immediate program exits (no tracebacks/summaries)"""
|
||||
|
||||
def __init__(self, msg="unknown reason", returncode=None):
|
||||
def __init__(
|
||||
self, msg: str = "unknown reason", returncode: Optional[int] = None
|
||||
) -> None:
|
||||
self.msg = msg
|
||||
self.returncode = returncode
|
||||
super().__init__(msg)
|
||||
|
@ -59,7 +71,7 @@ class Exit(Exception):
|
|||
# exposed helper methods
|
||||
|
||||
|
||||
def exit(msg, returncode=None):
|
||||
def exit(msg: str, returncode: Optional[int] = None) -> "NoReturn":
|
||||
"""
|
||||
Exit testing process.
|
||||
|
||||
|
@ -74,7 +86,7 @@ def exit(msg, returncode=None):
|
|||
exit.Exception = Exit # type: ignore
|
||||
|
||||
|
||||
def skip(msg="", *, allow_module_level=False):
|
||||
def skip(msg: str = "", *, allow_module_level: bool = False) -> "NoReturn":
|
||||
"""
|
||||
Skip an executing test with the given message.
|
||||
|
||||
|
@ -101,7 +113,7 @@ def skip(msg="", *, allow_module_level=False):
|
|||
skip.Exception = Skipped # type: ignore
|
||||
|
||||
|
||||
def fail(msg="", pytrace=True):
|
||||
def fail(msg: str = "", pytrace: bool = True) -> "NoReturn":
|
||||
"""
|
||||
Explicitly fail an executing test with the given message.
|
||||
|
||||
|
@ -121,7 +133,7 @@ class XFailed(Failed):
|
|||
""" raised from an explicit call to pytest.xfail() """
|
||||
|
||||
|
||||
def xfail(reason=""):
|
||||
def xfail(reason: str = "") -> "NoReturn":
|
||||
"""
|
||||
Imperatively xfail an executing test or setup functions with the given reason.
|
||||
|
||||
|
@ -139,7 +151,9 @@ def xfail(reason=""):
|
|||
xfail.Exception = XFailed # type: ignore
|
||||
|
||||
|
||||
def importorskip(modname, minversion=None, reason=None):
|
||||
def importorskip(
|
||||
modname: str, minversion: Optional[str] = None, reason: Optional[str] = None
|
||||
) -> Any:
|
||||
"""Imports and returns the requested module ``modname``, or skip the current test
|
||||
if the module cannot be imported.
|
||||
|
||||
|
|
|
@ -7,6 +7,16 @@ from collections.abc import Sized
|
|||
from decimal import Decimal
|
||||
from itertools import filterfalse
|
||||
from numbers import Number
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import cast
|
||||
from typing import Generic
|
||||
from typing import Optional
|
||||
from typing import overload
|
||||
from typing import Pattern
|
||||
from typing import Tuple
|
||||
from typing import TypeVar
|
||||
from typing import Union
|
||||
|
||||
from more_itertools.more import always_iterable
|
||||
|
@ -15,6 +25,9 @@ import _pytest._code
|
|||
from _pytest.compat import STRING_TYPES
|
||||
from _pytest.outcomes import fail
|
||||
|
||||
if False: # TYPE_CHECKING
|
||||
from typing import Type # noqa: F401 (used in type string)
|
||||
|
||||
BASE_TYPE = (type, STRING_TYPES)
|
||||
|
||||
|
||||
|
@ -527,8 +540,35 @@ def _is_numpy_array(obj):
|
|||
|
||||
# builtin pytest.raises helper
|
||||
|
||||
_E = TypeVar("_E", bound=BaseException)
|
||||
|
||||
def raises(expected_exception, *args, match=None, **kwargs):
|
||||
|
||||
@overload
|
||||
def raises(
|
||||
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
|
||||
*,
|
||||
match: Optional[Union[str, Pattern]] = ...
|
||||
) -> "RaisesContext[_E]":
|
||||
... # pragma: no cover
|
||||
|
||||
|
||||
@overload
|
||||
def raises(
|
||||
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
|
||||
func: Callable,
|
||||
*args: Any,
|
||||
match: Optional[str] = ...,
|
||||
**kwargs: Any
|
||||
) -> Optional[_pytest._code.ExceptionInfo[_E]]:
|
||||
... # pragma: no cover
|
||||
|
||||
|
||||
def raises(
|
||||
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
|
||||
*args: Any,
|
||||
match: Optional[Union[str, Pattern]] = None,
|
||||
**kwargs: Any
|
||||
) -> Union["RaisesContext[_E]", Optional[_pytest._code.ExceptionInfo[_E]]]:
|
||||
r"""
|
||||
Assert that a code block/function call raises ``expected_exception``
|
||||
or raise a failure exception otherwise.
|
||||
|
@ -647,18 +687,18 @@ def raises(expected_exception, *args, match=None, **kwargs):
|
|||
for exc in filterfalse(
|
||||
inspect.isclass, always_iterable(expected_exception, BASE_TYPE)
|
||||
):
|
||||
msg = (
|
||||
"exceptions must be old-style classes or"
|
||||
" derived from BaseException, not %s"
|
||||
)
|
||||
msg = "exceptions must be derived from BaseException, not %s"
|
||||
raise TypeError(msg % type(exc))
|
||||
|
||||
message = "DID NOT RAISE {}".format(expected_exception)
|
||||
|
||||
if not args:
|
||||
return RaisesContext(
|
||||
expected_exception, message=message, match_expr=match, **kwargs
|
||||
)
|
||||
if kwargs:
|
||||
msg = "Unexpected keyword arguments passed to pytest.raises: "
|
||||
msg += ", ".join(sorted(kwargs))
|
||||
msg += "\nUse context-manager form instead?"
|
||||
raise TypeError(msg)
|
||||
return RaisesContext(expected_exception, message, match)
|
||||
else:
|
||||
func = args[0]
|
||||
if not callable(func):
|
||||
|
@ -667,31 +707,51 @@ def raises(expected_exception, *args, match=None, **kwargs):
|
|||
)
|
||||
try:
|
||||
func(*args[1:], **kwargs)
|
||||
except expected_exception:
|
||||
return _pytest._code.ExceptionInfo.from_current()
|
||||
except expected_exception as e:
|
||||
# We just caught the exception - there is a traceback.
|
||||
assert e.__traceback__ is not None
|
||||
return _pytest._code.ExceptionInfo.from_exc_info(
|
||||
(type(e), e, e.__traceback__)
|
||||
)
|
||||
fail(message)
|
||||
|
||||
|
||||
raises.Exception = fail.Exception # type: ignore
|
||||
|
||||
|
||||
class RaisesContext:
|
||||
def __init__(self, expected_exception, message, match_expr):
|
||||
class RaisesContext(Generic[_E]):
|
||||
def __init__(
|
||||
self,
|
||||
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
|
||||
message: str,
|
||||
match_expr: Optional[Union[str, Pattern]] = None,
|
||||
) -> None:
|
||||
self.expected_exception = expected_exception
|
||||
self.message = message
|
||||
self.match_expr = match_expr
|
||||
self.excinfo = None
|
||||
self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo[_E]]
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> _pytest._code.ExceptionInfo[_E]:
|
||||
self.excinfo = _pytest._code.ExceptionInfo.for_later()
|
||||
return self.excinfo
|
||||
|
||||
def __exit__(self, *tp):
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional["Type[BaseException]"],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> bool:
|
||||
__tracebackhide__ = True
|
||||
if tp[0] is None:
|
||||
if exc_type is None:
|
||||
fail(self.message)
|
||||
self.excinfo.__init__(tp)
|
||||
suppress_exception = issubclass(self.excinfo.type, self.expected_exception)
|
||||
if self.match_expr is not None and suppress_exception:
|
||||
assert self.excinfo is not None
|
||||
if not issubclass(exc_type, self.expected_exception):
|
||||
return False
|
||||
# Cast to narrow the exception type now that it's verified.
|
||||
exc_info = cast(
|
||||
Tuple["Type[_E]", _E, TracebackType], (exc_type, exc_val, exc_tb)
|
||||
)
|
||||
self.excinfo.fill_unfilled(exc_info)
|
||||
if self.match_expr is not None:
|
||||
self.excinfo.match(self.match_expr)
|
||||
return suppress_exception
|
||||
return True
|
||||
|
|
|
@ -1,11 +1,23 @@
|
|||
""" recording warnings during test function execution. """
|
||||
import inspect
|
||||
import re
|
||||
import warnings
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Iterator
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import overload
|
||||
from typing import Pattern
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
from _pytest.fixtures import yield_fixture
|
||||
from _pytest.outcomes import fail
|
||||
|
||||
if False: # TYPE_CHECKING
|
||||
from typing import Type
|
||||
|
||||
|
||||
@yield_fixture
|
||||
def recwarn():
|
||||
|
@ -42,7 +54,32 @@ def deprecated_call(func=None, *args, **kwargs):
|
|||
return warns((DeprecationWarning, PendingDeprecationWarning), *args, **kwargs)
|
||||
|
||||
|
||||
def warns(expected_warning, *args, match=None, **kwargs):
|
||||
@overload
|
||||
def warns(
|
||||
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
|
||||
*,
|
||||
match: Optional[Union[str, Pattern]] = ...
|
||||
) -> "WarningsChecker":
|
||||
... # pragma: no cover
|
||||
|
||||
|
||||
@overload
|
||||
def warns(
|
||||
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
|
||||
func: Callable,
|
||||
*args: Any,
|
||||
match: Optional[Union[str, Pattern]] = ...,
|
||||
**kwargs: Any
|
||||
) -> Union[Any]:
|
||||
... # pragma: no cover
|
||||
|
||||
|
||||
def warns(
|
||||
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
|
||||
*args: Any,
|
||||
match: Optional[Union[str, Pattern]] = None,
|
||||
**kwargs: Any
|
||||
) -> Union["WarningsChecker", Any]:
|
||||
r"""Assert that code raises a particular class of warning.
|
||||
|
||||
Specifically, the parameter ``expected_warning`` can be a warning class or
|
||||
|
@ -76,7 +113,12 @@ def warns(expected_warning, *args, match=None, **kwargs):
|
|||
"""
|
||||
__tracebackhide__ = True
|
||||
if not args:
|
||||
return WarningsChecker(expected_warning, match_expr=match, **kwargs)
|
||||
if kwargs:
|
||||
msg = "Unexpected keyword arguments passed to pytest.warns: "
|
||||
msg += ", ".join(sorted(kwargs))
|
||||
msg += "\nUse context-manager form instead?"
|
||||
raise TypeError(msg)
|
||||
return WarningsChecker(expected_warning, match_expr=match)
|
||||
else:
|
||||
func = args[0]
|
||||
if not callable(func):
|
||||
|
@ -96,26 +138,26 @@ class WarningsRecorder(warnings.catch_warnings):
|
|||
def __init__(self):
|
||||
super().__init__(record=True)
|
||||
self._entered = False
|
||||
self._list = []
|
||||
self._list = [] # type: List[warnings._Record]
|
||||
|
||||
@property
|
||||
def list(self):
|
||||
def list(self) -> List["warnings._Record"]:
|
||||
"""The list of recorded warnings."""
|
||||
return self._list
|
||||
|
||||
def __getitem__(self, i):
|
||||
def __getitem__(self, i: int) -> "warnings._Record":
|
||||
"""Get a recorded warning by index."""
|
||||
return self._list[i]
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator["warnings._Record"]:
|
||||
"""Iterate through the recorded warnings."""
|
||||
return iter(self._list)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
"""The number of recorded warnings."""
|
||||
return len(self._list)
|
||||
|
||||
def pop(self, cls=Warning):
|
||||
def pop(self, cls: "Type[Warning]" = Warning) -> "warnings._Record":
|
||||
"""Pop the first recorded warning, raise exception if not exists."""
|
||||
for i, w in enumerate(self._list):
|
||||
if issubclass(w.category, cls):
|
||||
|
@ -123,54 +165,80 @@ class WarningsRecorder(warnings.catch_warnings):
|
|||
__tracebackhide__ = True
|
||||
raise AssertionError("%r not found in warning list" % cls)
|
||||
|
||||
def clear(self):
|
||||
def clear(self) -> None:
|
||||
"""Clear the list of recorded warnings."""
|
||||
self._list[:] = []
|
||||
|
||||
def __enter__(self):
|
||||
# Type ignored because it doesn't exactly warnings.catch_warnings.__enter__
|
||||
# -- it returns a List but we only emulate one.
|
||||
def __enter__(self) -> "WarningsRecorder": # type: ignore
|
||||
if self._entered:
|
||||
__tracebackhide__ = True
|
||||
raise RuntimeError("Cannot enter %r twice" % self)
|
||||
self._list = super().__enter__()
|
||||
_list = super().__enter__()
|
||||
# record=True means it's None.
|
||||
assert _list is not None
|
||||
self._list = _list
|
||||
warnings.simplefilter("always")
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc_info):
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional["Type[BaseException]"],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> bool:
|
||||
if not self._entered:
|
||||
__tracebackhide__ = True
|
||||
raise RuntimeError("Cannot exit %r without entering first" % self)
|
||||
|
||||
super().__exit__(*exc_info)
|
||||
super().__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
# Built-in catch_warnings does not reset entered state so we do it
|
||||
# manually here for this context manager to become reusable.
|
||||
self._entered = False
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class WarningsChecker(WarningsRecorder):
|
||||
def __init__(self, expected_warning=None, match_expr=None):
|
||||
def __init__(
|
||||
self,
|
||||
expected_warning: Optional[
|
||||
Union["Type[Warning]", Tuple["Type[Warning]", ...]]
|
||||
] = None,
|
||||
match_expr: Optional[Union[str, Pattern]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
msg = "exceptions must be old-style classes or derived from Warning, not %s"
|
||||
if isinstance(expected_warning, tuple):
|
||||
msg = "exceptions must be derived from Warning, not %s"
|
||||
if expected_warning is None:
|
||||
expected_warning_tup = None
|
||||
elif isinstance(expected_warning, tuple):
|
||||
for exc in expected_warning:
|
||||
if not inspect.isclass(exc):
|
||||
if not issubclass(exc, Warning):
|
||||
raise TypeError(msg % type(exc))
|
||||
elif inspect.isclass(expected_warning):
|
||||
expected_warning = (expected_warning,)
|
||||
elif expected_warning is not None:
|
||||
expected_warning_tup = expected_warning
|
||||
elif issubclass(expected_warning, Warning):
|
||||
expected_warning_tup = (expected_warning,)
|
||||
else:
|
||||
raise TypeError(msg % type(expected_warning))
|
||||
|
||||
self.expected_warning = expected_warning
|
||||
self.expected_warning = expected_warning_tup
|
||||
self.match_expr = match_expr
|
||||
|
||||
def __exit__(self, *exc_info):
|
||||
super().__exit__(*exc_info)
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional["Type[BaseException]"],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> bool:
|
||||
super().__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
__tracebackhide__ = True
|
||||
|
||||
# only check if we're not currently handling an exception
|
||||
if all(a is None for a in exc_info):
|
||||
if exc_type is None and exc_val is None and exc_tb is None:
|
||||
if self.expected_warning is not None:
|
||||
if not any(issubclass(r.category, self.expected_warning) for r in self):
|
||||
__tracebackhide__ = True
|
||||
|
@ -195,3 +263,4 @@ class WarningsChecker(WarningsRecorder):
|
|||
[each.message for each in self],
|
||||
)
|
||||
)
|
||||
return False
|
||||
|
|
|
@ -58,7 +58,7 @@ class TWMock:
|
|||
fullwidth = 80
|
||||
|
||||
|
||||
def test_excinfo_simple():
|
||||
def test_excinfo_simple() -> None:
|
||||
try:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
|
@ -66,6 +66,14 @@ def test_excinfo_simple():
|
|||
assert info.type == ValueError
|
||||
|
||||
|
||||
def test_excinfo_from_exc_info_simple():
|
||||
try:
|
||||
raise ValueError
|
||||
except ValueError as e:
|
||||
info = _pytest._code.ExceptionInfo.from_exc_info((type(e), e, e.__traceback__))
|
||||
assert info.type == ValueError
|
||||
|
||||
|
||||
def test_excinfo_getstatement():
|
||||
def g():
|
||||
raise ValueError
|
||||
|
|
|
@ -248,3 +248,9 @@ class TestRaises:
|
|||
with pytest.raises(CrappyClass()):
|
||||
pass
|
||||
assert "via __class__" in excinfo.value.args[0]
|
||||
|
||||
def test_raises_context_manager_with_kwargs(self):
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
with pytest.raises(Exception, foo="bar"):
|
||||
pass
|
||||
assert "Unexpected keyword arguments" in str(excinfo.value)
|
||||
|
|
|
@ -374,3 +374,9 @@ class TestWarns:
|
|||
assert f() == 10
|
||||
assert pytest.warns(UserWarning, f) == 10
|
||||
assert pytest.warns(UserWarning, f) == 10
|
||||
|
||||
def test_warns_context_manager_with_kwargs(self):
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
with pytest.warns(UserWarning, foo="bar"):
|
||||
pass
|
||||
assert "Unexpected keyword arguments" in str(excinfo.value)
|
||||
|
|
Loading…
Reference in New Issue