Fix type of pytest.warns, and fix check_untyped_defs in test_recwarn

The expected_warning is optional.
This commit is contained in:
Ran Benita 2019-12-03 14:34:41 +02:00
parent 0b603156b9
commit 3d2680b31b
2 changed files with 64 additions and 60 deletions

View File

@ -57,7 +57,7 @@ def deprecated_call(func=None, *args, **kwargs):
@overload @overload
def warns( def warns(
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]], expected_warning: Optional[Union["Type[Warning]", Tuple["Type[Warning]", ...]]],
*, *,
match: "Optional[Union[str, Pattern]]" = ... match: "Optional[Union[str, Pattern]]" = ...
) -> "WarningsChecker": ) -> "WarningsChecker":
@ -66,7 +66,7 @@ def warns(
@overload # noqa: F811 @overload # noqa: F811
def warns( # noqa: F811 def warns( # noqa: F811
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]], expected_warning: Optional[Union["Type[Warning]", Tuple["Type[Warning]", ...]]],
func: Callable, func: Callable,
*args: Any, *args: Any,
match: Optional[Union[str, "Pattern"]] = ..., match: Optional[Union[str, "Pattern"]] = ...,
@ -76,7 +76,7 @@ def warns( # noqa: F811
def warns( # noqa: F811 def warns( # noqa: F811
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]], expected_warning: Optional[Union["Type[Warning]", Tuple["Type[Warning]", ...]]],
*args: Any, *args: Any,
match: Optional[Union[str, "Pattern"]] = None, match: Optional[Union[str, "Pattern"]] = None,
**kwargs: Any **kwargs: Any

View File

