Use a hack to make typing of pytest.fail.Exception & co work

Mypy currently is unable to handle assigning attributes on function:
https://github.com/python/mypy/issues/2087.
pytest uses this for the outcome exceptions -- `pytest.fail.Exception`,
`pytest.exit.Exception` etc, and this is the canonical name by which they
are referred.

Initially we started working around this with type: ignores, and later
by switching e.g. `pytest.fail.Exception` with the direct exception
`Failed`. But this causes a lot of churn and is not as nice. And I also
found that some code relies on it, in skipping.py:

    def pytest_configure(config):
        if config.option.runxfail:
            # yay a hack
            import pytest

            old = pytest.xfail
            config._cleanup.append(lambda: setattr(pytest, "xfail", old))

            def nop(*args, **kwargs):
                pass

            nop.Exception = xfail.Exception
            setattr(pytest, "xfail", nop)
        ...

So it seems better to support it. Use a hack to make it work. The rest
of the commit rolls back all of the workarounds we added up to now.

`pytest.raises.Exception` also exists, but it's not used much so I kept
it as-is for now.

Hopefully in the future mypy supports this and this ugliness can be
removed.
This commit is contained in:
Ran Benita 2020-02-16 21:46:11 +02:00
parent d18c75baa3
commit 24dcc76495
9 changed files with 86 additions and 70 deletions

View File

