From ab8f5ce7f46593b0d7217ea75c1c919fa51b1ea7 Mon Sep 17 00:00:00 2001 From: Mihail Milushev Date: Sun, 10 Sep 2023 12:24:18 +0100 Subject: [PATCH 1/4] Add new `ExceptionInfo.group_contains` assertion helper method Tests if a captured exception group contains an expected exception. Will raise `AssertionError` if the wrapped exception is not an exception group. Supports recursive search into nested exception groups. --- AUTHORS | 1 + changelog/10441.feature.rst | 2 ++ src/_pytest/_code/code.py | 65 ++++++++++++++++++++++++++++++++---- testing/code/test_excinfo.py | 59 ++++++++++++++++++++++++++++++++ 4 files changed, 121 insertions(+), 6 deletions(-) create mode 100644 changelog/10441.feature.rst diff --git a/AUTHORS b/AUTHORS index e9e033c73..16918b40c 100644 --- a/AUTHORS +++ b/AUTHORS @@ -266,6 +266,7 @@ Michal Wajszczuk Michał Zięba Mickey Pashov Mihai Capotă +Mihail Milushev Mike Hoyle (hoylemd) Mike Lundy Milan Lesnek diff --git a/changelog/10441.feature.rst b/changelog/10441.feature.rst new file mode 100644 index 000000000..0019926ac --- /dev/null +++ b/changelog/10441.feature.rst @@ -0,0 +1,2 @@ +Added :func:`ExceptionInfo.group_contains() `, an assertion +helper that tests if an `ExceptionGroup` contains a matching exception. diff --git a/src/_pytest/_code/code.py b/src/_pytest/_code/code.py index b73c8bbb3..48a8685bd 100644 --- a/src/_pytest/_code/code.py +++ b/src/_pytest/_code/code.py @@ -697,6 +697,14 @@ class ExceptionInfo(Generic[E]): ) return fmt.repr_excinfo(self) + def _stringify_exception(self, exc: BaseException) -> str: + return "\n".join( + [ + str(exc), + *getattr(exc, "__notes__", []), + ] + ) + def match(self, regexp: Union[str, Pattern[str]]) -> "Literal[True]": """Check whether the regular expression `regexp` matches the string representation of the exception using :func:`python:re.search`. @@ -704,12 +712,7 @@ class ExceptionInfo(Generic[E]): If it matches `True` is returned, otherwise an `AssertionError` is raised. """ __tracebackhide__ = True - value = "\n".join( - [ - str(self.value), - *getattr(self.value, "__notes__", []), - ] - ) + value = self._stringify_exception(self.value) msg = f"Regex pattern did not match.\n Regex: {regexp!r}\n Input: {value!r}" if regexp == value: msg += "\n Did you mean to `re.escape()` the regex?" @@ -717,6 +720,56 @@ class ExceptionInfo(Generic[E]): # Return True to allow for "assert excinfo.match()". return True + def _group_contains( + self, + exc_group: BaseExceptionGroup[BaseException], + expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]], + match: Union[str, Pattern[str], None], + recursive: bool = False, + ) -> bool: + """Return `True` if a `BaseExceptionGroup` contains a matching exception.""" + for exc in exc_group.exceptions: + if recursive and isinstance(exc, BaseExceptionGroup): + if self._group_contains(exc, expected_exception, match, recursive): + return True + if not isinstance(exc, expected_exception): + continue + if match is not None: + value = self._stringify_exception(exc) + if not re.search(match, value): + continue + return True + return False + + def group_contains( + self, + expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]], + match: Union[str, Pattern[str], None] = None, + recursive: bool = False, + ) -> bool: + """Check whether a captured exception group contains a matching exception. + + :param Type[BaseException] | Tuple[Type[BaseException]] expected_exception: + The expected exception type, or a tuple if one of multiple possible + exception types are expected. + + :param str | Pattern[str] | None match: + If specified, a string containing a regular expression, + or a regular expression object, that is tested against the string + representation of the exception and its `PEP-678 ` `__notes__` + using :func:`re.search`. + + To match a literal string that may contain :ref:`special characters + `, the pattern can first be escaped with :func:`re.escape`. + + :param bool recursive: + If `True`, search will descend recursively into any nested exception groups. + If `False`, only the top exception group will be searched. + """ + msg = "Captured exception is not an instance of `BaseExceptionGroup`" + assert isinstance(self.value, BaseExceptionGroup), msg + return self._group_contains(self.value, expected_exception, match, recursive) + @dataclasses.dataclass class FormattedExcinfo: diff --git a/testing/code/test_excinfo.py b/testing/code/test_excinfo.py index 90f81123e..4e6e89a20 100644 --- a/testing/code/test_excinfo.py +++ b/testing/code/test_excinfo.py @@ -27,6 +27,9 @@ from _pytest.pytester import Pytester if TYPE_CHECKING: from _pytest._code.code import _TracebackStyle +if sys.version_info[:2] < (3, 11): + from exceptiongroup import ExceptionGroup + @pytest.fixture def limited_recursion_depth(): @@ -444,6 +447,62 @@ def test_match_raises_error(pytester: Pytester) -> None: result.stdout.re_match_lines([r".*__tracebackhide__ = True.*", *match]) +class TestGroupContains: + def test_contains_exception_type(self) -> None: + exc_group = ExceptionGroup("", [RuntimeError()]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError) + + def test_doesnt_contain_exception_type(self) -> None: + exc_group = ExceptionGroup("", [ValueError()]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert not exc_info.group_contains(RuntimeError) + + def test_contains_exception_match(self) -> None: + exc_group = ExceptionGroup("", [RuntimeError("exception message")]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError, match=r"^exception message$") + + def test_doesnt_contain_exception_match(self) -> None: + exc_group = ExceptionGroup("", [RuntimeError("message that will not match")]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert not exc_info.group_contains(RuntimeError, match=r"^exception message$") + + def test_contains_exception_type_recursive(self) -> None: + exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError, recursive=True) + + def test_doesnt_contain_exception_type_nonrecursive(self) -> None: + exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert not exc_info.group_contains(RuntimeError) + + def test_contains_exception_match_recursive(self) -> None: + exc_group = ExceptionGroup( + "", [ExceptionGroup("", [RuntimeError("exception message")])] + ) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains( + RuntimeError, match=r"^exception message$", recursive=True + ) + + def test_doesnt_contain_exception_match_nonrecursive(self) -> None: + exc_group = ExceptionGroup( + "", [ExceptionGroup("", [RuntimeError("message that will not match")])] + ) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert not exc_info.group_contains(RuntimeError, match=r"^exception message$") + + class TestFormattedExcinfo: @pytest.fixture def importasmod(self, tmp_path: Path, _sys_snapshot): From a47fcb48733e586b4e674ca6d21a392678c0f85c Mon Sep 17 00:00:00 2001 From: Mihail Milushev Date: Fri, 15 Sep 2023 13:36:04 +0100 Subject: [PATCH 2/4] code review: kwarg-only `match`, replace `recursive` with `depth` --- src/_pytest/_code/code.py | 29 +++++++++++++----- testing/code/test_excinfo.py | 58 +++++++++++++++++++++++++++--------- 2 files changed, 65 insertions(+), 22 deletions(-) diff --git a/src/_pytest/_code/code.py b/src/_pytest/_code/code.py index 48a8685bd..0288d7a54 100644 --- a/src/_pytest/_code/code.py +++ b/src/_pytest/_code/code.py @@ -725,13 +725,22 @@ class ExceptionInfo(Generic[E]): exc_group: BaseExceptionGroup[BaseException], expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]], match: Union[str, Pattern[str], None], - recursive: bool = False, + target_depth: Optional[int] = None, + current_depth: int = 1, ) -> bool: """Return `True` if a `BaseExceptionGroup` contains a matching exception.""" + if (target_depth is not None) and (current_depth > target_depth): + # already descended past the target depth + return False for exc in exc_group.exceptions: - if recursive and isinstance(exc, BaseExceptionGroup): - if self._group_contains(exc, expected_exception, match, recursive): + if isinstance(exc, BaseExceptionGroup): + if self._group_contains( + exc, expected_exception, match, target_depth, current_depth + 1 + ): return True + if (target_depth is not None) and (current_depth != target_depth): + # not at the target depth, no match + continue if not isinstance(exc, expected_exception): continue if match is not None: @@ -744,8 +753,9 @@ class ExceptionInfo(Generic[E]): def group_contains( self, expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]], + *, match: Union[str, Pattern[str], None] = None, - recursive: bool = False, + depth: Optional[int] = None, ) -> bool: """Check whether a captured exception group contains a matching exception. @@ -762,13 +772,16 @@ class ExceptionInfo(Generic[E]): To match a literal string that may contain :ref:`special characters `, the pattern can first be escaped with :func:`re.escape`. - :param bool recursive: - If `True`, search will descend recursively into any nested exception groups. - If `False`, only the top exception group will be searched. + :param Optional[int] depth: + If `None`, will search for a matching exception at any nesting depth. + If >= 1, will only match an exception if it's at the specified depth (depth = 1 being + the exceptions contained within the topmost exception group). """ msg = "Captured exception is not an instance of `BaseExceptionGroup`" assert isinstance(self.value, BaseExceptionGroup), msg - return self._group_contains(self.value, expected_exception, match, recursive) + msg = "`depth` must be >= 1 if specified" + assert (depth is None) or (depth >= 1), msg + return self._group_contains(self.value, expected_exception, match, depth) @dataclasses.dataclass diff --git a/testing/code/test_excinfo.py b/testing/code/test_excinfo.py index 4e6e89a20..89beefce5 100644 --- a/testing/code/test_excinfo.py +++ b/testing/code/test_excinfo.py @@ -472,36 +472,66 @@ class TestGroupContains: raise exc_group assert not exc_info.group_contains(RuntimeError, match=r"^exception message$") - def test_contains_exception_type_recursive(self) -> None: + def test_contains_exception_type_unlimited_depth(self) -> None: exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])]) with pytest.raises(ExceptionGroup) as exc_info: raise exc_group - assert exc_info.group_contains(RuntimeError, recursive=True) + assert exc_info.group_contains(RuntimeError) - def test_doesnt_contain_exception_type_nonrecursive(self) -> None: + def test_contains_exception_type_at_depth_1(self) -> None: + exc_group = ExceptionGroup("", [RuntimeError()]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError, depth=1) + + def test_doesnt_contain_exception_type_past_depth(self) -> None: exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])]) with pytest.raises(ExceptionGroup) as exc_info: raise exc_group - assert not exc_info.group_contains(RuntimeError) + assert not exc_info.group_contains(RuntimeError, depth=1) - def test_contains_exception_match_recursive(self) -> None: + def test_contains_exception_type_specific_depth(self) -> None: + exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError, depth=2) + + def test_contains_exception_match_unlimited_depth(self) -> None: + exc_group = ExceptionGroup( + "", [ExceptionGroup("", [RuntimeError("exception message")])] + ) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError, match=r"^exception message$") + + def test_contains_exception_match_at_depth_1(self) -> None: + exc_group = ExceptionGroup("", [RuntimeError("exception message")]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains( + RuntimeError, match=r"^exception message$", depth=1 + ) + + def test_doesnt_contain_exception_match_past_depth(self) -> None: + exc_group = ExceptionGroup( + "", [ExceptionGroup("", [RuntimeError("exception message")])] + ) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert not exc_info.group_contains( + RuntimeError, match=r"^exception message$", depth=1 + ) + + def test_contains_exception_match_specific_depth(self) -> None: exc_group = ExceptionGroup( "", [ExceptionGroup("", [RuntimeError("exception message")])] ) with pytest.raises(ExceptionGroup) as exc_info: raise exc_group assert exc_info.group_contains( - RuntimeError, match=r"^exception message$", recursive=True + RuntimeError, match=r"^exception message$", depth=2 ) - def test_doesnt_contain_exception_match_nonrecursive(self) -> None: - exc_group = ExceptionGroup( - "", [ExceptionGroup("", [RuntimeError("message that will not match")])] - ) - with pytest.raises(ExceptionGroup) as exc_info: - raise exc_group - assert not exc_info.group_contains(RuntimeError, match=r"^exception message$") - class TestFormattedExcinfo: @pytest.fixture From e7caaa0b3ee60ebb4aa446156f080053dd6a2d03 Mon Sep 17 00:00:00 2001 From: Mihail Milushev Date: Sun, 17 Sep 2023 22:26:58 +0100 Subject: [PATCH 3/4] Document the new `ExceptionInfo.group_contains()` method --- doc/en/getting-started.rst | 24 ++++++++++++++++++++ doc/en/how-to/assert.rst | 46 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/doc/en/getting-started.rst b/doc/en/getting-started.rst index f8f994473..8d37612df 100644 --- a/doc/en/getting-started.rst +++ b/doc/en/getting-started.rst @@ -97,6 +97,30 @@ Use the :ref:`raises ` helper to assert that some code raises an e with pytest.raises(SystemExit): f() +You can also use the context provided by :ref:`raises ` to +assert that an expected exception is part of a raised ``ExceptionGroup``: + +.. code-block:: python + + # content of test_exceptiongroup.py + import pytest + + + def f(): + raise ExceptionGroup( + "Group message", + [ + RuntimeError(), + ], + ) + + + def test_exception_in_group(): + with pytest.raises(ExceptionGroup) as excinfo: + f() + assert excinfo.group_contains(RuntimeError) + assert not excinfo.group_contains(TypeError) + Execute the test function with “quiet” reporting mode: .. code-block:: pytest diff --git a/doc/en/how-to/assert.rst b/doc/en/how-to/assert.rst index d99a1ce5c..f52ddc278 100644 --- a/doc/en/how-to/assert.rst +++ b/doc/en/how-to/assert.rst @@ -119,6 +119,52 @@ The regexp parameter of the ``match`` method is matched with the ``re.search`` function, so in the above example ``match='123'`` would have worked as well. +You can also use the :func:`excinfo.group_contains() ` +method to test for exceptions returned as part of an ``ExceptionGroup``: + +.. code-block:: python + + def test_exception_in_group(): + with pytest.raises(RuntimeError) as excinfo: + raise ExceptionGroup( + "Group message", + [ + RuntimeError("Exception 123 raised"), + ], + ) + assert excinfo.group_contains(RuntimeError, match=r".* 123 .*") + assert not excinfo.group_contains(TypeError) + +The optional ``match`` keyword parameter works the same way as for +:func:`pytest.raises`. + +By default ``group_contains()`` will recursively search for a matching +exception at any level of nested ``ExceptionGroup`` instances. You can +specify a ``depth`` keyword parameter if you only want to match an +exception at a specific level; exceptions contained directly in the top +``ExceptionGroup`` would match ``depth=1``. + +.. code-block:: python + + def test_exception_in_group_at_given_depth(): + with pytest.raises(RuntimeError) as excinfo: + raise ExceptionGroup( + "Group message", + [ + RuntimeError(), + ExceptionGroup( + "Nested group", + [ + TypeError(), + ], + ), + ], + ) + assert excinfo.group_contains(RuntimeError, depth=1) + assert excinfo.group_contains(TypeError, depth=2) + assert not excinfo.group_contains(RuntimeError, depth=2) + assert not excinfo.group_contains(TypeError, depth=1) + There's an alternate form of the :func:`pytest.raises` function where you pass a function that will be executed with the given ``*args`` and ``**kwargs`` and assert that the given exception is raised: From 5ace48ca5bc701d01cb50d30ca945234c26d5f17 Mon Sep 17 00:00:00 2001 From: Mihail Milushev Date: Sun, 17 Sep 2023 22:27:36 +0100 Subject: [PATCH 4/4] Fix a minor mistake in docs ("``match`` method" is actually talking about the `match` keyword parameter) --- doc/en/how-to/assert.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/en/how-to/assert.rst b/doc/en/how-to/assert.rst index f52ddc278..cc53d001f 100644 --- a/doc/en/how-to/assert.rst +++ b/doc/en/how-to/assert.rst @@ -115,7 +115,7 @@ that a regular expression matches on the string representation of an exception with pytest.raises(ValueError, match=r".* 123 .*"): myfunc() -The regexp parameter of the ``match`` method is matched with the ``re.search`` +The regexp parameter of the ``match`` parameter is matched with the ``re.search`` function, so in the above example ``match='123'`` would have worked as well.