Type annotate _pytest.doctest

This commit is contained in:
Ran Benita 2020-05-01 14:40:16 +03:00
parent 387d9d04f7
commit 32dd0e87cb
1 changed files with 77 additions and 34 deletions

View File

@ -4,12 +4,17 @@ import inspect
import platform
import sys
import traceback
import types
import warnings
from contextlib import contextmanager
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generator
from typing import Iterable
from typing import List
from typing import Optional
from typing import Pattern
from typing import Sequence
from typing import Tuple
from typing import Union
@ -24,6 +29,7 @@ from _pytest._code.code import TerminalRepr
from _pytest._io import TerminalWriter
from _pytest.compat import safe_getattr
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
from _pytest.outcomes import OutcomeException
@ -131,7 +137,7 @@ def _is_setup_py(path: py.path.local) -> bool:
return b"setuptools" in contents or b"distutils" in contents
def _is_doctest(config, path, parent):
def _is_doctest(config: Config, path: py.path.local, parent) -> bool:
if path.ext in (".txt", ".rst") and parent.session.isinitpath(path):
return True
globs = config.getoption("doctestglob") or ["test*.txt"]
@ -144,7 +150,7 @@ def _is_doctest(config, path, parent):
class ReprFailDoctest(TerminalRepr):
def __init__(
self, reprlocation_lines: Sequence[Tuple[ReprFileLocation, Sequence[str]]]
):
) -> None:
self.reprlocation_lines = reprlocation_lines
def toterminal(self, tw: TerminalWriter) -> None:
@ -155,7 +161,7 @@ class ReprFailDoctest(TerminalRepr):
class MultipleDoctestFailures(Exception):
def __init__(self, failures):
def __init__(self, failures: "Sequence[doctest.DocTestFailure]") -> None:
super().__init__()
self.failures = failures
@ -170,21 +176,33 @@ def _init_runner_class() -> "Type[doctest.DocTestRunner]":
"""
def __init__(
self, checker=None, verbose=None, optionflags=0, continue_on_failure=True
):
self,
checker: Optional[doctest.OutputChecker] = None,
verbose: Optional[bool] = None,
optionflags: int = 0,
continue_on_failure: bool = True,
) -> None:
doctest.DebugRunner.__init__(
self, checker=checker, verbose=verbose, optionflags=optionflags
)
self.continue_on_failure = continue_on_failure
def report_failure(self, out, test, example, got):
def report_failure(
self, out, test: "doctest.DocTest", example: "doctest.Example", got: str,
) -> None:
failure = doctest.DocTestFailure(test, example, got)
if self.continue_on_failure:
out.append(failure)
else:
raise failure
def report_unexpected_exception(self, out, test, example, exc_info):
def report_unexpected_exception(
self,
out,
test: "doctest.DocTest",
example: "doctest.Example",
exc_info: "Tuple[Type[BaseException], BaseException, types.TracebackType]",
) -> None:
if isinstance(exc_info[1], OutcomeException):
raise exc_info[1]
if isinstance(exc_info[1], bdb.BdbQuit):
@ -219,16 +237,27 @@ def _get_runner(
class DoctestItem(pytest.Item):
def __init__(self, name, parent, runner=None, dtest=None):
def __init__(
self,
name: str,
parent: "Union[DoctestTextfile, DoctestModule]",
runner: Optional["doctest.DocTestRunner"] = None,
dtest: Optional["doctest.DocTest"] = None,
) -> None:
super().__init__(name, parent)
self.runner = runner
self.dtest = dtest
self.obj = None
self.fixture_request = None
self.fixture_request = None # type: Optional[FixtureRequest]
@classmethod
def from_parent( # type: ignore
cls, parent: "Union[DoctestTextfile, DoctestModule]", *, name, runner, dtest
cls,
parent: "Union[DoctestTextfile, DoctestModule]",
*,
name: str,
runner: "doctest.DocTestRunner",
dtest: "doctest.DocTest"
):
# incompatible signature due to to imposed limits on sublcass
"""
@ -236,7 +265,7 @@ class DoctestItem(pytest.Item):
"""
return super().from_parent(name=name, parent=parent, runner=runner, dtest=dtest)
def setup(self):
def setup(self) -> None:
if self.dtest is not None:
self.fixture_request = _setup_fixtures(self)
globs = dict(getfixture=self.fixture_request.getfixturevalue)
@ -247,14 +276,18 @@ class DoctestItem(pytest.Item):
self.dtest.globs.update(globs)
def runtest(self) -> None:
assert self.dtest is not None
assert self.runner is not None
_check_all_skipped(self.dtest)
self._disable_output_capturing_for_darwin()
failures = [] # type: List[doctest.DocTestFailure]
self.runner.run(self.dtest, out=failures)
# Type ignored because we change the type of `out` from what
# doctest expects.
self.runner.run(self.dtest, out=failures) # type: ignore[arg-type] # noqa: F821
if failures:
raise MultipleDoctestFailures(failures)
def _disable_output_capturing_for_darwin(self):
def _disable_output_capturing_for_darwin(self) -> None:
"""
Disable output capturing. Otherwise, stdout is lost to doctest (#985)
"""
@ -272,10 +305,12 @@ class DoctestItem(pytest.Item):
failures = (
None
) # type: Optional[List[Union[doctest.DocTestFailure, doctest.UnexpectedException]]]
if excinfo.errisinstance((doctest.DocTestFailure, doctest.UnexpectedException)):
) # type: Optional[Sequence[Union[doctest.DocTestFailure, doctest.UnexpectedException]]]
if isinstance(
excinfo.value, (doctest.DocTestFailure, doctest.UnexpectedException)
):
failures = [excinfo.value]
elif excinfo.errisinstance(MultipleDoctestFailures):
elif isinstance(excinfo.value, MultipleDoctestFailures):
failures = excinfo.value.failures
if failures is not None:
@ -289,7 +324,8 @@ class DoctestItem(pytest.Item):
else:
lineno = test.lineno + example.lineno + 1
message = type(failure).__name__
reprlocation = ReprFileLocation(filename, lineno, message)
# TODO: ReprFileLocation doesn't expect a None lineno.
reprlocation = ReprFileLocation(filename, lineno, message) # type: ignore[arg-type] # noqa: F821
checker = _get_checker()
report_choice = _get_report_choice(
self.config.getoption("doctestreport")
@ -329,7 +365,8 @@ class DoctestItem(pytest.Item):
else:
return super().repr_failure(excinfo)
def reportinfo(self) -> Tuple[py.path.local, int, str]:
def reportinfo(self):
assert self.dtest is not None
return self.fspath, self.dtest.lineno, "[doctest] %s" % self.name
@ -399,7 +436,7 @@ class DoctestTextfile(pytest.Module):
)
def _check_all_skipped(test):
def _check_all_skipped(test: "doctest.DocTest") -> None:
"""raises pytest.skip() if all examples in the given DocTest have the SKIP
option set.
"""
@ -410,7 +447,7 @@ def _check_all_skipped(test):
pytest.skip("all tests skipped by +SKIP option")
def _is_mocked(obj):
def _is_mocked(obj: object) -> bool:
"""
returns if a object is possibly a mock object by checking the existence of a highly improbable attribute
"""
@ -421,23 +458,26 @@ def _is_mocked(obj):
@contextmanager
def _patch_unwrap_mock_aware():
def _patch_unwrap_mock_aware() -> Generator[None, None, None]:
"""
contextmanager which replaces ``inspect.unwrap`` with a version
that's aware of mock objects and doesn't recurse on them
"""
real_unwrap = inspect.unwrap
def _mock_aware_unwrap(obj, stop=None):
def _mock_aware_unwrap(
func: Callable[..., Any], *, stop: Optional[Callable[[Any], Any]] = None
) -> Any:
try:
if stop is None or stop is _is_mocked:
return real_unwrap(obj, stop=_is_mocked)
return real_unwrap(obj, stop=lambda obj: _is_mocked(obj) or stop(obj))
return real_unwrap(func, stop=_is_mocked)
_stop = stop
return real_unwrap(func, stop=lambda obj: _is_mocked(obj) or _stop(func))
except Exception as e:
warnings.warn(
"Got %r when unwrapping %r. This is usually caused "
"by a violation of Python's object protocol; see e.g. "
"https://github.com/pytest-dev/pytest/issues/5080" % (e, obj),
"https://github.com/pytest-dev/pytest/issues/5080" % (e, func),
PytestWarning,
)
raise
@ -469,7 +509,10 @@ class DoctestModule(pytest.Module):
"""
if isinstance(obj, property):
obj = getattr(obj, "fget", obj)
return doctest.DocTestFinder._find_lineno(self, obj, source_lines)
# Type ignored because this is a private function.
return doctest.DocTestFinder._find_lineno( # type: ignore
self, obj, source_lines,
)
def _find(
self, tests, obj, name, module, source_lines, globs, seen
@ -510,17 +553,17 @@ class DoctestModule(pytest.Module):
)
def _setup_fixtures(doctest_item):
def _setup_fixtures(doctest_item: DoctestItem) -> FixtureRequest:
"""
Used by DoctestTextfile and DoctestItem to setup fixture information.
"""
def func():
def func() -> None:
pass
doctest_item.funcargs = {}
doctest_item.funcargs = {} # type: ignore[attr-defined] # noqa: F821
fm = doctest_item.session._fixturemanager
doctest_item._fixtureinfo = fm.getfixtureinfo(
doctest_item._fixtureinfo = fm.getfixtureinfo( # type: ignore[attr-defined] # noqa: F821
node=doctest_item, func=func, cls=None, funcargs=False
)
fixture_request = FixtureRequest(doctest_item)
@ -564,7 +607,7 @@ def _init_checker_class() -> "Type[doctest.OutputChecker]":
re.VERBOSE,
)
def check_output(self, want, got, optionflags):
def check_output(self, want: str, got: str, optionflags: int) -> bool:
if doctest.OutputChecker.check_output(self, want, got, optionflags):
return True
@ -575,7 +618,7 @@ def _init_checker_class() -> "Type[doctest.OutputChecker]":
if not allow_unicode and not allow_bytes and not allow_number:
return False
def remove_prefixes(regex, txt):
def remove_prefixes(regex: Pattern[str], txt: str) -> str:
return re.sub(regex, r"\1\2", txt)
if allow_unicode:
@ -591,7 +634,7 @@ def _init_checker_class() -> "Type[doctest.OutputChecker]":
return doctest.OutputChecker.check_output(self, want, got, optionflags)
def _remove_unwanted_precision(self, want, got):
def _remove_unwanted_precision(self, want: str, got: str) -> str:
wants = list(self._number_re.finditer(want))
gots = list(self._number_re.finditer(got))
if len(wants) != len(gots):
@ -686,7 +729,7 @@ def _get_report_choice(key: str) -> int:
@pytest.fixture(scope="session")
def doctest_namespace():
def doctest_namespace() -> Dict[str, Any]:
"""
Fixture that returns a :py:class:`dict` that will be injected into the namespace of doctests.
"""