pytester: improve type annotations

This commit is contained in:
Ran Benita 2020-08-01 10:49:51 +03:00
parent 62ddf7a0e5
commit f0eb82f7d4
1 changed files with 86 additions and 25 deletions

View File

@ -28,6 +28,7 @@ import pytest
from _pytest import timing from _pytest import timing
from _pytest._code import Source from _pytest._code import Source
from _pytest.capture import _get_multicapture from _pytest.capture import _get_multicapture
from _pytest.compat import overload
from _pytest.compat import TYPE_CHECKING from _pytest.compat import TYPE_CHECKING
from _pytest.config import _PluggyPlugin from _pytest.config import _PluggyPlugin
from _pytest.config import Config from _pytest.config import Config
@ -42,11 +43,13 @@ from _pytest.nodes import Item
from _pytest.pathlib import make_numbered_dir from _pytest.pathlib import make_numbered_dir
from _pytest.pathlib import Path from _pytest.pathlib import Path
from _pytest.python import Module from _pytest.python import Module
from _pytest.reports import CollectReport
from _pytest.reports import TestReport from _pytest.reports import TestReport
from _pytest.tmpdir import TempdirFactory from _pytest.tmpdir import TempdirFactory
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Type from typing import Type
from typing_extensions import Literal
import pexpect import pexpect
@ -180,24 +183,24 @@ class PytestArg:
return hookrecorder return hookrecorder
def get_public_names(values): def get_public_names(values: Iterable[str]) -> List[str]:
"""Only return names from iterator values without a leading underscore.""" """Only return names from iterator values without a leading underscore."""
return [x for x in values if x[0] != "_"] return [x for x in values if x[0] != "_"]
class ParsedCall: class ParsedCall:
def __init__(self, name, kwargs): def __init__(self, name: str, kwargs) -> None:
self.__dict__.update(kwargs) self.__dict__.update(kwargs)
self._name = name self._name = name
def __repr__(self): def __repr__(self) -> str:
d = self.__dict__.copy() d = self.__dict__.copy()
del d["_name"] del d["_name"]
return "<ParsedCall {!r}(**{!r})>".format(self._name, d) return "<ParsedCall {!r}(**{!r})>".format(self._name, d)
if TYPE_CHECKING: if TYPE_CHECKING:
# The class has undetermined attributes, this tells mypy about it. # The class has undetermined attributes, this tells mypy about it.
def __getattr__(self, key): def __getattr__(self, key: str):
raise NotImplementedError() raise NotImplementedError()
@ -211,6 +214,7 @@ class HookRecorder:
def __init__(self, pluginmanager: PytestPluginManager) -> None: def __init__(self, pluginmanager: PytestPluginManager) -> None:
self._pluginmanager = pluginmanager self._pluginmanager = pluginmanager
self.calls = [] # type: List[ParsedCall] self.calls = [] # type: List[ParsedCall]
self.ret = None # type: Optional[Union[int, ExitCode]]
def before(hook_name: str, hook_impls, kwargs) -> None: def before(hook_name: str, hook_impls, kwargs) -> None:
self.calls.append(ParsedCall(hook_name, kwargs)) self.calls.append(ParsedCall(hook_name, kwargs))
@ -228,7 +232,7 @@ class HookRecorder:
names = names.split() names = names.split()
return [call for call in self.calls if call._name in names] return [call for call in self.calls if call._name in names]
def assert_contains(self, entries) -> None: def assert_contains(self, entries: Sequence[Tuple[str, str]]) -> None:
__tracebackhide__ = True __tracebackhide__ = True
i = 0 i = 0
entries = list(entries) entries = list(entries)
@ -266,22 +270,46 @@ class HookRecorder:
# functionality for test reports # functionality for test reports
@overload
def getreports( def getreports(
self, names: "Literal['pytest_collectreport']",
) -> Sequence[CollectReport]:
raise NotImplementedError()
@overload # noqa: F811
def getreports( # noqa: F811
self, names: "Literal['pytest_runtest_logreport']",
) -> Sequence[TestReport]:
raise NotImplementedError()
@overload # noqa: F811
def getreports( # noqa: F811
self, self,
names: Union[ names: Union[str, Iterable[str]] = (
str, Iterable[str] "pytest_collectreport",
] = "pytest_runtest_logreport pytest_collectreport", "pytest_runtest_logreport",
) -> List[TestReport]: ),
) -> Sequence[Union[CollectReport, TestReport]]:
raise NotImplementedError()
def getreports( # noqa: F811
self,
names: Union[str, Iterable[str]] = (
"pytest_collectreport",
"pytest_runtest_logreport",
),
) -> Sequence[Union[CollectReport, TestReport]]:
return [x.report for x in self.getcalls(names)] return [x.report for x in self.getcalls(names)]
def matchreport( def matchreport(
self, self,
inamepart: str = "", inamepart: str = "",
names: Union[ names: Union[str, Iterable[str]] = (
str, Iterable[str] "pytest_runtest_logreport",
] = "pytest_runtest_logreport pytest_collectreport", "pytest_collectreport",
when=None, ),
): when: Optional[str] = None,
) -> Union[CollectReport, TestReport]:
"""Return a testreport whose dotted import path matches.""" """Return a testreport whose dotted import path matches."""
values = [] values = []
for rep in self.getreports(names=names): for rep in self.getreports(names=names):
@ -305,26 +333,56 @@ class HookRecorder:
) )
return values[0] return values[0]
@overload
def getfailures( def getfailures(
self, names: "Literal['pytest_collectreport']",
) -> Sequence[CollectReport]:
raise NotImplementedError()
@overload # noqa: F811
def getfailures( # noqa: F811
self, names: "Literal['pytest_runtest_logreport']",
) -> Sequence[TestReport]:
raise NotImplementedError()
@overload # noqa: F811
def getfailures( # noqa: F811
self, self,
names: Union[ names: Union[str, Iterable[str]] = (
str, Iterable[str] "pytest_collectreport",
] = "pytest_runtest_logreport pytest_collectreport", "pytest_runtest_logreport",
) -> List[TestReport]: ),
) -> Sequence[Union[CollectReport, TestReport]]:
raise NotImplementedError()
def getfailures( # noqa: F811
self,
names: Union[str, Iterable[str]] = (
"pytest_collectreport",
"pytest_runtest_logreport",
),
) -> Sequence[Union[CollectReport, TestReport]]:
return [rep for rep in self.getreports(names) if rep.failed] return [rep for rep in self.getreports(names) if rep.failed]
def getfailedcollections(self) -> List[TestReport]: def getfailedcollections(self) -> Sequence[CollectReport]:
return self.getfailures("pytest_collectreport") return self.getfailures("pytest_collectreport")
def listoutcomes( def listoutcomes(
self, self,
) -> Tuple[List[TestReport], List[TestReport], List[TestReport]]: ) -> Tuple[
Sequence[TestReport],
Sequence[Union[CollectReport, TestReport]],
Sequence[Union[CollectReport, TestReport]],
]:
passed = [] passed = []
skipped = [] skipped = []
failed = [] failed = []
for rep in self.getreports("pytest_collectreport pytest_runtest_logreport"): for rep in self.getreports(
("pytest_collectreport", "pytest_runtest_logreport")
):
if rep.passed: if rep.passed:
if rep.when == "call": if rep.when == "call":
assert isinstance(rep, TestReport)
passed.append(rep) passed.append(rep)
elif rep.skipped: elif rep.skipped:
skipped.append(rep) skipped.append(rep)
@ -879,7 +937,7 @@ class Testdir:
runner = testclassinstance.getrunner() runner = testclassinstance.getrunner()
return runner(item) return runner(item)
def inline_runsource(self, source, *cmdlineargs): def inline_runsource(self, source, *cmdlineargs) -> HookRecorder:
"""Run a test module in process using ``pytest.main()``. """Run a test module in process using ``pytest.main()``.
This run writes "source" into a temporary file and runs This run writes "source" into a temporary file and runs
@ -896,7 +954,7 @@ class Testdir:
values = list(cmdlineargs) + [p] values = list(cmdlineargs) + [p]
return self.inline_run(*values) return self.inline_run(*values)
def inline_genitems(self, *args): def inline_genitems(self, *args) -> Tuple[List[Item], HookRecorder]:
"""Run ``pytest.main(['--collectonly'])`` in-process. """Run ``pytest.main(['--collectonly'])`` in-process.
Runs the :py:func:`pytest.main` function to run all of pytest inside Runs the :py:func:`pytest.main` function to run all of pytest inside
@ -907,7 +965,9 @@ class Testdir:
items = [x.item for x in rec.getcalls("pytest_itemcollected")] items = [x.item for x in rec.getcalls("pytest_itemcollected")]
return items, rec return items, rec
def inline_run(self, *args, plugins=(), no_reraise_ctrlc: bool = False): def inline_run(
self, *args, plugins=(), no_reraise_ctrlc: bool = False
) -> HookRecorder:
"""Run ``pytest.main()`` in-process, returning a HookRecorder. """Run ``pytest.main()`` in-process, returning a HookRecorder.
Runs the :py:func:`pytest.main` function to run all of pytest inside Runs the :py:func:`pytest.main` function to run all of pytest inside
@ -962,7 +1022,7 @@ class Testdir:
class reprec: # type: ignore class reprec: # type: ignore
pass pass
reprec.ret = ret # type: ignore[attr-defined] reprec.ret = ret
# Typically we reraise keyboard interrupts from the child run # Typically we reraise keyboard interrupts from the child run
# because it's our user requesting interruption of the testing. # because it's our user requesting interruption of the testing.
@ -1010,6 +1070,7 @@ class Testdir:
sys.stdout.write(out) sys.stdout.write(out)
sys.stderr.write(err) sys.stderr.write(err)
assert reprec.ret is not None
res = RunResult( res = RunResult(
reprec.ret, out.splitlines(), err.splitlines(), timing.time() - now reprec.ret, out.splitlines(), err.splitlines(), timing.time() - now
) )