@ -25,7 +25,7 @@ from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config import UsageError
from _pytest.fixtures import FixtureManager
from _pytest.outcomes import Exit
from _pytest.outcomes import exit
from _pytest.reports import CollectReport
from _pytest.runner import collect_one_node
from _pytest.runner import SetupState
@ -195,10 +195,10 @@ def wrap_session(
raise
except Failed:
session.exitstatus = ExitCode.TESTS_FAILED
except (KeyboardInterrupt, Exit):
except (KeyboardInterrupt, exit.Exception):
excinfo = _pytest._code.ExceptionInfo.from_current()
exitstatus = ExitCode.INTERRUPTED # type: Union[int, ExitCode]
if isinstance(excinfo.value, Exit):
if isinstance(excinfo.value, exit.Exception):
if excinfo.value.returncode is not None:
exitstatus = excinfo.value.returncode
if initstate < 2:
@ -212,7 +212,7 @@ def wrap_session(
excinfo = _pytest._code.ExceptionInfo.from_current()
try:
config.notify_exception(excinfo, config.option)
except Exit as exc:
except exit.Exception as exc:
if exc.returncode is not None:
session.exitstatus = exc.returncode
sys.stderr.write("{}: {}\n".format(type(exc).__name__, exc))
@ -229,7 +229,7 @@ def wrap_session(
config.hook.pytest_sessionfinish(
session=session, exitstatus=session.exitstatus
)
except Exit as exc:
except exit.Exception as exc:
if exc.returncode is not None:
session.exitstatus = exc.returncode
sys.stderr.write("{}: {}\n".format(type(exc).__name__, exc))

View File

@ -27,6 +27,7 @@ from _pytest.fixtures import FixtureLookupErrorRepr
from _pytest.mark.structures import Mark
from _pytest.mark.structures import MarkDecorator
from _pytest.mark.structures import NodeKeywords
from _pytest.outcomes import fail
from _pytest.outcomes import Failed
if TYPE_CHECKING:
@ -314,7 +315,7 @@ class Node(metaclass=NodeMeta):
def _repr_failure_py(
self, excinfo: ExceptionInfo[Union[Failed, FixtureLookupError]], style=None
) -> Union[str, ReprExceptionInfo, ExceptionChainRepr, FixtureLookupErrorRepr]:
if isinstance(excinfo.value, Failed):
if isinstance(excinfo.value, fail.Exception):
if not excinfo.value.pytrace:
return str(excinfo.value)
if isinstance(excinfo.value, FixtureLookupError):

View File

@ -4,7 +4,10 @@ as well as functions creating them
"""
import sys
from typing import Any
from typing import Callable
from typing import cast
from typing import Optional
from typing import TypeVar
from packaging.version import Version
@ -12,6 +15,15 @@ TYPE_CHECKING = False # avoid circular import through compat
if TYPE_CHECKING:
from typing import NoReturn
from typing import Type # noqa: F401 (Used in string type annotation.)
from typing_extensions import Protocol
else:
# typing.Protocol is only available starting from Python 3.8. It is also
# available from typing_extensions, but we don't want a runtime dependency
# on that. So use a dummy runtime implementation.
from typing import Generic
Protocol = Generic
class OutcomeException(BaseException):
@ -73,9 +85,31 @@ class Exit(Exception):
super().__init__(msg)
# Elaborate hack to work around https://github.com/python/mypy/issues/2087.
# Ideally would just be `exit.Exception = Exit` etc.
_F = TypeVar("_F", bound=Callable)
_ET = TypeVar("_ET", bound="Type[BaseException]")
class _WithException(Protocol[_F, _ET]):
Exception = None # type: _ET
__call__ = None # type: _F
def _with_exception(exception_type: _ET) -> Callable[[_F], _WithException[_F, _ET]]:
def decorate(func: _F) -> _WithException[_F, _ET]:
func_with_exception = cast(_WithException[_F, _ET], func)
func_with_exception.Exception = exception_type
return func_with_exception
return decorate
# exposed helper methods
@_with_exception(Exit)
def exit(msg: str, returncode: Optional[int] = None) -> "NoReturn":
"""
Exit testing process.
@ -87,10 +121,7 @@ def exit(msg: str, returncode: Optional[int] = None) -> "NoReturn":
raise Exit(msg, returncode)
# Ignore type because of https://github.com/python/mypy/issues/2087.
exit.Exception = Exit # type: ignore
@_with_exception(Skipped)
def skip(msg: str = "", *, allow_module_level: bool = False) -> "NoReturn":
"""
Skip an executing test with the given message.
@ -114,10 +145,7 @@ def skip(msg: str = "", *, allow_module_level: bool = False) -> "NoReturn":
raise Skipped(msg=msg, allow_module_level=allow_module_level)
# Ignore type because of https://github.com/python/mypy/issues/2087.
skip.Exception = Skipped # type: ignore
@_with_exception(Failed)
def fail(msg: str = "", pytrace: bool = True) -> "NoReturn":
"""
Explicitly fail an executing test with the given message.
@ -130,14 +158,11 @@ def fail(msg: str = "", pytrace: bool = True) -> "NoReturn":
raise Failed(msg=msg, pytrace=pytrace)
# Ignore type because of https://github.com/python/mypy/issues/2087.
fail.Exception = Failed # type: ignore
class XFailed(Failed):
""" raised from an explicit call to pytest.xfail() """
@_with_exception(XFailed)
def xfail(reason: str = "") -> "NoReturn":
"""
Imperatively xfail an executing test or setup functions with the given reason.
@ -152,10 +177,6 @@ def xfail(reason: str = "") -> "NoReturn":
raise XFailed(reason)
# Ignore type because of https://github.com/python/mypy/issues/2087.
xfail.Exception = XFailed # type: ignore
def importorskip(
modname: str, minversion: Optional[str] = None, reason: Optional[str] = None
) -> Any:

View File

@ -284,8 +284,7 @@ class TestReport(BaseReport):
if not isinstance(excinfo, ExceptionInfo):
outcome = "failed"
longrepr = excinfo
# Type ignored -- see comment where skip.Exception is defined.
elif excinfo.errisinstance(skip.Exception): # type: ignore
elif excinfo.errisinstance(skip.Exception):
outcome = "skipped"
r = excinfo._getreprcrash()
longrepr = (str(r.path), r.lineno, r.message)

View File

@ -904,6 +904,8 @@ class TestTracebackCutting:
pytest.skip("xxx")
assert excinfo.traceback[-1].frame.code.name == "skip"
assert excinfo.traceback[-1].ishidden()
assert excinfo.traceback[-2].frame.code.name == "test_skip_simple"
assert not excinfo.traceback[-2].ishidden()
def test_traceback_argsetup(self, testdir):
testdir.makeconftest(

View File

@ -16,7 +16,7 @@ from hypothesis import strategies
import pytest
from _pytest import fixtures
from _pytest import python
from _pytest.outcomes import Failed
from _pytest.outcomes import fail
from _pytest.pytester import Testdir
from _pytest.python import _idval
@ -99,7 +99,7 @@ class TestMetafunc:
({"x": 2}, "2"),
]
with pytest.raises(
Failed,
fail.Exception,
match=(
r"In func: ids must be list of string/float/int/bool, found:"
r" Exc\(from_gen\) \(type: <class .*Exc'>\) at index 2"
@ -113,7 +113,7 @@ class TestMetafunc:
metafunc = self.Metafunc(func)
with pytest.raises(
Failed,
fail.Exception,
match=r"parametrize\(\) call in func got an unexpected scope value 'doggy'",
):
metafunc.parametrize("x", [1], scope="doggy")
@ -126,7 +126,7 @@ class TestMetafunc:
metafunc = self.Metafunc(func)
with pytest.raises(
Failed,
fail.Exception,
match=r"'request' is a reserved name and cannot be used in @pytest.mark.parametrize",
):
metafunc.parametrize("request", [1])
@ -205,10 +205,10 @@ class TestMetafunc:
metafunc = self.Metafunc(func)
with pytest.raises(Failed):
with pytest.raises(fail.Exception):
metafunc.parametrize("x", [1, 2], ids=["basic"])
with pytest.raises(Failed):
with pytest.raises(fail.Exception):
metafunc.parametrize(
("x", "y"), [("abc", "def"), ("ghi", "jkl")], ids=["one"]
)
@ -689,7 +689,7 @@ class TestMetafunc:
metafunc = self.Metafunc(func)
with pytest.raises(
Failed,
fail.Exception,
match="In func: expected Sequence or boolean for indirect, got dict",
):
metafunc.parametrize("x, y", [("a", "b")], indirect={}) # type: ignore[arg-type] # noqa: F821
@ -730,7 +730,7 @@ class TestMetafunc:
pass
metafunc = self.Metafunc(func)
with pytest.raises(Failed):
with pytest.raises(fail.Exception):
metafunc.parametrize("x, y", [("a", "b")], indirect=["x", "z"])
def test_parametrize_uses_no_fixture_error_indirect_false(

View File

@ -10,7 +10,6 @@ import _pytest.pytester as pytester
import pytest
from _pytest.config import ExitCode
from _pytest.config import PytestPluginManager
from _pytest.outcomes import Failed
from _pytest.pytester import CwdSnapshot
from _pytest.pytester import HookRecorder
from _pytest.pytester import LineMatcher
@ -171,7 +170,7 @@ def test_hookrecorder_basic(holder) -> None:
call = rec.popcall("pytest_xyz")
assert call.arg == 123
assert call._name == "pytest_xyz"
pytest.raises(Failed, rec.popcall, "abc")
pytest.raises(pytest.fail.Exception, rec.popcall, "abc")
pm.hook.pytest_xyz_noarg()
call = rec.popcall("pytest_xyz_noarg")
assert call._name == "pytest_xyz_noarg"
@ -482,7 +481,7 @@ def test_linematcher_with_nonlist() -> None:
def test_linematcher_match_failure() -> None:
lm = LineMatcher(["foo", "foo", "bar"])
with pytest.raises(Failed) as e:
with pytest.raises(pytest.fail.Exception) as e:
lm.fnmatch_lines(["foo", "f*", "baz"])
assert e.value.msg is not None
assert e.value.msg.splitlines() == [
@ -495,7 +494,7 @@ def test_linematcher_match_failure() -> None:
]
lm = LineMatcher(["foo", "foo", "bar"])
with pytest.raises(Failed) as e:
with pytest.raises(pytest.fail.Exception) as e:
lm.re_match_lines(["foo", "^f.*", "baz"])
assert e.value.msg is not None
assert e.value.msg.splitlines() == [
@ -550,7 +549,7 @@ def test_linematcher_no_matching(function) -> None:
# check the function twice to ensure we don't accumulate the internal buffer
for i in range(2):
with pytest.raises(Failed) as e:
with pytest.raises(pytest.fail.Exception) as e:
func = getattr(lm, function)
func(good_pattern)
obtained = str(e.value).splitlines()
@ -580,7 +579,7 @@ def test_linematcher_no_matching(function) -> None:
def test_linematcher_no_matching_after_match() -> None:
lm = LineMatcher(["1", "2", "3"])
lm.fnmatch_lines(["1", "3"])
with pytest.raises(Failed) as e:
with pytest.raises(pytest.fail.Exception) as e:
lm.no_fnmatch_line("*")
assert str(e.value).splitlines() == ["fnmatch: '*'", " with: '1'"]

View File

@ -3,7 +3,6 @@ import warnings
from typing import Optional
import pytest
from _pytest.outcomes import Failed
from _pytest.recwarn import WarningsRecorder
@ -89,7 +88,7 @@ class TestDeprecatedCall:
)
def test_deprecated_call_raises(self) -> None:
with pytest.raises(Failed, match="No warnings of type"):
with pytest.raises(pytest.fail.Exception, match="No warnings of type"):
pytest.deprecated_call(self.dep, 3, 5)
def test_deprecated_call(self) -> None:
@ -114,7 +113,7 @@ class TestDeprecatedCall:
assert warn_explicit is warnings.warn_explicit
def test_deprecated_explicit_call_raises(self) -> None:
with pytest.raises(Failed):
with pytest.raises(pytest.fail.Exception):
pytest.deprecated_call(self.dep_explicit, 3)
def test_deprecated_explicit_call(self) -> None:
@ -131,7 +130,7 @@ class TestDeprecatedCall:
pass
msg = "No warnings of type (.*DeprecationWarning.*, .*PendingDeprecationWarning.*)"
with pytest.raises(Failed, match=msg):
with pytest.raises(pytest.fail.Exception, match=msg):
if mode == "call":
pytest.deprecated_call(f)
else:
@ -193,9 +192,9 @@ class TestDeprecatedCall:
def f():
warnings.warn(warning("hi"))
with pytest.raises(Failed):
with pytest.raises(pytest.fail.Exception):
pytest.deprecated_call(f)
with pytest.raises(Failed):
with pytest.raises(pytest.fail.Exception):
with pytest.deprecated_call():
f()
@ -203,7 +202,7 @@ class TestDeprecatedCall:
with pytest.deprecated_call(match=r"must be \d+$"):
warnings.warn("value must be 42", DeprecationWarning)
with pytest.raises(Failed):
with pytest.raises(pytest.fail.Exception):
with pytest.deprecated_call(match=r"must be \d+$"):
warnings.warn("this is not here", DeprecationWarning)
@ -217,7 +216,7 @@ class TestWarns:
def test_several_messages(self) -> None:
# different messages, b/c Python suppresses multiple identical warnings
pytest.warns(RuntimeWarning, lambda: warnings.warn("w1", RuntimeWarning))
with pytest.raises(Failed):
with pytest.raises(pytest.fail.Exception):
pytest.warns(UserWarning, lambda: warnings.warn("w2", RuntimeWarning))
pytest.warns(RuntimeWarning, lambda: warnings.warn("w3", RuntimeWarning))
@ -234,7 +233,7 @@ class TestWarns:
(RuntimeWarning, SyntaxWarning), lambda: warnings.warn("w2", SyntaxWarning)
)
pytest.raises(
Failed,
pytest.fail.Exception,
lambda: pytest.warns(
(RuntimeWarning, SyntaxWarning),
lambda: warnings.warn("w3", UserWarning),
@ -248,7 +247,7 @@ class TestWarns:
with pytest.warns(UserWarning):
warnings.warn("user", UserWarning)
with pytest.raises(Failed) as excinfo:
with pytest.raises(pytest.fail.Exception) as excinfo:
with pytest.warns(RuntimeWarning):
warnings.warn("user", UserWarning)
excinfo.match(
@ -256,7 +255,7 @@ class TestWarns:
r"The list of emitted warnings is: \[UserWarning\('user',?\)\]."
)
with pytest.raises(Failed) as excinfo:
with pytest.raises(pytest.fail.Exception) as excinfo:
with pytest.warns(UserWarning):
warnings.warn("runtime", RuntimeWarning)
excinfo.match(
@ -264,7 +263,7 @@ class TestWarns:
r"The list of emitted warnings is: \[RuntimeWarning\('runtime',?\)\]."
)
with pytest.raises(Failed) as excinfo:
with pytest.raises(pytest.fail.Exception) as excinfo:
with pytest.warns(UserWarning):
pass
excinfo.match(
@ -273,7 +272,7 @@ class TestWarns:
)
warning_classes = (UserWarning, FutureWarning)
with pytest.raises(Failed) as excinfo:
with pytest.raises(pytest.fail.Exception) as excinfo:
with pytest.warns(warning_classes) as warninfo:
warnings.warn("runtime", RuntimeWarning)
warnings.warn("import", ImportWarning)
@ -349,11 +348,11 @@ class TestWarns:
with pytest.warns(UserWarning, match=r"must be \d+$"):
warnings.warn("value must be 42", UserWarning)
with pytest.raises(Failed):
with pytest.raises(pytest.fail.Exception):
with pytest.warns(UserWarning, match=r"must be \d+$"):
warnings.warn("this is not here", UserWarning)
with pytest.raises(Failed):
with pytest.raises(pytest.fail.Exception):
with pytest.warns(FutureWarning, match=r"must be \d+$"):
warnings.warn("value must be 42", UserWarning)
@ -364,7 +363,7 @@ class TestWarns:
warnings.warn("aaaaaaaaaa", UserWarning)
def test_none_of_multiple_warns(self) -> None:
with pytest.raises(Failed):
with pytest.raises(pytest.fail.Exception):
with pytest.warns(UserWarning, match=r"aaa"):
warnings.warn("bbbbbbbbbb", UserWarning)
warnings.warn("cccccccccc", UserWarning)

View File

@ -14,10 +14,7 @@ from _pytest import outcomes
from _pytest import reports
from _pytest import runner
from _pytest.config import ExitCode
from _pytest.outcomes import Exit
from _pytest.outcomes import Failed
from _pytest.outcomes import OutcomeException
from _pytest.outcomes import Skipped
if False: # TYPE_CHECKING
from typing import Type
@ -398,7 +395,7 @@ class BaseFunctionalTests:
raise pytest.exit.Exception()
"""
)
except Exit:
except pytest.exit.Exception:
pass
else:
pytest.fail("did not raise")
@ -561,15 +558,13 @@ def test_outcomeexception_passes_except_Exception() -> None:
def test_pytest_exit() -> None:
assert Exit == pytest.exit.Exception # type: ignore
with pytest.raises(Exit) as excinfo:
with pytest.raises(pytest.exit.Exception) as excinfo:
pytest.exit("hello")
assert excinfo.errisinstance(Exit)
assert excinfo.errisinstance(pytest.exit.Exception)
def test_pytest_fail() -> None:
assert Failed == pytest.fail.Exception # type: ignore
with pytest.raises(Failed) as excinfo:
with pytest.raises(pytest.fail.Exception) as excinfo:
pytest.fail("hello")
s = excinfo.exconly(tryshort=True)
assert s.startswith("Failed")
@ -701,10 +696,10 @@ def test_pytest_no_tests_collected_exit_status(testdir) -> None:
def test_exception_printing_skip() -> None:
assert Skipped == pytest.skip.Exception # type: ignore
assert pytest.skip.Exception == pytest.skip.Exception
try:
pytest.skip("hello")
except Skipped:
except pytest.skip.Exception:
excinfo = _pytest._code.ExceptionInfo.from_current()
s = excinfo.exconly(tryshort=True)
assert s.startswith("Skipped")
@ -721,7 +716,7 @@ def test_importorskip(monkeypatch) -> None:
assert sysmod is sys
# path = pytest.importorskip("os.path")
# assert path == os.path
excinfo = pytest.raises(Skipped, f)
excinfo = pytest.raises(pytest.skip.Exception, f)
assert excinfo is not None
excrepr = excinfo.getrepr()
assert excrepr is not None
@ -735,11 +730,11 @@ def test_importorskip(monkeypatch) -> None:
mod = types.ModuleType("hello123")
mod.__version__ = "1.3" # type: ignore
monkeypatch.setitem(sys.modules, "hello123", mod)
with pytest.raises(Skipped):
with pytest.raises(pytest.skip.Exception):
pytest.importorskip("hello123", minversion="1.3.1")
mod2 = pytest.importorskip("hello123", minversion="1.3")
assert mod2 == mod
except Skipped:
except pytest.skip.Exception:
raise NotImplementedError(
"spurious skip: {}".format(_pytest._code.ExceptionInfo.from_current())
)
@ -757,9 +752,9 @@ def test_importorskip_dev_module(monkeypatch) -> None:
monkeypatch.setitem(sys.modules, "mockmodule", mod)
mod2 = pytest.importorskip("mockmodule", minversion="0.12.0")
assert mod2 == mod
with pytest.raises(Skipped):
with pytest.raises(pytest.skip.Exception):
pytest.importorskip("mockmodule1", minversion="0.14.0")
except Skipped:
except pytest.skip.Exception:
raise NotImplementedError(
"spurious skip: {}".format(_pytest._code.ExceptionInfo.from_current())
)