Add new `ExceptionInfo.group_contains` assertion helper method

Tests if a captured exception group contains an expected exception.
Will raise `AssertionError` if the wrapped exception is not an exception group.
Supports recursive search into nested exception groups.
This commit is contained in:
Mihail Milushev 2023-09-10 12:24:18 +01:00
parent 6c2feb75d2
commit ab8f5ce7f4
4 changed files with 121 additions and 6 deletions

View File

@ -266,6 +266,7 @@ Michal Wajszczuk
Michał Zięba
Mickey Pashov
Mihai Capotă
Mihail Milushev
Mike Hoyle (hoylemd)
Mike Lundy
Milan Lesnek

View File

@ -0,0 +1,2 @@
Added :func:`ExceptionInfo.group_contains() <pytest.ExceptionInfo.group_contains>`, an assertion
helper that tests if an `ExceptionGroup` contains a matching exception.

View File

@ -697,6 +697,14 @@ class ExceptionInfo(Generic[E]):
)
return fmt.repr_excinfo(self)
def _stringify_exception(self, exc: BaseException) -> str:
return "\n".join(
[
str(exc),
*getattr(exc, "__notes__", []),
]
)
def match(self, regexp: Union[str, Pattern[str]]) -> "Literal[True]":
"""Check whether the regular expression `regexp` matches the string
representation of the exception using :func:`python:re.search`.
@ -704,12 +712,7 @@ class ExceptionInfo(Generic[E]):
If it matches `True` is returned, otherwise an `AssertionError` is raised.
"""
__tracebackhide__ = True
value = "\n".join(
[
str(self.value),
*getattr(self.value, "__notes__", []),
]
)
value = self._stringify_exception(self.value)
msg = f"Regex pattern did not match.\n Regex: {regexp!r}\n Input: {value!r}"
if regexp == value:
msg += "\n Did you mean to `re.escape()` the regex?"
@ -717,6 +720,56 @@ class ExceptionInfo(Generic[E]):
# Return True to allow for "assert excinfo.match()".
return True
def _group_contains(
self,
exc_group: BaseExceptionGroup[BaseException],
expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]],
match: Union[str, Pattern[str], None],
recursive: bool = False,
) -> bool:
"""Return `True` if a `BaseExceptionGroup` contains a matching exception."""
for exc in exc_group.exceptions:
if recursive and isinstance(exc, BaseExceptionGroup):
if self._group_contains(exc, expected_exception, match, recursive):
return True
if not isinstance(exc, expected_exception):
continue
if match is not None:
value = self._stringify_exception(exc)
if not re.search(match, value):
continue
return True
return False
def group_contains(
self,
expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]],
match: Union[str, Pattern[str], None] = None,
recursive: bool = False,
) -> bool:
"""Check whether a captured exception group contains a matching exception.
:param Type[BaseException] | Tuple[Type[BaseException]] expected_exception:
The expected exception type, or a tuple if one of multiple possible
exception types are expected.
:param str | Pattern[str] | None match:
If specified, a string containing a regular expression,
or a regular expression object, that is tested against the string
representation of the exception and its `PEP-678 <https://peps.python.org/pep-0678/>` `__notes__`
using :func:`re.search`.
To match a literal string that may contain :ref:`special characters
<re-syntax>`, the pattern can first be escaped with :func:`re.escape`.
:param bool recursive:
If `True`, search will descend recursively into any nested exception groups.
If `False`, only the top exception group will be searched.
"""
msg = "Captured exception is not an instance of `BaseExceptionGroup`"
assert isinstance(self.value, BaseExceptionGroup), msg
return self._group_contains(self.value, expected_exception, match, recursive)
@dataclasses.dataclass
class FormattedExcinfo:

View File

@ -27,6 +27,9 @@ from _pytest.pytester import Pytester
if TYPE_CHECKING:
from _pytest._code.code import _TracebackStyle
if sys.version_info[:2] < (3, 11):
from exceptiongroup import ExceptionGroup
@pytest.fixture
def limited_recursion_depth():
@ -444,6 +447,62 @@ def test_match_raises_error(pytester: Pytester) -> None:
result.stdout.re_match_lines([r".*__tracebackhide__ = True.*", *match])
class TestGroupContains:
def test_contains_exception_type(self) -> None:
exc_group = ExceptionGroup("", [RuntimeError()])
with pytest.raises(ExceptionGroup) as exc_info:
raise exc_group
assert exc_info.group_contains(RuntimeError)
def test_doesnt_contain_exception_type(self) -> None:
exc_group = ExceptionGroup("", [ValueError()])
with pytest.raises(ExceptionGroup) as exc_info:
raise exc_group
assert not exc_info.group_contains(RuntimeError)
def test_contains_exception_match(self) -> None:
exc_group = ExceptionGroup("", [RuntimeError("exception message")])
with pytest.raises(ExceptionGroup) as exc_info:
raise exc_group
assert exc_info.group_contains(RuntimeError, match=r"^exception message$")
def test_doesnt_contain_exception_match(self) -> None:
exc_group = ExceptionGroup("", [RuntimeError("message that will not match")])
with pytest.raises(ExceptionGroup) as exc_info:
raise exc_group
assert not exc_info.group_contains(RuntimeError, match=r"^exception message$")
def test_contains_exception_type_recursive(self) -> None:
exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])])
with pytest.raises(ExceptionGroup) as exc_info:
raise exc_group
assert exc_info.group_contains(RuntimeError, recursive=True)
def test_doesnt_contain_exception_type_nonrecursive(self) -> None:
exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])])
with pytest.raises(ExceptionGroup) as exc_info:
raise exc_group
assert not exc_info.group_contains(RuntimeError)
def test_contains_exception_match_recursive(self) -> None:
exc_group = ExceptionGroup(
"", [ExceptionGroup("", [RuntimeError("exception message")])]
)
with pytest.raises(ExceptionGroup) as exc_info:
raise exc_group
assert exc_info.group_contains(
RuntimeError, match=r"^exception message$", recursive=True
)
def test_doesnt_contain_exception_match_nonrecursive(self) -> None:
exc_group = ExceptionGroup(
"", [ExceptionGroup("", [RuntimeError("message that will not match")])]
)
with pytest.raises(ExceptionGroup) as exc_info:
raise exc_group
assert not exc_info.group_contains(RuntimeError, match=r"^exception message$")
class TestFormattedExcinfo:
@pytest.fixture
def importasmod(self, tmp_path: Path, _sys_snapshot):