Type annotate more of _pytest.nodes

This commit is contained in:
Ran Benita 2020-05-01 14:40:16 +03:00
parent 32dd0e87cb
commit fc325bc0c3
3 changed files with 55 additions and 20 deletions

View File

@ -147,9 +147,9 @@ class KeywordMatcher:
# Add the names of the current item and any parent items
import pytest
for item in item.listchain():
if not isinstance(item, (pytest.Instance, pytest.Session)):
mapped_names.add(item.name)
for node in item.listchain():
if not isinstance(node, (pytest.Instance, pytest.Session)):
mapped_names.add(node.name)
# Add the names added as extra keywords to current or parent items
mapped_names.update(item.listextrakeywords())

View File

@ -5,11 +5,13 @@ from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import TypeVar
from typing import Union
import py
@ -20,6 +22,7 @@ from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ExceptionInfo
from _pytest._code.code import ReprExceptionInfo
from _pytest.compat import cached_property
from _pytest.compat import overload
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config import ConftestImportFailure
@ -36,6 +39,8 @@ from _pytest.pathlib import Path
from _pytest.store import Store
if TYPE_CHECKING:
from typing import Type
# Imported here due to circular import.
from _pytest.main import Session
@ -45,7 +50,7 @@ tracebackcutdir = py.path.local(_pytest.__file__).dirpath()
@lru_cache(maxsize=None)
def _splitnode(nodeid):
def _splitnode(nodeid: str) -> Tuple[str, ...]:
"""Split a nodeid into constituent 'parts'.
Node IDs are strings, and can be things like:
@ -70,7 +75,7 @@ def _splitnode(nodeid):
return tuple(parts)
def ischildnode(baseid, nodeid):
def ischildnode(baseid: str, nodeid: str) -> bool:
"""Return True if the nodeid is a child node of the baseid.
E.g. 'foo/bar::Baz' is a child of 'foo', 'foo/bar' and 'foo/bar::Baz', but not of 'foo/blorp'
@ -82,6 +87,9 @@ def ischildnode(baseid, nodeid):
return node_parts[: len(base_parts)] == base_parts
_NodeType = TypeVar("_NodeType", bound="Node")
class NodeMeta(type):
def __call__(self, *k, **kw):
warnings.warn(NODE_USE_FROM_PARENT.format(name=self.__name__), stacklevel=2)
@ -191,7 +199,7 @@ class Node(metaclass=NodeMeta):
""" fspath sensitive hook proxy used to call pytest hooks"""
return self.session.gethookproxy(self.fspath)
def __repr__(self):
def __repr__(self) -> str:
return "<{} {}>".format(self.__class__.__name__, getattr(self, "name", None))
def warn(self, warning):
@ -232,16 +240,16 @@ class Node(metaclass=NodeMeta):
""" a ::-separated string denoting its collection tree address. """
return self._nodeid
def __hash__(self):
def __hash__(self) -> int:
return hash(self._nodeid)
def setup(self):
def setup(self) -> None:
pass
def teardown(self):
def teardown(self) -> None:
pass
def listchain(self):
def listchain(self) -> List["Node"]:
""" return list of all parent collectors up to self,
starting from root of collection tree. """
chain = []
@ -276,7 +284,7 @@ class Node(metaclass=NodeMeta):
else:
self.own_markers.insert(0, marker_.mark)
def iter_markers(self, name=None):
def iter_markers(self, name: Optional[str] = None) -> Iterator[Mark]:
"""
:param name: if given, filter the results by the name attribute
@ -284,7 +292,9 @@ class Node(metaclass=NodeMeta):
"""
return (x[1] for x in self.iter_markers_with_node(name=name))
def iter_markers_with_node(self, name=None):
def iter_markers_with_node(
self, name: Optional[str] = None
) -> Iterator[Tuple["Node", Mark]]:
"""
:param name: if given, filter the results by the name attribute
@ -296,7 +306,17 @@ class Node(metaclass=NodeMeta):
if name is None or getattr(mark, "name", None) == name:
yield node, mark
def get_closest_marker(self, name, default=None):
@overload
def get_closest_marker(self, name: str) -> Optional[Mark]:
raise NotImplementedError()
@overload # noqa: F811
def get_closest_marker(self, name: str, default: Mark) -> Mark: # noqa: F811
raise NotImplementedError()
def get_closest_marker( # noqa: F811
self, name: str, default: Optional[Mark] = None
) -> Optional[Mark]:
"""return the first marker matching the name, from closest (for example function) to farther level (for example
module level).
@ -305,14 +325,14 @@ class Node(metaclass=NodeMeta):
"""
return next(self.iter_markers(name=name), default)
def listextrakeywords(self):
def listextrakeywords(self) -> Set[str]:
""" Return a set of all extra keywords in self and any parents."""
extra_keywords = set() # type: Set[str]
for item in self.listchain():
extra_keywords.update(item.extra_keyword_matches)
return extra_keywords
def listnames(self):
def listnames(self) -> List[str]:
return [x.name for x in self.listchain()]
def addfinalizer(self, fin: Callable[[], object]) -> None:
@ -323,12 +343,13 @@ class Node(metaclass=NodeMeta):
"""
self.session._setupstate.addfinalizer(fin, self)
def getparent(self, cls):
def getparent(self, cls: "Type[_NodeType]") -> Optional[_NodeType]:
""" get the next parent node (including ourself)
which is an instance of the given class"""
current = self # type: Optional[Node]
while current and not isinstance(current, cls):
current = current.parent
assert current is None or isinstance(current, cls)
return current
def _prunetraceback(self, excinfo):
@ -479,7 +500,12 @@ class FSHookProxy:
class FSCollector(Collector):
def __init__(
self, fspath: py.path.local, parent=None, config=None, session=None, nodeid=None
self,
fspath: py.path.local,
parent=None,
config: Optional[Config] = None,
session: Optional["Session"] = None,
nodeid: Optional[str] = None,
) -> None:
name = fspath.basename
if parent is not None:
@ -579,7 +605,14 @@ class Item(Node):
nextitem = None
def __init__(self, name, parent=None, config=None, session=None, nodeid=None):
def __init__(
self,
name,
parent=None,
config: Optional[Config] = None,
session: Optional["Session"] = None,
nodeid: Optional[str] = None,
) -> None:
super().__init__(name, parent, config, session, nodeid=nodeid)
self._report_sections = [] # type: List[Tuple[str, str, str]]

View File

@ -423,7 +423,9 @@ class PyCollector(PyobjMixin, nodes.Collector):
return item
def _genfunctions(self, name, funcobj):
module = self.getparent(Module).obj
modulecol = self.getparent(Module)
assert modulecol is not None
module = modulecol.obj
clscol = self.getparent(Class)
cls = clscol and clscol.obj or None
fm = self.session._fixturemanager
@ -437,7 +439,7 @@ class PyCollector(PyobjMixin, nodes.Collector):
methods = []
if hasattr(module, "pytest_generate_tests"):
methods.append(module.pytest_generate_tests)
if hasattr(cls, "pytest_generate_tests"):
if cls is not None and hasattr(cls, "pytest_generate_tests"):
methods.append(cls().pytest_generate_tests)
self.ihook.pytest_generate_tests.call_extra(methods, dict(metafunc=metafunc))