Type annotate some parts related to runner & reports

This commit is contained in:
Ran Benita 2020-05-01 14:40:16 +03:00
parent 709bcbf3c4
commit 90e58f8961
11 changed files with 132 additions and 84 deletions

View File

@ -278,7 +278,7 @@ class LFPlugin:
elif report.failed:
self.lastfailed[report.nodeid] = True
def pytest_collectreport(self, report):
def pytest_collectreport(self, report: CollectReport) -> None:
passed = report.outcome in ("passed", "skipped")
if passed:
if report.nodeid in self.lastfailed:

View File

@ -2,6 +2,7 @@
import os
import sys
from argparse import Action
from typing import List
from typing import Optional
from typing import Union
@ -235,7 +236,7 @@ def getpluginversioninfo(config):
return lines
def pytest_report_header(config):
def pytest_report_header(config: Config) -> List[str]:
lines = []
if config.option.debug or config.option.traceconfig:
lines.append(

View File

@ -3,6 +3,7 @@ from typing import Any
from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
@ -284,7 +285,7 @@ def pytest_itemcollected(item):
""" we just collected a test item. """
def pytest_collectreport(report):
def pytest_collectreport(report: "CollectReport") -> None:
""" collector finished collecting. """
@ -430,7 +431,7 @@ def pytest_runtest_teardown(item: "Item", nextitem: "Optional[Item]") -> None:
@hookspec(firstresult=True)
def pytest_runtest_makereport(item: "Item", call: "CallInfo") -> Optional[object]:
def pytest_runtest_makereport(item: "Item", call: "CallInfo[None]") -> Optional[object]:
""" return a :py:class:`_pytest.runner.TestReport` object
for the given :py:class:`pytest.Item <_pytest.main.Item>` and
:py:class:`_pytest.runner.CallInfo`.
@ -444,7 +445,7 @@ def pytest_runtest_logreport(report: "TestReport") -> None:
@hookspec(firstresult=True)
def pytest_report_to_serializable(config: "Config", report):
def pytest_report_to_serializable(config: "Config", report: "BaseReport"):
"""
Serializes the given report object into a data structure suitable for sending
over the wire, e.g. converted to JSON.
@ -580,7 +581,9 @@ def pytest_assertion_pass(item, lineno: int, orig: str, expl: str) -> None:
# -------------------------------------------------------------------------
def pytest_report_header(config: "Config", startdir):
def pytest_report_header(
config: "Config", startdir: py.path.local
) -> Union[str, List[str]]:
""" return a string or list of strings to be displayed as header info for terminal reporting.
:param _pytest.config.Config config: pytest config object
@ -601,7 +604,9 @@ def pytest_report_header(config: "Config", startdir):
"""
def pytest_report_collectionfinish(config: "Config", startdir, items):
def pytest_report_collectionfinish(
config: "Config", startdir: py.path.local, items: "Sequence[Item]"
) -> Union[str, List[str]]:
"""
.. versionadded:: 3.2
@ -758,7 +763,7 @@ def pytest_keyboard_interrupt(excinfo):
def pytest_exception_interact(
node: "Node", call: "CallInfo", report: "BaseReport"
node: "Node", call: "CallInfo[object]", report: "Union[CollectReport, TestReport]"
) -> None:
"""called when an exception was raised which can potentially be
interactively handled.

View File

@ -442,7 +442,9 @@ class Session(nodes.FSCollector):
raise self.Interrupted(self.shouldstop)
@hookimpl(tryfirst=True)
def pytest_runtest_logreport(self, report: TestReport) -> None:
def pytest_runtest_logreport(
self, report: Union[TestReport, CollectReport]
) -> None:
if report.failed and not hasattr(report, "wasxfail"):
self.testsfailed += 1
maxfail = self.config.getvalue("maxfail")

View File

@ -1,9 +1,12 @@
from io import StringIO
from pprint import pprint
from typing import Any
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union
import attr
@ -21,12 +24,17 @@ from _pytest._code.code import ReprTraceback
from _pytest._code.code import TerminalRepr
from _pytest._io import TerminalWriter
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.outcomes import skip
from _pytest.pathlib import Path
if TYPE_CHECKING:
from typing import NoReturn
from typing_extensions import Type
from typing_extensions import Literal
from _pytest.runner import CallInfo
@ -42,6 +50,9 @@ def getslaveinfoline(node):
return s
_R = TypeVar("_R", bound="BaseReport")
class BaseReport:
when = None # type: Optional[str]
location = None # type: Optional[Tuple[str, Optional[int], str]]
@ -74,13 +85,13 @@ class BaseReport:
except UnicodeEncodeError:
out.line("<unprintable longrepr>")
def get_sections(self, prefix):
def get_sections(self, prefix: str) -> Iterator[Tuple[str, str]]:
for name, content in self.sections:
if name.startswith(prefix):
yield prefix, content
@property
def longreprtext(self):
def longreprtext(self) -> str:
"""
Read-only property that returns the full string representation
of ``longrepr``.
@ -95,7 +106,7 @@ class BaseReport:
return exc.strip()
@property
def caplog(self):
def caplog(self) -> str:
"""Return captured log lines, if log capturing is enabled
.. versionadded:: 3.5
@ -105,7 +116,7 @@ class BaseReport:
)
@property
def capstdout(self):
def capstdout(self) -> str:
"""Return captured text from stdout, if capturing is enabled
.. versionadded:: 3.0
@ -115,7 +126,7 @@ class BaseReport:
)
@property
def capstderr(self):
def capstderr(self) -> str:
"""Return captured text from stderr, if capturing is enabled
.. versionadded:: 3.0
@ -133,7 +144,7 @@ class BaseReport:
return self.nodeid.split("::")[0]
@property
def count_towards_summary(self):
def count_towards_summary(self) -> bool:
"""
**Experimental**
@ -148,7 +159,7 @@ class BaseReport:
return True
@property
def head_line(self):
def head_line(self) -> Optional[str]:
"""
**Experimental**
@ -168,8 +179,9 @@ class BaseReport:
if self.location is not None:
fspath, lineno, domain = self.location
return domain
return None
def _get_verbose_word(self, config):
def _get_verbose_word(self, config: Config):
_category, _short, verbose = config.hook.pytest_report_teststatus(
report=self, config=config
)
@ -187,7 +199,7 @@ class BaseReport:
return _report_to_json(self)
@classmethod
def _from_json(cls, reportdict):
def _from_json(cls: "Type[_R]", reportdict) -> _R:
"""
This was originally the serialize_report() function from xdist (ca03269).
@ -200,7 +212,9 @@ class BaseReport:
return cls(**kwargs)
def _report_unserialization_failure(type_name, report_class, reportdict):
def _report_unserialization_failure(
type_name: str, report_class: "Type[BaseReport]", reportdict
) -> "NoReturn":
url = "https://github.com/pytest-dev/pytest/issues"
stream = StringIO()
pprint("-" * 100, stream=stream)
@ -221,15 +235,15 @@ class TestReport(BaseReport):
def __init__(
self,
nodeid,
nodeid: str,
location: Tuple[str, Optional[int], str],
keywords,
outcome,
outcome: "Literal['passed', 'failed', 'skipped']",
longrepr,
when,
sections=(),
duration=0,
user_properties=None,
when: "Literal['setup', 'call', 'teardown']",
sections: Iterable[Tuple[str, str]] = (),
duration: float = 0,
user_properties: Optional[Iterable[Tuple[str, object]]] = None,
**extra
) -> None:
#: normalized collection node id
@ -268,23 +282,25 @@ class TestReport(BaseReport):
self.__dict__.update(extra)
def __repr__(self):
def __repr__(self) -> str:
return "<{} {!r} when={!r} outcome={!r}>".format(
self.__class__.__name__, self.nodeid, self.when, self.outcome
)
@classmethod
def from_item_and_call(cls, item: Item, call: "CallInfo") -> "TestReport":
def from_item_and_call(cls, item: Item, call: "CallInfo[None]") -> "TestReport":
"""
Factory method to create and fill a TestReport with standard item and call info.
"""
when = call.when
# Remove "collect" from the Literal type -- only for collection calls.
assert when != "collect"
duration = call.duration
keywords = {x: 1 for x in item.keywords}
excinfo = call.excinfo
sections = []
if not call.excinfo:
outcome = "passed"
outcome = "passed" # type: Literal["passed", "failed", "skipped"]
# TODO: Improve this Any.
longrepr = None # type: Optional[Any]
else:
@ -324,10 +340,10 @@ class CollectReport(BaseReport):
def __init__(
self,
nodeid: str,
outcome,
outcome: "Literal['passed', 'skipped', 'failed']",
longrepr,
result: Optional[List[Union[Item, Collector]]],
sections=(),
sections: Iterable[Tuple[str, str]] = (),
**extra
) -> None:
self.nodeid = nodeid
@ -341,28 +357,29 @@ class CollectReport(BaseReport):
def location(self):
return (self.fspath, None, self.fspath)
def __repr__(self):
def __repr__(self) -> str:
return "<CollectReport {!r} lenresult={} outcome={!r}>".format(
self.nodeid, len(self.result), self.outcome
)
class CollectErrorRepr(TerminalRepr):
def __init__(self, msg):
def __init__(self, msg) -> None:
self.longrepr = msg
def toterminal(self, out) -> None:
out.line(self.longrepr, red=True)
def pytest_report_to_serializable(report):
def pytest_report_to_serializable(report: BaseReport):
if isinstance(report, (TestReport, CollectReport)):
data = report._to_json()
data["$report_type"] = report.__class__.__name__
return data
return None
def pytest_report_from_serializable(data):
def pytest_report_from_serializable(data) -> Optional[BaseReport]:
if "$report_type" in data:
if data["$report_type"] == "TestReport":
return TestReport._from_json(data)
@ -371,9 +388,10 @@ def pytest_report_from_serializable(data):
assert False, "Unknown report_type unserialize data: {}".format(
data["$report_type"]
)
return None
def _report_to_json(report):
def _report_to_json(report: BaseReport):
"""
This was originally the serialize_report() function from xdist (ca03269).
@ -381,11 +399,12 @@ def _report_to_json(report):
serialization.
"""
def serialize_repr_entry(entry):
entry_data = {"type": type(entry).__name__, "data": attr.asdict(entry)}
for key, value in entry_data["data"].items():
def serialize_repr_entry(entry: Union[ReprEntry, ReprEntryNative]):
data = attr.asdict(entry)
for key, value in data.items():
if hasattr(value, "__dict__"):
entry_data["data"][key] = attr.asdict(value)
data[key] = attr.asdict(value)
entry_data = {"type": type(entry).__name__, "data": data}
return entry_data
def serialize_repr_traceback(reprtraceback: ReprTraceback):

View File

@ -7,6 +7,7 @@ import py
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.reports import CollectReport
from _pytest.reports import TestReport
from _pytest.store import StoreKey
@ -87,7 +88,7 @@ class ResultLog:
longrepr = str(report.longrepr)
self.log_outcome(report, code, longrepr)
def pytest_collectreport(self, report):
def pytest_collectreport(self, report: CollectReport) -> None:
if not report.passed:
if report.failed:
code = "F"
@ -95,7 +96,7 @@ class ResultLog:
else:
assert report.skipped
code = "S"
longrepr = "%s:%d: %s" % report.longrepr
longrepr = "%s:%d: %s" % report.longrepr # type: ignore
self.log_outcome(report, code, longrepr)
def pytest_internalerror(self, excrepr):

View File

@ -3,10 +3,14 @@ import bdb
import os
import sys
from typing import Callable
from typing import cast
from typing import Dict
from typing import Generic
from typing import List
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union
import attr
@ -179,7 +183,7 @@ def _update_current_test_var(
os.environ.pop(var_name)
def pytest_report_teststatus(report):
def pytest_report_teststatus(report: BaseReport) -> Optional[Tuple[str, str, str]]:
if report.when in ("setup", "teardown"):
if report.failed:
# category, shortletter, verbose-word
@ -188,6 +192,7 @@ def pytest_report_teststatus(report):
return "skipped", "s", "SKIPPED"
else:
return "", "", ""
return None
#
@ -217,9 +222,9 @@ def check_interactive_exception(call: "CallInfo", report: BaseReport) -> bool:
def call_runtest_hook(
item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds
) -> "CallInfo":
) -> "CallInfo[None]":
if when == "setup":
ihook = item.ihook.pytest_runtest_setup
ihook = item.ihook.pytest_runtest_setup # type: Callable[..., None]
elif when == "call":
ihook = item.ihook.pytest_runtest_call
elif when == "teardown":
@ -234,11 +239,14 @@ def call_runtest_hook(
)
_T = TypeVar("_T")
@attr.s(repr=False)
class CallInfo:
class CallInfo(Generic[_T]):
""" Result/Exception info a function invocation.
:param result: The return value of the call, if it didn't raise. Can only be accessed
:param T result: The return value of the call, if it didn't raise. Can only be accessed
if excinfo is None.
:param Optional[ExceptionInfo] excinfo: The captured exception of the call, if it raised.
:param float start: The system time when the call started, in seconds since the epoch.
@ -247,28 +255,34 @@ class CallInfo:
:param str when: The context of invocation: "setup", "call", "teardown", ...
"""
_result = attr.ib()
_result = attr.ib(type="Optional[_T]")
excinfo = attr.ib(type=Optional[ExceptionInfo])
start = attr.ib(type=float)
stop = attr.ib(type=float)
duration = attr.ib(type=float)
when = attr.ib(type=str)
when = attr.ib(type="Literal['collect', 'setup', 'call', 'teardown']")
@property
def result(self):
def result(self) -> _T:
if self.excinfo is not None:
raise AttributeError("{!r} has no valid result".format(self))
return self._result
# The cast is safe because an exception wasn't raised, hence
# _result has the expected function return type (which may be
# None, that's why a cast and not an assert).
return cast(_T, self._result)
@classmethod
def from_call(cls, func, when, reraise=None) -> "CallInfo":
#: context of invocation: one of "setup", "call",
#: "teardown", "memocollect"
def from_call(
cls,
func: "Callable[[], _T]",
when: "Literal['collect', 'setup', 'call', 'teardown']",
reraise: "Optional[Union[Type[BaseException], Tuple[Type[BaseException], ...]]]" = None,
) -> "CallInfo[_T]":
excinfo = None
start = timing.time()
precise_start = timing.perf_counter()
try:
result = func()
result = func() # type: Optional[_T]
except BaseException:
excinfo = ExceptionInfo.from_current()
if reraise is not None and excinfo.errisinstance(reraise):
@ -293,7 +307,7 @@ class CallInfo:
return "<CallInfo when={!r} excinfo={!r}>".format(self.when, self.excinfo)
def pytest_runtest_makereport(item: Item, call: CallInfo) -> TestReport:
def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> TestReport:
return TestReport.from_item_and_call(item, call)
@ -301,7 +315,7 @@ def pytest_make_collect_report(collector: Collector) -> CollectReport:
call = CallInfo.from_call(lambda: list(collector.collect()), "collect")
longrepr = None
if not call.excinfo:
outcome = "passed"
outcome = "passed" # type: Literal["passed", "skipped", "failed"]
else:
skip_exceptions = [Skipped]
unittest = sys.modules.get("unittest")
@ -321,9 +335,8 @@ def pytest_make_collect_report(collector: Collector) -> CollectReport:
if not hasattr(errorinfo, "toterminal"):
errorinfo = CollectErrorRepr(errorinfo)
longrepr = errorinfo
rep = CollectReport(
collector.nodeid, outcome, longrepr, getattr(call, "result", None)
)
result = call.result if not call.excinfo else None
rep = CollectReport(collector.nodeid, outcome, longrepr, result)
rep.call = call # type: ignore # see collect_one_node
return rep

View File

@ -1,4 +1,7 @@
""" support for skip/xfail functions and markers. """
from typing import Optional
from typing import Tuple
from _pytest.config import Config
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
@ -8,6 +11,7 @@ from _pytest.outcomes import fail
from _pytest.outcomes import skip
from _pytest.outcomes import xfail
from _pytest.python import Function
from _pytest.reports import BaseReport
from _pytest.runner import CallInfo
from _pytest.store import StoreKey
@ -129,7 +133,7 @@ def check_strict_xfail(pyfuncitem: Function) -> None:
@hookimpl(hookwrapper=True)
def pytest_runtest_makereport(item: Item, call: CallInfo):
def pytest_runtest_makereport(item: Item, call: CallInfo[None]):
outcome = yield
rep = outcome.get_result()
evalxfail = item._store.get(evalxfail_key, None)
@ -181,9 +185,10 @@ def pytest_runtest_makereport(item: Item, call: CallInfo):
# called by terminalreporter progress reporting
def pytest_report_teststatus(report):
def pytest_report_teststatus(report: BaseReport) -> Optional[Tuple[str, str, str]]:
if hasattr(report, "wasxfail"):
if report.skipped:
return "xfailed", "x", "XFAIL"
elif report.passed:
return "xpassed", "X", "XPASS"
return None

View File

@ -37,6 +37,7 @@ from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config.argparsing import Parser
from _pytest.deprecated import TERMINALWRITER_WRITER
from _pytest.reports import BaseReport
from _pytest.reports import CollectReport
from _pytest.reports import TestReport
@ -218,14 +219,14 @@ def getreportopt(config: Config) -> str:
@pytest.hookimpl(trylast=True) # after _pytest.runner
def pytest_report_teststatus(report: TestReport) -> Tuple[str, str, str]:
def pytest_report_teststatus(report: BaseReport) -> Tuple[str, str, str]:
letter = "F"
if report.passed:
letter = "."
elif report.skipped:
letter = "s"
outcome = report.outcome
outcome = report.outcome # type: str
if report.when in ("collect", "setup", "teardown") and outcome == "failed":
outcome = "error"
letter = "E"
@ -364,7 +365,7 @@ class TerminalReporter:
self._tw.write(extra, **kwargs)
self.currentfspath = -2
def ensure_newline(self):
def ensure_newline(self) -> None:
if self.currentfspath:
self._tw.line()
self.currentfspath = None
@ -375,7 +376,7 @@ class TerminalReporter:
def flush(self) -> None:
self._tw.flush()
def write_line(self, line, **markup):
def write_line(self, line: Union[str, bytes], **markup) -> None:
if not isinstance(line, str):
line = str(line, errors="replace")
self.ensure_newline()
@ -642,12 +643,12 @@ class TerminalReporter:
)
self._write_report_lines_from_hooks(lines)
def _write_report_lines_from_hooks(self, lines):
def _write_report_lines_from_hooks(self, lines) -> None:
lines.reverse()
for line in collapse(lines):
self.write_line(line)
def pytest_report_header(self, config):
def pytest_report_header(self, config: Config) -> List[str]:
line = "rootdir: %s" % config.rootdir
if config.inifile:
@ -664,7 +665,7 @@ class TerminalReporter:
result.append("plugins: %s" % ", ".join(_plugin_nameversions(plugininfo)))
return result
def pytest_collection_finish(self, session):
def pytest_collection_finish(self, session: "Session") -> None:
self.report_collect(True)
lines = self.config.hook.pytest_report_collectionfinish(

View File

@ -255,7 +255,7 @@ class TestCaseFunction(Function):
@hookimpl(tryfirst=True)
def pytest_runtest_makereport(item: Item, call: CallInfo) -> None:
def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> None:
if isinstance(item, TestCaseFunction):
if item._excinfo:
call.excinfo = item._excinfo.pop(0)
@ -272,9 +272,10 @@ def pytest_runtest_makereport(item: Item, call: CallInfo) -> None:
unittest.SkipTest # type: ignore[attr-defined] # noqa: F821
)
):
excinfo = call.excinfo
# let's substitute the excinfo with a pytest.skip one
call2 = CallInfo.from_call(
lambda: pytest.skip(str(call.excinfo.value)), call.when
call2 = CallInfo[None].from_call(
lambda: pytest.skip(str(excinfo.value)), call.when
)
call.excinfo = call2.excinfo

View File

@ -465,27 +465,27 @@ def test_report_extra_parameters(reporttype: "Type[reports.BaseReport]") -> None
def test_callinfo() -> None:
ci = runner.CallInfo.from_call(lambda: 0, "123")
assert ci.when == "123"
ci = runner.CallInfo.from_call(lambda: 0, "collect")
assert ci.when == "collect"
assert ci.result == 0
assert "result" in repr(ci)
assert repr(ci) == "<CallInfo when='123' result: 0>"
assert str(ci) == "<CallInfo when='123' result: 0>"
assert repr(ci) == "<CallInfo when='collect' result: 0>"
assert str(ci) == "<CallInfo when='collect' result: 0>"
ci = runner.CallInfo.from_call(lambda: 0 / 0, "123")
assert ci.when == "123"
assert not hasattr(ci, "result")
assert repr(ci) == "<CallInfo when='123' excinfo={!r}>".format(ci.excinfo)
assert str(ci) == repr(ci)
assert ci.excinfo
ci2 = runner.CallInfo.from_call(lambda: 0 / 0, "collect")
assert ci2.when == "collect"
assert not hasattr(ci2, "result")
assert repr(ci2) == "<CallInfo when='collect' excinfo={!r}>".format(ci2.excinfo)
assert str(ci2) == repr(ci2)
assert ci2.excinfo
# Newlines are escaped.
def raise_assertion():
assert 0, "assert_msg"
ci = runner.CallInfo.from_call(raise_assertion, "call")
assert repr(ci) == "<CallInfo when='call' excinfo={!r}>".format(ci.excinfo)
assert "\n" not in repr(ci)
ci3 = runner.CallInfo.from_call(raise_assertion, "call")
assert repr(ci3) == "<CallInfo when='call' excinfo={!r}>".format(ci3.excinfo)
assert "\n" not in repr(ci3)
# design question: do we want general hooks in python files?