pytester: improve type annotations

This commit is contained in:
Ran Benita 2020-08-01 10:49:51 +03:00
parent 62ddf7a0e5
commit f0eb82f7d4
1 changed files with 86 additions and 25 deletions

View File

@ -28,6 +28,7 @@ import pytest
from _pytest import timing
from _pytest._code import Source
from _pytest.capture import _get_multicapture
from _pytest.compat import overload
from _pytest.compat import TYPE_CHECKING
from _pytest.config import _PluggyPlugin
from _pytest.config import Config
@ -42,11 +43,13 @@ from _pytest.nodes import Item
from _pytest.pathlib import make_numbered_dir
from _pytest.pathlib import Path
from _pytest.python import Module
from _pytest.reports import CollectReport
from _pytest.reports import TestReport
from _pytest.tmpdir import TempdirFactory
if TYPE_CHECKING:
from typing import Type
from typing_extensions import Literal
import pexpect
@ -180,24 +183,24 @@ class PytestArg:
return hookrecorder
def get_public_names(values):
def get_public_names(values: Iterable[str]) -> List[str]:
"""Only return names from iterator values without a leading underscore."""
return [x for x in values if x[0] != "_"]
class ParsedCall:
def __init__(self, name, kwargs):
def __init__(self, name: str, kwargs) -> None:
self.__dict__.update(kwargs)
self._name = name
def __repr__(self):
def __repr__(self) -> str:
d = self.__dict__.copy()
del d["_name"]
return "<ParsedCall {!r}(**{!r})>".format(self._name, d)
if TYPE_CHECKING:
# The class has undetermined attributes, this tells mypy about it.
def __getattr__(self, key):
def __getattr__(self, key: str):
raise NotImplementedError()
@ -211,6 +214,7 @@ class HookRecorder:
def __init__(self, pluginmanager: PytestPluginManager) -> None:
self._pluginmanager = pluginmanager
self.calls = [] # type: List[ParsedCall]
self.ret = None # type: Optional[Union[int, ExitCode]]
def before(hook_name: str, hook_impls, kwargs) -> None:
self.calls.append(ParsedCall(hook_name, kwargs))
@ -228,7 +232,7 @@ class HookRecorder:
names = names.split()
return [call for call in self.calls if call._name in names]
def assert_contains(self, entries) -> None:
def assert_contains(self, entries: Sequence[Tuple[str, str]]) -> None:
__tracebackhide__ = True
i = 0
entries = list(entries)
@ -266,22 +270,46 @@ class HookRecorder:
# functionality for test reports
@overload
def getreports(
self, names: "Literal['pytest_collectreport']",
) -> Sequence[CollectReport]:
raise NotImplementedError()
@overload # noqa: F811
def getreports( # noqa: F811
self, names: "Literal['pytest_runtest_logreport']",
) -> Sequence[TestReport]:
raise NotImplementedError()
@overload # noqa: F811
def getreports( # noqa: F811
self,
names: Union[
str, Iterable[str]
] = "pytest_runtest_logreport pytest_collectreport",
) -> List[TestReport]:
names: Union[str, Iterable[str]] = (
"pytest_collectreport",
"pytest_runtest_logreport",
),
) -> Sequence[Union[CollectReport, TestReport]]:
raise NotImplementedError()
def getreports( # noqa: F811
self,
names: Union[str, Iterable[str]] = (
"pytest_collectreport",
"pytest_runtest_logreport",
),
) -> Sequence[Union[CollectReport, TestReport]]:
return [x.report for x in self.getcalls(names)]
def matchreport(
self,
inamepart: str = "",
names: Union[
str, Iterable[str]
] = "pytest_runtest_logreport pytest_collectreport",
when=None,
):
names: Union[str, Iterable[str]] = (
"pytest_runtest_logreport",
"pytest_collectreport",
),
when: Optional[str] = None,
) -> Union[CollectReport, TestReport]:
"""Return a testreport whose dotted import path matches."""
values = []
for rep in self.getreports(names=names):
@ -305,26 +333,56 @@ class HookRecorder:
)
return values[0]
@overload
def getfailures(
self, names: "Literal['pytest_collectreport']",
) -> Sequence[CollectReport]:
raise NotImplementedError()
@overload # noqa: F811
def getfailures( # noqa: F811
self, names: "Literal['pytest_runtest_logreport']",
) -> Sequence[TestReport]:
raise NotImplementedError()
@overload # noqa: F811
def getfailures( # noqa: F811
self,
names: Union[
str, Iterable[str]
] = "pytest_runtest_logreport pytest_collectreport",
) -> List[TestReport]:
names: Union[str, Iterable[str]] = (
"pytest_collectreport",
"pytest_runtest_logreport",
),
) -> Sequence[Union[CollectReport, TestReport]]:
raise NotImplementedError()
def getfailures( # noqa: F811
self,
names: Union[str, Iterable[str]] = (
"pytest_collectreport",
"pytest_runtest_logreport",
),
) -> Sequence[Union[CollectReport, TestReport]]:
return [rep for rep in self.getreports(names) if rep.failed]
def getfailedcollections(self) -> List[TestReport]:
def getfailedcollections(self) -> Sequence[CollectReport]:
return self.getfailures("pytest_collectreport")
def listoutcomes(
self,
) -> Tuple[List[TestReport], List[TestReport], List[TestReport]]:
) -> Tuple[
Sequence[TestReport],
Sequence[Union[CollectReport, TestReport]],
Sequence[Union[CollectReport, TestReport]],
]:
passed = []
skipped = []
failed = []
for rep in self.getreports("pytest_collectreport pytest_runtest_logreport"):
for rep in self.getreports(
("pytest_collectreport", "pytest_runtest_logreport")
):
if rep.passed:
if rep.when == "call":
assert isinstance(rep, TestReport)
passed.append(rep)
elif rep.skipped:
skipped.append(rep)
@ -879,7 +937,7 @@ class Testdir:
runner = testclassinstance.getrunner()
return runner(item)
def inline_runsource(self, source, *cmdlineargs):
def inline_runsource(self, source, *cmdlineargs) -> HookRecorder:
"""Run a test module in process using ``pytest.main()``.
This run writes "source" into a temporary file and runs
@ -896,7 +954,7 @@ class Testdir:
values = list(cmdlineargs) + [p]
return self.inline_run(*values)
def inline_genitems(self, *args):
def inline_genitems(self, *args) -> Tuple[List[Item], HookRecorder]:
"""Run ``pytest.main(['--collectonly'])`` in-process.
Runs the :py:func:`pytest.main` function to run all of pytest inside
@ -907,7 +965,9 @@ class Testdir:
items = [x.item for x in rec.getcalls("pytest_itemcollected")]
return items, rec
def inline_run(self, *args, plugins=(), no_reraise_ctrlc: bool = False):
def inline_run(
self, *args, plugins=(), no_reraise_ctrlc: bool = False
) -> HookRecorder:
"""Run ``pytest.main()`` in-process, returning a HookRecorder.
Runs the :py:func:`pytest.main` function to run all of pytest inside
@ -962,7 +1022,7 @@ class Testdir:
class reprec: # type: ignore
pass
reprec.ret = ret # type: ignore[attr-defined]
reprec.ret = ret
# Typically we reraise keyboard interrupts from the child run
# because it's our user requesting interruption of the testing.
@ -1010,6 +1070,7 @@ class Testdir:
sys.stdout.write(out)
sys.stderr.write(err)
assert reprec.ret is not None
res = RunResult(
reprec.ret, out.splitlines(), err.splitlines(), timing.time() - now
)