Type annotate _pytest.doctest

This commit is contained in:
Ran Benita 2020-05-01 14:40:16 +03:00
parent 387d9d04f7
commit 32dd0e87cb
1 changed files with 77 additions and 34 deletions

View File

@ -4,12 +4,17 @@ import inspect
import platform import platform
import sys import sys
import traceback import traceback
import types
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any
from typing import Callable
from typing import Dict from typing import Dict
from typing import Generator
from typing import Iterable from typing import Iterable
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Pattern
from typing import Sequence from typing import Sequence
from typing import Tuple from typing import Tuple
from typing import Union from typing import Union
@ -24,6 +29,7 @@ from _pytest._code.code import TerminalRepr
from _pytest._io import TerminalWriter from _pytest._io import TerminalWriter
from _pytest.compat import safe_getattr from _pytest.compat import safe_getattr
from _pytest.compat import TYPE_CHECKING from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config.argparsing import Parser from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest from _pytest.fixtures import FixtureRequest
from _pytest.outcomes import OutcomeException from _pytest.outcomes import OutcomeException
@ -131,7 +137,7 @@ def _is_setup_py(path: py.path.local) -> bool:
return b"setuptools" in contents or b"distutils" in contents return b"setuptools" in contents or b"distutils" in contents
def _is_doctest(config, path, parent): def _is_doctest(config: Config, path: py.path.local, parent) -> bool:
if path.ext in (".txt", ".rst") and parent.session.isinitpath(path): if path.ext in (".txt", ".rst") and parent.session.isinitpath(path):
return True return True
globs = config.getoption("doctestglob") or ["test*.txt"] globs = config.getoption("doctestglob") or ["test*.txt"]
@ -144,7 +150,7 @@ def _is_doctest(config, path, parent):
class ReprFailDoctest(TerminalRepr): class ReprFailDoctest(TerminalRepr):
def __init__( def __init__(
self, reprlocation_lines: Sequence[Tuple[ReprFileLocation, Sequence[str]]] self, reprlocation_lines: Sequence[Tuple[ReprFileLocation, Sequence[str]]]
): ) -> None:
self.reprlocation_lines = reprlocation_lines self.reprlocation_lines = reprlocation_lines
def toterminal(self, tw: TerminalWriter) -> None: def toterminal(self, tw: TerminalWriter) -> None:
@ -155,7 +161,7 @@ class ReprFailDoctest(TerminalRepr):
class MultipleDoctestFailures(Exception): class MultipleDoctestFailures(Exception):
def __init__(self, failures): def __init__(self, failures: "Sequence[doctest.DocTestFailure]") -> None:
super().__init__() super().__init__()
self.failures = failures self.failures = failures
@ -170,21 +176,33 @@ def _init_runner_class() -> "Type[doctest.DocTestRunner]":
""" """
def __init__( def __init__(
self, checker=None, verbose=None, optionflags=0, continue_on_failure=True self,
): checker: Optional[doctest.OutputChecker] = None,
verbose: Optional[bool] = None,
optionflags: int = 0,
continue_on_failure: bool = True,
) -> None:
doctest.DebugRunner.__init__( doctest.DebugRunner.__init__(
self, checker=checker, verbose=verbose, optionflags=optionflags self, checker=checker, verbose=verbose, optionflags=optionflags
) )
self.continue_on_failure = continue_on_failure self.continue_on_failure = continue_on_failure
def report_failure(self, out, test, example, got): def report_failure(
self, out, test: "doctest.DocTest", example: "doctest.Example", got: str,
) -> None:
failure = doctest.DocTestFailure(test, example, got) failure = doctest.DocTestFailure(test, example, got)
if self.continue_on_failure: if self.continue_on_failure:
out.append(failure) out.append(failure)
else: else:
raise failure raise failure
def report_unexpected_exception(self, out, test, example, exc_info): def report_unexpected_exception(
self,
out,
test: "doctest.DocTest",
example: "doctest.Example",
exc_info: "Tuple[Type[BaseException], BaseException, types.TracebackType]",
) -> None:
if isinstance(exc_info[1], OutcomeException): if isinstance(exc_info[1], OutcomeException):
raise exc_info[1] raise exc_info[1]
if isinstance(exc_info[1], bdb.BdbQuit): if isinstance(exc_info[1], bdb.BdbQuit):
@ -219,16 +237,27 @@ def _get_runner(
class DoctestItem(pytest.Item): class DoctestItem(pytest.Item):
def __init__(self, name, parent, runner=None, dtest=None): def __init__(
self,
name: str,
parent: "Union[DoctestTextfile, DoctestModule]",
runner: Optional["doctest.DocTestRunner"] = None,
dtest: Optional["doctest.DocTest"] = None,
) -> None:
super().__init__(name, parent) super().__init__(name, parent)
self.runner = runner self.runner = runner
self.dtest = dtest self.dtest = dtest
self.obj = None self.obj = None
self.fixture_request = None self.fixture_request = None # type: Optional[FixtureRequest]
@classmethod @classmethod
def from_parent( # type: ignore def from_parent( # type: ignore
cls, parent: "Union[DoctestTextfile, DoctestModule]", *, name, runner, dtest cls,
parent: "Union[DoctestTextfile, DoctestModule]",
*,
name: str,
runner: "doctest.DocTestRunner",
dtest: "doctest.DocTest"
): ):
# incompatible signature due to to imposed limits on sublcass # incompatible signature due to to imposed limits on sublcass
""" """
@ -236,7 +265,7 @@ class DoctestItem(pytest.Item):
""" """
return super().from_parent(name=name, parent=parent, runner=runner, dtest=dtest) return super().from_parent(name=name, parent=parent, runner=runner, dtest=dtest)
def setup(self): def setup(self) -> None:
if self.dtest is not None: if self.dtest is not None:
self.fixture_request = _setup_fixtures(self) self.fixture_request = _setup_fixtures(self)
globs = dict(getfixture=self.fixture_request.getfixturevalue) globs = dict(getfixture=self.fixture_request.getfixturevalue)
@ -247,14 +276,18 @@ class DoctestItem(pytest.Item):
self.dtest.globs.update(globs) self.dtest.globs.update(globs)
def runtest(self) -> None: def runtest(self) -> None:
assert self.dtest is not None
assert self.runner is not None
_check_all_skipped(self.dtest) _check_all_skipped(self.dtest)
self._disable_output_capturing_for_darwin() self._disable_output_capturing_for_darwin()
failures = [] # type: List[doctest.DocTestFailure] failures = [] # type: List[doctest.DocTestFailure]
self.runner.run(self.dtest, out=failures) # Type ignored because we change the type of `out` from what
# doctest expects.
self.runner.run(self.dtest, out=failures) # type: ignore[arg-type] # noqa: F821
if failures: if failures:
raise MultipleDoctestFailures(failures) raise MultipleDoctestFailures(failures)
def _disable_output_capturing_for_darwin(self): def _disable_output_capturing_for_darwin(self) -> None:
""" """
Disable output capturing. Otherwise, stdout is lost to doctest (#985) Disable output capturing. Otherwise, stdout is lost to doctest (#985)
""" """
@ -272,10 +305,12 @@ class DoctestItem(pytest.Item):
failures = ( failures = (
None None
) # type: Optional[List[Union[doctest.DocTestFailure, doctest.UnexpectedException]]] ) # type: Optional[Sequence[Union[doctest.DocTestFailure, doctest.UnexpectedException]]]
if excinfo.errisinstance((doctest.DocTestFailure, doctest.UnexpectedException)): if isinstance(
excinfo.value, (doctest.DocTestFailure, doctest.UnexpectedException)
):
failures = [excinfo.value] failures = [excinfo.value]
elif excinfo.errisinstance(MultipleDoctestFailures): elif isinstance(excinfo.value, MultipleDoctestFailures):
failures = excinfo.value.failures failures = excinfo.value.failures
if failures is not None: if failures is not None:
@ -289,7 +324,8 @@ class DoctestItem(pytest.Item):
else: else:
lineno = test.lineno + example.lineno + 1 lineno = test.lineno + example.lineno + 1
message = type(failure).__name__ message = type(failure).__name__
reprlocation = ReprFileLocation(filename, lineno, message) # TODO: ReprFileLocation doesn't expect a None lineno.
reprlocation = ReprFileLocation(filename, lineno, message) # type: ignore[arg-type] # noqa: F821
checker = _get_checker() checker = _get_checker()
report_choice = _get_report_choice( report_choice = _get_report_choice(
self.config.getoption("doctestreport") self.config.getoption("doctestreport")
@ -329,7 +365,8 @@ class DoctestItem(pytest.Item):
else: else:
return super().repr_failure(excinfo) return super().repr_failure(excinfo)
def reportinfo(self) -> Tuple[py.path.local, int, str]: def reportinfo(self):
assert self.dtest is not None
return self.fspath, self.dtest.lineno, "[doctest] %s" % self.name return self.fspath, self.dtest.lineno, "[doctest] %s" % self.name
@ -399,7 +436,7 @@ class DoctestTextfile(pytest.Module):
) )
def _check_all_skipped(test): def _check_all_skipped(test: "doctest.DocTest") -> None:
"""raises pytest.skip() if all examples in the given DocTest have the SKIP """raises pytest.skip() if all examples in the given DocTest have the SKIP
option set. option set.
""" """
@ -410,7 +447,7 @@ def _check_all_skipped(test):
pytest.skip("all tests skipped by +SKIP option") pytest.skip("all tests skipped by +SKIP option")
def _is_mocked(obj): def _is_mocked(obj: object) -> bool:
""" """
returns if a object is possibly a mock object by checking the existence of a highly improbable attribute returns if a object is possibly a mock object by checking the existence of a highly improbable attribute
""" """
@ -421,23 +458,26 @@ def _is_mocked(obj):
@contextmanager @contextmanager
def _patch_unwrap_mock_aware(): def _patch_unwrap_mock_aware() -> Generator[None, None, None]:
""" """
contextmanager which replaces ``inspect.unwrap`` with a version contextmanager which replaces ``inspect.unwrap`` with a version
that's aware of mock objects and doesn't recurse on them that's aware of mock objects and doesn't recurse on them
""" """
real_unwrap = inspect.unwrap real_unwrap = inspect.unwrap
def _mock_aware_unwrap(obj, stop=None): def _mock_aware_unwrap(
func: Callable[..., Any], *, stop: Optional[Callable[[Any], Any]] = None
) -> Any:
try: try:
if stop is None or stop is _is_mocked: if stop is None or stop is _is_mocked:
return real_unwrap(obj, stop=_is_mocked) return real_unwrap(func, stop=_is_mocked)
return real_unwrap(obj, stop=lambda obj: _is_mocked(obj) or stop(obj)) _stop = stop
return real_unwrap(func, stop=lambda obj: _is_mocked(obj) or _stop(func))
except Exception as e: except Exception as e:
warnings.warn( warnings.warn(
"Got %r when unwrapping %r. This is usually caused " "Got %r when unwrapping %r. This is usually caused "
"by a violation of Python's object protocol; see e.g. " "by a violation of Python's object protocol; see e.g. "
"https://github.com/pytest-dev/pytest/issues/5080" % (e, obj), "https://github.com/pytest-dev/pytest/issues/5080" % (e, func),
PytestWarning, PytestWarning,
) )
raise raise
@ -469,7 +509,10 @@ class DoctestModule(pytest.Module):
""" """
if isinstance(obj, property): if isinstance(obj, property):
obj = getattr(obj, "fget", obj) obj = getattr(obj, "fget", obj)
return doctest.DocTestFinder._find_lineno(self, obj, source_lines) # Type ignored because this is a private function.
return doctest.DocTestFinder._find_lineno( # type: ignore
self, obj, source_lines,
)
def _find( def _find(
self, tests, obj, name, module, source_lines, globs, seen self, tests, obj, name, module, source_lines, globs, seen
@ -510,17 +553,17 @@ class DoctestModule(pytest.Module):
) )
def _setup_fixtures(doctest_item): def _setup_fixtures(doctest_item: DoctestItem) -> FixtureRequest:
""" """
Used by DoctestTextfile and DoctestItem to setup fixture information. Used by DoctestTextfile and DoctestItem to setup fixture information.
""" """
def func(): def func() -> None:
pass pass
doctest_item.funcargs = {} doctest_item.funcargs = {} # type: ignore[attr-defined] # noqa: F821
fm = doctest_item.session._fixturemanager fm = doctest_item.session._fixturemanager
doctest_item._fixtureinfo = fm.getfixtureinfo( doctest_item._fixtureinfo = fm.getfixtureinfo( # type: ignore[attr-defined] # noqa: F821
node=doctest_item, func=func, cls=None, funcargs=False node=doctest_item, func=func, cls=None, funcargs=False
) )
fixture_request = FixtureRequest(doctest_item) fixture_request = FixtureRequest(doctest_item)
@ -564,7 +607,7 @@ def _init_checker_class() -> "Type[doctest.OutputChecker]":
re.VERBOSE, re.VERBOSE,
) )
def check_output(self, want, got, optionflags): def check_output(self, want: str, got: str, optionflags: int) -> bool:
if doctest.OutputChecker.check_output(self, want, got, optionflags): if doctest.OutputChecker.check_output(self, want, got, optionflags):
return True return True
@ -575,7 +618,7 @@ def _init_checker_class() -> "Type[doctest.OutputChecker]":
if not allow_unicode and not allow_bytes and not allow_number: if not allow_unicode and not allow_bytes and not allow_number:
return False return False
def remove_prefixes(regex, txt): def remove_prefixes(regex: Pattern[str], txt: str) -> str:
return re.sub(regex, r"\1\2", txt) return re.sub(regex, r"\1\2", txt)
if allow_unicode: if allow_unicode:
@ -591,7 +634,7 @@ def _init_checker_class() -> "Type[doctest.OutputChecker]":
return doctest.OutputChecker.check_output(self, want, got, optionflags) return doctest.OutputChecker.check_output(self, want, got, optionflags)
def _remove_unwanted_precision(self, want, got): def _remove_unwanted_precision(self, want: str, got: str) -> str:
wants = list(self._number_re.finditer(want)) wants = list(self._number_re.finditer(want))
gots = list(self._number_re.finditer(got)) gots = list(self._number_re.finditer(got))
if len(wants) != len(gots): if len(wants) != len(gots):
@ -686,7 +729,7 @@ def _get_report_choice(key: str) -> int:
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def doctest_namespace(): def doctest_namespace() -> Dict[str, Any]:
""" """
Fixture that returns a :py:class:`dict` that will be injected into the namespace of doctests. Fixture that returns a :py:class:`dict` that will be injected into the namespace of doctests.
""" """