Refactor warns() exit logic

This commit is contained in:
Zac Hatfield-Dodds 2023-06-30 15:29:02 -07:00
parent 9279ea2882
commit a1b37022af
2 changed files with 44 additions and 55 deletions

View File

@ -281,6 +281,12 @@ class WarningsChecker(WarningsRecorder):
self.expected_warning = expected_warning_tup self.expected_warning = expected_warning_tup
self.match_expr = match_expr self.match_expr = match_expr
def matches(self, warning: warnings.WarningMessage) -> bool:
assert self.expected_warning is not None
return issubclass(warning.category, self.expected_warning) and bool(
self.match_expr is None or re.search(self.match_expr, str(warning.message))
)
def __exit__( def __exit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: Optional[Type[BaseException]],
@ -291,56 +297,39 @@ class WarningsChecker(WarningsRecorder):
__tracebackhide__ = True __tracebackhide__ = True
if self.expected_warning is None:
# nothing to do in this deprecated case, see WARNS_NONE_ARG above
return
if not (exc_type is None and exc_val is None and exc_tb is None):
# We currently ignore missing warnings if an exception is active.
# TODO: fix this, because it means things are surprisingly order-sensitive.
return
def found_str(): def found_str():
return pformat([record.message for record in self], indent=2) return pformat([record.message for record in self], indent=2)
def re_emit() -> None: try:
for r in self: if not any(issubclass(w.category, self.expected_warning) for w in self):
if matches(r): fail(
continue f"DID NOT WARN. No warnings of type {self.expected_warning} were emitted.\n"
f" Emitted warnings: {found_str()}."
assert issubclass(r.message.__class__, Warning)
warnings.warn_explicit(
str(r.message),
r.message.__class__,
r.filename,
r.lineno,
module=r.__module__,
source=r.source,
) )
elif not any(self.matches(w) for w in self):
def matches(warning) -> bool: fail(
if self.expected_warning is not None: f"DID NOT WARN. No warnings of type {self.expected_warning} matching the regex were emitted.\n"
if issubclass(warning.category, self.expected_warning): f" Regex: {self.match_expr}\n"
if self.match_expr is not None: f" Emitted warnings: {found_str()}."
if re.compile(self.match_expr).search(str(warning.message)): )
return True finally:
return False # Whether or not any warnings matched, we want to re-emit all unmatched warnings.
return True for w in self:
return False if not self.matches(w):
warnings.warn_explicit(
# only check if we're not currently handling an exception str(w.message),
if exc_type is None and exc_val is None and exc_tb is None: w.message.__class__,
if self.expected_warning is not None: w.filename,
if not any(issubclass(r.category, self.expected_warning) for r in self): w.lineno,
__tracebackhide__ = True module=w.__module__,
fail( source=w.source,
f"DID NOT WARN. No warnings of type {self.expected_warning} were emitted.\n"
f"The list of emitted warnings is: {found_str()}."
) )
elif self.match_expr is not None:
for r in self:
if issubclass(r.category, self.expected_warning):
if re.compile(self.match_expr).search(str(r.message)):
re_emit()
break
else:
fail(
f"""\
DID NOT WARN. No warnings of type {self.expected_warning} matching the regex were emitted.
Regex: {self.match_expr}
Emitted warnings: {found_str()}"""
)
else:
re_emit()

View File

@ -376,7 +376,7 @@ class TestWarns:
warnings.warn("value must be 42", UserWarning) warnings.warn("value must be 42", UserWarning)
def test_one_from_multiple_warns(self) -> None: def test_one_from_multiple_warns(self) -> None:
with pytest.raises(pytest.fail.Exception): with pytest.raises(pytest.fail.Exception, match="DID NOT WARN"):
with pytest.warns(UserWarning, match=r"aaa"): with pytest.warns(UserWarning, match=r"aaa"):
with pytest.warns(UserWarning, match=r"aaa"): with pytest.warns(UserWarning, match=r"aaa"):
warnings.warn("cccccccccc", UserWarning) warnings.warn("cccccccccc", UserWarning)
@ -384,7 +384,7 @@ class TestWarns:
warnings.warn("aaaaaaaaaa", UserWarning) warnings.warn("aaaaaaaaaa", UserWarning)
def test_none_of_multiple_warns(self) -> None: def test_none_of_multiple_warns(self) -> None:
with pytest.raises(pytest.fail.Exception): with pytest.raises(pytest.fail.Exception, match="DID NOT WARN"):
with pytest.warns(UserWarning, match=r"aaa"): with pytest.warns(UserWarning, match=r"aaa"):
warnings.warn("bbbbbbbbbb", UserWarning) warnings.warn("bbbbbbbbbb", UserWarning)
warnings.warn("cccccccccc", UserWarning) warnings.warn("cccccccccc", UserWarning)
@ -424,13 +424,13 @@ class TestWarns:
warnings.warn("some deprecation warning", DeprecationWarning) warnings.warn("some deprecation warning", DeprecationWarning)
def test_re_emit_match_multiple(self) -> None: def test_re_emit_match_multiple(self) -> None:
# with pytest.warns(UserWarning): with warnings.catch_warnings():
with pytest.warns(UserWarning, match="user warning"): warnings.simplefilter("error") # if anything is re-emitted
warnings.warn("first user warning", UserWarning) with pytest.warns(UserWarning, match="user warning"):
warnings.warn("second user warning", UserWarning) warnings.warn("first user warning", UserWarning)
warnings.warn("second user warning", UserWarning)
def test_re_emit_non_match_single(self) -> None: def test_re_emit_non_match_single(self) -> None:
# with pytest.warns(UserWarning):
with pytest.warns(UserWarning, match="v2 warning"): with pytest.warns(UserWarning, match="v2 warning"):
with pytest.warns(UserWarning, match="v1 warning"): with pytest.warns(UserWarning, match="v1 warning"):
warnings.warn("v1 warning", UserWarning) warnings.warn("v1 warning", UserWarning)