Merge pull request #5593 from bluetech/type-annotations-1
Type-annotate pytest.{exit,skip,fail,xfail,importorskip,warns,raises}
This commit is contained in:
commit
faf222f8fb
|
@ -5,6 +5,13 @@ import traceback
|
||||||
from inspect import CO_VARARGS
|
from inspect import CO_VARARGS
|
||||||
from inspect import CO_VARKEYWORDS
|
from inspect import CO_VARKEYWORDS
|
||||||
from traceback import format_exception_only
|
from traceback import format_exception_only
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Generic
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Pattern
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import TypeVar
|
||||||
|
from typing import Union
|
||||||
from weakref import ref
|
from weakref import ref
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
@ -15,6 +22,9 @@ import _pytest
|
||||||
from _pytest._io.saferepr import safeformat
|
from _pytest._io.saferepr import safeformat
|
||||||
from _pytest._io.saferepr import saferepr
|
from _pytest._io.saferepr import saferepr
|
||||||
|
|
||||||
|
if False: # TYPE_CHECKING
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
|
||||||
class Code:
|
class Code:
|
||||||
""" wrapper around Python code objects """
|
""" wrapper around Python code objects """
|
||||||
|
@ -371,20 +381,52 @@ co_equal = compile(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_E = TypeVar("_E", bound=BaseException)
|
||||||
|
|
||||||
|
|
||||||
@attr.s(repr=False)
|
@attr.s(repr=False)
|
||||||
class ExceptionInfo:
|
class ExceptionInfo(Generic[_E]):
|
||||||
""" wraps sys.exc_info() objects and offers
|
""" wraps sys.exc_info() objects and offers
|
||||||
help for navigating the traceback.
|
help for navigating the traceback.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_assert_start_repr = "AssertionError('assert "
|
_assert_start_repr = "AssertionError('assert "
|
||||||
|
|
||||||
_excinfo = attr.ib()
|
_excinfo = attr.ib(type=Optional[Tuple["Type[_E]", "_E", TracebackType]])
|
||||||
_striptext = attr.ib(default="")
|
_striptext = attr.ib(type=str, default="")
|
||||||
_traceback = attr.ib(default=None)
|
_traceback = attr.ib(type=Optional[Traceback], default=None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_current(cls, exprinfo=None):
|
def from_exc_info(
|
||||||
|
cls,
|
||||||
|
exc_info: Tuple["Type[_E]", "_E", TracebackType],
|
||||||
|
exprinfo: Optional[str] = None,
|
||||||
|
) -> "ExceptionInfo[_E]":
|
||||||
|
"""returns an ExceptionInfo for an existing exc_info tuple.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
Experimental API
|
||||||
|
|
||||||
|
|
||||||
|
:param exprinfo: a text string helping to determine if we should
|
||||||
|
strip ``AssertionError`` from the output, defaults
|
||||||
|
to the exception message/``__str__()``
|
||||||
|
"""
|
||||||
|
_striptext = ""
|
||||||
|
if exprinfo is None and isinstance(exc_info[1], AssertionError):
|
||||||
|
exprinfo = getattr(exc_info[1], "msg", None)
|
||||||
|
if exprinfo is None:
|
||||||
|
exprinfo = saferepr(exc_info[1])
|
||||||
|
if exprinfo and exprinfo.startswith(cls._assert_start_repr):
|
||||||
|
_striptext = "AssertionError: "
|
||||||
|
|
||||||
|
return cls(exc_info, _striptext)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_current(
|
||||||
|
cls, exprinfo: Optional[str] = None
|
||||||
|
) -> "ExceptionInfo[BaseException]":
|
||||||
"""returns an ExceptionInfo matching the current traceback
|
"""returns an ExceptionInfo matching the current traceback
|
||||||
|
|
||||||
.. warning::
|
.. warning::
|
||||||
|
@ -398,59 +440,71 @@ class ExceptionInfo:
|
||||||
"""
|
"""
|
||||||
tup = sys.exc_info()
|
tup = sys.exc_info()
|
||||||
assert tup[0] is not None, "no current exception"
|
assert tup[0] is not None, "no current exception"
|
||||||
_striptext = ""
|
assert tup[1] is not None, "no current exception"
|
||||||
if exprinfo is None and isinstance(tup[1], AssertionError):
|
assert tup[2] is not None, "no current exception"
|
||||||
exprinfo = getattr(tup[1], "msg", None)
|
exc_info = (tup[0], tup[1], tup[2])
|
||||||
if exprinfo is None:
|
return cls.from_exc_info(exc_info)
|
||||||
exprinfo = saferepr(tup[1])
|
|
||||||
if exprinfo and exprinfo.startswith(cls._assert_start_repr):
|
|
||||||
_striptext = "AssertionError: "
|
|
||||||
|
|
||||||
return cls(tup, _striptext)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def for_later(cls):
|
def for_later(cls) -> "ExceptionInfo[_E]":
|
||||||
"""return an unfilled ExceptionInfo
|
"""return an unfilled ExceptionInfo
|
||||||
"""
|
"""
|
||||||
return cls(None)
|
return cls(None)
|
||||||
|
|
||||||
|
def fill_unfilled(self, exc_info: Tuple["Type[_E]", _E, TracebackType]) -> None:
|
||||||
|
"""fill an unfilled ExceptionInfo created with for_later()"""
|
||||||
|
assert self._excinfo is None, "ExceptionInfo was already filled"
|
||||||
|
self._excinfo = exc_info
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type(self):
|
def type(self) -> "Type[_E]":
|
||||||
"""the exception class"""
|
"""the exception class"""
|
||||||
|
assert (
|
||||||
|
self._excinfo is not None
|
||||||
|
), ".type can only be used after the context manager exits"
|
||||||
return self._excinfo[0]
|
return self._excinfo[0]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def value(self):
|
def value(self) -> _E:
|
||||||
"""the exception value"""
|
"""the exception value"""
|
||||||
|
assert (
|
||||||
|
self._excinfo is not None
|
||||||
|
), ".value can only be used after the context manager exits"
|
||||||
return self._excinfo[1]
|
return self._excinfo[1]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tb(self):
|
def tb(self) -> TracebackType:
|
||||||
"""the exception raw traceback"""
|
"""the exception raw traceback"""
|
||||||
|
assert (
|
||||||
|
self._excinfo is not None
|
||||||
|
), ".tb can only be used after the context manager exits"
|
||||||
return self._excinfo[2]
|
return self._excinfo[2]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def typename(self):
|
def typename(self) -> str:
|
||||||
"""the type name of the exception"""
|
"""the type name of the exception"""
|
||||||
|
assert (
|
||||||
|
self._excinfo is not None
|
||||||
|
), ".typename can only be used after the context manager exits"
|
||||||
return self.type.__name__
|
return self.type.__name__
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def traceback(self):
|
def traceback(self) -> Traceback:
|
||||||
"""the traceback"""
|
"""the traceback"""
|
||||||
if self._traceback is None:
|
if self._traceback is None:
|
||||||
self._traceback = Traceback(self.tb, excinfo=ref(self))
|
self._traceback = Traceback(self.tb, excinfo=ref(self))
|
||||||
return self._traceback
|
return self._traceback
|
||||||
|
|
||||||
@traceback.setter
|
@traceback.setter
|
||||||
def traceback(self, value):
|
def traceback(self, value: Traceback) -> None:
|
||||||
self._traceback = value
|
self._traceback = value
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
if self._excinfo is None:
|
if self._excinfo is None:
|
||||||
return "<ExceptionInfo for raises contextmanager>"
|
return "<ExceptionInfo for raises contextmanager>"
|
||||||
return "<ExceptionInfo %s tblen=%d>" % (self.typename, len(self.traceback))
|
return "<ExceptionInfo %s tblen=%d>" % (self.typename, len(self.traceback))
|
||||||
|
|
||||||
def exconly(self, tryshort=False):
|
def exconly(self, tryshort: bool = False) -> str:
|
||||||
""" return the exception as a string
|
""" return the exception as a string
|
||||||
|
|
||||||
when 'tryshort' resolves to True, and the exception is a
|
when 'tryshort' resolves to True, and the exception is a
|
||||||
|
@ -466,11 +520,11 @@ class ExceptionInfo:
|
||||||
text = text[len(self._striptext) :]
|
text = text[len(self._striptext) :]
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def errisinstance(self, exc):
|
def errisinstance(self, exc: "Type[BaseException]") -> bool:
|
||||||
""" return True if the exception is an instance of exc """
|
""" return True if the exception is an instance of exc """
|
||||||
return isinstance(self.value, exc)
|
return isinstance(self.value, exc)
|
||||||
|
|
||||||
def _getreprcrash(self):
|
def _getreprcrash(self) -> "ReprFileLocation":
|
||||||
exconly = self.exconly(tryshort=True)
|
exconly = self.exconly(tryshort=True)
|
||||||
entry = self.traceback.getcrashentry()
|
entry = self.traceback.getcrashentry()
|
||||||
path, lineno = entry.frame.code.raw.co_filename, entry.lineno
|
path, lineno = entry.frame.code.raw.co_filename, entry.lineno
|
||||||
|
@ -478,13 +532,13 @@ class ExceptionInfo:
|
||||||
|
|
||||||
def getrepr(
|
def getrepr(
|
||||||
self,
|
self,
|
||||||
showlocals=False,
|
showlocals: bool = False,
|
||||||
style="long",
|
style: str = "long",
|
||||||
abspath=False,
|
abspath: bool = False,
|
||||||
tbfilter=True,
|
tbfilter: bool = True,
|
||||||
funcargs=False,
|
funcargs: bool = False,
|
||||||
truncate_locals=True,
|
truncate_locals: bool = True,
|
||||||
chain=True,
|
chain: bool = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Return str()able representation of this exception info.
|
Return str()able representation of this exception info.
|
||||||
|
@ -535,7 +589,7 @@ class ExceptionInfo:
|
||||||
)
|
)
|
||||||
return fmt.repr_excinfo(self)
|
return fmt.repr_excinfo(self)
|
||||||
|
|
||||||
def match(self, regexp):
|
def match(self, regexp: Union[str, Pattern]) -> bool:
|
||||||
"""
|
"""
|
||||||
Check whether the regular expression 'regexp' is found in the string
|
Check whether the regular expression 'regexp' is found in the string
|
||||||
representation of the exception using ``re.search``. If it matches
|
representation of the exception using ``re.search``. If it matches
|
||||||
|
|
|
@ -3,21 +3,26 @@ exception classes and constants handling test outcomes
|
||||||
as well as functions creating them
|
as well as functions creating them
|
||||||
"""
|
"""
|
||||||
import sys
|
import sys
|
||||||
|
from typing import Any
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
|
||||||
|
if False: # TYPE_CHECKING
|
||||||
|
from typing import NoReturn
|
||||||
|
|
||||||
|
|
||||||
class OutcomeException(BaseException):
|
class OutcomeException(BaseException):
|
||||||
""" OutcomeException and its subclass instances indicate and
|
""" OutcomeException and its subclass instances indicate and
|
||||||
contain info about test and collection outcomes.
|
contain info about test and collection outcomes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, msg=None, pytrace=True):
|
def __init__(self, msg: Optional[str] = None, pytrace: bool = True) -> None:
|
||||||
BaseException.__init__(self, msg)
|
BaseException.__init__(self, msg)
|
||||||
self.msg = msg
|
self.msg = msg
|
||||||
self.pytrace = pytrace
|
self.pytrace = pytrace
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
if self.msg:
|
if self.msg:
|
||||||
val = self.msg
|
val = self.msg
|
||||||
if isinstance(val, bytes):
|
if isinstance(val, bytes):
|
||||||
|
@ -36,7 +41,12 @@ class Skipped(OutcomeException):
|
||||||
# in order to have Skipped exception printing shorter/nicer
|
# in order to have Skipped exception printing shorter/nicer
|
||||||
__module__ = "builtins"
|
__module__ = "builtins"
|
||||||
|
|
||||||
def __init__(self, msg=None, pytrace=True, allow_module_level=False):
|
def __init__(
|
||||||
|
self,
|
||||||
|
msg: Optional[str] = None,
|
||||||
|
pytrace: bool = True,
|
||||||
|
allow_module_level: bool = False,
|
||||||
|
) -> None:
|
||||||
OutcomeException.__init__(self, msg=msg, pytrace=pytrace)
|
OutcomeException.__init__(self, msg=msg, pytrace=pytrace)
|
||||||
self.allow_module_level = allow_module_level
|
self.allow_module_level = allow_module_level
|
||||||
|
|
||||||
|
@ -50,7 +60,9 @@ class Failed(OutcomeException):
|
||||||
class Exit(Exception):
|
class Exit(Exception):
|
||||||
""" raised for immediate program exits (no tracebacks/summaries)"""
|
""" raised for immediate program exits (no tracebacks/summaries)"""
|
||||||
|
|
||||||
def __init__(self, msg="unknown reason", returncode=None):
|
def __init__(
|
||||||
|
self, msg: str = "unknown reason", returncode: Optional[int] = None
|
||||||
|
) -> None:
|
||||||
self.msg = msg
|
self.msg = msg
|
||||||
self.returncode = returncode
|
self.returncode = returncode
|
||||||
super().__init__(msg)
|
super().__init__(msg)
|
||||||
|
@ -59,7 +71,7 @@ class Exit(Exception):
|
||||||
# exposed helper methods
|
# exposed helper methods
|
||||||
|
|
||||||
|
|
||||||
def exit(msg, returncode=None):
|
def exit(msg: str, returncode: Optional[int] = None) -> "NoReturn":
|
||||||
"""
|
"""
|
||||||
Exit testing process.
|
Exit testing process.
|
||||||
|
|
||||||
|
@ -74,7 +86,7 @@ def exit(msg, returncode=None):
|
||||||
exit.Exception = Exit # type: ignore
|
exit.Exception = Exit # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def skip(msg="", *, allow_module_level=False):
|
def skip(msg: str = "", *, allow_module_level: bool = False) -> "NoReturn":
|
||||||
"""
|
"""
|
||||||
Skip an executing test with the given message.
|
Skip an executing test with the given message.
|
||||||
|
|
||||||
|
@ -101,7 +113,7 @@ def skip(msg="", *, allow_module_level=False):
|
||||||
skip.Exception = Skipped # type: ignore
|
skip.Exception = Skipped # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def fail(msg="", pytrace=True):
|
def fail(msg: str = "", pytrace: bool = True) -> "NoReturn":
|
||||||
"""
|
"""
|
||||||
Explicitly fail an executing test with the given message.
|
Explicitly fail an executing test with the given message.
|
||||||
|
|
||||||
|
@ -121,7 +133,7 @@ class XFailed(Failed):
|
||||||
""" raised from an explicit call to pytest.xfail() """
|
""" raised from an explicit call to pytest.xfail() """
|
||||||
|
|
||||||
|
|
||||||
def xfail(reason=""):
|
def xfail(reason: str = "") -> "NoReturn":
|
||||||
"""
|
"""
|
||||||
Imperatively xfail an executing test or setup functions with the given reason.
|
Imperatively xfail an executing test or setup functions with the given reason.
|
||||||
|
|
||||||
|
@ -139,7 +151,9 @@ def xfail(reason=""):
|
||||||
xfail.Exception = XFailed # type: ignore
|
xfail.Exception = XFailed # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def importorskip(modname, minversion=None, reason=None):
|
def importorskip(
|
||||||
|
modname: str, minversion: Optional[str] = None, reason: Optional[str] = None
|
||||||
|
) -> Any:
|
||||||
"""Imports and returns the requested module ``modname``, or skip the current test
|
"""Imports and returns the requested module ``modname``, or skip the current test
|
||||||
if the module cannot be imported.
|
if the module cannot be imported.
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,16 @@ from collections.abc import Sized
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from itertools import filterfalse
|
from itertools import filterfalse
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Any
|
||||||
|
from typing import Callable
|
||||||
|
from typing import cast
|
||||||
|
from typing import Generic
|
||||||
|
from typing import Optional
|
||||||
|
from typing import overload
|
||||||
|
from typing import Pattern
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import TypeVar
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from more_itertools.more import always_iterable
|
from more_itertools.more import always_iterable
|
||||||
|
@ -15,6 +25,9 @@ import _pytest._code
|
||||||
from _pytest.compat import STRING_TYPES
|
from _pytest.compat import STRING_TYPES
|
||||||
from _pytest.outcomes import fail
|
from _pytest.outcomes import fail
|
||||||
|
|
||||||
|
if False: # TYPE_CHECKING
|
||||||
|
from typing import Type # noqa: F401 (used in type string)
|
||||||
|
|
||||||
BASE_TYPE = (type, STRING_TYPES)
|
BASE_TYPE = (type, STRING_TYPES)
|
||||||
|
|
||||||
|
|
||||||
|
@ -527,8 +540,35 @@ def _is_numpy_array(obj):
|
||||||
|
|
||||||
# builtin pytest.raises helper
|
# builtin pytest.raises helper
|
||||||
|
|
||||||
|
_E = TypeVar("_E", bound=BaseException)
|
||||||
|
|
||||||
def raises(expected_exception, *args, match=None, **kwargs):
|
|
||||||
|
@overload
|
||||||
|
def raises(
|
||||||
|
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
|
||||||
|
*,
|
||||||
|
match: Optional[Union[str, Pattern]] = ...
|
||||||
|
) -> "RaisesContext[_E]":
|
||||||
|
... # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def raises(
|
||||||
|
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
|
||||||
|
func: Callable,
|
||||||
|
*args: Any,
|
||||||
|
match: Optional[str] = ...,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> Optional[_pytest._code.ExceptionInfo[_E]]:
|
||||||
|
... # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
def raises(
|
||||||
|
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
|
||||||
|
*args: Any,
|
||||||
|
match: Optional[Union[str, Pattern]] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> Union["RaisesContext[_E]", Optional[_pytest._code.ExceptionInfo[_E]]]:
|
||||||
r"""
|
r"""
|
||||||
Assert that a code block/function call raises ``expected_exception``
|
Assert that a code block/function call raises ``expected_exception``
|
||||||
or raise a failure exception otherwise.
|
or raise a failure exception otherwise.
|
||||||
|
@ -647,18 +687,18 @@ def raises(expected_exception, *args, match=None, **kwargs):
|
||||||
for exc in filterfalse(
|
for exc in filterfalse(
|
||||||
inspect.isclass, always_iterable(expected_exception, BASE_TYPE)
|
inspect.isclass, always_iterable(expected_exception, BASE_TYPE)
|
||||||
):
|
):
|
||||||
msg = (
|
msg = "exceptions must be derived from BaseException, not %s"
|
||||||
"exceptions must be old-style classes or"
|
|
||||||
" derived from BaseException, not %s"
|
|
||||||
)
|
|
||||||
raise TypeError(msg % type(exc))
|
raise TypeError(msg % type(exc))
|
||||||
|
|
||||||
message = "DID NOT RAISE {}".format(expected_exception)
|
message = "DID NOT RAISE {}".format(expected_exception)
|
||||||
|
|
||||||
if not args:
|
if not args:
|
||||||
return RaisesContext(
|
if kwargs:
|
||||||
expected_exception, message=message, match_expr=match, **kwargs
|
msg = "Unexpected keyword arguments passed to pytest.raises: "
|
||||||
)
|
msg += ", ".join(sorted(kwargs))
|
||||||
|
msg += "\nUse context-manager form instead?"
|
||||||
|
raise TypeError(msg)
|
||||||
|
return RaisesContext(expected_exception, message, match)
|
||||||
else:
|
else:
|
||||||
func = args[0]
|
func = args[0]
|
||||||
if not callable(func):
|
if not callable(func):
|
||||||
|
@ -667,31 +707,51 @@ def raises(expected_exception, *args, match=None, **kwargs):
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
func(*args[1:], **kwargs)
|
func(*args[1:], **kwargs)
|
||||||
except expected_exception:
|
except expected_exception as e:
|
||||||
return _pytest._code.ExceptionInfo.from_current()
|
# We just caught the exception - there is a traceback.
|
||||||
|
assert e.__traceback__ is not None
|
||||||
|
return _pytest._code.ExceptionInfo.from_exc_info(
|
||||||
|
(type(e), e, e.__traceback__)
|
||||||
|
)
|
||||||
fail(message)
|
fail(message)
|
||||||
|
|
||||||
|
|
||||||
raises.Exception = fail.Exception # type: ignore
|
raises.Exception = fail.Exception # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class RaisesContext:
|
class RaisesContext(Generic[_E]):
|
||||||
def __init__(self, expected_exception, message, match_expr):
|
def __init__(
|
||||||
|
self,
|
||||||
|
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
|
||||||
|
message: str,
|
||||||
|
match_expr: Optional[Union[str, Pattern]] = None,
|
||||||
|
) -> None:
|
||||||
self.expected_exception = expected_exception
|
self.expected_exception = expected_exception
|
||||||
self.message = message
|
self.message = message
|
||||||
self.match_expr = match_expr
|
self.match_expr = match_expr
|
||||||
self.excinfo = None
|
self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo[_E]]
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self) -> _pytest._code.ExceptionInfo[_E]:
|
||||||
self.excinfo = _pytest._code.ExceptionInfo.for_later()
|
self.excinfo = _pytest._code.ExceptionInfo.for_later()
|
||||||
return self.excinfo
|
return self.excinfo
|
||||||
|
|
||||||
def __exit__(self, *tp):
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional["Type[BaseException]"],
|
||||||
|
exc_val: Optional[BaseException],
|
||||||
|
exc_tb: Optional[TracebackType],
|
||||||
|
) -> bool:
|
||||||
__tracebackhide__ = True
|
__tracebackhide__ = True
|
||||||
if tp[0] is None:
|
if exc_type is None:
|
||||||
fail(self.message)
|
fail(self.message)
|
||||||
self.excinfo.__init__(tp)
|
assert self.excinfo is not None
|
||||||
suppress_exception = issubclass(self.excinfo.type, self.expected_exception)
|
if not issubclass(exc_type, self.expected_exception):
|
||||||
if self.match_expr is not None and suppress_exception:
|
return False
|
||||||
|
# Cast to narrow the exception type now that it's verified.
|
||||||
|
exc_info = cast(
|
||||||
|
Tuple["Type[_E]", _E, TracebackType], (exc_type, exc_val, exc_tb)
|
||||||
|
)
|
||||||
|
self.excinfo.fill_unfilled(exc_info)
|
||||||
|
if self.match_expr is not None:
|
||||||
self.excinfo.match(self.match_expr)
|
self.excinfo.match(self.match_expr)
|
||||||
return suppress_exception
|
return True
|
||||||
|
|
|
@ -1,11 +1,23 @@
|
||||||
""" recording warnings during test function execution. """
|
""" recording warnings during test function execution. """
|
||||||
import inspect
|
|
||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Any
|
||||||
|
from typing import Callable
|
||||||
|
from typing import Iterator
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import overload
|
||||||
|
from typing import Pattern
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from _pytest.fixtures import yield_fixture
|
from _pytest.fixtures import yield_fixture
|
||||||
from _pytest.outcomes import fail
|
from _pytest.outcomes import fail
|
||||||
|
|
||||||
|
if False: # TYPE_CHECKING
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
|
||||||
@yield_fixture
|
@yield_fixture
|
||||||
def recwarn():
|
def recwarn():
|
||||||
|
@ -42,7 +54,32 @@ def deprecated_call(func=None, *args, **kwargs):
|
||||||
return warns((DeprecationWarning, PendingDeprecationWarning), *args, **kwargs)
|
return warns((DeprecationWarning, PendingDeprecationWarning), *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def warns(expected_warning, *args, match=None, **kwargs):
|
@overload
|
||||||
|
def warns(
|
||||||
|
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
|
||||||
|
*,
|
||||||
|
match: Optional[Union[str, Pattern]] = ...
|
||||||
|
) -> "WarningsChecker":
|
||||||
|
... # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def warns(
|
||||||
|
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
|
||||||
|
func: Callable,
|
||||||
|
*args: Any,
|
||||||
|
match: Optional[Union[str, Pattern]] = ...,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> Union[Any]:
|
||||||
|
... # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
def warns(
|
||||||
|
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
|
||||||
|
*args: Any,
|
||||||
|
match: Optional[Union[str, Pattern]] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> Union["WarningsChecker", Any]:
|
||||||
r"""Assert that code raises a particular class of warning.
|
r"""Assert that code raises a particular class of warning.
|
||||||
|
|
||||||
Specifically, the parameter ``expected_warning`` can be a warning class or
|
Specifically, the parameter ``expected_warning`` can be a warning class or
|
||||||
|
@ -76,7 +113,12 @@ def warns(expected_warning, *args, match=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
__tracebackhide__ = True
|
__tracebackhide__ = True
|
||||||
if not args:
|
if not args:
|
||||||
return WarningsChecker(expected_warning, match_expr=match, **kwargs)
|
if kwargs:
|
||||||
|
msg = "Unexpected keyword arguments passed to pytest.warns: "
|
||||||
|
msg += ", ".join(sorted(kwargs))
|
||||||
|
msg += "\nUse context-manager form instead?"
|
||||||
|
raise TypeError(msg)
|
||||||
|
return WarningsChecker(expected_warning, match_expr=match)
|
||||||
else:
|
else:
|
||||||
func = args[0]
|
func = args[0]
|
||||||
if not callable(func):
|
if not callable(func):
|
||||||
|
@ -96,26 +138,26 @@ class WarningsRecorder(warnings.catch_warnings):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(record=True)
|
super().__init__(record=True)
|
||||||
self._entered = False
|
self._entered = False
|
||||||
self._list = []
|
self._list = [] # type: List[warnings._Record]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def list(self):
|
def list(self) -> List["warnings._Record"]:
|
||||||
"""The list of recorded warnings."""
|
"""The list of recorded warnings."""
|
||||||
return self._list
|
return self._list
|
||||||
|
|
||||||
def __getitem__(self, i):
|
def __getitem__(self, i: int) -> "warnings._Record":
|
||||||
"""Get a recorded warning by index."""
|
"""Get a recorded warning by index."""
|
||||||
return self._list[i]
|
return self._list[i]
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self) -> Iterator["warnings._Record"]:
|
||||||
"""Iterate through the recorded warnings."""
|
"""Iterate through the recorded warnings."""
|
||||||
return iter(self._list)
|
return iter(self._list)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self) -> int:
|
||||||
"""The number of recorded warnings."""
|
"""The number of recorded warnings."""
|
||||||
return len(self._list)
|
return len(self._list)
|
||||||
|
|
||||||
def pop(self, cls=Warning):
|
def pop(self, cls: "Type[Warning]" = Warning) -> "warnings._Record":
|
||||||
"""Pop the first recorded warning, raise exception if not exists."""
|
"""Pop the first recorded warning, raise exception if not exists."""
|
||||||
for i, w in enumerate(self._list):
|
for i, w in enumerate(self._list):
|
||||||
if issubclass(w.category, cls):
|
if issubclass(w.category, cls):
|
||||||
|
@ -123,54 +165,80 @@ class WarningsRecorder(warnings.catch_warnings):
|
||||||
__tracebackhide__ = True
|
__tracebackhide__ = True
|
||||||
raise AssertionError("%r not found in warning list" % cls)
|
raise AssertionError("%r not found in warning list" % cls)
|
||||||
|
|
||||||
def clear(self):
|
def clear(self) -> None:
|
||||||
"""Clear the list of recorded warnings."""
|
"""Clear the list of recorded warnings."""
|
||||||
self._list[:] = []
|
self._list[:] = []
|
||||||
|
|
||||||
def __enter__(self):
|
# Type ignored because it doesn't exactly warnings.catch_warnings.__enter__
|
||||||
|
# -- it returns a List but we only emulate one.
|
||||||
|
def __enter__(self) -> "WarningsRecorder": # type: ignore
|
||||||
if self._entered:
|
if self._entered:
|
||||||
__tracebackhide__ = True
|
__tracebackhide__ = True
|
||||||
raise RuntimeError("Cannot enter %r twice" % self)
|
raise RuntimeError("Cannot enter %r twice" % self)
|
||||||
self._list = super().__enter__()
|
_list = super().__enter__()
|
||||||
|
# record=True means it's None.
|
||||||
|
assert _list is not None
|
||||||
|
self._list = _list
|
||||||
warnings.simplefilter("always")
|
warnings.simplefilter("always")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, *exc_info):
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional["Type[BaseException]"],
|
||||||
|
exc_val: Optional[BaseException],
|
||||||
|
exc_tb: Optional[TracebackType],
|
||||||
|
) -> bool:
|
||||||
if not self._entered:
|
if not self._entered:
|
||||||
__tracebackhide__ = True
|
__tracebackhide__ = True
|
||||||
raise RuntimeError("Cannot exit %r without entering first" % self)
|
raise RuntimeError("Cannot exit %r without entering first" % self)
|
||||||
|
|
||||||
super().__exit__(*exc_info)
|
super().__exit__(exc_type, exc_val, exc_tb)
|
||||||
|
|
||||||
# Built-in catch_warnings does not reset entered state so we do it
|
# Built-in catch_warnings does not reset entered state so we do it
|
||||||
# manually here for this context manager to become reusable.
|
# manually here for this context manager to become reusable.
|
||||||
self._entered = False
|
self._entered = False
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class WarningsChecker(WarningsRecorder):
|
class WarningsChecker(WarningsRecorder):
|
||||||
def __init__(self, expected_warning=None, match_expr=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
expected_warning: Optional[
|
||||||
|
Union["Type[Warning]", Tuple["Type[Warning]", ...]]
|
||||||
|
] = None,
|
||||||
|
match_expr: Optional[Union[str, Pattern]] = None,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
msg = "exceptions must be old-style classes or derived from Warning, not %s"
|
msg = "exceptions must be derived from Warning, not %s"
|
||||||
if isinstance(expected_warning, tuple):
|
if expected_warning is None:
|
||||||
|
expected_warning_tup = None
|
||||||
|
elif isinstance(expected_warning, tuple):
|
||||||
for exc in expected_warning:
|
for exc in expected_warning:
|
||||||
if not inspect.isclass(exc):
|
if not issubclass(exc, Warning):
|
||||||
raise TypeError(msg % type(exc))
|
raise TypeError(msg % type(exc))
|
||||||
elif inspect.isclass(expected_warning):
|
expected_warning_tup = expected_warning
|
||||||
expected_warning = (expected_warning,)
|
elif issubclass(expected_warning, Warning):
|
||||||
elif expected_warning is not None:
|
expected_warning_tup = (expected_warning,)
|
||||||
|
else:
|
||||||
raise TypeError(msg % type(expected_warning))
|
raise TypeError(msg % type(expected_warning))
|
||||||
|
|
||||||
self.expected_warning = expected_warning
|
self.expected_warning = expected_warning_tup
|
||||||
self.match_expr = match_expr
|
self.match_expr = match_expr
|
||||||
|
|
||||||
def __exit__(self, *exc_info):
|
def __exit__(
|
||||||
super().__exit__(*exc_info)
|
self,
|
||||||
|
exc_type: Optional["Type[BaseException]"],
|
||||||
|
exc_val: Optional[BaseException],
|
||||||
|
exc_tb: Optional[TracebackType],
|
||||||
|
) -> bool:
|
||||||
|
super().__exit__(exc_type, exc_val, exc_tb)
|
||||||
|
|
||||||
__tracebackhide__ = True
|
__tracebackhide__ = True
|
||||||
|
|
||||||
# only check if we're not currently handling an exception
|
# only check if we're not currently handling an exception
|
||||||
if all(a is None for a in exc_info):
|
if exc_type is None and exc_val is None and exc_tb is None:
|
||||||
if self.expected_warning is not None:
|
if self.expected_warning is not None:
|
||||||
if not any(issubclass(r.category, self.expected_warning) for r in self):
|
if not any(issubclass(r.category, self.expected_warning) for r in self):
|
||||||
__tracebackhide__ = True
|
__tracebackhide__ = True
|
||||||
|
@ -195,3 +263,4 @@ class WarningsChecker(WarningsRecorder):
|
||||||
[each.message for each in self],
|
[each.message for each in self],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
return False
|
||||||
|
|
|
@ -58,7 +58,7 @@ class TWMock:
|
||||||
fullwidth = 80
|
fullwidth = 80
|
||||||
|
|
||||||
|
|
||||||
def test_excinfo_simple():
|
def test_excinfo_simple() -> None:
|
||||||
try:
|
try:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
@ -66,6 +66,14 @@ def test_excinfo_simple():
|
||||||
assert info.type == ValueError
|
assert info.type == ValueError
|
||||||
|
|
||||||
|
|
||||||
|
def test_excinfo_from_exc_info_simple():
|
||||||
|
try:
|
||||||
|
raise ValueError
|
||||||
|
except ValueError as e:
|
||||||
|
info = _pytest._code.ExceptionInfo.from_exc_info((type(e), e, e.__traceback__))
|
||||||
|
assert info.type == ValueError
|
||||||
|
|
||||||
|
|
||||||
def test_excinfo_getstatement():
|
def test_excinfo_getstatement():
|
||||||
def g():
|
def g():
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
|
@ -248,3 +248,9 @@ class TestRaises:
|
||||||
with pytest.raises(CrappyClass()):
|
with pytest.raises(CrappyClass()):
|
||||||
pass
|
pass
|
||||||
assert "via __class__" in excinfo.value.args[0]
|
assert "via __class__" in excinfo.value.args[0]
|
||||||
|
|
||||||
|
def test_raises_context_manager_with_kwargs(self):
|
||||||
|
with pytest.raises(TypeError) as excinfo:
|
||||||
|
with pytest.raises(Exception, foo="bar"):
|
||||||
|
pass
|
||||||
|
assert "Unexpected keyword arguments" in str(excinfo.value)
|
||||||
|
|
|
@ -374,3 +374,9 @@ class TestWarns:
|
||||||
assert f() == 10
|
assert f() == 10
|
||||||
assert pytest.warns(UserWarning, f) == 10
|
assert pytest.warns(UserWarning, f) == 10
|
||||||
assert pytest.warns(UserWarning, f) == 10
|
assert pytest.warns(UserWarning, f) == 10
|
||||||
|
|
||||||
|
def test_warns_context_manager_with_kwargs(self):
|
||||||
|
with pytest.raises(TypeError) as excinfo:
|
||||||
|
with pytest.warns(UserWarning, foo="bar"):
|
||||||
|
pass
|
||||||
|
assert "Unexpected keyword arguments" in str(excinfo.value)
|
||||||
|
|
Loading…
Reference in New Issue