diff --git a/src/_pytest/recwarn.py b/src/_pytest/recwarn.py index 7e772aa35..19e3938c3 100644 --- a/src/_pytest/recwarn.py +++ b/src/_pytest/recwarn.py @@ -1,11 +1,23 @@ """ recording warnings during test function execution. """ -import inspect import re import warnings +from types import TracebackType +from typing import Any +from typing import Callable +from typing import Iterator +from typing import List +from typing import Optional +from typing import overload +from typing import Pattern +from typing import Tuple +from typing import Union from _pytest.fixtures import yield_fixture from _pytest.outcomes import fail +if False: # TYPE_CHECKING + from typing import Type + @yield_fixture def recwarn(): @@ -42,7 +54,32 @@ def deprecated_call(func=None, *args, **kwargs): return warns((DeprecationWarning, PendingDeprecationWarning), *args, **kwargs) -def warns(expected_warning, *args, match=None, **kwargs): +@overload +def warns( + expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]], + *, + match: Optional[Union[str, Pattern]] = ... +) -> "WarningsChecker": + ... # pragma: no cover + + +@overload +def warns( + expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]], + func: Callable, + *args: Any, + match: Optional[Union[str, Pattern]] = ..., + **kwargs: Any +) -> Union[Any]: + ... # pragma: no cover + + +def warns( + expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]], + *args: Any, + match: Optional[Union[str, Pattern]] = None, + **kwargs: Any +) -> Union["WarningsChecker", Any]: r"""Assert that code raises a particular class of warning. Specifically, the parameter ``expected_warning`` can be a warning class or @@ -101,26 +138,26 @@ class WarningsRecorder(warnings.catch_warnings): def __init__(self): super().__init__(record=True) self._entered = False - self._list = [] + self._list = [] # type: List[warnings._Record] @property - def list(self): + def list(self) -> List["warnings._Record"]: """The list of recorded warnings.""" return self._list - def __getitem__(self, i): + def __getitem__(self, i: int) -> "warnings._Record": """Get a recorded warning by index.""" return self._list[i] - def __iter__(self): + def __iter__(self) -> Iterator["warnings._Record"]: """Iterate through the recorded warnings.""" return iter(self._list) - def __len__(self): + def __len__(self) -> int: """The number of recorded warnings.""" return len(self._list) - def pop(self, cls=Warning): + def pop(self, cls: "Type[Warning]" = Warning) -> "warnings._Record": """Pop the first recorded warning, raise exception if not exists.""" for i, w in enumerate(self._list): if issubclass(w.category, cls): @@ -128,54 +165,80 @@ class WarningsRecorder(warnings.catch_warnings): __tracebackhide__ = True raise AssertionError("%r not found in warning list" % cls) - def clear(self): + def clear(self) -> None: """Clear the list of recorded warnings.""" self._list[:] = [] - def __enter__(self): + # Type ignored because it doesn't exactly warnings.catch_warnings.__enter__ + # -- it returns a List but we only emulate one. + def __enter__(self) -> "WarningsRecorder": # type: ignore if self._entered: __tracebackhide__ = True raise RuntimeError("Cannot enter %r twice" % self) - self._list = super().__enter__() + _list = super().__enter__() + # record=True means it's None. + assert _list is not None + self._list = _list warnings.simplefilter("always") return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: Optional["Type[BaseException]"], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: if not self._entered: __tracebackhide__ = True raise RuntimeError("Cannot exit %r without entering first" % self) - super().__exit__(*exc_info) + super().__exit__(exc_type, exc_val, exc_tb) # Built-in catch_warnings does not reset entered state so we do it # manually here for this context manager to become reusable. self._entered = False + return False + class WarningsChecker(WarningsRecorder): - def __init__(self, expected_warning=None, match_expr=None): + def __init__( + self, + expected_warning: Optional[ + Union["Type[Warning]", Tuple["Type[Warning]", ...]] + ] = None, + match_expr: Optional[Union[str, Pattern]] = None, + ) -> None: super().__init__() msg = "exceptions must be derived from Warning, not %s" - if isinstance(expected_warning, tuple): + if expected_warning is None: + expected_warning_tup = None + elif isinstance(expected_warning, tuple): for exc in expected_warning: - if not inspect.isclass(exc): + if not issubclass(exc, Warning): raise TypeError(msg % type(exc)) - elif inspect.isclass(expected_warning): - expected_warning = (expected_warning,) - elif expected_warning is not None: + expected_warning_tup = expected_warning + elif issubclass(expected_warning, Warning): + expected_warning_tup = (expected_warning,) + else: raise TypeError(msg % type(expected_warning)) - self.expected_warning = expected_warning + self.expected_warning = expected_warning_tup self.match_expr = match_expr - def __exit__(self, *exc_info): - super().__exit__(*exc_info) + def __exit__( + self, + exc_type: Optional["Type[BaseException]"], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + super().__exit__(exc_type, exc_val, exc_tb) __tracebackhide__ = True # only check if we're not currently handling an exception - if all(a is None for a in exc_info): + if exc_type is None and exc_val is None and exc_tb is None: if self.expected_warning is not None: if not any(issubclass(r.category, self.expected_warning) for r in self): __tracebackhide__ = True @@ -200,3 +263,4 @@ class WarningsChecker(WarningsRecorder): [each.message for each in self], ) ) + return False