@ -1,17 +1,19 @@
import re import re
import warnings import warnings
from typing import Optional
import pytest import pytest
from _pytest.outcomes import Failed
from _pytest.recwarn import WarningsRecorder from _pytest.recwarn import WarningsRecorder
def test_recwarn_stacklevel(recwarn): def test_recwarn_stacklevel(recwarn: WarningsRecorder) -> None:
warnings.warn("hello") warnings.warn("hello")
warn = recwarn.pop() warn = recwarn.pop()
assert warn.filename == __file__ assert warn.filename == __file__
def test_recwarn_functional(testdir): def test_recwarn_functional(testdir) -> None:
testdir.makepyfile( testdir.makepyfile(
""" """
import warnings import warnings
@ -26,7 +28,7 @@ def test_recwarn_functional(testdir):
class TestWarningsRecorderChecker: class TestWarningsRecorderChecker:
def test_recording(self): def test_recording(self) -> None:
rec = WarningsRecorder() rec = WarningsRecorder()
with rec: with rec:
assert not rec.list assert not rec.list
@ -42,23 +44,23 @@ class TestWarningsRecorderChecker:
assert values is rec.list assert values is rec.list
pytest.raises(AssertionError, rec.pop) pytest.raises(AssertionError, rec.pop)
def test_warn_stacklevel(self): def test_warn_stacklevel(self) -> None:
"""#4243""" """#4243"""
rec = WarningsRecorder() rec = WarningsRecorder()
with rec: with rec:
warnings.warn("test", DeprecationWarning, 2) warnings.warn("test", DeprecationWarning, 2)
def test_typechecking(self): def test_typechecking(self) -> None:
from _pytest.recwarn import WarningsChecker from _pytest.recwarn import WarningsChecker
with pytest.raises(TypeError): with pytest.raises(TypeError):
WarningsChecker(5) WarningsChecker(5) # type: ignore
with pytest.raises(TypeError): with pytest.raises(TypeError):
WarningsChecker(("hi", RuntimeWarning)) WarningsChecker(("hi", RuntimeWarning)) # type: ignore
with pytest.raises(TypeError): with pytest.raises(TypeError):
WarningsChecker([DeprecationWarning, RuntimeWarning]) WarningsChecker([DeprecationWarning, RuntimeWarning]) # type: ignore
def test_invalid_enter_exit(self): def test_invalid_enter_exit(self) -> None:
# wrap this test in WarningsRecorder to ensure warning state gets reset # wrap this test in WarningsRecorder to ensure warning state gets reset
with WarningsRecorder(): with WarningsRecorder():
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
@ -75,50 +77,52 @@ class TestWarningsRecorderChecker:
class TestDeprecatedCall: class TestDeprecatedCall:
"""test pytest.deprecated_call()""" """test pytest.deprecated_call()"""
def dep(self, i, j=None): def dep(self, i: int, j: Optional[int] = None) -> int:
if i == 0: if i == 0:
warnings.warn("is deprecated", DeprecationWarning, stacklevel=1) warnings.warn("is deprecated", DeprecationWarning, stacklevel=1)
return 42 return 42
def dep_explicit(self, i): def dep_explicit(self, i: int) -> None:
if i == 0: if i == 0:
warnings.warn_explicit( warnings.warn_explicit(
"dep_explicit", category=DeprecationWarning, filename="hello", lineno=3 "dep_explicit", category=DeprecationWarning, filename="hello", lineno=3
) )
def test_deprecated_call_raises(self): def test_deprecated_call_raises(self) -> None:
with pytest.raises(pytest.fail.Exception, match="No warnings of type"): with pytest.raises(Failed, match="No warnings of type"):
pytest.deprecated_call(self.dep, 3, 5) pytest.deprecated_call(self.dep, 3, 5)
def test_deprecated_call(self): def test_deprecated_call(self) -> None:
pytest.deprecated_call(self.dep, 0, 5) pytest.deprecated_call(self.dep, 0, 5)
def test_deprecated_call_ret(self): def test_deprecated_call_ret(self) -> None:
ret = pytest.deprecated_call(self.dep, 0) ret = pytest.deprecated_call(self.dep, 0)
assert ret == 42 assert ret == 42
def test_deprecated_call_preserves(self): def test_deprecated_call_preserves(self) -> None:
onceregistry = warnings.onceregistry.copy() # Type ignored because `onceregistry` and `filters` are not
filters = warnings.filters[:] # documented API.
onceregistry = warnings.onceregistry.copy() # type: ignore
filters = warnings.filters[:] # type: ignore
warn = warnings.warn warn = warnings.warn
warn_explicit = warnings.warn_explicit warn_explicit = warnings.warn_explicit
self.test_deprecated_call_raises() self.test_deprecated_call_raises()
self.test_deprecated_call() self.test_deprecated_call()
assert onceregistry == warnings.onceregistry assert onceregistry == warnings.onceregistry # type: ignore
assert filters == warnings.filters assert filters == warnings.filters # type: ignore
assert warn is warnings.warn assert warn is warnings.warn
assert warn_explicit is warnings.warn_explicit assert warn_explicit is warnings.warn_explicit
def test_deprecated_explicit_call_raises(self): def test_deprecated_explicit_call_raises(self) -> None:
with pytest.raises(pytest.fail.Exception): with pytest.raises(Failed):
pytest.deprecated_call(self.dep_explicit, 3) pytest.deprecated_call(self.dep_explicit, 3)
def test_deprecated_explicit_call(self): def test_deprecated_explicit_call(self) -> None:
pytest.deprecated_call(self.dep_explicit, 0) pytest.deprecated_call(self.dep_explicit, 0)
pytest.deprecated_call(self.dep_explicit, 0) pytest.deprecated_call(self.dep_explicit, 0)
@pytest.mark.parametrize("mode", ["context_manager", "call"]) @pytest.mark.parametrize("mode", ["context_manager", "call"])
def test_deprecated_call_no_warning(self, mode): def test_deprecated_call_no_warning(self, mode) -> None:
"""Ensure deprecated_call() raises the expected failure when its block/function does """Ensure deprecated_call() raises the expected failure when its block/function does
not raise a deprecation warning. not raise a deprecation warning.
""" """
@ -127,7 +131,7 @@ class TestDeprecatedCall:
pass pass
msg = "No warnings of type (.*DeprecationWarning.*, .*PendingDeprecationWarning.*)" msg = "No warnings of type (.*DeprecationWarning.*, .*PendingDeprecationWarning.*)"
with pytest.raises(pytest.fail.Exception, match=msg): with pytest.raises(Failed, match=msg):
if mode == "call": if mode == "call":
pytest.deprecated_call(f) pytest.deprecated_call(f)
else: else:
@ -140,7 +144,7 @@ class TestDeprecatedCall:
@pytest.mark.parametrize("mode", ["context_manager", "call"]) @pytest.mark.parametrize("mode", ["context_manager", "call"])
@pytest.mark.parametrize("call_f_first", [True, False]) @pytest.mark.parametrize("call_f_first", [True, False])
@pytest.mark.filterwarnings("ignore") @pytest.mark.filterwarnings("ignore")
def test_deprecated_call_modes(self, warning_type, mode, call_f_first): def test_deprecated_call_modes(self, warning_type, mode, call_f_first) -> None:
"""Ensure deprecated_call() captures a deprecation warning as expected inside its """Ensure deprecated_call() captures a deprecation warning as expected inside its
block/function. block/function.
""" """
@ -159,7 +163,7 @@ class TestDeprecatedCall:
assert f() == 10 assert f() == 10
@pytest.mark.parametrize("mode", ["context_manager", "call"]) @pytest.mark.parametrize("mode", ["context_manager", "call"])
def test_deprecated_call_exception_is_raised(self, mode): def test_deprecated_call_exception_is_raised(self, mode) -> None:
"""If the block of the code being tested by deprecated_call() raises an exception, """If the block of the code being tested by deprecated_call() raises an exception,
it must raise the exception undisturbed. it must raise the exception undisturbed.
""" """
@ -174,7 +178,7 @@ class TestDeprecatedCall:
with pytest.deprecated_call(): with pytest.deprecated_call():
f() f()
def test_deprecated_call_specificity(self): def test_deprecated_call_specificity(self) -> None:
other_warnings = [ other_warnings = [
Warning, Warning,
UserWarning, UserWarning,
@ -189,40 +193,40 @@ class TestDeprecatedCall:
def f(): def f():
warnings.warn(warning("hi")) warnings.warn(warning("hi"))
with pytest.raises(pytest.fail.Exception): with pytest.raises(Failed):
pytest.deprecated_call(f) pytest.deprecated_call(f)
with pytest.raises(pytest.fail.Exception): with pytest.raises(Failed):
with pytest.deprecated_call(): with pytest.deprecated_call():
f() f()
def test_deprecated_call_supports_match(self): def test_deprecated_call_supports_match(self) -> None:
with pytest.deprecated_call(match=r"must be \d+$"): with pytest.deprecated_call(match=r"must be \d+$"):
warnings.warn("value must be 42", DeprecationWarning) warnings.warn("value must be 42", DeprecationWarning)
with pytest.raises(pytest.fail.Exception): with pytest.raises(Failed):
with pytest.deprecated_call(match=r"must be \d+$"): with pytest.deprecated_call(match=r"must be \d+$"):
warnings.warn("this is not here", DeprecationWarning) warnings.warn("this is not here", DeprecationWarning)
class TestWarns: class TestWarns:
def test_check_callable(self): def test_check_callable(self) -> None:
source = "warnings.warn('w1', RuntimeWarning)" source = "warnings.warn('w1', RuntimeWarning)"
with pytest.raises(TypeError, match=r".* must be callable"): with pytest.raises(TypeError, match=r".* must be callable"):
pytest.warns(RuntimeWarning, source) pytest.warns(RuntimeWarning, source) # type: ignore
def test_several_messages(self): def test_several_messages(self) -> None:
# different messages, b/c Python suppresses multiple identical warnings # different messages, b/c Python suppresses multiple identical warnings
pytest.warns(RuntimeWarning, lambda: warnings.warn("w1", RuntimeWarning)) pytest.warns(RuntimeWarning, lambda: warnings.warn("w1", RuntimeWarning))
with pytest.raises(pytest.fail.Exception): with pytest.raises(Failed):
pytest.warns(UserWarning, lambda: warnings.warn("w2", RuntimeWarning)) pytest.warns(UserWarning, lambda: warnings.warn("w2", RuntimeWarning))
pytest.warns(RuntimeWarning, lambda: warnings.warn("w3", RuntimeWarning)) pytest.warns(RuntimeWarning, lambda: warnings.warn("w3", RuntimeWarning))
def test_function(self): def test_function(self) -> None:
pytest.warns( pytest.warns(
SyntaxWarning, lambda msg: warnings.warn(msg, SyntaxWarning), "syntax" SyntaxWarning, lambda msg: warnings.warn(msg, SyntaxWarning), "syntax"
) )
def test_warning_tuple(self): def test_warning_tuple(self) -> None:
pytest.warns( pytest.warns(
(RuntimeWarning, SyntaxWarning), lambda: warnings.warn("w1", RuntimeWarning) (RuntimeWarning, SyntaxWarning), lambda: warnings.warn("w1", RuntimeWarning)
) )
@ -230,21 +234,21 @@ class TestWarns:
(RuntimeWarning, SyntaxWarning), lambda: warnings.warn("w2", SyntaxWarning) (RuntimeWarning, SyntaxWarning), lambda: warnings.warn("w2", SyntaxWarning)
) )
pytest.raises( pytest.raises(
pytest.fail.Exception, Failed,
lambda: pytest.warns( lambda: pytest.warns(
(RuntimeWarning, SyntaxWarning), (RuntimeWarning, SyntaxWarning),
lambda: warnings.warn("w3", UserWarning), lambda: warnings.warn("w3", UserWarning),
), ),
) )
def test_as_contextmanager(self): def test_as_contextmanager(self) -> None:
with pytest.warns(RuntimeWarning): with pytest.warns(RuntimeWarning):
warnings.warn("runtime", RuntimeWarning) warnings.warn("runtime", RuntimeWarning)
with pytest.warns(UserWarning): with pytest.warns(UserWarning):
warnings.warn("user", UserWarning) warnings.warn("user", UserWarning)
with pytest.raises(pytest.fail.Exception) as excinfo: with pytest.raises(Failed) as excinfo:
with pytest.warns(RuntimeWarning): with pytest.warns(RuntimeWarning):
warnings.warn("user", UserWarning) warnings.warn("user", UserWarning)
excinfo.match( excinfo.match(
@ -252,7 +256,7 @@ class TestWarns:
r"The list of emitted warnings is: \[UserWarning\('user',?\)\]." r"The list of emitted warnings is: \[UserWarning\('user',?\)\]."
) )
with pytest.raises(pytest.fail.Exception) as excinfo: with pytest.raises(Failed) as excinfo:
with pytest.warns(UserWarning): with pytest.warns(UserWarning):
warnings.warn("runtime", RuntimeWarning) warnings.warn("runtime", RuntimeWarning)
excinfo.match( excinfo.match(
@ -260,7 +264,7 @@ class TestWarns:
r"The list of emitted warnings is: \[RuntimeWarning\('runtime',?\)\]." r"The list of emitted warnings is: \[RuntimeWarning\('runtime',?\)\]."
) )
with pytest.raises(pytest.fail.Exception) as excinfo: with pytest.raises(Failed) as excinfo:
with pytest.warns(UserWarning): with pytest.warns(UserWarning):
pass pass
excinfo.match( excinfo.match(
@ -269,7 +273,7 @@ class TestWarns:
) )
warning_classes = (UserWarning, FutureWarning) warning_classes = (UserWarning, FutureWarning)
with pytest.raises(pytest.fail.Exception) as excinfo: with pytest.raises(Failed) as excinfo:
with pytest.warns(warning_classes) as warninfo: with pytest.warns(warning_classes) as warninfo:
warnings.warn("runtime", RuntimeWarning) warnings.warn("runtime", RuntimeWarning)
warnings.warn("import", ImportWarning) warnings.warn("import", ImportWarning)
@ -286,14 +290,14 @@ class TestWarns:
) )
) )
def test_record(self): def test_record(self) -> None:
with pytest.warns(UserWarning) as record: with pytest.warns(UserWarning) as record:
warnings.warn("user", UserWarning) warnings.warn("user", UserWarning)
assert len(record) == 1 assert len(record) == 1
assert str(record[0].message) == "user" assert str(record[0].message) == "user"
def test_record_only(self): def test_record_only(self) -> None:
with pytest.warns(None) as record: with pytest.warns(None) as record:
warnings.warn("user", UserWarning) warnings.warn("user", UserWarning)
warnings.warn("runtime", RuntimeWarning) warnings.warn("runtime", RuntimeWarning)
@ -302,7 +306,7 @@ class TestWarns:
assert str(record[0].message) == "user" assert str(record[0].message) == "user"
assert str(record[1].message) == "runtime" assert str(record[1].message) == "runtime"
def test_record_by_subclass(self): def test_record_by_subclass(self) -> None:
with pytest.warns(Warning) as record: with pytest.warns(Warning) as record:
warnings.warn("user", UserWarning) warnings.warn("user", UserWarning)
warnings.warn("runtime", RuntimeWarning) warnings.warn("runtime", RuntimeWarning)
@ -325,7 +329,7 @@ class TestWarns:
assert str(record[0].message) == "user" assert str(record[0].message) == "user"
assert str(record[1].message) == "runtime" assert str(record[1].message) == "runtime"
def test_double_test(self, testdir): def test_double_test(self, testdir) -> None:
"""If a test is run again, the warning should still be raised""" """If a test is run again, the warning should still be raised"""
testdir.makepyfile( testdir.makepyfile(
""" """
@ -341,32 +345,32 @@ class TestWarns:
result = testdir.runpytest() result = testdir.runpytest()
result.stdout.fnmatch_lines(["*2 passed in*"]) result.stdout.fnmatch_lines(["*2 passed in*"])
def test_match_regex(self): def test_match_regex(self) -> None:
with pytest.warns(UserWarning, match=r"must be \d+$"): with pytest.warns(UserWarning, match=r"must be \d+$"):
warnings.warn("value must be 42", UserWarning) warnings.warn("value must be 42", UserWarning)
with pytest.raises(pytest.fail.Exception): with pytest.raises(Failed):
with pytest.warns(UserWarning, match=r"must be \d+$"): with pytest.warns(UserWarning, match=r"must be \d+$"):
warnings.warn("this is not here", UserWarning) warnings.warn("this is not here", UserWarning)
with pytest.raises(pytest.fail.Exception): with pytest.raises(Failed):
with pytest.warns(FutureWarning, match=r"must be \d+$"): with pytest.warns(FutureWarning, match=r"must be \d+$"):
warnings.warn("value must be 42", UserWarning) warnings.warn("value must be 42", UserWarning)
def test_one_from_multiple_warns(self): def test_one_from_multiple_warns(self) -> None:
with pytest.warns(UserWarning, match=r"aaa"): with pytest.warns(UserWarning, match=r"aaa"):
warnings.warn("cccccccccc", UserWarning) warnings.warn("cccccccccc", UserWarning)
warnings.warn("bbbbbbbbbb", UserWarning) warnings.warn("bbbbbbbbbb", UserWarning)
warnings.warn("aaaaaaaaaa", UserWarning) warnings.warn("aaaaaaaaaa", UserWarning)
def test_none_of_multiple_warns(self): def test_none_of_multiple_warns(self) -> None:
with pytest.raises(pytest.fail.Exception): with pytest.raises(Failed):
with pytest.warns(UserWarning, match=r"aaa"): with pytest.warns(UserWarning, match=r"aaa"):
warnings.warn("bbbbbbbbbb", UserWarning) warnings.warn("bbbbbbbbbb", UserWarning)
warnings.warn("cccccccccc", UserWarning) warnings.warn("cccccccccc", UserWarning)
@pytest.mark.filterwarnings("ignore") @pytest.mark.filterwarnings("ignore")
def test_can_capture_previously_warned(self): def test_can_capture_previously_warned(self) -> None:
def f(): def f():
warnings.warn(UserWarning("ohai")) warnings.warn(UserWarning("ohai"))
return 10 return 10
@ -375,8 +379,8 @@ class TestWarns:
assert pytest.warns(UserWarning, f) == 10 assert pytest.warns(UserWarning, f) == 10
assert pytest.warns(UserWarning, f) == 10 assert pytest.warns(UserWarning, f) == 10
def test_warns_context_manager_with_kwargs(self): def test_warns_context_manager_with_kwargs(self) -> None:
with pytest.raises(TypeError) as excinfo: with pytest.raises(TypeError) as excinfo:
with pytest.warns(UserWarning, foo="bar"): with pytest.warns(UserWarning, foo="bar"): # type: ignore
pass pass
assert "Unexpected keyword arguments" in str(excinfo.value) assert "Unexpected keyword arguments" in str(excinfo.value)