Type annotate _pytest.junitxml

This commit is contained in:
Ran Benita 2020-05-01 14:40:16 +03:00
parent 3e351afeb3
commit 216a010ab7
2 changed files with 84 additions and 66 deletions

View File

@ -14,6 +14,11 @@ import platform
import re import re
import sys import sys
from datetime import datetime from datetime import datetime
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import py import py
@ -21,14 +26,19 @@ import pytest
from _pytest import deprecated from _pytest import deprecated
from _pytest import nodes from _pytest import nodes
from _pytest import timing from _pytest import timing
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config from _pytest.config import Config
from _pytest.config import filename_arg from _pytest.config import filename_arg
from _pytest.config.argparsing import Parser from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
from _pytest.reports import TestReport from _pytest.reports import TestReport
from _pytest.store import StoreKey from _pytest.store import StoreKey
from _pytest.terminal import TerminalReporter from _pytest.terminal import TerminalReporter
from _pytest.warnings import _issue_warning_captured from _pytest.warnings import _issue_warning_captured
if TYPE_CHECKING:
from typing import Type
xml_key = StoreKey["LogXML"]() xml_key = StoreKey["LogXML"]()
@ -58,8 +68,8 @@ del _legal_xml_re
_py_ext_re = re.compile(r"\.py$") _py_ext_re = re.compile(r"\.py$")
def bin_xml_escape(arg): def bin_xml_escape(arg: str) -> py.xml.raw:
def repl(matchobj): def repl(matchobj: "re.Match[str]") -> str:
i = ord(matchobj.group()) i = ord(matchobj.group())
if i <= 0xFF: if i <= 0xFF:
return "#x%02X" % i return "#x%02X" % i
@ -69,7 +79,7 @@ def bin_xml_escape(arg):
return py.xml.raw(illegal_xml_re.sub(repl, py.xml.escape(arg))) return py.xml.raw(illegal_xml_re.sub(repl, py.xml.escape(arg)))
def merge_family(left, right): def merge_family(left, right) -> None:
result = {} result = {}
for kl, vl in left.items(): for kl, vl in left.items():
for kr, vr in right.items(): for kr, vr in right.items():
@ -92,28 +102,27 @@ families["xunit2"] = families["_base"]
class _NodeReporter: class _NodeReporter:
def __init__(self, nodeid, xml): def __init__(self, nodeid: Union[str, TestReport], xml: "LogXML") -> None:
self.id = nodeid self.id = nodeid
self.xml = xml self.xml = xml
self.add_stats = self.xml.add_stats self.add_stats = self.xml.add_stats
self.family = self.xml.family self.family = self.xml.family
self.duration = 0 self.duration = 0
self.properties = [] self.properties = [] # type: List[Tuple[str, py.xml.raw]]
self.nodes = [] self.nodes = [] # type: List[py.xml.Tag]
self.testcase = None self.attrs = {} # type: Dict[str, Union[str, py.xml.raw]]
self.attrs = {}
def append(self, node): def append(self, node: py.xml.Tag) -> None:
self.xml.add_stats(type(node).__name__) self.xml.add_stats(type(node).__name__)
self.nodes.append(node) self.nodes.append(node)
def add_property(self, name, value): def add_property(self, name: str, value: str) -> None:
self.properties.append((str(name), bin_xml_escape(value))) self.properties.append((str(name), bin_xml_escape(value)))
def add_attribute(self, name, value): def add_attribute(self, name: str, value: str) -> None:
self.attrs[str(name)] = bin_xml_escape(value) self.attrs[str(name)] = bin_xml_escape(value)
def make_properties_node(self): def make_properties_node(self) -> Union[py.xml.Tag, str]:
"""Return a Junit node containing custom properties, if any. """Return a Junit node containing custom properties, if any.
""" """
if self.properties: if self.properties:
@ -125,8 +134,7 @@ class _NodeReporter:
) )
return "" return ""
def record_testreport(self, testreport): def record_testreport(self, testreport: TestReport) -> None:
assert not self.testcase
names = mangle_test_address(testreport.nodeid) names = mangle_test_address(testreport.nodeid)
existing_attrs = self.attrs existing_attrs = self.attrs
classnames = names[:-1] classnames = names[:-1]
@ -136,9 +144,9 @@ class _NodeReporter:
"classname": ".".join(classnames), "classname": ".".join(classnames),
"name": bin_xml_escape(names[-1]), "name": bin_xml_escape(names[-1]),
"file": testreport.location[0], "file": testreport.location[0],
} } # type: Dict[str, Union[str, py.xml.raw]]
if testreport.location[1] is not None: if testreport.location[1] is not None:
attrs["line"] = testreport.location[1] attrs["line"] = str(testreport.location[1])
if hasattr(testreport, "url"): if hasattr(testreport, "url"):
attrs["url"] = testreport.url attrs["url"] = testreport.url
self.attrs = attrs self.attrs = attrs
@ -156,19 +164,19 @@ class _NodeReporter:
temp_attrs[key] = self.attrs[key] temp_attrs[key] = self.attrs[key]
self.attrs = temp_attrs self.attrs = temp_attrs
def to_xml(self): def to_xml(self) -> py.xml.Tag:
testcase = Junit.testcase(time="%.3f" % self.duration, **self.attrs) testcase = Junit.testcase(time="%.3f" % self.duration, **self.attrs)
testcase.append(self.make_properties_node()) testcase.append(self.make_properties_node())
for node in self.nodes: for node in self.nodes:
testcase.append(node) testcase.append(node)
return testcase return testcase
def _add_simple(self, kind, message, data=None): def _add_simple(self, kind: "Type[py.xml.Tag]", message: str, data=None) -> None:
data = bin_xml_escape(data) data = bin_xml_escape(data)
node = kind(data, message=message) node = kind(data, message=message)
self.append(node) self.append(node)
def write_captured_output(self, report): def write_captured_output(self, report: TestReport) -> None:
if not self.xml.log_passing_tests and report.passed: if not self.xml.log_passing_tests and report.passed:
return return
@ -191,21 +199,22 @@ class _NodeReporter:
if content_all: if content_all:
self._write_content(report, content_all, "system-out") self._write_content(report, content_all, "system-out")
def _prepare_content(self, content, header): def _prepare_content(self, content: str, header: str) -> str:
return "\n".join([header.center(80, "-"), content, ""]) return "\n".join([header.center(80, "-"), content, ""])
def _write_content(self, report, content, jheader): def _write_content(self, report: TestReport, content: str, jheader: str) -> None:
tag = getattr(Junit, jheader) tag = getattr(Junit, jheader)
self.append(tag(bin_xml_escape(content))) self.append(tag(bin_xml_escape(content)))
def append_pass(self, report): def append_pass(self, report: TestReport) -> None:
self.add_stats("passed") self.add_stats("passed")
def append_failure(self, report): def append_failure(self, report: TestReport) -> None:
# msg = str(report.longrepr.reprtraceback.extraline) # msg = str(report.longrepr.reprtraceback.extraline)
if hasattr(report, "wasxfail"): if hasattr(report, "wasxfail"):
self._add_simple(Junit.skipped, "xfail-marked test passes unexpectedly") self._add_simple(Junit.skipped, "xfail-marked test passes unexpectedly")
else: else:
assert report.longrepr is not None
if getattr(report.longrepr, "reprcrash", None) is not None: if getattr(report.longrepr, "reprcrash", None) is not None:
message = report.longrepr.reprcrash.message message = report.longrepr.reprcrash.message
else: else:
@ -215,23 +224,24 @@ class _NodeReporter:
fail.append(bin_xml_escape(report.longrepr)) fail.append(bin_xml_escape(report.longrepr))
self.append(fail) self.append(fail)
def append_collect_error(self, report): def append_collect_error(self, report: TestReport) -> None:
# msg = str(report.longrepr.reprtraceback.extraline) # msg = str(report.longrepr.reprtraceback.extraline)
assert report.longrepr is not None
self.append( self.append(
Junit.error(bin_xml_escape(report.longrepr), message="collection failure") Junit.error(bin_xml_escape(report.longrepr), message="collection failure")
) )
def append_collect_skipped(self, report): def append_collect_skipped(self, report: TestReport) -> None:
self._add_simple(Junit.skipped, "collection skipped", report.longrepr) self._add_simple(Junit.skipped, "collection skipped", report.longrepr)
def append_error(self, report): def append_error(self, report: TestReport) -> None:
if report.when == "teardown": if report.when == "teardown":
msg = "test teardown failure" msg = "test teardown failure"
else: else:
msg = "test setup failure" msg = "test setup failure"
self._add_simple(Junit.error, msg, report.longrepr) self._add_simple(Junit.error, msg, report.longrepr)
def append_skipped(self, report): def append_skipped(self, report: TestReport) -> None:
if hasattr(report, "wasxfail"): if hasattr(report, "wasxfail"):
xfailreason = report.wasxfail xfailreason = report.wasxfail
if xfailreason.startswith("reason: "): if xfailreason.startswith("reason: "):
@ -242,6 +252,7 @@ class _NodeReporter:
) )
) )
else: else:
assert report.longrepr is not None
filename, lineno, skipreason = report.longrepr filename, lineno, skipreason = report.longrepr
if skipreason.startswith("Skipped: "): if skipreason.startswith("Skipped: "):
skipreason = skipreason[9:] skipreason = skipreason[9:]
@ -256,13 +267,17 @@ class _NodeReporter:
) )
self.write_captured_output(report) self.write_captured_output(report)
def finalize(self): def finalize(self) -> None:
data = self.to_xml().unicode(indent=0) data = self.to_xml().unicode(indent=0)
self.__dict__.clear() self.__dict__.clear()
self.to_xml = lambda: py.xml.raw(data) # Type ignored becuase mypy doesn't like overriding a method.
# Also the return value doesn't match...
self.to_xml = lambda: py.xml.raw(data) # type: ignore # noqa: F821
def _warn_incompatibility_with_xunit2(request, fixture_name): def _warn_incompatibility_with_xunit2(
request: FixtureRequest, fixture_name: str
) -> None:
"""Emits a PytestWarning about the given fixture being incompatible with newer xunit revisions""" """Emits a PytestWarning about the given fixture being incompatible with newer xunit revisions"""
from _pytest.warning_types import PytestWarning from _pytest.warning_types import PytestWarning
@ -278,7 +293,7 @@ def _warn_incompatibility_with_xunit2(request, fixture_name):
@pytest.fixture @pytest.fixture
def record_property(request): def record_property(request: FixtureRequest):
"""Add an extra properties the calling test. """Add an extra properties the calling test.
User properties become part of the test report and are available to the User properties become part of the test report and are available to the
configured reporters, like JUnit XML. configured reporters, like JUnit XML.
@ -292,14 +307,14 @@ def record_property(request):
""" """
_warn_incompatibility_with_xunit2(request, "record_property") _warn_incompatibility_with_xunit2(request, "record_property")
def append_property(name, value): def append_property(name: str, value: object) -> None:
request.node.user_properties.append((name, value)) request.node.user_properties.append((name, value))
return append_property return append_property
@pytest.fixture @pytest.fixture
def record_xml_attribute(request): def record_xml_attribute(request: FixtureRequest):
"""Add extra xml attributes to the tag for the calling test. """Add extra xml attributes to the tag for the calling test.
The fixture is callable with ``(name, value)``, with value being The fixture is callable with ``(name, value)``, with value being
automatically xml-encoded automatically xml-encoded
@ -313,7 +328,7 @@ def record_xml_attribute(request):
_warn_incompatibility_with_xunit2(request, "record_xml_attribute") _warn_incompatibility_with_xunit2(request, "record_xml_attribute")
# Declare noop # Declare noop
def add_attr_noop(name, value): def add_attr_noop(name: str, value: str) -> None:
pass pass
attr_func = add_attr_noop attr_func = add_attr_noop
@ -326,7 +341,7 @@ def record_xml_attribute(request):
return attr_func return attr_func
def _check_record_param_type(param, v): def _check_record_param_type(param: str, v: str) -> None:
"""Used by record_testsuite_property to check that the given parameter name is of the proper """Used by record_testsuite_property to check that the given parameter name is of the proper
type""" type"""
__tracebackhide__ = True __tracebackhide__ = True
@ -336,7 +351,7 @@ def _check_record_param_type(param, v):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def record_testsuite_property(request): def record_testsuite_property(request: FixtureRequest):
""" """
Records a new ``<property>`` tag as child of the root ``<testsuite>``. This is suitable to Records a new ``<property>`` tag as child of the root ``<testsuite>``. This is suitable to
writing global information regarding the entire test suite, and is compatible with ``xunit2`` JUnit family. writing global information regarding the entire test suite, and is compatible with ``xunit2`` JUnit family.
@ -354,7 +369,7 @@ def record_testsuite_property(request):
__tracebackhide__ = True __tracebackhide__ = True
def record_func(name, value): def record_func(name: str, value: str):
"""noop function in case --junitxml was not passed in the command-line""" """noop function in case --junitxml was not passed in the command-line"""
__tracebackhide__ = True __tracebackhide__ = True
_check_record_param_type("name", name) _check_record_param_type("name", name)
@ -437,7 +452,7 @@ def pytest_unconfigure(config: Config) -> None:
config.pluginmanager.unregister(xml) config.pluginmanager.unregister(xml)
def mangle_test_address(address): def mangle_test_address(address: str) -> List[str]:
path, possible_open_bracket, params = address.partition("[") path, possible_open_bracket, params = address.partition("[")
names = path.split("::") names = path.split("::")
try: try:
@ -456,13 +471,13 @@ class LogXML:
def __init__( def __init__(
self, self,
logfile, logfile,
prefix, prefix: Optional[str],
suite_name="pytest", suite_name: str = "pytest",
logging="no", logging: str = "no",
report_duration="total", report_duration: str = "total",
family="xunit1", family="xunit1",
log_passing_tests=True, log_passing_tests: bool = True,
): ) -> None:
logfile = os.path.expanduser(os.path.expandvars(logfile)) logfile = os.path.expanduser(os.path.expandvars(logfile))
self.logfile = os.path.normpath(os.path.abspath(logfile)) self.logfile = os.path.normpath(os.path.abspath(logfile))
self.prefix = prefix self.prefix = prefix
@ -471,20 +486,24 @@ class LogXML:
self.log_passing_tests = log_passing_tests self.log_passing_tests = log_passing_tests
self.report_duration = report_duration self.report_duration = report_duration
self.family = family self.family = family
self.stats = dict.fromkeys(["error", "passed", "failure", "skipped"], 0) self.stats = dict.fromkeys(
self.node_reporters = {} # nodeid -> _NodeReporter ["error", "passed", "failure", "skipped"], 0
self.node_reporters_ordered = [] ) # type: Dict[str, int]
self.global_properties = [] self.node_reporters = (
{}
) # type: Dict[Tuple[Union[str, TestReport], object], _NodeReporter]
self.node_reporters_ordered = [] # type: List[_NodeReporter]
self.global_properties = [] # type: List[Tuple[str, py.xml.raw]]
# List of reports that failed on call but teardown is pending. # List of reports that failed on call but teardown is pending.
self.open_reports = [] self.open_reports = [] # type: List[TestReport]
self.cnt_double_fail_tests = 0 self.cnt_double_fail_tests = 0
# Replaces convenience family with real family # Replaces convenience family with real family
if self.family == "legacy": if self.family == "legacy":
self.family = "xunit1" self.family = "xunit1"
def finalize(self, report): def finalize(self, report: TestReport) -> None:
nodeid = getattr(report, "nodeid", report) nodeid = getattr(report, "nodeid", report)
# local hack to handle xdist report order # local hack to handle xdist report order
slavenode = getattr(report, "node", None) slavenode = getattr(report, "node", None)
@ -492,8 +511,8 @@ class LogXML:
if reporter is not None: if reporter is not None:
reporter.finalize() reporter.finalize()
def node_reporter(self, report): def node_reporter(self, report: Union[TestReport, str]) -> _NodeReporter:
nodeid = getattr(report, "nodeid", report) nodeid = getattr(report, "nodeid", report) # type: Union[str, TestReport]
# local hack to handle xdist report order # local hack to handle xdist report order
slavenode = getattr(report, "node", None) slavenode = getattr(report, "node", None)
@ -510,11 +529,11 @@ class LogXML:
return reporter return reporter
def add_stats(self, key): def add_stats(self, key: str) -> None:
if key in self.stats: if key in self.stats:
self.stats[key] += 1 self.stats[key] += 1
def _opentestcase(self, report): def _opentestcase(self, report: TestReport) -> _NodeReporter:
reporter = self.node_reporter(report) reporter = self.node_reporter(report)
reporter.record_testreport(report) reporter.record_testreport(report)
return reporter return reporter
@ -587,7 +606,7 @@ class LogXML:
reporter.write_captured_output(report) reporter.write_captured_output(report)
for propname, propvalue in report.user_properties: for propname, propvalue in report.user_properties:
reporter.add_property(propname, propvalue) reporter.add_property(propname, str(propvalue))
self.finalize(report) self.finalize(report)
report_wid = getattr(report, "worker_id", None) report_wid = getattr(report, "worker_id", None)
@ -607,7 +626,7 @@ class LogXML:
if close_report: if close_report:
self.open_reports.remove(close_report) self.open_reports.remove(close_report)
def update_testcase_duration(self, report): def update_testcase_duration(self, report: TestReport) -> None:
"""accumulates total duration for nodeid from given report and updates """accumulates total duration for nodeid from given report and updates
the Junit.testcase with the new total if already created. the Junit.testcase with the new total if already created.
""" """
@ -615,7 +634,7 @@ class LogXML:
reporter = self.node_reporter(report) reporter = self.node_reporter(report)
reporter.duration += getattr(report, "duration", 0.0) reporter.duration += getattr(report, "duration", 0.0)
def pytest_collectreport(self, report): def pytest_collectreport(self, report: TestReport) -> None:
if not report.passed: if not report.passed:
reporter = self._opentestcase(report) reporter = self._opentestcase(report)
if report.failed: if report.failed:
@ -623,7 +642,7 @@ class LogXML:
else: else:
reporter.append_collect_skipped(report) reporter.append_collect_skipped(report)
def pytest_internalerror(self, excrepr): def pytest_internalerror(self, excrepr) -> None:
reporter = self.node_reporter("internal") reporter = self.node_reporter("internal")
reporter.attrs.update(classname="pytest", name="internal") reporter.attrs.update(classname="pytest", name="internal")
reporter._add_simple(Junit.error, "internal error", excrepr) reporter._add_simple(Junit.error, "internal error", excrepr)
@ -652,10 +671,10 @@ class LogXML:
self._get_global_properties_node(), self._get_global_properties_node(),
[x.to_xml() for x in self.node_reporters_ordered], [x.to_xml() for x in self.node_reporters_ordered],
name=self.suite_name, name=self.suite_name,
errors=self.stats["error"], errors=str(self.stats["error"]),
failures=self.stats["failure"], failures=str(self.stats["failure"]),
skipped=self.stats["skipped"], skipped=str(self.stats["skipped"]),
tests=numtests, tests=str(numtests),
time="%.3f" % suite_time_delta, time="%.3f" % suite_time_delta,
timestamp=datetime.fromtimestamp(self.suite_start_time).isoformat(), timestamp=datetime.fromtimestamp(self.suite_start_time).isoformat(),
hostname=platform.node(), hostname=platform.node(),
@ -666,12 +685,12 @@ class LogXML:
def pytest_terminal_summary(self, terminalreporter: TerminalReporter) -> None: def pytest_terminal_summary(self, terminalreporter: TerminalReporter) -> None:
terminalreporter.write_sep("-", "generated xml file: {}".format(self.logfile)) terminalreporter.write_sep("-", "generated xml file: {}".format(self.logfile))
def add_global_property(self, name, value): def add_global_property(self, name: str, value: str) -> None:
__tracebackhide__ = True __tracebackhide__ = True
_check_record_param_type("name", name) _check_record_param_type("name", name)
self.global_properties.append((name, bin_xml_escape(value))) self.global_properties.append((name, bin_xml_escape(value)))
def _get_global_properties_node(self): def _get_global_properties_node(self) -> Union[py.xml.Tag, str]:
"""Return a Junit node containing custom properties, if any. """Return a Junit node containing custom properties, if any.
""" """
if self.global_properties: if self.global_properties:

View File

@ -1,7 +1,6 @@
import os import os
import warnings import warnings
from functools import lru_cache from functools import lru_cache
from typing import Any
from typing import Callable from typing import Callable
from typing import Dict from typing import Dict
from typing import Iterable from typing import Iterable
@ -618,7 +617,7 @@ class Item(Node):
#: user properties is a list of tuples (name, value) that holds user #: user properties is a list of tuples (name, value) that holds user
#: defined properties for this test. #: defined properties for this test.
self.user_properties = [] # type: List[Tuple[str, Any]] self.user_properties = [] # type: List[Tuple[str, object]]
def runtest(self) -> None: def runtest(self) -> None:
raise NotImplementedError("runtest must be implemented by Item subclass") raise NotImplementedError("runtest must be implemented by Item subclass")