Merge pull request #5593 from bluetech/type-annotations-1

Type-annotate pytest.{exit,skip,fail,xfail,importorskip,warns,raises}
This commit is contained in:
Ran Benita 2019-07-16 22:38:20 +03:00 committed by GitHub
commit faf222f8fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 306 additions and 89 deletions

View File

@ -5,6 +5,13 @@ import traceback
from inspect import CO_VARARGS
from inspect import CO_VARKEYWORDS
from traceback import format_exception_only
from types import TracebackType
from typing import Generic
from typing import Optional
from typing import Pattern
from typing import Tuple
from typing import TypeVar
from typing import Union
from weakref import ref
import attr
@ -15,6 +22,9 @@ import _pytest
from _pytest._io.saferepr import safeformat
from _pytest._io.saferepr import saferepr
if False: # TYPE_CHECKING
from typing import Type
class Code:
""" wrapper around Python code objects """
@ -371,20 +381,52 @@ co_equal = compile(
)
_E = TypeVar("_E", bound=BaseException)
@attr.s(repr=False)
class ExceptionInfo:
class ExceptionInfo(Generic[_E]):
""" wraps sys.exc_info() objects and offers
help for navigating the traceback.
"""
_assert_start_repr = "AssertionError('assert "
_excinfo = attr.ib()
_striptext = attr.ib(default="")
_traceback = attr.ib(default=None)
_excinfo = attr.ib(type=Optional[Tuple["Type[_E]", "_E", TracebackType]])
_striptext = attr.ib(type=str, default="")
_traceback = attr.ib(type=Optional[Traceback], default=None)
@classmethod
def from_current(cls, exprinfo=None):
def from_exc_info(
cls,
exc_info: Tuple["Type[_E]", "_E", TracebackType],
exprinfo: Optional[str] = None,
) -> "ExceptionInfo[_E]":
"""returns an ExceptionInfo for an existing exc_info tuple.
.. warning::
Experimental API
:param exprinfo: a text string helping to determine if we should
strip ``AssertionError`` from the output, defaults
to the exception message/``__str__()``
"""
_striptext = ""
if exprinfo is None and isinstance(exc_info[1], AssertionError):
exprinfo = getattr(exc_info[1], "msg", None)
if exprinfo is None:
exprinfo = saferepr(exc_info[1])
if exprinfo and exprinfo.startswith(cls._assert_start_repr):
_striptext = "AssertionError: "
return cls(exc_info, _striptext)
@classmethod
def from_current(
cls, exprinfo: Optional[str] = None
) -> "ExceptionInfo[BaseException]":
"""returns an ExceptionInfo matching the current traceback
.. warning::
@ -398,59 +440,71 @@ class ExceptionInfo:
"""
tup = sys.exc_info()
assert tup[0] is not None, "no current exception"
_striptext = ""
if exprinfo is None and isinstance(tup[1], AssertionError):
exprinfo = getattr(tup[1], "msg", None)
if exprinfo is None:
exprinfo = saferepr(tup[1])
if exprinfo and exprinfo.startswith(cls._assert_start_repr):
_striptext = "AssertionError: "
return cls(tup, _striptext)
assert tup[1] is not None, "no current exception"
assert tup[2] is not None, "no current exception"
exc_info = (tup[0], tup[1], tup[2])
return cls.from_exc_info(exc_info)
@classmethod
def for_later(cls):
def for_later(cls) -> "ExceptionInfo[_E]":
"""return an unfilled ExceptionInfo
"""
return cls(None)
def fill_unfilled(self, exc_info: Tuple["Type[_E]", _E, TracebackType]) -> None:
"""fill an unfilled ExceptionInfo created with for_later()"""
assert self._excinfo is None, "ExceptionInfo was already filled"
self._excinfo = exc_info
@property
def type(self):
def type(self) -> "Type[_E]":
"""the exception class"""
assert (
self._excinfo is not None
), ".type can only be used after the context manager exits"
return self._excinfo[0]
@property
def value(self):
def value(self) -> _E:
"""the exception value"""
assert (
self._excinfo is not None
), ".value can only be used after the context manager exits"
return self._excinfo[1]
@property
def tb(self):
def tb(self) -> TracebackType:
"""the exception raw traceback"""
assert (
self._excinfo is not None
), ".tb can only be used after the context manager exits"
return self._excinfo[2]
@property
def typename(self):
def typename(self) -> str:
"""the type name of the exception"""
assert (
self._excinfo is not None
), ".typename can only be used after the context manager exits"
return self.type.__name__
@property
def traceback(self):
def traceback(self) -> Traceback:
"""the traceback"""
if self._traceback is None:
self._traceback = Traceback(self.tb, excinfo=ref(self))
return self._traceback
@traceback.setter
def traceback(self, value):
def traceback(self, value: Traceback) -> None:
self._traceback = value
def __repr__(self):
def __repr__(self) -> str:
if self._excinfo is None:
return "<ExceptionInfo for raises contextmanager>"
return "<ExceptionInfo %s tblen=%d>" % (self.typename, len(self.traceback))
def exconly(self, tryshort=False):
def exconly(self, tryshort: bool = False) -> str:
""" return the exception as a string
when 'tryshort' resolves to True, and the exception is a
@ -466,11 +520,11 @@ class ExceptionInfo:
text = text[len(self._striptext) :]
return text
def errisinstance(self, exc):
def errisinstance(self, exc: "Type[BaseException]") -> bool:
""" return True if the exception is an instance of exc """
return isinstance(self.value, exc)
def _getreprcrash(self):
def _getreprcrash(self) -> "ReprFileLocation":
exconly = self.exconly(tryshort=True)
entry = self.traceback.getcrashentry()
path, lineno = entry.frame.code.raw.co_filename, entry.lineno
@ -478,13 +532,13 @@ class ExceptionInfo:
def getrepr(
self,
showlocals=False,
style="long",
abspath=False,
tbfilter=True,
funcargs=False,
truncate_locals=True,
chain=True,
showlocals: bool = False,
style: str = "long",
abspath: bool = False,
tbfilter: bool = True,
funcargs: bool = False,
truncate_locals: bool = True,
chain: bool = True,
):
"""
Return str()able representation of this exception info.
@ -535,7 +589,7 @@ class ExceptionInfo:
)
return fmt.repr_excinfo(self)
def match(self, regexp):
def match(self, regexp: Union[str, Pattern]) -> bool:
"""
Check whether the regular expression 'regexp' is found in the string
representation of the exception using ``re.search``. If it matches

View File

@ -3,21 +3,26 @@ exception classes and constants handling test outcomes
as well as functions creating them
"""
import sys
from typing import Any
from typing import Optional
from packaging.version import Version
if False: # TYPE_CHECKING
from typing import NoReturn
class OutcomeException(BaseException):
""" OutcomeException and its subclass instances indicate and
contain info about test and collection outcomes.
"""
def __init__(self, msg=None, pytrace=True):
def __init__(self, msg: Optional[str] = None, pytrace: bool = True) -> None:
BaseException.__init__(self, msg)
self.msg = msg
self.pytrace = pytrace
def __repr__(self):
def __repr__(self) -> str:
if self.msg:
val = self.msg
if isinstance(val, bytes):
@ -36,7 +41,12 @@ class Skipped(OutcomeException):
# in order to have Skipped exception printing shorter/nicer
__module__ = "builtins"
def __init__(self, msg=None, pytrace=True, allow_module_level=False):
def __init__(
self,
msg: Optional[str] = None,
pytrace: bool = True,
allow_module_level: bool = False,
) -> None:
OutcomeException.__init__(self, msg=msg, pytrace=pytrace)
self.allow_module_level = allow_module_level
@ -50,7 +60,9 @@ class Failed(OutcomeException):
class Exit(Exception):
""" raised for immediate program exits (no tracebacks/summaries)"""
def __init__(self, msg="unknown reason", returncode=None):
def __init__(
self, msg: str = "unknown reason", returncode: Optional[int] = None
) -> None:
self.msg = msg
self.returncode = returncode
super().__init__(msg)
@ -59,7 +71,7 @@ class Exit(Exception):
# exposed helper methods
def exit(msg, returncode=None):
def exit(msg: str, returncode: Optional[int] = None) -> "NoReturn":
"""
Exit testing process.
@ -74,7 +86,7 @@ def exit(msg, returncode=None):
exit.Exception = Exit # type: ignore
def skip(msg="", *, allow_module_level=False):
def skip(msg: str = "", *, allow_module_level: bool = False) -> "NoReturn":
"""
Skip an executing test with the given message.
@ -101,7 +113,7 @@ def skip(msg="", *, allow_module_level=False):
skip.Exception = Skipped # type: ignore
def fail(msg="", pytrace=True):
def fail(msg: str = "", pytrace: bool = True) -> "NoReturn":
"""
Explicitly fail an executing test with the given message.
@ -121,7 +133,7 @@ class XFailed(Failed):
""" raised from an explicit call to pytest.xfail() """
def xfail(reason=""):
def xfail(reason: str = "") -> "NoReturn":
"""
Imperatively xfail an executing test or setup functions with the given reason.
@ -139,7 +151,9 @@ def xfail(reason=""):
xfail.Exception = XFailed # type: ignore
def importorskip(modname, minversion=None, reason=None):
def importorskip(
modname: str, minversion: Optional[str] = None, reason: Optional[str] = None
) -> Any:
"""Imports and returns the requested module ``modname``, or skip the current test
if the module cannot be imported.

View File

@ -7,6 +7,16 @@ from collections.abc import Sized
from decimal import Decimal
from itertools import filterfalse
from numbers import Number
from types import TracebackType
from typing import Any
from typing import Callable
from typing import cast
from typing import Generic
from typing import Optional
from typing import overload
from typing import Pattern
from typing import Tuple
from typing import TypeVar
from typing import Union
from more_itertools.more import always_iterable
@ -15,6 +25,9 @@ import _pytest._code
from _pytest.compat import STRING_TYPES
from _pytest.outcomes import fail
if False: # TYPE_CHECKING
from typing import Type # noqa: F401 (used in type string)
BASE_TYPE = (type, STRING_TYPES)
@ -527,8 +540,35 @@ def _is_numpy_array(obj):
# builtin pytest.raises helper
_E = TypeVar("_E", bound=BaseException)
def raises(expected_exception, *args, match=None, **kwargs):
@overload
def raises(
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
*,
match: Optional[Union[str, Pattern]] = ...
) -> "RaisesContext[_E]":
... # pragma: no cover
@overload
def raises(
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
func: Callable,
*args: Any,
match: Optional[str] = ...,
**kwargs: Any
) -> Optional[_pytest._code.ExceptionInfo[_E]]:
... # pragma: no cover
def raises(
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
*args: Any,
match: Optional[Union[str, Pattern]] = None,
**kwargs: Any
) -> Union["RaisesContext[_E]", Optional[_pytest._code.ExceptionInfo[_E]]]:
r"""
Assert that a code block/function call raises ``expected_exception``
or raise a failure exception otherwise.
@ -647,18 +687,18 @@ def raises(expected_exception, *args, match=None, **kwargs):
for exc in filterfalse(
inspect.isclass, always_iterable(expected_exception, BASE_TYPE)
):
msg = (
"exceptions must be old-style classes or"
" derived from BaseException, not %s"
)
msg = "exceptions must be derived from BaseException, not %s"
raise TypeError(msg % type(exc))
message = "DID NOT RAISE {}".format(expected_exception)
if not args:
return RaisesContext(
expected_exception, message=message, match_expr=match, **kwargs
)
if kwargs:
msg = "Unexpected keyword arguments passed to pytest.raises: "
msg += ", ".join(sorted(kwargs))
msg += "\nUse context-manager form instead?"
raise TypeError(msg)
return RaisesContext(expected_exception, message, match)
else:
func = args[0]
if not callable(func):
@ -667,31 +707,51 @@ def raises(expected_exception, *args, match=None, **kwargs):
)
try:
func(*args[1:], **kwargs)
except expected_exception:
return _pytest._code.ExceptionInfo.from_current()
except expected_exception as e:
# We just caught the exception - there is a traceback.
assert e.__traceback__ is not None
return _pytest._code.ExceptionInfo.from_exc_info(
(type(e), e, e.__traceback__)
)
fail(message)
raises.Exception = fail.Exception # type: ignore
class RaisesContext:
def __init__(self, expected_exception, message, match_expr):
class RaisesContext(Generic[_E]):
def __init__(
self,
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
message: str,
match_expr: Optional[Union[str, Pattern]] = None,
) -> None:
self.expected_exception = expected_exception
self.message = message
self.match_expr = match_expr
self.excinfo = None
self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo[_E]]
def __enter__(self):
def __enter__(self) -> _pytest._code.ExceptionInfo[_E]:
self.excinfo = _pytest._code.ExceptionInfo.for_later()
return self.excinfo
def __exit__(self, *tp):
def __exit__(
self,
exc_type: Optional["Type[BaseException]"],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
__tracebackhide__ = True
if tp[0] is None:
if exc_type is None:
fail(self.message)
self.excinfo.__init__(tp)
suppress_exception = issubclass(self.excinfo.type, self.expected_exception)
if self.match_expr is not None and suppress_exception:
assert self.excinfo is not None
if not issubclass(exc_type, self.expected_exception):
return False
# Cast to narrow the exception type now that it's verified.
exc_info = cast(
Tuple["Type[_E]", _E, TracebackType], (exc_type, exc_val, exc_tb)
)
self.excinfo.fill_unfilled(exc_info)
if self.match_expr is not None:
self.excinfo.match(self.match_expr)
return suppress_exception
return True

View File

@ -1,11 +1,23 @@
""" recording warnings during test function execution. """
import inspect
import re
import warnings
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Iterator
from typing import List
from typing import Optional
from typing import overload
from typing import Pattern
from typing import Tuple
from typing import Union
from _pytest.fixtures import yield_fixture
from _pytest.outcomes import fail
if False: # TYPE_CHECKING
from typing import Type
@yield_fixture
def recwarn():
@ -42,7 +54,32 @@ def deprecated_call(func=None, *args, **kwargs):
return warns((DeprecationWarning, PendingDeprecationWarning), *args, **kwargs)
def warns(expected_warning, *args, match=None, **kwargs):
@overload
def warns(
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
*,
match: Optional[Union[str, Pattern]] = ...
) -> "WarningsChecker":
... # pragma: no cover
@overload
def warns(
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
func: Callable,
*args: Any,
match: Optional[Union[str, Pattern]] = ...,
**kwargs: Any
) -> Union[Any]:
... # pragma: no cover
def warns(
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
*args: Any,
match: Optional[Union[str, Pattern]] = None,
**kwargs: Any
) -> Union["WarningsChecker", Any]:
r"""Assert that code raises a particular class of warning.
Specifically, the parameter ``expected_warning`` can be a warning class or
@ -76,7 +113,12 @@ def warns(expected_warning, *args, match=None, **kwargs):
"""
__tracebackhide__ = True
if not args:
return WarningsChecker(expected_warning, match_expr=match, **kwargs)
if kwargs:
msg = "Unexpected keyword arguments passed to pytest.warns: "
msg += ", ".join(sorted(kwargs))
msg += "\nUse context-manager form instead?"
raise TypeError(msg)
return WarningsChecker(expected_warning, match_expr=match)
else:
func = args[0]
if not callable(func):
@ -96,26 +138,26 @@ class WarningsRecorder(warnings.catch_warnings):
def __init__(self):
super().__init__(record=True)
self._entered = False
self._list = []
self._list = [] # type: List[warnings._Record]
@property
def list(self):
def list(self) -> List["warnings._Record"]:
"""The list of recorded warnings."""
return self._list
def __getitem__(self, i):
def __getitem__(self, i: int) -> "warnings._Record":
"""Get a recorded warning by index."""
return self._list[i]
def __iter__(self):
def __iter__(self) -> Iterator["warnings._Record"]:
"""Iterate through the recorded warnings."""
return iter(self._list)
def __len__(self):
def __len__(self) -> int:
"""The number of recorded warnings."""
return len(self._list)
def pop(self, cls=Warning):
def pop(self, cls: "Type[Warning]" = Warning) -> "warnings._Record":
"""Pop the first recorded warning, raise exception if not exists."""
for i, w in enumerate(self._list):
if issubclass(w.category, cls):
@ -123,54 +165,80 @@ class WarningsRecorder(warnings.catch_warnings):
__tracebackhide__ = True
raise AssertionError("%r not found in warning list" % cls)
def clear(self):
def clear(self) -> None:
"""Clear the list of recorded warnings."""
self._list[:] = []
def __enter__(self):
# Type ignored because it doesn't exactly warnings.catch_warnings.__enter__
# -- it returns a List but we only emulate one.
def __enter__(self) -> "WarningsRecorder": # type: ignore
if self._entered:
__tracebackhide__ = True
raise RuntimeError("Cannot enter %r twice" % self)
self._list = super().__enter__()
_list = super().__enter__()
# record=True means it's None.
assert _list is not None
self._list = _list
warnings.simplefilter("always")
return self
def __exit__(self, *exc_info):
def __exit__(
self,
exc_type: Optional["Type[BaseException]"],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
if not self._entered:
__tracebackhide__ = True
raise RuntimeError("Cannot exit %r without entering first" % self)
super().__exit__(*exc_info)
super().__exit__(exc_type, exc_val, exc_tb)
# Built-in catch_warnings does not reset entered state so we do it
# manually here for this context manager to become reusable.
self._entered = False
return False
class WarningsChecker(WarningsRecorder):
def __init__(self, expected_warning=None, match_expr=None):
def __init__(
self,
expected_warning: Optional[
Union["Type[Warning]", Tuple["Type[Warning]", ...]]
] = None,
match_expr: Optional[Union[str, Pattern]] = None,
) -> None:
super().__init__()
msg = "exceptions must be old-style classes or derived from Warning, not %s"
if isinstance(expected_warning, tuple):
msg = "exceptions must be derived from Warning, not %s"
if expected_warning is None:
expected_warning_tup = None
elif isinstance(expected_warning, tuple):
for exc in expected_warning:
if not inspect.isclass(exc):
if not issubclass(exc, Warning):
raise TypeError(msg % type(exc))
elif inspect.isclass(expected_warning):
expected_warning = (expected_warning,)
elif expected_warning is not None:
expected_warning_tup = expected_warning
elif issubclass(expected_warning, Warning):
expected_warning_tup = (expected_warning,)
else:
raise TypeError(msg % type(expected_warning))
self.expected_warning = expected_warning
self.expected_warning = expected_warning_tup
self.match_expr = match_expr
def __exit__(self, *exc_info):
super().__exit__(*exc_info)
def __exit__(
self,
exc_type: Optional["Type[BaseException]"],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
super().__exit__(exc_type, exc_val, exc_tb)
__tracebackhide__ = True
# only check if we're not currently handling an exception
if all(a is None for a in exc_info):
if exc_type is None and exc_val is None and exc_tb is None:
if self.expected_warning is not None:
if not any(issubclass(r.category, self.expected_warning) for r in self):
__tracebackhide__ = True
@ -195,3 +263,4 @@ class WarningsChecker(WarningsRecorder):
[each.message for each in self],
)
)
return False

View File

@ -58,7 +58,7 @@ class TWMock:
fullwidth = 80
def test_excinfo_simple():
def test_excinfo_simple() -> None:
try:
raise ValueError
except ValueError:
@ -66,6 +66,14 @@ def test_excinfo_simple():
assert info.type == ValueError
def test_excinfo_from_exc_info_simple():
try:
raise ValueError
except ValueError as e:
info = _pytest._code.ExceptionInfo.from_exc_info((type(e), e, e.__traceback__))
assert info.type == ValueError
def test_excinfo_getstatement():
def g():
raise ValueError

View File

@ -248,3 +248,9 @@ class TestRaises:
with pytest.raises(CrappyClass()):
pass
assert "via __class__" in excinfo.value.args[0]
def test_raises_context_manager_with_kwargs(self):
with pytest.raises(TypeError) as excinfo:
with pytest.raises(Exception, foo="bar"):
pass
assert "Unexpected keyword arguments" in str(excinfo.value)

View File

@ -374,3 +374,9 @@ class TestWarns:
assert f() == 10
assert pytest.warns(UserWarning, f) == 10
assert pytest.warns(UserWarning, f) == 10
def test_warns_context_manager_with_kwargs(self):
with pytest.raises(TypeError) as excinfo:
with pytest.warns(UserWarning, foo="bar"):
pass
assert "Unexpected keyword arguments" in str(excinfo.value)

View File

@ -155,7 +155,7 @@ markers =
[flake8]
max-line-length = 120
ignore = E203,W503
extend-ignore = E203
[isort]
; This config mimics what reorder-python-imports does.