Derive pytest.raises from AbstractContextManager

Makes `AbstractContextManager` the shared base class between "raises" and other context managers.

The motivation is for type checkers to narrow `pytest.raises(...) if x else nullcontext()` to a `ContextManager` rather than `object`.
This commit is contained in:
Ilya Konstantinov 2023-01-13 09:29:38 -05:00
parent 3ad4344656
commit 1a96f16401
3 changed files with 15 additions and 2 deletions

View File

@ -0,0 +1,2 @@
Fix :py:func:`pytest.raises` to return a 'ContextManager' so that type-checkers could narrow
:code:`pytest.raises(...) if ... else nullcontext()` down to 'ContextManager' rather than 'object'.

View File

@ -8,7 +8,7 @@ from types import TracebackType
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import cast from typing import cast
from typing import Generic from typing import ContextManager
from typing import List from typing import List
from typing import Mapping from typing import Mapping
from typing import Optional from typing import Optional
@ -957,7 +957,7 @@ raises.Exception = fail.Exception # type: ignore
@final @final
class RaisesContext(Generic[E]): class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]):
def __init__( def __init__(
self, self,
expected_exception: Union[Type[E], Tuple[Type[E], ...]], expected_exception: Union[Type[E], Tuple[Type[E], ...]],

View File

@ -3,6 +3,11 @@
This file is not executed, it is only checked by mypy to ensure that This file is not executed, it is only checked by mypy to ensure that
none of the code triggers any mypy errors. none of the code triggers any mypy errors.
""" """
import contextlib
from typing import Optional
from typing_extensions import assert_type
import pytest import pytest
@ -22,3 +27,9 @@ def check_fixture_ids_callable() -> None:
@pytest.mark.parametrize("func", [str, int], ids=lambda x: str(x.__name__)) @pytest.mark.parametrize("func", [str, int], ids=lambda x: str(x.__name__))
def check_parametrize_ids_callable(func) -> None: def check_parametrize_ids_callable(func) -> None:
pass pass
def check_raises_is_a_context_manager(val: bool) -> None:
with pytest.raises(RuntimeError) if val else contextlib.nullcontext() as excinfo:
pass
assert_type(excinfo, Optional[pytest.ExceptionInfo[RuntimeError]])