Type-annotate pytest.warns

This commit is contained in:
Ran Benita 2019-07-10 14:36:07 +03:00
parent d7ee3dac2c
commit 2dca68b863
1 changed files with 87 additions and 23 deletions

View File

@ -1,11 +1,23 @@
""" recording warnings during test function execution. """
import inspect
import re
import warnings
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Iterator
from typing import List
from typing import Optional
from typing import overload
from typing import Pattern
from typing import Tuple
from typing import Union
from _pytest.fixtures import yield_fixture
from _pytest.outcomes import fail
if False: # TYPE_CHECKING
from typing import Type
@yield_fixture
def recwarn():
@ -42,7 +54,32 @@ def deprecated_call(func=None, *args, **kwargs):
return warns((DeprecationWarning, PendingDeprecationWarning), *args, **kwargs)
def warns(expected_warning, *args, match=None, **kwargs):
@overload
def warns(
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
*,
match: Optional[Union[str, Pattern]] = ...
) -> "WarningsChecker":
... # pragma: no cover
@overload
def warns(
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
func: Callable,
*args: Any,
match: Optional[Union[str, Pattern]] = ...,
**kwargs: Any
) -> Union[Any]:
... # pragma: no cover
def warns(
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
*args: Any,
match: Optional[Union[str, Pattern]] = None,
**kwargs: Any
) -> Union["WarningsChecker", Any]:
r"""Assert that code raises a particular class of warning.
Specifically, the parameter ``expected_warning`` can be a warning class or
@ -101,26 +138,26 @@ class WarningsRecorder(warnings.catch_warnings):
def __init__(self):
super().__init__(record=True)
self._entered = False
self._list = []
self._list = [] # type: List[warnings._Record]
@property
def list(self):
def list(self) -> List["warnings._Record"]:
"""The list of recorded warnings."""
return self._list
def __getitem__(self, i):
def __getitem__(self, i: int) -> "warnings._Record":
"""Get a recorded warning by index."""
return self._list[i]
def __iter__(self):
def __iter__(self) -> Iterator["warnings._Record"]:
"""Iterate through the recorded warnings."""
return iter(self._list)
def __len__(self):
def __len__(self) -> int:
"""The number of recorded warnings."""
return len(self._list)
def pop(self, cls=Warning):
def pop(self, cls: "Type[Warning]" = Warning) -> "warnings._Record":
"""Pop the first recorded warning, raise exception if not exists."""
for i, w in enumerate(self._list):
if issubclass(w.category, cls):
@ -128,54 +165,80 @@ class WarningsRecorder(warnings.catch_warnings):
__tracebackhide__ = True
raise AssertionError("%r not found in warning list" % cls)
def clear(self):
def clear(self) -> None:
"""Clear the list of recorded warnings."""
self._list[:] = []
def __enter__(self):
# Type ignored because it doesn't exactly warnings.catch_warnings.__enter__
# -- it returns a List but we only emulate one.
def __enter__(self) -> "WarningsRecorder": # type: ignore
if self._entered:
__tracebackhide__ = True
raise RuntimeError("Cannot enter %r twice" % self)
self._list = super().__enter__()
_list = super().__enter__()
# record=True means it's None.
assert _list is not None
self._list = _list
warnings.simplefilter("always")
return self
def __exit__(self, *exc_info):
def __exit__(
self,
exc_type: Optional["Type[BaseException]"],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
if not self._entered:
__tracebackhide__ = True
raise RuntimeError("Cannot exit %r without entering first" % self)
super().__exit__(*exc_info)
super().__exit__(exc_type, exc_val, exc_tb)
# Built-in catch_warnings does not reset entered state so we do it
# manually here for this context manager to become reusable.
self._entered = False
return False
class WarningsChecker(WarningsRecorder):
def __init__(self, expected_warning=None, match_expr=None):
def __init__(
self,
expected_warning: Optional[
Union["Type[Warning]", Tuple["Type[Warning]", ...]]
] = None,
match_expr: Optional[Union[str, Pattern]] = None,
) -> None:
super().__init__()
msg = "exceptions must be derived from Warning, not %s"
if isinstance(expected_warning, tuple):
if expected_warning is None:
expected_warning_tup = None
elif isinstance(expected_warning, tuple):
for exc in expected_warning:
if not inspect.isclass(exc):
if not issubclass(exc, Warning):
raise TypeError(msg % type(exc))
elif inspect.isclass(expected_warning):
expected_warning = (expected_warning,)
elif expected_warning is not None:
expected_warning_tup = expected_warning
elif issubclass(expected_warning, Warning):
expected_warning_tup = (expected_warning,)
else:
raise TypeError(msg % type(expected_warning))
self.expected_warning = expected_warning
self.expected_warning = expected_warning_tup
self.match_expr = match_expr
def __exit__(self, *exc_info):
super().__exit__(*exc_info)
def __exit__(
self,
exc_type: Optional["Type[BaseException]"],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
super().__exit__(exc_type, exc_val, exc_tb)
__tracebackhide__ = True
# only check if we're not currently handling an exception
if all(a is None for a in exc_info):
if exc_type is None and exc_val is None and exc_tb is None:
if self.expected_warning is not None:
if not any(issubclass(r.category, self.expected_warning) for r in self):
__tracebackhide__ = True
@ -200,3 +263,4 @@ class WarningsChecker(WarningsRecorder):
[each.message for each in self],
)
)
return False