Type-annotate pytest.warns
This commit is contained in:
parent
d7ee3dac2c
commit
2dca68b863
|
@ -1,11 +1,23 @@
|
|||
""" recording warnings during test function execution. """
|
||||
import inspect
|
||||
import re
|
||||
import warnings
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Iterator
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import overload
|
||||
from typing import Pattern
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
from _pytest.fixtures import yield_fixture
|
||||
from _pytest.outcomes import fail
|
||||
|
||||
if False: # TYPE_CHECKING
|
||||
from typing import Type
|
||||
|
||||
|
||||
@yield_fixture
|
||||
def recwarn():
|
||||
|
@ -42,7 +54,32 @@ def deprecated_call(func=None, *args, **kwargs):
|
|||
return warns((DeprecationWarning, PendingDeprecationWarning), *args, **kwargs)
|
||||
|
||||
|
||||
def warns(expected_warning, *args, match=None, **kwargs):
|
||||
@overload
|
||||
def warns(
|
||||
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
|
||||
*,
|
||||
match: Optional[Union[str, Pattern]] = ...
|
||||
) -> "WarningsChecker":
|
||||
... # pragma: no cover
|
||||
|
||||
|
||||
@overload
|
||||
def warns(
|
||||
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
|
||||
func: Callable,
|
||||
*args: Any,
|
||||
match: Optional[Union[str, Pattern]] = ...,
|
||||
**kwargs: Any
|
||||
) -> Union[Any]:
|
||||
... # pragma: no cover
|
||||
|
||||
|
||||
def warns(
|
||||
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
|
||||
*args: Any,
|
||||
match: Optional[Union[str, Pattern]] = None,
|
||||
**kwargs: Any
|
||||
) -> Union["WarningsChecker", Any]:
|
||||
r"""Assert that code raises a particular class of warning.
|
||||
|
||||
Specifically, the parameter ``expected_warning`` can be a warning class or
|
||||
|
@ -101,26 +138,26 @@ class WarningsRecorder(warnings.catch_warnings):
|
|||
def __init__(self):
|
||||
super().__init__(record=True)
|
||||
self._entered = False
|
||||
self._list = []
|
||||
self._list = [] # type: List[warnings._Record]
|
||||
|
||||
@property
|
||||
def list(self):
|
||||
def list(self) -> List["warnings._Record"]:
|
||||
"""The list of recorded warnings."""
|
||||
return self._list
|
||||
|
||||
def __getitem__(self, i):
|
||||
def __getitem__(self, i: int) -> "warnings._Record":
|
||||
"""Get a recorded warning by index."""
|
||||
return self._list[i]
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator["warnings._Record"]:
|
||||
"""Iterate through the recorded warnings."""
|
||||
return iter(self._list)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
"""The number of recorded warnings."""
|
||||
return len(self._list)
|
||||
|
||||
def pop(self, cls=Warning):
|
||||
def pop(self, cls: "Type[Warning]" = Warning) -> "warnings._Record":
|
||||
"""Pop the first recorded warning, raise exception if not exists."""
|
||||
for i, w in enumerate(self._list):
|
||||
if issubclass(w.category, cls):
|
||||
|
@ -128,54 +165,80 @@ class WarningsRecorder(warnings.catch_warnings):
|
|||
__tracebackhide__ = True
|
||||
raise AssertionError("%r not found in warning list" % cls)
|
||||
|
||||
def clear(self):
|
||||
def clear(self) -> None:
|
||||
"""Clear the list of recorded warnings."""
|
||||
self._list[:] = []
|
||||
|
||||
def __enter__(self):
|
||||
# Type ignored because it doesn't exactly warnings.catch_warnings.__enter__
|
||||
# -- it returns a List but we only emulate one.
|
||||
def __enter__(self) -> "WarningsRecorder": # type: ignore
|
||||
if self._entered:
|
||||
__tracebackhide__ = True
|
||||
raise RuntimeError("Cannot enter %r twice" % self)
|
||||
self._list = super().__enter__()
|
||||
_list = super().__enter__()
|
||||
# record=True means it's None.
|
||||
assert _list is not None
|
||||
self._list = _list
|
||||
warnings.simplefilter("always")
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc_info):
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional["Type[BaseException]"],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> bool:
|
||||
if not self._entered:
|
||||
__tracebackhide__ = True
|
||||
raise RuntimeError("Cannot exit %r without entering first" % self)
|
||||
|
||||
super().__exit__(*exc_info)
|
||||
super().__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
# Built-in catch_warnings does not reset entered state so we do it
|
||||
# manually here for this context manager to become reusable.
|
||||
self._entered = False
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class WarningsChecker(WarningsRecorder):
|
||||
def __init__(self, expected_warning=None, match_expr=None):
|
||||
def __init__(
|
||||
self,
|
||||
expected_warning: Optional[
|
||||
Union["Type[Warning]", Tuple["Type[Warning]", ...]]
|
||||
] = None,
|
||||
match_expr: Optional[Union[str, Pattern]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
msg = "exceptions must be derived from Warning, not %s"
|
||||
if isinstance(expected_warning, tuple):
|
||||
if expected_warning is None:
|
||||
expected_warning_tup = None
|
||||
elif isinstance(expected_warning, tuple):
|
||||
for exc in expected_warning:
|
||||
if not inspect.isclass(exc):
|
||||
if not issubclass(exc, Warning):
|
||||
raise TypeError(msg % type(exc))
|
||||
elif inspect.isclass(expected_warning):
|
||||
expected_warning = (expected_warning,)
|
||||
elif expected_warning is not None:
|
||||
expected_warning_tup = expected_warning
|
||||
elif issubclass(expected_warning, Warning):
|
||||
expected_warning_tup = (expected_warning,)
|
||||
else:
|
||||
raise TypeError(msg % type(expected_warning))
|
||||
|
||||
self.expected_warning = expected_warning
|
||||
self.expected_warning = expected_warning_tup
|
||||
self.match_expr = match_expr
|
||||
|
||||
def __exit__(self, *exc_info):
|
||||
super().__exit__(*exc_info)
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional["Type[BaseException]"],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> bool:
|
||||
super().__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
__tracebackhide__ = True
|
||||
|
||||
# only check if we're not currently handling an exception
|
||||
if all(a is None for a in exc_info):
|
||||
if exc_type is None and exc_val is None and exc_tb is None:
|
||||
if self.expected_warning is not None:
|
||||
if not any(issubclass(r.category, self.expected_warning) for r in self):
|
||||
__tracebackhide__ = True
|
||||
|
@ -200,3 +263,4 @@ class WarningsChecker(WarningsRecorder):
|
|||
[each.message for each in self],
|
||||
)
|
||||
)
|
||||
return False
|
||||
|
|
Loading…
Reference in New Issue