From 32dd0e87cb2e6750c1fc2356eb451c9811bdb065 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Fri, 1 May 2020 14:40:16 +0300 Subject: [PATCH] Type annotate _pytest.doctest --- src/_pytest/doctest.py | 111 ++++++++++++++++++++++++++++------------- 1 file changed, 77 insertions(+), 34 deletions(-) diff --git a/src/_pytest/doctest.py b/src/_pytest/doctest.py index 026476b8a..ab8085982 100644 --- a/src/_pytest/doctest.py +++ b/src/_pytest/doctest.py @@ -4,12 +4,17 @@ import inspect import platform import sys import traceback +import types import warnings from contextlib import contextmanager +from typing import Any +from typing import Callable from typing import Dict +from typing import Generator from typing import Iterable from typing import List from typing import Optional +from typing import Pattern from typing import Sequence from typing import Tuple from typing import Union @@ -24,6 +29,7 @@ from _pytest._code.code import TerminalRepr from _pytest._io import TerminalWriter from _pytest.compat import safe_getattr from _pytest.compat import TYPE_CHECKING +from _pytest.config import Config from _pytest.config.argparsing import Parser from _pytest.fixtures import FixtureRequest from _pytest.outcomes import OutcomeException @@ -131,7 +137,7 @@ def _is_setup_py(path: py.path.local) -> bool: return b"setuptools" in contents or b"distutils" in contents -def _is_doctest(config, path, parent): +def _is_doctest(config: Config, path: py.path.local, parent) -> bool: if path.ext in (".txt", ".rst") and parent.session.isinitpath(path): return True globs = config.getoption("doctestglob") or ["test*.txt"] @@ -144,7 +150,7 @@ def _is_doctest(config, path, parent): class ReprFailDoctest(TerminalRepr): def __init__( self, reprlocation_lines: Sequence[Tuple[ReprFileLocation, Sequence[str]]] - ): + ) -> None: self.reprlocation_lines = reprlocation_lines def toterminal(self, tw: TerminalWriter) -> None: @@ -155,7 +161,7 @@ class ReprFailDoctest(TerminalRepr): class MultipleDoctestFailures(Exception): - def __init__(self, failures): + def __init__(self, failures: "Sequence[doctest.DocTestFailure]") -> None: super().__init__() self.failures = failures @@ -170,21 +176,33 @@ def _init_runner_class() -> "Type[doctest.DocTestRunner]": """ def __init__( - self, checker=None, verbose=None, optionflags=0, continue_on_failure=True - ): + self, + checker: Optional[doctest.OutputChecker] = None, + verbose: Optional[bool] = None, + optionflags: int = 0, + continue_on_failure: bool = True, + ) -> None: doctest.DebugRunner.__init__( self, checker=checker, verbose=verbose, optionflags=optionflags ) self.continue_on_failure = continue_on_failure - def report_failure(self, out, test, example, got): + def report_failure( + self, out, test: "doctest.DocTest", example: "doctest.Example", got: str, + ) -> None: failure = doctest.DocTestFailure(test, example, got) if self.continue_on_failure: out.append(failure) else: raise failure - def report_unexpected_exception(self, out, test, example, exc_info): + def report_unexpected_exception( + self, + out, + test: "doctest.DocTest", + example: "doctest.Example", + exc_info: "Tuple[Type[BaseException], BaseException, types.TracebackType]", + ) -> None: if isinstance(exc_info[1], OutcomeException): raise exc_info[1] if isinstance(exc_info[1], bdb.BdbQuit): @@ -219,16 +237,27 @@ def _get_runner( class DoctestItem(pytest.Item): - def __init__(self, name, parent, runner=None, dtest=None): + def __init__( + self, + name: str, + parent: "Union[DoctestTextfile, DoctestModule]", + runner: Optional["doctest.DocTestRunner"] = None, + dtest: Optional["doctest.DocTest"] = None, + ) -> None: super().__init__(name, parent) self.runner = runner self.dtest = dtest self.obj = None - self.fixture_request = None + self.fixture_request = None # type: Optional[FixtureRequest] @classmethod def from_parent( # type: ignore - cls, parent: "Union[DoctestTextfile, DoctestModule]", *, name, runner, dtest + cls, + parent: "Union[DoctestTextfile, DoctestModule]", + *, + name: str, + runner: "doctest.DocTestRunner", + dtest: "doctest.DocTest" ): # incompatible signature due to to imposed limits on sublcass """ @@ -236,7 +265,7 @@ class DoctestItem(pytest.Item): """ return super().from_parent(name=name, parent=parent, runner=runner, dtest=dtest) - def setup(self): + def setup(self) -> None: if self.dtest is not None: self.fixture_request = _setup_fixtures(self) globs = dict(getfixture=self.fixture_request.getfixturevalue) @@ -247,14 +276,18 @@ class DoctestItem(pytest.Item): self.dtest.globs.update(globs) def runtest(self) -> None: + assert self.dtest is not None + assert self.runner is not None _check_all_skipped(self.dtest) self._disable_output_capturing_for_darwin() failures = [] # type: List[doctest.DocTestFailure] - self.runner.run(self.dtest, out=failures) + # Type ignored because we change the type of `out` from what + # doctest expects. + self.runner.run(self.dtest, out=failures) # type: ignore[arg-type] # noqa: F821 if failures: raise MultipleDoctestFailures(failures) - def _disable_output_capturing_for_darwin(self): + def _disable_output_capturing_for_darwin(self) -> None: """ Disable output capturing. Otherwise, stdout is lost to doctest (#985) """ @@ -272,10 +305,12 @@ class DoctestItem(pytest.Item): failures = ( None - ) # type: Optional[List[Union[doctest.DocTestFailure, doctest.UnexpectedException]]] - if excinfo.errisinstance((doctest.DocTestFailure, doctest.UnexpectedException)): + ) # type: Optional[Sequence[Union[doctest.DocTestFailure, doctest.UnexpectedException]]] + if isinstance( + excinfo.value, (doctest.DocTestFailure, doctest.UnexpectedException) + ): failures = [excinfo.value] - elif excinfo.errisinstance(MultipleDoctestFailures): + elif isinstance(excinfo.value, MultipleDoctestFailures): failures = excinfo.value.failures if failures is not None: @@ -289,7 +324,8 @@ class DoctestItem(pytest.Item): else: lineno = test.lineno + example.lineno + 1 message = type(failure).__name__ - reprlocation = ReprFileLocation(filename, lineno, message) + # TODO: ReprFileLocation doesn't expect a None lineno. + reprlocation = ReprFileLocation(filename, lineno, message) # type: ignore[arg-type] # noqa: F821 checker = _get_checker() report_choice = _get_report_choice( self.config.getoption("doctestreport") @@ -329,7 +365,8 @@ class DoctestItem(pytest.Item): else: return super().repr_failure(excinfo) - def reportinfo(self) -> Tuple[py.path.local, int, str]: + def reportinfo(self): + assert self.dtest is not None return self.fspath, self.dtest.lineno, "[doctest] %s" % self.name @@ -399,7 +436,7 @@ class DoctestTextfile(pytest.Module): ) -def _check_all_skipped(test): +def _check_all_skipped(test: "doctest.DocTest") -> None: """raises pytest.skip() if all examples in the given DocTest have the SKIP option set. """ @@ -410,7 +447,7 @@ def _check_all_skipped(test): pytest.skip("all tests skipped by +SKIP option") -def _is_mocked(obj): +def _is_mocked(obj: object) -> bool: """ returns if a object is possibly a mock object by checking the existence of a highly improbable attribute """ @@ -421,23 +458,26 @@ def _is_mocked(obj): @contextmanager -def _patch_unwrap_mock_aware(): +def _patch_unwrap_mock_aware() -> Generator[None, None, None]: """ contextmanager which replaces ``inspect.unwrap`` with a version that's aware of mock objects and doesn't recurse on them """ real_unwrap = inspect.unwrap - def _mock_aware_unwrap(obj, stop=None): + def _mock_aware_unwrap( + func: Callable[..., Any], *, stop: Optional[Callable[[Any], Any]] = None + ) -> Any: try: if stop is None or stop is _is_mocked: - return real_unwrap(obj, stop=_is_mocked) - return real_unwrap(obj, stop=lambda obj: _is_mocked(obj) or stop(obj)) + return real_unwrap(func, stop=_is_mocked) + _stop = stop + return real_unwrap(func, stop=lambda obj: _is_mocked(obj) or _stop(func)) except Exception as e: warnings.warn( "Got %r when unwrapping %r. This is usually caused " "by a violation of Python's object protocol; see e.g. " - "https://github.com/pytest-dev/pytest/issues/5080" % (e, obj), + "https://github.com/pytest-dev/pytest/issues/5080" % (e, func), PytestWarning, ) raise @@ -469,7 +509,10 @@ class DoctestModule(pytest.Module): """ if isinstance(obj, property): obj = getattr(obj, "fget", obj) - return doctest.DocTestFinder._find_lineno(self, obj, source_lines) + # Type ignored because this is a private function. + return doctest.DocTestFinder._find_lineno( # type: ignore + self, obj, source_lines, + ) def _find( self, tests, obj, name, module, source_lines, globs, seen @@ -510,17 +553,17 @@ class DoctestModule(pytest.Module): ) -def _setup_fixtures(doctest_item): +def _setup_fixtures(doctest_item: DoctestItem) -> FixtureRequest: """ Used by DoctestTextfile and DoctestItem to setup fixture information. """ - def func(): + def func() -> None: pass - doctest_item.funcargs = {} + doctest_item.funcargs = {} # type: ignore[attr-defined] # noqa: F821 fm = doctest_item.session._fixturemanager - doctest_item._fixtureinfo = fm.getfixtureinfo( + doctest_item._fixtureinfo = fm.getfixtureinfo( # type: ignore[attr-defined] # noqa: F821 node=doctest_item, func=func, cls=None, funcargs=False ) fixture_request = FixtureRequest(doctest_item) @@ -564,7 +607,7 @@ def _init_checker_class() -> "Type[doctest.OutputChecker]": re.VERBOSE, ) - def check_output(self, want, got, optionflags): + def check_output(self, want: str, got: str, optionflags: int) -> bool: if doctest.OutputChecker.check_output(self, want, got, optionflags): return True @@ -575,7 +618,7 @@ def _init_checker_class() -> "Type[doctest.OutputChecker]": if not allow_unicode and not allow_bytes and not allow_number: return False - def remove_prefixes(regex, txt): + def remove_prefixes(regex: Pattern[str], txt: str) -> str: return re.sub(regex, r"\1\2", txt) if allow_unicode: @@ -591,7 +634,7 @@ def _init_checker_class() -> "Type[doctest.OutputChecker]": return doctest.OutputChecker.check_output(self, want, got, optionflags) - def _remove_unwanted_precision(self, want, got): + def _remove_unwanted_precision(self, want: str, got: str) -> str: wants = list(self._number_re.finditer(want)) gots = list(self._number_re.finditer(got)) if len(wants) != len(gots): @@ -686,7 +729,7 @@ def _get_report_choice(key: str) -> int: @pytest.fixture(scope="session") -def doctest_namespace(): +def doctest_namespace() -> Dict[str, Any]: """ Fixture that returns a :py:class:`dict` that will be injected into the namespace of doctests. """