diff --git a/src/_pytest/_code/code.py b/src/_pytest/_code/code.py index 48a8685bd..0288d7a54 100644 --- a/src/_pytest/_code/code.py +++ b/src/_pytest/_code/code.py @@ -725,13 +725,22 @@ class ExceptionInfo(Generic[E]): exc_group: BaseExceptionGroup[BaseException], expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]], match: Union[str, Pattern[str], None], - recursive: bool = False, + target_depth: Optional[int] = None, + current_depth: int = 1, ) -> bool: """Return `True` if a `BaseExceptionGroup` contains a matching exception.""" + if (target_depth is not None) and (current_depth > target_depth): + # already descended past the target depth + return False for exc in exc_group.exceptions: - if recursive and isinstance(exc, BaseExceptionGroup): - if self._group_contains(exc, expected_exception, match, recursive): + if isinstance(exc, BaseExceptionGroup): + if self._group_contains( + exc, expected_exception, match, target_depth, current_depth + 1 + ): return True + if (target_depth is not None) and (current_depth != target_depth): + # not at the target depth, no match + continue if not isinstance(exc, expected_exception): continue if match is not None: @@ -744,8 +753,9 @@ class ExceptionInfo(Generic[E]): def group_contains( self, expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]], + *, match: Union[str, Pattern[str], None] = None, - recursive: bool = False, + depth: Optional[int] = None, ) -> bool: """Check whether a captured exception group contains a matching exception. @@ -762,13 +772,16 @@ class ExceptionInfo(Generic[E]): To match a literal string that may contain :ref:`special characters `, the pattern can first be escaped with :func:`re.escape`. - :param bool recursive: - If `True`, search will descend recursively into any nested exception groups. - If `False`, only the top exception group will be searched. + :param Optional[int] depth: + If `None`, will search for a matching exception at any nesting depth. + If >= 1, will only match an exception if it's at the specified depth (depth = 1 being + the exceptions contained within the topmost exception group). """ msg = "Captured exception is not an instance of `BaseExceptionGroup`" assert isinstance(self.value, BaseExceptionGroup), msg - return self._group_contains(self.value, expected_exception, match, recursive) + msg = "`depth` must be >= 1 if specified" + assert (depth is None) or (depth >= 1), msg + return self._group_contains(self.value, expected_exception, match, depth) @dataclasses.dataclass diff --git a/testing/code/test_excinfo.py b/testing/code/test_excinfo.py index 4e6e89a20..89beefce5 100644 --- a/testing/code/test_excinfo.py +++ b/testing/code/test_excinfo.py @@ -472,36 +472,66 @@ class TestGroupContains: raise exc_group assert not exc_info.group_contains(RuntimeError, match=r"^exception message$") - def test_contains_exception_type_recursive(self) -> None: + def test_contains_exception_type_unlimited_depth(self) -> None: exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])]) with pytest.raises(ExceptionGroup) as exc_info: raise exc_group - assert exc_info.group_contains(RuntimeError, recursive=True) + assert exc_info.group_contains(RuntimeError) - def test_doesnt_contain_exception_type_nonrecursive(self) -> None: + def test_contains_exception_type_at_depth_1(self) -> None: + exc_group = ExceptionGroup("", [RuntimeError()]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError, depth=1) + + def test_doesnt_contain_exception_type_past_depth(self) -> None: exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])]) with pytest.raises(ExceptionGroup) as exc_info: raise exc_group - assert not exc_info.group_contains(RuntimeError) + assert not exc_info.group_contains(RuntimeError, depth=1) - def test_contains_exception_match_recursive(self) -> None: + def test_contains_exception_type_specific_depth(self) -> None: + exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError, depth=2) + + def test_contains_exception_match_unlimited_depth(self) -> None: + exc_group = ExceptionGroup( + "", [ExceptionGroup("", [RuntimeError("exception message")])] + ) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError, match=r"^exception message$") + + def test_contains_exception_match_at_depth_1(self) -> None: + exc_group = ExceptionGroup("", [RuntimeError("exception message")]) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert exc_info.group_contains( + RuntimeError, match=r"^exception message$", depth=1 + ) + + def test_doesnt_contain_exception_match_past_depth(self) -> None: + exc_group = ExceptionGroup( + "", [ExceptionGroup("", [RuntimeError("exception message")])] + ) + with pytest.raises(ExceptionGroup) as exc_info: + raise exc_group + assert not exc_info.group_contains( + RuntimeError, match=r"^exception message$", depth=1 + ) + + def test_contains_exception_match_specific_depth(self) -> None: exc_group = ExceptionGroup( "", [ExceptionGroup("", [RuntimeError("exception message")])] ) with pytest.raises(ExceptionGroup) as exc_info: raise exc_group assert exc_info.group_contains( - RuntimeError, match=r"^exception message$", recursive=True + RuntimeError, match=r"^exception message$", depth=2 ) - def test_doesnt_contain_exception_match_nonrecursive(self) -> None: - exc_group = ExceptionGroup( - "", [ExceptionGroup("", [RuntimeError("message that will not match")])] - ) - with pytest.raises(ExceptionGroup) as exc_info: - raise exc_group - assert not exc_info.group_contains(RuntimeError, match=r"^exception message$") - class TestFormattedExcinfo: @pytest.fixture