Refactor warns() exit logic

This commit is contained in:
Zac Hatfield-Dodds 2023-06-30 15:29:02 -07:00
parent 9279ea2882
commit a1b37022af
2 changed files with 44 additions and 55 deletions

View File

@ -281,6 +281,12 @@ class WarningsChecker(WarningsRecorder):
self.expected_warning = expected_warning_tup
self.match_expr = match_expr
def matches(self, warning: warnings.WarningMessage) -> bool:
assert self.expected_warning is not None
return issubclass(warning.category, self.expected_warning) and bool(
self.match_expr is None or re.search(self.match_expr, str(warning.message))
)
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
@ -291,56 +297,39 @@ class WarningsChecker(WarningsRecorder):
__tracebackhide__ = True
if self.expected_warning is None:
# nothing to do in this deprecated case, see WARNS_NONE_ARG above
return
if not (exc_type is None and exc_val is None and exc_tb is None):
# We currently ignore missing warnings if an exception is active.
# TODO: fix this, because it means things are surprisingly order-sensitive.
return
def found_str():
return pformat([record.message for record in self], indent=2)
def re_emit() -> None:
for r in self:
if matches(r):
continue
assert issubclass(r.message.__class__, Warning)
warnings.warn_explicit(
str(r.message),
r.message.__class__,
r.filename,
r.lineno,
module=r.__module__,
source=r.source,
try:
if not any(issubclass(w.category, self.expected_warning) for w in self):
fail(
f"DID NOT WARN. No warnings of type {self.expected_warning} were emitted.\n"
f" Emitted warnings: {found_str()}."
)
def matches(warning) -> bool:
if self.expected_warning is not None:
if issubclass(warning.category, self.expected_warning):
if self.match_expr is not None:
if re.compile(self.match_expr).search(str(warning.message)):
return True
return False
return True
return False
# only check if we're not currently handling an exception
if exc_type is None and exc_val is None and exc_tb is None:
if self.expected_warning is not None:
if not any(issubclass(r.category, self.expected_warning) for r in self):
__tracebackhide__ = True
fail(
f"DID NOT WARN. No warnings of type {self.expected_warning} were emitted.\n"
f"The list of emitted warnings is: {found_str()}."
elif not any(self.matches(w) for w in self):
fail(
f"DID NOT WARN. No warnings of type {self.expected_warning} matching the regex were emitted.\n"
f" Regex: {self.match_expr}\n"
f" Emitted warnings: {found_str()}."
)
finally:
# Whether or not any warnings matched, we want to re-emit all unmatched warnings.
for w in self:
if not self.matches(w):
warnings.warn_explicit(
str(w.message),
w.message.__class__,
w.filename,
w.lineno,
module=w.__module__,
source=w.source,
)
elif self.match_expr is not None:
for r in self:
if issubclass(r.category, self.expected_warning):
if re.compile(self.match_expr).search(str(r.message)):
re_emit()
break
else:
fail(
f"""\
DID NOT WARN. No warnings of type {self.expected_warning} matching the regex were emitted.
Regex: {self.match_expr}
Emitted warnings: {found_str()}"""
)
else:
re_emit()

View File

@ -376,7 +376,7 @@ class TestWarns:
warnings.warn("value must be 42", UserWarning)
def test_one_from_multiple_warns(self) -> None:
with pytest.raises(pytest.fail.Exception):
with pytest.raises(pytest.fail.Exception, match="DID NOT WARN"):
with pytest.warns(UserWarning, match=r"aaa"):
with pytest.warns(UserWarning, match=r"aaa"):
warnings.warn("cccccccccc", UserWarning)
@ -384,7 +384,7 @@ class TestWarns:
warnings.warn("aaaaaaaaaa", UserWarning)
def test_none_of_multiple_warns(self) -> None:
with pytest.raises(pytest.fail.Exception):
with pytest.raises(pytest.fail.Exception, match="DID NOT WARN"):
with pytest.warns(UserWarning, match=r"aaa"):
warnings.warn("bbbbbbbbbb", UserWarning)
warnings.warn("cccccccccc", UserWarning)
@ -424,13 +424,13 @@ class TestWarns:
warnings.warn("some deprecation warning", DeprecationWarning)
def test_re_emit_match_multiple(self) -> None:
# with pytest.warns(UserWarning):
with pytest.warns(UserWarning, match="user warning"):
warnings.warn("first user warning", UserWarning)
warnings.warn("second user warning", UserWarning)
with warnings.catch_warnings():
warnings.simplefilter("error") # if anything is re-emitted
with pytest.warns(UserWarning, match="user warning"):
warnings.warn("first user warning", UserWarning)
warnings.warn("second user warning", UserWarning)
def test_re_emit_non_match_single(self) -> None:
# with pytest.warns(UserWarning):
with pytest.warns(UserWarning, match="v2 warning"):
with pytest.warns(UserWarning, match="v1 warning"):
warnings.warn("v1 warning", UserWarning)