diff --git a/AUTHORS b/AUTHORS index e9e033c73..16918b40c 100644 --- a/AUTHORS +++ b/AUTHORS @@ -266,6 +266,7 @@ Michal Wajszczuk Michał Zięba Mickey Pashov Mihai Capotă +Mihail Milushev Mike Hoyle (hoylemd) Mike Lundy Milan Lesnek diff --git a/changelog/10441.feature.rst b/changelog/10441.feature.rst new file mode 100644 index 000000000..0019926ac --- /dev/null +++ b/changelog/10441.feature.rst @@ -0,0 +1,2 @@ +Added :func:`ExceptionInfo.group_contains() `, an assertion +helper that tests if an `ExceptionGroup` contains a matching exception. diff --git a/src/_pytest/_code/code.py b/src/_pytest/_code/code.py index b73c8bbb3..48a8685bd 100644 --- a/src/_pytest/_code/code.py +++ b/src/_pytest/_code/code.py @@ -697,6 +697,14 @@ class ExceptionInfo(Generic[E]): ) return fmt.repr_excinfo(self) + def _stringify_exception(self, exc: BaseException) -> str: + return "\n".join( + [ + str(exc), + *getattr(exc, "__notes__", []), + ] + ) + def match(self, regexp: Union[str, Pattern[str]]) -> "Literal[True]": """Check whether the regular expression `regexp` matches the string representation of the exception using :func:`python:re.search`. @@ -704,12 +712,7 @@ class ExceptionInfo(Generic[E]): If it matches `True` is returned, otherwise an `AssertionError` is raised. """ __tracebackhide__ = True - value = "\n".join( - [ - str(self.value), - *getattr(self.value, "__notes__", []), - ] - ) + value = self._stringify_exception(self.value) msg = f"Regex pattern did not match.\n Regex: {regexp!r}\n Input: {value!r}" if regexp == value: msg += "\n Did you mean to `re.escape()` the regex?" @@ -717,6 +720,56 @@ class ExceptionInfo(Generic[E]): # Return True to allow for "assert excinfo.match()". return True + def _group_contains( + self, + exc_group: BaseExceptionGroup[BaseException], + expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]], + match: Union[str, Pattern[str], None], + recursive: bool = False, + ) -> bool: + """Return `True` if a `BaseExceptionGroup` contains a matching exception.""" + for exc in exc_group.exceptions: + if recursive and isinstance(exc, BaseExceptionGroup): + if self._group_contains(exc, expected_exception, match, recursive): + return True + if not isinstance(exc, expected_exception): + continue + if match is not None: + value = self._stringify_exception(exc) + if not re.search(match, value): + continue + return True + return False + + def group_contains( + self, + expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]], + match: Union[str, Pattern[str], None] = None, + recursive: bool = False, + ) -> bool: + """Check whether a captured exception group contains a matching exception. + + :param Type[BaseException] | Tuple[Type[BaseException]] expected_exception: + The expected exception type, or a tuple if one of multiple possible + exception types are expected. + + :param str | Pattern[str] | None match: + If specified, a string containing a regular expression, + or a regular expression object, that is tested against the string + representation of the exception and its `PEP-678 ` `__notes__` + using :func:`re.search`. + + To match a literal string that may contain :ref:`special characters + `, the pattern can first be escaped with :func:`re.escape`. + + :param bool recursive: + If `True`, search will descend recursively into any nested exception groups. + If `False`, only the top exception group will be searched. + """ + msg = "Captured exception is not an instance of `BaseExceptionGroup`" + assert isinstance(self.value, BaseExceptionGroup), msg + return self._group_contains(self.value, expected_exception, match, recursive) + @dataclasses.dataclass class FormattedExcinfo: diff --git a/testing/code/test_excinfo.py b/testing/code/test_excinfo.py index 90f81123e..4e6e89a20 100644 --- a/testing/code/test_excinfo.py +++ b/testing/code/test_excinfo.py @@ -27,6 +27,9 @@ from _pytest.pytester import Pytester if TYPE_CHECKING: from _pytest._code.code import _TracebackStyle +if sys.version_info[:2] < (3, 11): + from exceptiongroup import ExceptionGroup + @pytest.fixture def limited_recursion_depth(): @@ -444,6 +447,62 @@ def test_match_raises_error(pytester: Pytester) -> None: result.stdout.re_match_lines([r".*__tracebackhide__ = True.*", *match]) +class TestGroupContains: + def test_contains_exception_type(self) -> None: + exc_group = ExceptionGroup("", [RuntimeError()]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError) + + def test_doesnt_contain_exception_type(self) -> None: + exc_group = ExceptionGroup("", [ValueError()]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert not exc_info.group_contains(RuntimeError) + + def test_contains_exception_match(self) -> None: + exc_group = ExceptionGroup("", [RuntimeError("exception message")]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError, match=r"^exception message$") + + def test_doesnt_contain_exception_match(self) -> None: + exc_group = ExceptionGroup("", [RuntimeError("message that will not match")]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert not exc_info.group_contains(RuntimeError, match=r"^exception message$") + + def test_contains_exception_type_recursive(self) -> None: + exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError, recursive=True) + + def test_doesnt_contain_exception_type_nonrecursive(self) -> None: + exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert not exc_info.group_contains(RuntimeError) + + def test_contains_exception_match_recursive(self) -> None: + exc_group = ExceptionGroup( + "", [ExceptionGroup("", [RuntimeError("exception message")])] + ) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains( + RuntimeError, match=r"^exception message$", recursive=True + ) + + def test_doesnt_contain_exception_match_nonrecursive(self) -> None: + exc_group = ExceptionGroup( + "", [ExceptionGroup("", [RuntimeError("message that will not match")])] + ) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert not exc_info.group_contains(RuntimeError, match=r"^exception message$") + + class TestFormattedExcinfo: @pytest.fixture def importasmod(self, tmp_path: Path, _sys_snapshot):