Merge pull request #11801 from bluetech/node-iterchain

nodes: add `Node.iterchain()` function
This commit is contained in:
Ran Benita 2024-01-12 11:01:48 +02:00 committed by GitHub
commit 5645fa45fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 30 additions and 38 deletions

View File

@ -0,0 +1,2 @@
Added the :func:`iterparents() <_pytest.nodes.Node.iterparents>` helper method on nodes.
It is similar to :func:`listchain <_pytest.nodes.Node.listchain>`, but goes from bottom to top, and returns an iterator, not a list.

View File

@ -116,22 +116,16 @@ def pytest_sessionstart(session: "Session") -> None:
def get_scope_package( def get_scope_package(
node: nodes.Item, node: nodes.Item,
fixturedef: "FixtureDef[object]", fixturedef: "FixtureDef[object]",
) -> Optional[Union[nodes.Item, nodes.Collector]]: ) -> Optional[nodes.Node]:
from _pytest.python import Package from _pytest.python import Package
current: Optional[Union[nodes.Item, nodes.Collector]] = node for parent in node.iterparents():
while current and ( if isinstance(parent, Package) and parent.nodeid == fixturedef.baseid:
not isinstance(current, Package) or current.nodeid != fixturedef.baseid return parent
): return node.session
current = current.parent # type: ignore[assignment]
if current is None:
return node.session
return current
def get_scope_node( def get_scope_node(node: nodes.Node, scope: Scope) -> Optional[nodes.Node]:
node: nodes.Node, scope: Scope
) -> Optional[Union[nodes.Item, nodes.Collector]]:
import _pytest.python import _pytest.python
if scope is Scope.Function: if scope is Scope.Function:
@ -738,7 +732,7 @@ class SubRequest(FixtureRequest):
scope = self._scope scope = self._scope
if scope is Scope.Function: if scope is Scope.Function:
# This might also be a non-function Item despite its attribute name. # This might also be a non-function Item despite its attribute name.
node: Optional[Union[nodes.Item, nodes.Collector]] = self._pyfuncitem node: Optional[nodes.Node] = self._pyfuncitem
elif scope is Scope.Package: elif scope is Scope.Package:
node = get_scope_package(self._pyfuncitem, self._fixturedef) node = get_scope_package(self._pyfuncitem, self._fixturedef)
else: else:
@ -1513,7 +1507,7 @@ class FixtureManager:
def _getautousenames(self, node: nodes.Node) -> Iterator[str]: def _getautousenames(self, node: nodes.Node) -> Iterator[str]:
"""Return the names of autouse fixtures applicable to node.""" """Return the names of autouse fixtures applicable to node."""
for parentnode in reversed(list(nodes.iterparentnodes(node))): for parentnode in node.listchain():
basenames = self._nodeid_autousenames.get(parentnode.nodeid) basenames = self._nodeid_autousenames.get(parentnode.nodeid)
if basenames: if basenames:
yield from basenames yield from basenames
@ -1781,7 +1775,7 @@ class FixtureManager:
def _matchfactories( def _matchfactories(
self, fixturedefs: Iterable[FixtureDef[Any]], node: nodes.Node self, fixturedefs: Iterable[FixtureDef[Any]], node: nodes.Node
) -> Iterator[FixtureDef[Any]]: ) -> Iterator[FixtureDef[Any]]:
parentnodeids = {n.nodeid for n in nodes.iterparentnodes(node)} parentnodeids = {n.nodeid for n in node.iterparents()}
for fixturedef in fixturedefs: for fixturedef in fixturedefs:
if fixturedef.baseid in parentnodeids: if fixturedef.baseid in parentnodeids:
yield fixturedef yield fixturedef

View File

@ -49,15 +49,6 @@ SEP = "/"
tracebackcutdir = Path(_pytest.__file__).parent tracebackcutdir = Path(_pytest.__file__).parent
def iterparentnodes(node: "Node") -> Iterator["Node"]:
"""Return the parent nodes, including the node itself, from the node
upwards."""
parent: Optional[Node] = node
while parent is not None:
yield parent
parent = parent.parent
_NodeType = TypeVar("_NodeType", bound="Node") _NodeType = TypeVar("_NodeType", bound="Node")
@ -265,12 +256,20 @@ class Node(abc.ABC, metaclass=NodeMeta):
def teardown(self) -> None: def teardown(self) -> None:
pass pass
def listchain(self) -> List["Node"]: def iterparents(self) -> Iterator["Node"]:
"""Return list of all parent collectors up to self, starting from """Iterate over all parent collectors starting from and including self
the root of collection tree. up to the root of the collection tree.
:returns: The nodes. .. versionadded:: 8.1
""" """
parent: Optional[Node] = self
while parent is not None:
yield parent
parent = parent.parent
def listchain(self) -> List["Node"]:
"""Return a list of all parent collectors starting from the root of the
collection tree down to and including self."""
chain = [] chain = []
item: Optional[Node] = self item: Optional[Node] = self
while item is not None: while item is not None:
@ -319,7 +318,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
:param name: If given, filter the results by the name attribute. :param name: If given, filter the results by the name attribute.
:returns: An iterator of (node, mark) tuples. :returns: An iterator of (node, mark) tuples.
""" """
for node in reversed(self.listchain()): for node in self.iterparents():
for mark in node.own_markers: for mark in node.own_markers:
if name is None or getattr(mark, "name", None) == name: if name is None or getattr(mark, "name", None) == name:
yield node, mark yield node, mark
@ -363,17 +362,16 @@ class Node(abc.ABC, metaclass=NodeMeta):
self.session._setupstate.addfinalizer(fin, self) self.session._setupstate.addfinalizer(fin, self)
def getparent(self, cls: Type[_NodeType]) -> Optional[_NodeType]: def getparent(self, cls: Type[_NodeType]) -> Optional[_NodeType]:
"""Get the next parent node (including self) which is an instance of """Get the closest parent node (including self) which is an instance of
the given class. the given class.
:param cls: The node class to search for. :param cls: The node class to search for.
:returns: The node, if found. :returns: The node, if found.
""" """
current: Optional[Node] = self for node in self.iterparents():
while current and not isinstance(current, cls): if isinstance(node, cls):
current = current.parent return node
assert current is None or isinstance(current, cls) return None
return current
def _traceback_filter(self, excinfo: ExceptionInfo[BaseException]) -> Traceback: def _traceback_filter(self, excinfo: ExceptionInfo[BaseException]) -> Traceback:
return excinfo.traceback return excinfo.traceback

View File

@ -332,10 +332,8 @@ class PyobjMixin(nodes.Node):
def getmodpath(self, stopatmodule: bool = True, includemodule: bool = False) -> str: def getmodpath(self, stopatmodule: bool = True, includemodule: bool = False) -> str:
"""Return Python path relative to the containing module.""" """Return Python path relative to the containing module."""
chain = self.listchain()
chain.reverse()
parts = [] parts = []
for node in chain: for node in self.iterparents():
name = node.name name = node.name
if isinstance(node, Module): if isinstance(node, Module):
name = os.path.splitext(name)[0] name = os.path.splitext(name)[0]