diff --git a/_pytest/junitxml.py b/_pytest/junitxml.py index b88050712..de7c368ee 100644 --- a/_pytest/junitxml.py +++ b/_pytest/junitxml.py @@ -58,22 +58,48 @@ def bin_xml_escape(arg): return py.xml.raw(illegal_xml_re.sub(repl, py.xml.escape(arg))) +class _NodeReporter(object): + def __init__(self, nodeid, xml): + + self.id = nodeid + self.duration = 0 + self.properties = {} + self.testcase = None + + def add_property(self, name, value): + self.properties[str(name)] = bin_xml_escape(value) + + + def make_properties_node(self): + """Return a Junit node containing custom properties set for + the current test, if any, and reset the current custom properties. + """ + if self.properties: + return Junit.properties([ + Junit.property(name=name, value=value) + for name, value in self.properties.items() + ]) + + @pytest.fixture def record_xml_property(request): """Fixture that adds extra xml properties to the tag for the calling test. The fixture is callable with (name, value), with value being automatically xml-encoded. """ + request.config.warn( + code='C3', + message='record_xml_property is an experimental feature', + fslocation=request.node.location[:2]) + xml = getattr(request.config, "_xml", None) + if xml is not None: + nodereporter = xml.nodereporter(request.node.nodeid) + return nodereporter.add_property + else: + def add_property_noop(name, value): + pass - def inner(name, value): - if hasattr(request.config, "_xml"): - request.config._xml.add_custom_property(name, value) - msg = 'record_xml_property is an experimental feature' - request.config.warn(code='C3', - message=msg, - fslocation=request.node.location[:2]) - - return inner + return add_property_noop def pytest_addoption(parser): @@ -126,12 +152,16 @@ class LogXML(object): 'failure', 'skipped', ], 0) - self.tests_by_nodeid = {} # nodeid -> Junit.testcase - self.durations = {} # nodeid -> total duration (setup+call+teardown) - self.custom_properties = {} + self.nodereporters = {} # nodeid -> Junit.testcase + self.tests_by_nodeid = {} + + def nodereporter(self, nodeid): + if nodeid in self.nodereporters: + return self.nodereporters[nodeid] + reporter = _NodeReporter(nodeid, self) + self.nodereporters[nodeid] = reporter + return reporter - def add_custom_property(self, name, value): - self.custom_properties[str(name)] = bin_xml_escape(str(value)) def _addtestcase(self, attrs=None, **kw): testcase = Junit.testcase(**(attrs or kw)) @@ -148,6 +178,7 @@ class LogXML(object): self.append(node) def _opentestcase(self, report): + reporter = self.nodereporter(report.nodeid) names = mangle_testnames(report.nodeid.split("::")) classnames = names[:-1] if self.prefix: @@ -156,12 +187,13 @@ class LogXML(object): "classname": ".".join(classnames), "name": bin_xml_escape(names[-1]), "file": report.location[0], - "time": self.durations.get(report.nodeid, 0), + "time": reporter.duration, } if report.location[1] is not None: attrs["line"] = report.location[1] testcase = self._addtestcase(attrs) - custom_properties = self.pop_custom_properties() + reporter = self.nodereporter(report.nodeid) + custom_properties = reporter.make_properties_node() if custom_properties: testcase.append(custom_properties) self.tests_by_nodeid[report.nodeid] = testcase @@ -180,20 +212,7 @@ class LogXML(object): self._add_stats(type(obj).__name__) self.tests[-1].append(obj) - def pop_custom_properties(self): - """Return a Junit node containing custom properties set for - the current test, if any, and reset the current custom properties. - """ - if self.custom_properties: - result = Junit.properties( - [ - Junit.property(name=name, - value=value) - for name, value in self.custom_properties.items() - ]) - self.custom_properties.clear() - return result - return None + def append_pass(self, report): self._add_stats('passed') @@ -288,13 +307,12 @@ class LogXML(object): """accumulates total duration for nodeid from given report and updates the Junit.testcase with the new total if already created. """ - total = self.durations.get(report.nodeid, 0.0) - total += getattr(report, 'duration', 0.0) - self.durations[report.nodeid] = total + reporter = self.nodereporter(report.nodeid) + reporter.duration += getattr(report, 'duration', 0.0) testcase = self.tests_by_nodeid.get(report.nodeid) if testcase is not None: - testcase.attr.time = total + testcase.attr.time = reporter.duration def pytest_collectreport(self, report): if not report.passed: