From 90e58f89615327d78a0c25d148321edb296ca982 Mon Sep 17 00:00:00 2001 From: Ran Benita <ran@unusedvar.com> Date: Fri, 1 May 2020 14:40:16 +0300 Subject: [PATCH] Type annotate some parts related to runner & reports --- src/_pytest/cacheprovider.py | 2 +- src/_pytest/helpconfig.py | 3 +- src/_pytest/hookspec.py | 17 +++++--- src/_pytest/main.py | 4 +- src/_pytest/reports.py | 79 ++++++++++++++++++++++-------------- src/_pytest/resultlog.py | 5 ++- src/_pytest/runner.py | 49 ++++++++++++++-------- src/_pytest/skipping.py | 9 +++- src/_pytest/terminal.py | 15 +++---- src/_pytest/unittest.py | 7 ++-- testing/test_runner.py | 26 ++++++------ 11 files changed, 132 insertions(+), 84 deletions(-) diff --git a/src/_pytest/cacheprovider.py b/src/_pytest/cacheprovider.py index bb08c5a6e..af7d57a24 100755 --- a/src/_pytest/cacheprovider.py +++ b/src/_pytest/cacheprovider.py @@ -278,7 +278,7 @@ class LFPlugin: elif report.failed: self.lastfailed[report.nodeid] = True - def pytest_collectreport(self, report): + def pytest_collectreport(self, report: CollectReport) -> None: passed = report.outcome in ("passed", "skipped") if passed: if report.nodeid in self.lastfailed: diff --git a/src/_pytest/helpconfig.py b/src/_pytest/helpconfig.py index c2519c8af..06e0954cf 100644 --- a/src/_pytest/helpconfig.py +++ b/src/_pytest/helpconfig.py @@ -2,6 +2,7 @@ import os import sys from argparse import Action +from typing import List from typing import Optional from typing import Union @@ -235,7 +236,7 @@ def getpluginversioninfo(config): return lines -def pytest_report_header(config): +def pytest_report_header(config: Config) -> List[str]: lines = [] if config.option.debug or config.option.traceconfig: lines.append( diff --git a/src/_pytest/hookspec.py b/src/_pytest/hookspec.py index c5d5bdedd..99f646bd6 100644 --- a/src/_pytest/hookspec.py +++ b/src/_pytest/hookspec.py @@ -3,6 +3,7 @@ from typing import Any from typing import List from typing import Mapping from typing import Optional +from typing import Sequence from typing import Tuple from typing import Union @@ -284,7 +285,7 @@ def pytest_itemcollected(item): """ we just collected a test item. """ -def pytest_collectreport(report): +def pytest_collectreport(report: "CollectReport") -> None: """ collector finished collecting. """ @@ -430,7 +431,7 @@ def pytest_runtest_teardown(item: "Item", nextitem: "Optional[Item]") -> None: @hookspec(firstresult=True) -def pytest_runtest_makereport(item: "Item", call: "CallInfo") -> Optional[object]: +def pytest_runtest_makereport(item: "Item", call: "CallInfo[None]") -> Optional[object]: """ return a :py:class:`_pytest.runner.TestReport` object for the given :py:class:`pytest.Item <_pytest.main.Item>` and :py:class:`_pytest.runner.CallInfo`. @@ -444,7 +445,7 @@ def pytest_runtest_logreport(report: "TestReport") -> None: @hookspec(firstresult=True) -def pytest_report_to_serializable(config: "Config", report): +def pytest_report_to_serializable(config: "Config", report: "BaseReport"): """ Serializes the given report object into a data structure suitable for sending over the wire, e.g. converted to JSON. @@ -580,7 +581,9 @@ def pytest_assertion_pass(item, lineno: int, orig: str, expl: str) -> None: # ------------------------------------------------------------------------- -def pytest_report_header(config: "Config", startdir): +def pytest_report_header( + config: "Config", startdir: py.path.local +) -> Union[str, List[str]]: """ return a string or list of strings to be displayed as header info for terminal reporting. :param _pytest.config.Config config: pytest config object @@ -601,7 +604,9 @@ def pytest_report_header(config: "Config", startdir): """ -def pytest_report_collectionfinish(config: "Config", startdir, items): +def pytest_report_collectionfinish( + config: "Config", startdir: py.path.local, items: "Sequence[Item]" +) -> Union[str, List[str]]: """ .. versionadded:: 3.2 @@ -758,7 +763,7 @@ def pytest_keyboard_interrupt(excinfo): def pytest_exception_interact( - node: "Node", call: "CallInfo", report: "BaseReport" + node: "Node", call: "CallInfo[object]", report: "Union[CollectReport, TestReport]" ) -> None: """called when an exception was raised which can potentially be interactively handled. diff --git a/src/_pytest/main.py b/src/_pytest/main.py index a80097f5a..1c1cda18b 100644 --- a/src/_pytest/main.py +++ b/src/_pytest/main.py @@ -442,7 +442,9 @@ class Session(nodes.FSCollector): raise self.Interrupted(self.shouldstop) @hookimpl(tryfirst=True) - def pytest_runtest_logreport(self, report: TestReport) -> None: + def pytest_runtest_logreport( + self, report: Union[TestReport, CollectReport] + ) -> None: if report.failed and not hasattr(report, "wasxfail"): self.testsfailed += 1 maxfail = self.config.getvalue("maxfail") diff --git a/src/_pytest/reports.py b/src/_pytest/reports.py index 9763cb4ad..7462cea0b 100644 --- a/src/_pytest/reports.py +++ b/src/_pytest/reports.py @@ -1,9 +1,12 @@ from io import StringIO from pprint import pprint from typing import Any +from typing import Iterable +from typing import Iterator from typing import List from typing import Optional from typing import Tuple +from typing import TypeVar from typing import Union import attr @@ -21,12 +24,17 @@ from _pytest._code.code import ReprTraceback from _pytest._code.code import TerminalRepr from _pytest._io import TerminalWriter from _pytest.compat import TYPE_CHECKING +from _pytest.config import Config from _pytest.nodes import Collector from _pytest.nodes import Item from _pytest.outcomes import skip from _pytest.pathlib import Path if TYPE_CHECKING: + from typing import NoReturn + from typing_extensions import Type + from typing_extensions import Literal + from _pytest.runner import CallInfo @@ -42,6 +50,9 @@ def getslaveinfoline(node): return s +_R = TypeVar("_R", bound="BaseReport") + + class BaseReport: when = None # type: Optional[str] location = None # type: Optional[Tuple[str, Optional[int], str]] @@ -74,13 +85,13 @@ class BaseReport: except UnicodeEncodeError: out.line("<unprintable longrepr>") - def get_sections(self, prefix): + def get_sections(self, prefix: str) -> Iterator[Tuple[str, str]]: for name, content in self.sections: if name.startswith(prefix): yield prefix, content @property - def longreprtext(self): + def longreprtext(self) -> str: """ Read-only property that returns the full string representation of ``longrepr``. @@ -95,7 +106,7 @@ class BaseReport: return exc.strip() @property - def caplog(self): + def caplog(self) -> str: """Return captured log lines, if log capturing is enabled .. versionadded:: 3.5 @@ -105,7 +116,7 @@ class BaseReport: ) @property - def capstdout(self): + def capstdout(self) -> str: """Return captured text from stdout, if capturing is enabled .. versionadded:: 3.0 @@ -115,7 +126,7 @@ class BaseReport: ) @property - def capstderr(self): + def capstderr(self) -> str: """Return captured text from stderr, if capturing is enabled .. versionadded:: 3.0 @@ -133,7 +144,7 @@ class BaseReport: return self.nodeid.split("::")[0] @property - def count_towards_summary(self): + def count_towards_summary(self) -> bool: """ **Experimental** @@ -148,7 +159,7 @@ class BaseReport: return True @property - def head_line(self): + def head_line(self) -> Optional[str]: """ **Experimental** @@ -168,8 +179,9 @@ class BaseReport: if self.location is not None: fspath, lineno, domain = self.location return domain + return None - def _get_verbose_word(self, config): + def _get_verbose_word(self, config: Config): _category, _short, verbose = config.hook.pytest_report_teststatus( report=self, config=config ) @@ -187,7 +199,7 @@ class BaseReport: return _report_to_json(self) @classmethod - def _from_json(cls, reportdict): + def _from_json(cls: "Type[_R]", reportdict) -> _R: """ This was originally the serialize_report() function from xdist (ca03269). @@ -200,7 +212,9 @@ class BaseReport: return cls(**kwargs) -def _report_unserialization_failure(type_name, report_class, reportdict): +def _report_unserialization_failure( + type_name: str, report_class: "Type[BaseReport]", reportdict +) -> "NoReturn": url = "https://github.com/pytest-dev/pytest/issues" stream = StringIO() pprint("-" * 100, stream=stream) @@ -221,15 +235,15 @@ class TestReport(BaseReport): def __init__( self, - nodeid, + nodeid: str, location: Tuple[str, Optional[int], str], keywords, - outcome, + outcome: "Literal['passed', 'failed', 'skipped']", longrepr, - when, - sections=(), - duration=0, - user_properties=None, + when: "Literal['setup', 'call', 'teardown']", + sections: Iterable[Tuple[str, str]] = (), + duration: float = 0, + user_properties: Optional[Iterable[Tuple[str, object]]] = None, **extra ) -> None: #: normalized collection node id @@ -268,23 +282,25 @@ class TestReport(BaseReport): self.__dict__.update(extra) - def __repr__(self): + def __repr__(self) -> str: return "<{} {!r} when={!r} outcome={!r}>".format( self.__class__.__name__, self.nodeid, self.when, self.outcome ) @classmethod - def from_item_and_call(cls, item: Item, call: "CallInfo") -> "TestReport": + def from_item_and_call(cls, item: Item, call: "CallInfo[None]") -> "TestReport": """ Factory method to create and fill a TestReport with standard item and call info. """ when = call.when + # Remove "collect" from the Literal type -- only for collection calls. + assert when != "collect" duration = call.duration keywords = {x: 1 for x in item.keywords} excinfo = call.excinfo sections = [] if not call.excinfo: - outcome = "passed" + outcome = "passed" # type: Literal["passed", "failed", "skipped"] # TODO: Improve this Any. longrepr = None # type: Optional[Any] else: @@ -324,10 +340,10 @@ class CollectReport(BaseReport): def __init__( self, nodeid: str, - outcome, + outcome: "Literal['passed', 'skipped', 'failed']", longrepr, result: Optional[List[Union[Item, Collector]]], - sections=(), + sections: Iterable[Tuple[str, str]] = (), **extra ) -> None: self.nodeid = nodeid @@ -341,28 +357,29 @@ class CollectReport(BaseReport): def location(self): return (self.fspath, None, self.fspath) - def __repr__(self): + def __repr__(self) -> str: return "<CollectReport {!r} lenresult={} outcome={!r}>".format( self.nodeid, len(self.result), self.outcome ) class CollectErrorRepr(TerminalRepr): - def __init__(self, msg): + def __init__(self, msg) -> None: self.longrepr = msg def toterminal(self, out) -> None: out.line(self.longrepr, red=True) -def pytest_report_to_serializable(report): +def pytest_report_to_serializable(report: BaseReport): if isinstance(report, (TestReport, CollectReport)): data = report._to_json() data["$report_type"] = report.__class__.__name__ return data + return None -def pytest_report_from_serializable(data): +def pytest_report_from_serializable(data) -> Optional[BaseReport]: if "$report_type" in data: if data["$report_type"] == "TestReport": return TestReport._from_json(data) @@ -371,9 +388,10 @@ def pytest_report_from_serializable(data): assert False, "Unknown report_type unserialize data: {}".format( data["$report_type"] ) + return None -def _report_to_json(report): +def _report_to_json(report: BaseReport): """ This was originally the serialize_report() function from xdist (ca03269). @@ -381,11 +399,12 @@ def _report_to_json(report): serialization. """ - def serialize_repr_entry(entry): - entry_data = {"type": type(entry).__name__, "data": attr.asdict(entry)} - for key, value in entry_data["data"].items(): + def serialize_repr_entry(entry: Union[ReprEntry, ReprEntryNative]): + data = attr.asdict(entry) + for key, value in data.items(): if hasattr(value, "__dict__"): - entry_data["data"][key] = attr.asdict(value) + data[key] = attr.asdict(value) + entry_data = {"type": type(entry).__name__, "data": data} return entry_data def serialize_repr_traceback(reprtraceback: ReprTraceback): diff --git a/src/_pytest/resultlog.py b/src/_pytest/resultlog.py index 720ea9f49..c2b0cf556 100644 --- a/src/_pytest/resultlog.py +++ b/src/_pytest/resultlog.py @@ -7,6 +7,7 @@ import py from _pytest.config import Config from _pytest.config.argparsing import Parser +from _pytest.reports import CollectReport from _pytest.reports import TestReport from _pytest.store import StoreKey @@ -87,7 +88,7 @@ class ResultLog: longrepr = str(report.longrepr) self.log_outcome(report, code, longrepr) - def pytest_collectreport(self, report): + def pytest_collectreport(self, report: CollectReport) -> None: if not report.passed: if report.failed: code = "F" @@ -95,7 +96,7 @@ class ResultLog: else: assert report.skipped code = "S" - longrepr = "%s:%d: %s" % report.longrepr + longrepr = "%s:%d: %s" % report.longrepr # type: ignore self.log_outcome(report, code, longrepr) def pytest_internalerror(self, excrepr): diff --git a/src/_pytest/runner.py b/src/_pytest/runner.py index 568065d94..f89b67399 100644 --- a/src/_pytest/runner.py +++ b/src/_pytest/runner.py @@ -3,10 +3,14 @@ import bdb import os import sys from typing import Callable +from typing import cast from typing import Dict +from typing import Generic from typing import List from typing import Optional from typing import Tuple +from typing import TypeVar +from typing import Union import attr @@ -179,7 +183,7 @@ def _update_current_test_var( os.environ.pop(var_name) -def pytest_report_teststatus(report): +def pytest_report_teststatus(report: BaseReport) -> Optional[Tuple[str, str, str]]: if report.when in ("setup", "teardown"): if report.failed: # category, shortletter, verbose-word @@ -188,6 +192,7 @@ def pytest_report_teststatus(report): return "skipped", "s", "SKIPPED" else: return "", "", "" + return None # @@ -217,9 +222,9 @@ def check_interactive_exception(call: "CallInfo", report: BaseReport) -> bool: def call_runtest_hook( item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds -) -> "CallInfo": +) -> "CallInfo[None]": if when == "setup": - ihook = item.ihook.pytest_runtest_setup + ihook = item.ihook.pytest_runtest_setup # type: Callable[..., None] elif when == "call": ihook = item.ihook.pytest_runtest_call elif when == "teardown": @@ -234,11 +239,14 @@ def call_runtest_hook( ) +_T = TypeVar("_T") + + @attr.s(repr=False) -class CallInfo: +class CallInfo(Generic[_T]): """ Result/Exception info a function invocation. - :param result: The return value of the call, if it didn't raise. Can only be accessed + :param T result: The return value of the call, if it didn't raise. Can only be accessed if excinfo is None. :param Optional[ExceptionInfo] excinfo: The captured exception of the call, if it raised. :param float start: The system time when the call started, in seconds since the epoch. @@ -247,28 +255,34 @@ class CallInfo: :param str when: The context of invocation: "setup", "call", "teardown", ... """ - _result = attr.ib() + _result = attr.ib(type="Optional[_T]") excinfo = attr.ib(type=Optional[ExceptionInfo]) start = attr.ib(type=float) stop = attr.ib(type=float) duration = attr.ib(type=float) - when = attr.ib(type=str) + when = attr.ib(type="Literal['collect', 'setup', 'call', 'teardown']") @property - def result(self): + def result(self) -> _T: if self.excinfo is not None: raise AttributeError("{!r} has no valid result".format(self)) - return self._result + # The cast is safe because an exception wasn't raised, hence + # _result has the expected function return type (which may be + # None, that's why a cast and not an assert). + return cast(_T, self._result) @classmethod - def from_call(cls, func, when, reraise=None) -> "CallInfo": - #: context of invocation: one of "setup", "call", - #: "teardown", "memocollect" + def from_call( + cls, + func: "Callable[[], _T]", + when: "Literal['collect', 'setup', 'call', 'teardown']", + reraise: "Optional[Union[Type[BaseException], Tuple[Type[BaseException], ...]]]" = None, + ) -> "CallInfo[_T]": excinfo = None start = timing.time() precise_start = timing.perf_counter() try: - result = func() + result = func() # type: Optional[_T] except BaseException: excinfo = ExceptionInfo.from_current() if reraise is not None and excinfo.errisinstance(reraise): @@ -293,7 +307,7 @@ class CallInfo: return "<CallInfo when={!r} excinfo={!r}>".format(self.when, self.excinfo) -def pytest_runtest_makereport(item: Item, call: CallInfo) -> TestReport: +def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> TestReport: return TestReport.from_item_and_call(item, call) @@ -301,7 +315,7 @@ def pytest_make_collect_report(collector: Collector) -> CollectReport: call = CallInfo.from_call(lambda: list(collector.collect()), "collect") longrepr = None if not call.excinfo: - outcome = "passed" + outcome = "passed" # type: Literal["passed", "skipped", "failed"] else: skip_exceptions = [Skipped] unittest = sys.modules.get("unittest") @@ -321,9 +335,8 @@ def pytest_make_collect_report(collector: Collector) -> CollectReport: if not hasattr(errorinfo, "toterminal"): errorinfo = CollectErrorRepr(errorinfo) longrepr = errorinfo - rep = CollectReport( - collector.nodeid, outcome, longrepr, getattr(call, "result", None) - ) + result = call.result if not call.excinfo else None + rep = CollectReport(collector.nodeid, outcome, longrepr, result) rep.call = call # type: ignore # see collect_one_node return rep diff --git a/src/_pytest/skipping.py b/src/_pytest/skipping.py index 5994b5b2f..54621f111 100644 --- a/src/_pytest/skipping.py +++ b/src/_pytest/skipping.py @@ -1,4 +1,7 @@ """ support for skip/xfail functions and markers. """ +from typing import Optional +from typing import Tuple + from _pytest.config import Config from _pytest.config import hookimpl from _pytest.config.argparsing import Parser @@ -8,6 +11,7 @@ from _pytest.outcomes import fail from _pytest.outcomes import skip from _pytest.outcomes import xfail from _pytest.python import Function +from _pytest.reports import BaseReport from _pytest.runner import CallInfo from _pytest.store import StoreKey @@ -129,7 +133,7 @@ def check_strict_xfail(pyfuncitem: Function) -> None: @hookimpl(hookwrapper=True) -def pytest_runtest_makereport(item: Item, call: CallInfo): +def pytest_runtest_makereport(item: Item, call: CallInfo[None]): outcome = yield rep = outcome.get_result() evalxfail = item._store.get(evalxfail_key, None) @@ -181,9 +185,10 @@ def pytest_runtest_makereport(item: Item, call: CallInfo): # called by terminalreporter progress reporting -def pytest_report_teststatus(report): +def pytest_report_teststatus(report: BaseReport) -> Optional[Tuple[str, str, str]]: if hasattr(report, "wasxfail"): if report.skipped: return "xfailed", "x", "XFAIL" elif report.passed: return "xpassed", "X", "XPASS" + return None diff --git a/src/_pytest/terminal.py b/src/_pytest/terminal.py index bc2b5bf23..1b9601a22 100644 --- a/src/_pytest/terminal.py +++ b/src/_pytest/terminal.py @@ -37,6 +37,7 @@ from _pytest.config import Config from _pytest.config import ExitCode from _pytest.config.argparsing import Parser from _pytest.deprecated import TERMINALWRITER_WRITER +from _pytest.reports import BaseReport from _pytest.reports import CollectReport from _pytest.reports import TestReport @@ -218,14 +219,14 @@ def getreportopt(config: Config) -> str: @pytest.hookimpl(trylast=True) # after _pytest.runner -def pytest_report_teststatus(report: TestReport) -> Tuple[str, str, str]: +def pytest_report_teststatus(report: BaseReport) -> Tuple[str, str, str]: letter = "F" if report.passed: letter = "." elif report.skipped: letter = "s" - outcome = report.outcome + outcome = report.outcome # type: str if report.when in ("collect", "setup", "teardown") and outcome == "failed": outcome = "error" letter = "E" @@ -364,7 +365,7 @@ class TerminalReporter: self._tw.write(extra, **kwargs) self.currentfspath = -2 - def ensure_newline(self): + def ensure_newline(self) -> None: if self.currentfspath: self._tw.line() self.currentfspath = None @@ -375,7 +376,7 @@ class TerminalReporter: def flush(self) -> None: self._tw.flush() - def write_line(self, line, **markup): + def write_line(self, line: Union[str, bytes], **markup) -> None: if not isinstance(line, str): line = str(line, errors="replace") self.ensure_newline() @@ -642,12 +643,12 @@ class TerminalReporter: ) self._write_report_lines_from_hooks(lines) - def _write_report_lines_from_hooks(self, lines): + def _write_report_lines_from_hooks(self, lines) -> None: lines.reverse() for line in collapse(lines): self.write_line(line) - def pytest_report_header(self, config): + def pytest_report_header(self, config: Config) -> List[str]: line = "rootdir: %s" % config.rootdir if config.inifile: @@ -664,7 +665,7 @@ class TerminalReporter: result.append("plugins: %s" % ", ".join(_plugin_nameversions(plugininfo))) return result - def pytest_collection_finish(self, session): + def pytest_collection_finish(self, session: "Session") -> None: self.report_collect(True) lines = self.config.hook.pytest_report_collectionfinish( diff --git a/src/_pytest/unittest.py b/src/_pytest/unittest.py index 3fbf7c88d..f9eb6e719 100644 --- a/src/_pytest/unittest.py +++ b/src/_pytest/unittest.py @@ -255,7 +255,7 @@ class TestCaseFunction(Function): @hookimpl(tryfirst=True) -def pytest_runtest_makereport(item: Item, call: CallInfo) -> None: +def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> None: if isinstance(item, TestCaseFunction): if item._excinfo: call.excinfo = item._excinfo.pop(0) @@ -272,9 +272,10 @@ def pytest_runtest_makereport(item: Item, call: CallInfo) -> None: unittest.SkipTest # type: ignore[attr-defined] # noqa: F821 ) ): + excinfo = call.excinfo # let's substitute the excinfo with a pytest.skip one - call2 = CallInfo.from_call( - lambda: pytest.skip(str(call.excinfo.value)), call.when + call2 = CallInfo[None].from_call( + lambda: pytest.skip(str(excinfo.value)), call.when ) call.excinfo = call2.excinfo diff --git a/testing/test_runner.py b/testing/test_runner.py index be79b14fd..9c19ded0e 100644 --- a/testing/test_runner.py +++ b/testing/test_runner.py @@ -465,27 +465,27 @@ def test_report_extra_parameters(reporttype: "Type[reports.BaseReport]") -> None def test_callinfo() -> None: - ci = runner.CallInfo.from_call(lambda: 0, "123") - assert ci.when == "123" + ci = runner.CallInfo.from_call(lambda: 0, "collect") + assert ci.when == "collect" assert ci.result == 0 assert "result" in repr(ci) - assert repr(ci) == "<CallInfo when='123' result: 0>" - assert str(ci) == "<CallInfo when='123' result: 0>" + assert repr(ci) == "<CallInfo when='collect' result: 0>" + assert str(ci) == "<CallInfo when='collect' result: 0>" - ci = runner.CallInfo.from_call(lambda: 0 / 0, "123") - assert ci.when == "123" - assert not hasattr(ci, "result") - assert repr(ci) == "<CallInfo when='123' excinfo={!r}>".format(ci.excinfo) - assert str(ci) == repr(ci) - assert ci.excinfo + ci2 = runner.CallInfo.from_call(lambda: 0 / 0, "collect") + assert ci2.when == "collect" + assert not hasattr(ci2, "result") + assert repr(ci2) == "<CallInfo when='collect' excinfo={!r}>".format(ci2.excinfo) + assert str(ci2) == repr(ci2) + assert ci2.excinfo # Newlines are escaped. def raise_assertion(): assert 0, "assert_msg" - ci = runner.CallInfo.from_call(raise_assertion, "call") - assert repr(ci) == "<CallInfo when='call' excinfo={!r}>".format(ci.excinfo) - assert "\n" not in repr(ci) + ci3 = runner.CallInfo.from_call(raise_assertion, "call") + assert repr(ci3) == "<CallInfo when='call' excinfo={!r}>".format(ci3.excinfo) + assert "\n" not in repr(ci3) # design question: do we want general hooks in python files?