pytester: improve type annotations
This commit is contained in:
parent
62ddf7a0e5
commit
f0eb82f7d4
|
@ -28,6 +28,7 @@ import pytest
|
|||
from _pytest import timing
|
||||
from _pytest._code import Source
|
||||
from _pytest.capture import _get_multicapture
|
||||
from _pytest.compat import overload
|
||||
from _pytest.compat import TYPE_CHECKING
|
||||
from _pytest.config import _PluggyPlugin
|
||||
from _pytest.config import Config
|
||||
|
@ -42,11 +43,13 @@ from _pytest.nodes import Item
|
|||
from _pytest.pathlib import make_numbered_dir
|
||||
from _pytest.pathlib import Path
|
||||
from _pytest.python import Module
|
||||
from _pytest.reports import CollectReport
|
||||
from _pytest.reports import TestReport
|
||||
from _pytest.tmpdir import TempdirFactory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Type
|
||||
from typing_extensions import Literal
|
||||
|
||||
import pexpect
|
||||
|
||||
|
@ -180,24 +183,24 @@ class PytestArg:
|
|||
return hookrecorder
|
||||
|
||||
|
||||
def get_public_names(values):
|
||||
def get_public_names(values: Iterable[str]) -> List[str]:
|
||||
"""Only return names from iterator values without a leading underscore."""
|
||||
return [x for x in values if x[0] != "_"]
|
||||
|
||||
|
||||
class ParsedCall:
|
||||
def __init__(self, name, kwargs):
|
||||
def __init__(self, name: str, kwargs) -> None:
|
||||
self.__dict__.update(kwargs)
|
||||
self._name = name
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
d = self.__dict__.copy()
|
||||
del d["_name"]
|
||||
return "<ParsedCall {!r}(**{!r})>".format(self._name, d)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# The class has undetermined attributes, this tells mypy about it.
|
||||
def __getattr__(self, key):
|
||||
def __getattr__(self, key: str):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
|
@ -211,6 +214,7 @@ class HookRecorder:
|
|||
def __init__(self, pluginmanager: PytestPluginManager) -> None:
|
||||
self._pluginmanager = pluginmanager
|
||||
self.calls = [] # type: List[ParsedCall]
|
||||
self.ret = None # type: Optional[Union[int, ExitCode]]
|
||||
|
||||
def before(hook_name: str, hook_impls, kwargs) -> None:
|
||||
self.calls.append(ParsedCall(hook_name, kwargs))
|
||||
|
@ -228,7 +232,7 @@ class HookRecorder:
|
|||
names = names.split()
|
||||
return [call for call in self.calls if call._name in names]
|
||||
|
||||
def assert_contains(self, entries) -> None:
|
||||
def assert_contains(self, entries: Sequence[Tuple[str, str]]) -> None:
|
||||
__tracebackhide__ = True
|
||||
i = 0
|
||||
entries = list(entries)
|
||||
|
@ -266,22 +270,46 @@ class HookRecorder:
|
|||
|
||||
# functionality for test reports
|
||||
|
||||
@overload
|
||||
def getreports(
|
||||
self, names: "Literal['pytest_collectreport']",
|
||||
) -> Sequence[CollectReport]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@overload # noqa: F811
|
||||
def getreports( # noqa: F811
|
||||
self, names: "Literal['pytest_runtest_logreport']",
|
||||
) -> Sequence[TestReport]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@overload # noqa: F811
|
||||
def getreports( # noqa: F811
|
||||
self,
|
||||
names: Union[
|
||||
str, Iterable[str]
|
||||
] = "pytest_runtest_logreport pytest_collectreport",
|
||||
) -> List[TestReport]:
|
||||
names: Union[str, Iterable[str]] = (
|
||||
"pytest_collectreport",
|
||||
"pytest_runtest_logreport",
|
||||
),
|
||||
) -> Sequence[Union[CollectReport, TestReport]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def getreports( # noqa: F811
|
||||
self,
|
||||
names: Union[str, Iterable[str]] = (
|
||||
"pytest_collectreport",
|
||||
"pytest_runtest_logreport",
|
||||
),
|
||||
) -> Sequence[Union[CollectReport, TestReport]]:
|
||||
return [x.report for x in self.getcalls(names)]
|
||||
|
||||
def matchreport(
|
||||
self,
|
||||
inamepart: str = "",
|
||||
names: Union[
|
||||
str, Iterable[str]
|
||||
] = "pytest_runtest_logreport pytest_collectreport",
|
||||
when=None,
|
||||
):
|
||||
names: Union[str, Iterable[str]] = (
|
||||
"pytest_runtest_logreport",
|
||||
"pytest_collectreport",
|
||||
),
|
||||
when: Optional[str] = None,
|
||||
) -> Union[CollectReport, TestReport]:
|
||||
"""Return a testreport whose dotted import path matches."""
|
||||
values = []
|
||||
for rep in self.getreports(names=names):
|
||||
|
@ -305,26 +333,56 @@ class HookRecorder:
|
|||
)
|
||||
return values[0]
|
||||
|
||||
@overload
|
||||
def getfailures(
|
||||
self, names: "Literal['pytest_collectreport']",
|
||||
) -> Sequence[CollectReport]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@overload # noqa: F811
|
||||
def getfailures( # noqa: F811
|
||||
self, names: "Literal['pytest_runtest_logreport']",
|
||||
) -> Sequence[TestReport]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@overload # noqa: F811
|
||||
def getfailures( # noqa: F811
|
||||
self,
|
||||
names: Union[
|
||||
str, Iterable[str]
|
||||
] = "pytest_runtest_logreport pytest_collectreport",
|
||||
) -> List[TestReport]:
|
||||
names: Union[str, Iterable[str]] = (
|
||||
"pytest_collectreport",
|
||||
"pytest_runtest_logreport",
|
||||
),
|
||||
) -> Sequence[Union[CollectReport, TestReport]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def getfailures( # noqa: F811
|
||||
self,
|
||||
names: Union[str, Iterable[str]] = (
|
||||
"pytest_collectreport",
|
||||
"pytest_runtest_logreport",
|
||||
),
|
||||
) -> Sequence[Union[CollectReport, TestReport]]:
|
||||
return [rep for rep in self.getreports(names) if rep.failed]
|
||||
|
||||
def getfailedcollections(self) -> List[TestReport]:
|
||||
def getfailedcollections(self) -> Sequence[CollectReport]:
|
||||
return self.getfailures("pytest_collectreport")
|
||||
|
||||
def listoutcomes(
|
||||
self,
|
||||
) -> Tuple[List[TestReport], List[TestReport], List[TestReport]]:
|
||||
) -> Tuple[
|
||||
Sequence[TestReport],
|
||||
Sequence[Union[CollectReport, TestReport]],
|
||||
Sequence[Union[CollectReport, TestReport]],
|
||||
]:
|
||||
passed = []
|
||||
skipped = []
|
||||
failed = []
|
||||
for rep in self.getreports("pytest_collectreport pytest_runtest_logreport"):
|
||||
for rep in self.getreports(
|
||||
("pytest_collectreport", "pytest_runtest_logreport")
|
||||
):
|
||||
if rep.passed:
|
||||
if rep.when == "call":
|
||||
assert isinstance(rep, TestReport)
|
||||
passed.append(rep)
|
||||
elif rep.skipped:
|
||||
skipped.append(rep)
|
||||
|
@ -879,7 +937,7 @@ class Testdir:
|
|||
runner = testclassinstance.getrunner()
|
||||
return runner(item)
|
||||
|
||||
def inline_runsource(self, source, *cmdlineargs):
|
||||
def inline_runsource(self, source, *cmdlineargs) -> HookRecorder:
|
||||
"""Run a test module in process using ``pytest.main()``.
|
||||
|
||||
This run writes "source" into a temporary file and runs
|
||||
|
@ -896,7 +954,7 @@ class Testdir:
|
|||
values = list(cmdlineargs) + [p]
|
||||
return self.inline_run(*values)
|
||||
|
||||
def inline_genitems(self, *args):
|
||||
def inline_genitems(self, *args) -> Tuple[List[Item], HookRecorder]:
|
||||
"""Run ``pytest.main(['--collectonly'])`` in-process.
|
||||
|
||||
Runs the :py:func:`pytest.main` function to run all of pytest inside
|
||||
|
@ -907,7 +965,9 @@ class Testdir:
|
|||
items = [x.item for x in rec.getcalls("pytest_itemcollected")]
|
||||
return items, rec
|
||||
|
||||
def inline_run(self, *args, plugins=(), no_reraise_ctrlc: bool = False):
|
||||
def inline_run(
|
||||
self, *args, plugins=(), no_reraise_ctrlc: bool = False
|
||||
) -> HookRecorder:
|
||||
"""Run ``pytest.main()`` in-process, returning a HookRecorder.
|
||||
|
||||
Runs the :py:func:`pytest.main` function to run all of pytest inside
|
||||
|
@ -962,7 +1022,7 @@ class Testdir:
|
|||
class reprec: # type: ignore
|
||||
pass
|
||||
|
||||
reprec.ret = ret # type: ignore[attr-defined]
|
||||
reprec.ret = ret
|
||||
|
||||
# Typically we reraise keyboard interrupts from the child run
|
||||
# because it's our user requesting interruption of the testing.
|
||||
|
@ -1010,6 +1070,7 @@ class Testdir:
|
|||
sys.stdout.write(out)
|
||||
sys.stderr.write(err)
|
||||
|
||||
assert reprec.ret is not None
|
||||
res = RunResult(
|
||||
reprec.ret, out.splitlines(), err.splitlines(), timing.time() - now
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue