Type annotate _pytest.junitxml

This commit is contained in:
Ran Benita 2020-05-01 14:40:16 +03:00
parent 3e351afeb3
commit 216a010ab7
2 changed files with 84 additions and 66 deletions

View File

@ -14,6 +14,11 @@ import platform
import re
import sys
from datetime import datetime
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import py
@ -21,14 +26,19 @@ import pytest
from _pytest import deprecated
from _pytest import nodes
from _pytest import timing
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config import filename_arg
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
from _pytest.reports import TestReport
from _pytest.store import StoreKey
from _pytest.terminal import TerminalReporter
from _pytest.warnings import _issue_warning_captured
if TYPE_CHECKING:
from typing import Type
xml_key = StoreKey["LogXML"]()
@ -58,8 +68,8 @@ del _legal_xml_re
_py_ext_re = re.compile(r"\.py$")
def bin_xml_escape(arg):
def repl(matchobj):
def bin_xml_escape(arg: str) -> py.xml.raw:
def repl(matchobj: "re.Match[str]") -> str:
i = ord(matchobj.group())
if i <= 0xFF:
return "#x%02X" % i
@ -69,7 +79,7 @@ def bin_xml_escape(arg):
return py.xml.raw(illegal_xml_re.sub(repl, py.xml.escape(arg)))
def merge_family(left, right):
def merge_family(left, right) -> None:
result = {}
for kl, vl in left.items():
for kr, vr in right.items():
@ -92,28 +102,27 @@ families["xunit2"] = families["_base"]
class _NodeReporter:
def __init__(self, nodeid, xml):
def __init__(self, nodeid: Union[str, TestReport], xml: "LogXML") -> None:
self.id = nodeid
self.xml = xml
self.add_stats = self.xml.add_stats
self.family = self.xml.family
self.duration = 0
self.properties = []
self.nodes = []
self.testcase = None
self.attrs = {}
self.properties = [] # type: List[Tuple[str, py.xml.raw]]
self.nodes = [] # type: List[py.xml.Tag]
self.attrs = {} # type: Dict[str, Union[str, py.xml.raw]]
def append(self, node):
def append(self, node: py.xml.Tag) -> None:
self.xml.add_stats(type(node).__name__)
self.nodes.append(node)
def add_property(self, name, value):
def add_property(self, name: str, value: str) -> None:
self.properties.append((str(name), bin_xml_escape(value)))
def add_attribute(self, name, value):
def add_attribute(self, name: str, value: str) -> None:
self.attrs[str(name)] = bin_xml_escape(value)
def make_properties_node(self):
def make_properties_node(self) -> Union[py.xml.Tag, str]:
"""Return a Junit node containing custom properties, if any.
"""
if self.properties:
@ -125,8 +134,7 @@ class _NodeReporter:
)
return ""
def record_testreport(self, testreport):
assert not self.testcase
def record_testreport(self, testreport: TestReport) -> None:
names = mangle_test_address(testreport.nodeid)
existing_attrs = self.attrs
classnames = names[:-1]
@ -136,9 +144,9 @@ class _NodeReporter:
"classname": ".".join(classnames),
"name": bin_xml_escape(names[-1]),
"file": testreport.location[0],
}
} # type: Dict[str, Union[str, py.xml.raw]]
if testreport.location[1] is not None:
attrs["line"] = testreport.location[1]
attrs["line"] = str(testreport.location[1])
if hasattr(testreport, "url"):
attrs["url"] = testreport.url
self.attrs = attrs
@ -156,19 +164,19 @@ class _NodeReporter:
temp_attrs[key] = self.attrs[key]
self.attrs = temp_attrs
def to_xml(self):
def to_xml(self) -> py.xml.Tag:
testcase = Junit.testcase(time="%.3f" % self.duration, **self.attrs)
testcase.append(self.make_properties_node())
for node in self.nodes:
testcase.append(node)
return testcase
def _add_simple(self, kind, message, data=None):
def _add_simple(self, kind: "Type[py.xml.Tag]", message: str, data=None) -> None:
data = bin_xml_escape(data)
node = kind(data, message=message)
self.append(node)
def write_captured_output(self, report):
def write_captured_output(self, report: TestReport) -> None:
if not self.xml.log_passing_tests and report.passed:
return
@ -191,21 +199,22 @@ class _NodeReporter:
if content_all:
self._write_content(report, content_all, "system-out")
def _prepare_content(self, content, header):
def _prepare_content(self, content: str, header: str) -> str:
return "\n".join([header.center(80, "-"), content, ""])
def _write_content(self, report, content, jheader):
def _write_content(self, report: TestReport, content: str, jheader: str) -> None:
tag = getattr(Junit, jheader)
self.append(tag(bin_xml_escape(content)))
def append_pass(self, report):
def append_pass(self, report: TestReport) -> None:
self.add_stats("passed")
def append_failure(self, report):
def append_failure(self, report: TestReport) -> None:
# msg = str(report.longrepr.reprtraceback.extraline)
if hasattr(report, "wasxfail"):
self._add_simple(Junit.skipped, "xfail-marked test passes unexpectedly")
else:
assert report.longrepr is not None
if getattr(report.longrepr, "reprcrash", None) is not None:
message = report.longrepr.reprcrash.message
else:
@ -215,23 +224,24 @@ class _NodeReporter:
fail.append(bin_xml_escape(report.longrepr))
self.append(fail)
def append_collect_error(self, report):
def append_collect_error(self, report: TestReport) -> None:
# msg = str(report.longrepr.reprtraceback.extraline)
assert report.longrepr is not None
self.append(
Junit.error(bin_xml_escape(report.longrepr), message="collection failure")
)
def append_collect_skipped(self, report):
def append_collect_skipped(self, report: TestReport) -> None:
self._add_simple(Junit.skipped, "collection skipped", report.longrepr)
def append_error(self, report):
def append_error(self, report: TestReport) -> None:
if report.when == "teardown":
msg = "test teardown failure"
else:
msg = "test setup failure"
self._add_simple(Junit.error, msg, report.longrepr)
def append_skipped(self, report):
def append_skipped(self, report: TestReport) -> None:
if hasattr(report, "wasxfail"):
xfailreason = report.wasxfail
if xfailreason.startswith("reason: "):
@ -242,6 +252,7 @@ class _NodeReporter:
)
)
else:
assert report.longrepr is not None
filename, lineno, skipreason = report.longrepr
if skipreason.startswith("Skipped: "):
skipreason = skipreason[9:]
@ -256,13 +267,17 @@ class _NodeReporter:
)
self.write_captured_output(report)
def finalize(self):
def finalize(self) -> None:
data = self.to_xml().unicode(indent=0)
self.__dict__.clear()
self.to_xml = lambda: py.xml.raw(data)
# Type ignored becuase mypy doesn't like overriding a method.
# Also the return value doesn't match...
self.to_xml = lambda: py.xml.raw(data) # type: ignore # noqa: F821
def _warn_incompatibility_with_xunit2(request, fixture_name):
def _warn_incompatibility_with_xunit2(
request: FixtureRequest, fixture_name: str
) -> None:
"""Emits a PytestWarning about the given fixture being incompatible with newer xunit revisions"""
from _pytest.warning_types import PytestWarning
@ -278,7 +293,7 @@ def _warn_incompatibility_with_xunit2(request, fixture_name):
@pytest.fixture
def record_property(request):
def record_property(request: FixtureRequest):
"""Add an extra properties the calling test.
User properties become part of the test report and are available to the
configured reporters, like JUnit XML.
@ -292,14 +307,14 @@ def record_property(request):
"""
_warn_incompatibility_with_xunit2(request, "record_property")
def append_property(name, value):
def append_property(name: str, value: object) -> None:
request.node.user_properties.append((name, value))
return append_property
@pytest.fixture
def record_xml_attribute(request):
def record_xml_attribute(request: FixtureRequest):
"""Add extra xml attributes to the tag for the calling test.
The fixture is callable with ``(name, value)``, with value being
automatically xml-encoded
@ -313,7 +328,7 @@ def record_xml_attribute(request):
_warn_incompatibility_with_xunit2(request, "record_xml_attribute")
# Declare noop
def add_attr_noop(name, value):
def add_attr_noop(name: str, value: str) -> None:
pass
attr_func = add_attr_noop
@ -326,7 +341,7 @@ def record_xml_attribute(request):
return attr_func
def _check_record_param_type(param, v):
def _check_record_param_type(param: str, v: str) -> None:
"""Used by record_testsuite_property to check that the given parameter name is of the proper
type"""
__tracebackhide__ = True
@ -336,7 +351,7 @@ def _check_record_param_type(param, v):
@pytest.fixture(scope="session")
def record_testsuite_property(request):
def record_testsuite_property(request: FixtureRequest):
"""
Records a new ``<property>`` tag as child of the root ``<testsuite>``. This is suitable to
writing global information regarding the entire test suite, and is compatible with ``xunit2`` JUnit family.
@ -354,7 +369,7 @@ def record_testsuite_property(request):
__tracebackhide__ = True
def record_func(name, value):
def record_func(name: str, value: str):
"""noop function in case --junitxml was not passed in the command-line"""
__tracebackhide__ = True
_check_record_param_type("name", name)
@ -437,7 +452,7 @@ def pytest_unconfigure(config: Config) -> None:
config.pluginmanager.unregister(xml)
def mangle_test_address(address):
def mangle_test_address(address: str) -> List[str]:
path, possible_open_bracket, params = address.partition("[")
names = path.split("::")
try:
@ -456,13 +471,13 @@ class LogXML:
def __init__(
self,
logfile,
prefix,
suite_name="pytest",
logging="no",
report_duration="total",
prefix: Optional[str],
suite_name: str = "pytest",
logging: str = "no",
report_duration: str = "total",
family="xunit1",
log_passing_tests=True,
):
log_passing_tests: bool = True,
) -> None:
logfile = os.path.expanduser(os.path.expandvars(logfile))
self.logfile = os.path.normpath(os.path.abspath(logfile))
self.prefix = prefix
@ -471,20 +486,24 @@ class LogXML:
self.log_passing_tests = log_passing_tests
self.report_duration = report_duration
self.family = family
self.stats = dict.fromkeys(["error", "passed", "failure", "skipped"], 0)
self.node_reporters = {} # nodeid -> _NodeReporter
self.node_reporters_ordered = []
self.global_properties = []
self.stats = dict.fromkeys(
["error", "passed", "failure", "skipped"], 0
) # type: Dict[str, int]
self.node_reporters = (
{}
) # type: Dict[Tuple[Union[str, TestReport], object], _NodeReporter]
self.node_reporters_ordered = [] # type: List[_NodeReporter]
self.global_properties = [] # type: List[Tuple[str, py.xml.raw]]
# List of reports that failed on call but teardown is pending.
self.open_reports = []
self.open_reports = [] # type: List[TestReport]
self.cnt_double_fail_tests = 0
# Replaces convenience family with real family
if self.family == "legacy":
self.family = "xunit1"
def finalize(self, report):
def finalize(self, report: TestReport) -> None:
nodeid = getattr(report, "nodeid", report)
# local hack to handle xdist report order
slavenode = getattr(report, "node", None)
@ -492,8 +511,8 @@ class LogXML:
if reporter is not None:
reporter.finalize()
def node_reporter(self, report):
nodeid = getattr(report, "nodeid", report)
def node_reporter(self, report: Union[TestReport, str]) -> _NodeReporter:
nodeid = getattr(report, "nodeid", report) # type: Union[str, TestReport]
# local hack to handle xdist report order
slavenode = getattr(report, "node", None)
@ -510,11 +529,11 @@ class LogXML:
return reporter
def add_stats(self, key):
def add_stats(self, key: str) -> None:
if key in self.stats:
self.stats[key] += 1
def _opentestcase(self, report):
def _opentestcase(self, report: TestReport) -> _NodeReporter:
reporter = self.node_reporter(report)
reporter.record_testreport(report)
return reporter
@ -587,7 +606,7 @@ class LogXML:
reporter.write_captured_output(report)
for propname, propvalue in report.user_properties:
reporter.add_property(propname, propvalue)
reporter.add_property(propname, str(propvalue))
self.finalize(report)
report_wid = getattr(report, "worker_id", None)
@ -607,7 +626,7 @@ class LogXML:
if close_report:
self.open_reports.remove(close_report)
def update_testcase_duration(self, report):
def update_testcase_duration(self, report: TestReport) -> None:
"""accumulates total duration for nodeid from given report and updates
the Junit.testcase with the new total if already created.
"""
@ -615,7 +634,7 @@ class LogXML:
reporter = self.node_reporter(report)
reporter.duration += getattr(report, "duration", 0.0)
def pytest_collectreport(self, report):
def pytest_collectreport(self, report: TestReport) -> None:
if not report.passed:
reporter = self._opentestcase(report)
if report.failed:
@ -623,7 +642,7 @@ class LogXML:
else:
reporter.append_collect_skipped(report)
def pytest_internalerror(self, excrepr):
def pytest_internalerror(self, excrepr) -> None:
reporter = self.node_reporter("internal")
reporter.attrs.update(classname="pytest", name="internal")
reporter._add_simple(Junit.error, "internal error", excrepr)
@ -652,10 +671,10 @@ class LogXML:
self._get_global_properties_node(),
[x.to_xml() for x in self.node_reporters_ordered],
name=self.suite_name,
errors=self.stats["error"],
failures=self.stats["failure"],
skipped=self.stats["skipped"],
tests=numtests,
errors=str(self.stats["error"]),
failures=str(self.stats["failure"]),
skipped=str(self.stats["skipped"]),
tests=str(numtests),
time="%.3f" % suite_time_delta,
timestamp=datetime.fromtimestamp(self.suite_start_time).isoformat(),
hostname=platform.node(),
@ -666,12 +685,12 @@ class LogXML:
def pytest_terminal_summary(self, terminalreporter: TerminalReporter) -> None:
terminalreporter.write_sep("-", "generated xml file: {}".format(self.logfile))
def add_global_property(self, name, value):
def add_global_property(self, name: str, value: str) -> None:
__tracebackhide__ = True
_check_record_param_type("name", name)
self.global_properties.append((name, bin_xml_escape(value)))
def _get_global_properties_node(self):
def _get_global_properties_node(self) -> Union[py.xml.Tag, str]:
"""Return a Junit node containing custom properties, if any.
"""
if self.global_properties:

View File

@ -1,7 +1,6 @@
import os
import warnings
from functools import lru_cache
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
@ -618,7 +617,7 @@ class Item(Node):
#: user properties is a list of tuples (name, value) that holds user
#: defined properties for this test.
self.user_properties = [] # type: List[Tuple[str, Any]]
self.user_properties = [] # type: List[Tuple[str, object]]
def runtest(self) -> None:
raise NotImplementedError("runtest must be implemented by Item subclass")