diff --git a/AUTHORS b/AUTHORS index e9e033c73..16918b40c 100644 --- a/AUTHORS +++ b/AUTHORS @@ -266,6 +266,7 @@ Michal Wajszczuk Michał Zięba Mickey Pashov Mihai Capotă +Mihail Milushev Mike Hoyle (hoylemd) Mike Lundy Milan Lesnek diff --git a/changelog/10441.feature.rst b/changelog/10441.feature.rst new file mode 100644 index 000000000..0019926ac --- /dev/null +++ b/changelog/10441.feature.rst @@ -0,0 +1,2 @@ +Added :func:`ExceptionInfo.group_contains() `, an assertion +helper that tests if an `ExceptionGroup` contains a matching exception. diff --git a/doc/en/getting-started.rst b/doc/en/getting-started.rst index f8f994473..8d37612df 100644 --- a/doc/en/getting-started.rst +++ b/doc/en/getting-started.rst @@ -97,6 +97,30 @@ Use the :ref:`raises ` helper to assert that some code raises an e with pytest.raises(SystemExit): f() +You can also use the context provided by :ref:`raises ` to +assert that an expected exception is part of a raised ``ExceptionGroup``: + +.. code-block:: python + + # content of test_exceptiongroup.py + import pytest + + + def f(): + raise ExceptionGroup( + "Group message", + [ + RuntimeError(), + ], + ) + + + def test_exception_in_group(): + with pytest.raises(ExceptionGroup) as excinfo: + f() + assert excinfo.group_contains(RuntimeError) + assert not excinfo.group_contains(TypeError) + Execute the test function with “quiet” reporting mode: .. code-block:: pytest diff --git a/doc/en/how-to/assert.rst b/doc/en/how-to/assert.rst index d99a1ce5c..cc53d001f 100644 --- a/doc/en/how-to/assert.rst +++ b/doc/en/how-to/assert.rst @@ -115,10 +115,56 @@ that a regular expression matches on the string representation of an exception with pytest.raises(ValueError, match=r".* 123 .*"): myfunc() -The regexp parameter of the ``match`` method is matched with the ``re.search`` +The regexp parameter of the ``match`` parameter is matched with the ``re.search`` function, so in the above example ``match='123'`` would have worked as well. +You can also use the :func:`excinfo.group_contains() ` +method to test for exceptions returned as part of an ``ExceptionGroup``: + +.. code-block:: python + + def test_exception_in_group(): + with pytest.raises(RuntimeError) as excinfo: + raise ExceptionGroup( + "Group message", + [ + RuntimeError("Exception 123 raised"), + ], + ) + assert excinfo.group_contains(RuntimeError, match=r".* 123 .*") + assert not excinfo.group_contains(TypeError) + +The optional ``match`` keyword parameter works the same way as for +:func:`pytest.raises`. + +By default ``group_contains()`` will recursively search for a matching +exception at any level of nested ``ExceptionGroup`` instances. You can +specify a ``depth`` keyword parameter if you only want to match an +exception at a specific level; exceptions contained directly in the top +``ExceptionGroup`` would match ``depth=1``. + +.. code-block:: python + + def test_exception_in_group_at_given_depth(): + with pytest.raises(RuntimeError) as excinfo: + raise ExceptionGroup( + "Group message", + [ + RuntimeError(), + ExceptionGroup( + "Nested group", + [ + TypeError(), + ], + ), + ], + ) + assert excinfo.group_contains(RuntimeError, depth=1) + assert excinfo.group_contains(TypeError, depth=2) + assert not excinfo.group_contains(RuntimeError, depth=2) + assert not excinfo.group_contains(TypeError, depth=1) + There's an alternate form of the :func:`pytest.raises` function where you pass a function that will be executed with the given ``*args`` and ``**kwargs`` and assert that the given exception is raised: diff --git a/src/_pytest/_code/code.py b/src/_pytest/_code/code.py index b73c8bbb3..0288d7a54 100644 --- a/src/_pytest/_code/code.py +++ b/src/_pytest/_code/code.py @@ -697,6 +697,14 @@ class ExceptionInfo(Generic[E]): ) return fmt.repr_excinfo(self) + def _stringify_exception(self, exc: BaseException) -> str: + return "\n".join( + [ + str(exc), + *getattr(exc, "__notes__", []), + ] + ) + def match(self, regexp: Union[str, Pattern[str]]) -> "Literal[True]": """Check whether the regular expression `regexp` matches the string representation of the exception using :func:`python:re.search`. @@ -704,12 +712,7 @@ class ExceptionInfo(Generic[E]): If it matches `True` is returned, otherwise an `AssertionError` is raised. """ __tracebackhide__ = True - value = "\n".join( - [ - str(self.value), - *getattr(self.value, "__notes__", []), - ] - ) + value = self._stringify_exception(self.value) msg = f"Regex pattern did not match.\n Regex: {regexp!r}\n Input: {value!r}" if regexp == value: msg += "\n Did you mean to `re.escape()` the regex?" @@ -717,6 +720,69 @@ class ExceptionInfo(Generic[E]): # Return True to allow for "assert excinfo.match()". return True + def _group_contains( + self, + exc_group: BaseExceptionGroup[BaseException], + expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]], + match: Union[str, Pattern[str], None], + target_depth: Optional[int] = None, + current_depth: int = 1, + ) -> bool: + """Return `True` if a `BaseExceptionGroup` contains a matching exception.""" + if (target_depth is not None) and (current_depth > target_depth): + # already descended past the target depth + return False + for exc in exc_group.exceptions: + if isinstance(exc, BaseExceptionGroup): + if self._group_contains( + exc, expected_exception, match, target_depth, current_depth + 1 + ): + return True + if (target_depth is not None) and (current_depth != target_depth): + # not at the target depth, no match + continue + if not isinstance(exc, expected_exception): + continue + if match is not None: + value = self._stringify_exception(exc) + if not re.search(match, value): + continue + return True + return False + + def group_contains( + self, + expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]], + *, + match: Union[str, Pattern[str], None] = None, + depth: Optional[int] = None, + ) -> bool: + """Check whether a captured exception group contains a matching exception. + + :param Type[BaseException] | Tuple[Type[BaseException]] expected_exception: + The expected exception type, or a tuple if one of multiple possible + exception types are expected. + + :param str | Pattern[str] | None match: + If specified, a string containing a regular expression, + or a regular expression object, that is tested against the string + representation of the exception and its `PEP-678 ` `__notes__` + using :func:`re.search`. + + To match a literal string that may contain :ref:`special characters + `, the pattern can first be escaped with :func:`re.escape`. + + :param Optional[int] depth: + If `None`, will search for a matching exception at any nesting depth. + If >= 1, will only match an exception if it's at the specified depth (depth = 1 being + the exceptions contained within the topmost exception group). + """ + msg = "Captured exception is not an instance of `BaseExceptionGroup`" + assert isinstance(self.value, BaseExceptionGroup), msg + msg = "`depth` must be >= 1 if specified" + assert (depth is None) or (depth >= 1), msg + return self._group_contains(self.value, expected_exception, match, depth) + @dataclasses.dataclass class FormattedExcinfo: diff --git a/testing/code/test_excinfo.py b/testing/code/test_excinfo.py index 90f81123e..89beefce5 100644 --- a/testing/code/test_excinfo.py +++ b/testing/code/test_excinfo.py @@ -27,6 +27,9 @@ from _pytest.pytester import Pytester if TYPE_CHECKING: from _pytest._code.code import _TracebackStyle +if sys.version_info[:2] < (3, 11): + from exceptiongroup import ExceptionGroup + @pytest.fixture def limited_recursion_depth(): @@ -444,6 +447,92 @@ def test_match_raises_error(pytester: Pytester) -> None: result.stdout.re_match_lines([r".*__tracebackhide__ = True.*", *match]) +class TestGroupContains: + def test_contains_exception_type(self) -> None: + exc_group = ExceptionGroup("", [RuntimeError()]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError) + + def test_doesnt_contain_exception_type(self) -> None: + exc_group = ExceptionGroup("", [ValueError()]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert not exc_info.group_contains(RuntimeError) + + def test_contains_exception_match(self) -> None: + exc_group = ExceptionGroup("", [RuntimeError("exception message")]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError, match=r"^exception message$") + + def test_doesnt_contain_exception_match(self) -> None: + exc_group = ExceptionGroup("", [RuntimeError("message that will not match")]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert not exc_info.group_contains(RuntimeError, match=r"^exception message$") + + def test_contains_exception_type_unlimited_depth(self) -> None: + exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError) + + def test_contains_exception_type_at_depth_1(self) -> None: + exc_group = ExceptionGroup("", [RuntimeError()]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError, depth=1) + + def test_doesnt_contain_exception_type_past_depth(self) -> None: + exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert not exc_info.group_contains(RuntimeError, depth=1) + + def test_contains_exception_type_specific_depth(self) -> None: + exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError, depth=2) + + def test_contains_exception_match_unlimited_depth(self) -> None: + exc_group = ExceptionGroup( + "", [ExceptionGroup("", [RuntimeError("exception message")])] + ) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError, match=r"^exception message$") + + def test_contains_exception_match_at_depth_1(self) -> None: + exc_group = ExceptionGroup("", [RuntimeError("exception message")]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains( + RuntimeError, match=r"^exception message$", depth=1 + ) + + def test_doesnt_contain_exception_match_past_depth(self) -> None: + exc_group = ExceptionGroup( + "", [ExceptionGroup("", [RuntimeError("exception message")])] + ) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert not exc_info.group_contains( + RuntimeError, match=r"^exception message$", depth=1 + ) + + def test_contains_exception_match_specific_depth(self) -> None: + exc_group = ExceptionGroup( + "", [ExceptionGroup("", [RuntimeError("exception message")])] + ) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains( + RuntimeError, match=r"^exception message$", depth=2 + ) + + class TestFormattedExcinfo: @pytest.fixture def importasmod(self, tmp_path: Path, _sys_snapshot):