Type annotate more of _pytest.nodes

This commit is contained in:
Ran Benita 2020-05-01 14:40:16 +03:00
parent 32dd0e87cb
commit fc325bc0c3
3 changed files with 55 additions and 20 deletions

View File

@ -147,9 +147,9 @@ class KeywordMatcher:
# Add the names of the current item and any parent items # Add the names of the current item and any parent items
import pytest import pytest
for item in item.listchain(): for node in item.listchain():
if not isinstance(item, (pytest.Instance, pytest.Session)): if not isinstance(node, (pytest.Instance, pytest.Session)):
mapped_names.add(item.name) mapped_names.add(node.name)
# Add the names added as extra keywords to current or parent items # Add the names added as extra keywords to current or parent items
mapped_names.update(item.listextrakeywords()) mapped_names.update(item.listextrakeywords())

View File

@ -5,11 +5,13 @@ from typing import Any
from typing import Callable from typing import Callable
from typing import Dict from typing import Dict
from typing import Iterable from typing import Iterable
from typing import Iterator
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Sequence from typing import Sequence
from typing import Set from typing import Set
from typing import Tuple from typing import Tuple
from typing import TypeVar
from typing import Union from typing import Union
import py import py
@ -20,6 +22,7 @@ from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ExceptionInfo from _pytest._code.code import ExceptionInfo
from _pytest._code.code import ReprExceptionInfo from _pytest._code.code import ReprExceptionInfo
from _pytest.compat import cached_property from _pytest.compat import cached_property
from _pytest.compat import overload
from _pytest.compat import TYPE_CHECKING from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config from _pytest.config import Config
from _pytest.config import ConftestImportFailure from _pytest.config import ConftestImportFailure
@ -36,6 +39,8 @@ from _pytest.pathlib import Path
from _pytest.store import Store from _pytest.store import Store
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Type
# Imported here due to circular import. # Imported here due to circular import.
from _pytest.main import Session from _pytest.main import Session
@ -45,7 +50,7 @@ tracebackcutdir = py.path.local(_pytest.__file__).dirpath()
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def _splitnode(nodeid): def _splitnode(nodeid: str) -> Tuple[str, ...]:
"""Split a nodeid into constituent 'parts'. """Split a nodeid into constituent 'parts'.
Node IDs are strings, and can be things like: Node IDs are strings, and can be things like:
@ -70,7 +75,7 @@ def _splitnode(nodeid):
return tuple(parts) return tuple(parts)
def ischildnode(baseid, nodeid): def ischildnode(baseid: str, nodeid: str) -> bool:
"""Return True if the nodeid is a child node of the baseid. """Return True if the nodeid is a child node of the baseid.
E.g. 'foo/bar::Baz' is a child of 'foo', 'foo/bar' and 'foo/bar::Baz', but not of 'foo/blorp' E.g. 'foo/bar::Baz' is a child of 'foo', 'foo/bar' and 'foo/bar::Baz', but not of 'foo/blorp'
@ -82,6 +87,9 @@ def ischildnode(baseid, nodeid):
return node_parts[: len(base_parts)] == base_parts return node_parts[: len(base_parts)] == base_parts
_NodeType = TypeVar("_NodeType", bound="Node")
class NodeMeta(type): class NodeMeta(type):
def __call__(self, *k, **kw): def __call__(self, *k, **kw):
warnings.warn(NODE_USE_FROM_PARENT.format(name=self.__name__), stacklevel=2) warnings.warn(NODE_USE_FROM_PARENT.format(name=self.__name__), stacklevel=2)
@ -191,7 +199,7 @@ class Node(metaclass=NodeMeta):
""" fspath sensitive hook proxy used to call pytest hooks""" """ fspath sensitive hook proxy used to call pytest hooks"""
return self.session.gethookproxy(self.fspath) return self.session.gethookproxy(self.fspath)
def __repr__(self): def __repr__(self) -> str:
return "<{} {}>".format(self.__class__.__name__, getattr(self, "name", None)) return "<{} {}>".format(self.__class__.__name__, getattr(self, "name", None))
def warn(self, warning): def warn(self, warning):
@ -232,16 +240,16 @@ class Node(metaclass=NodeMeta):
""" a ::-separated string denoting its collection tree address. """ """ a ::-separated string denoting its collection tree address. """
return self._nodeid return self._nodeid
def __hash__(self): def __hash__(self) -> int:
return hash(self._nodeid) return hash(self._nodeid)
def setup(self): def setup(self) -> None:
pass pass
def teardown(self): def teardown(self) -> None:
pass pass
def listchain(self): def listchain(self) -> List["Node"]:
""" return list of all parent collectors up to self, """ return list of all parent collectors up to self,
starting from root of collection tree. """ starting from root of collection tree. """
chain = [] chain = []
@ -276,7 +284,7 @@ class Node(metaclass=NodeMeta):
else: else:
self.own_markers.insert(0, marker_.mark) self.own_markers.insert(0, marker_.mark)
def iter_markers(self, name=None): def iter_markers(self, name: Optional[str] = None) -> Iterator[Mark]:
""" """
:param name: if given, filter the results by the name attribute :param name: if given, filter the results by the name attribute
@ -284,7 +292,9 @@ class Node(metaclass=NodeMeta):
""" """
return (x[1] for x in self.iter_markers_with_node(name=name)) return (x[1] for x in self.iter_markers_with_node(name=name))
def iter_markers_with_node(self, name=None): def iter_markers_with_node(
self, name: Optional[str] = None
) -> Iterator[Tuple["Node", Mark]]:
""" """
:param name: if given, filter the results by the name attribute :param name: if given, filter the results by the name attribute
@ -296,7 +306,17 @@ class Node(metaclass=NodeMeta):
if name is None or getattr(mark, "name", None) == name: if name is None or getattr(mark, "name", None) == name:
yield node, mark yield node, mark
def get_closest_marker(self, name, default=None): @overload
def get_closest_marker(self, name: str) -> Optional[Mark]:
raise NotImplementedError()
@overload # noqa: F811
def get_closest_marker(self, name: str, default: Mark) -> Mark: # noqa: F811
raise NotImplementedError()
def get_closest_marker( # noqa: F811
self, name: str, default: Optional[Mark] = None
) -> Optional[Mark]:
"""return the first marker matching the name, from closest (for example function) to farther level (for example """return the first marker matching the name, from closest (for example function) to farther level (for example
module level). module level).
@ -305,14 +325,14 @@ class Node(metaclass=NodeMeta):
""" """
return next(self.iter_markers(name=name), default) return next(self.iter_markers(name=name), default)
def listextrakeywords(self): def listextrakeywords(self) -> Set[str]:
""" Return a set of all extra keywords in self and any parents.""" """ Return a set of all extra keywords in self and any parents."""
extra_keywords = set() # type: Set[str] extra_keywords = set() # type: Set[str]
for item in self.listchain(): for item in self.listchain():
extra_keywords.update(item.extra_keyword_matches) extra_keywords.update(item.extra_keyword_matches)
return extra_keywords return extra_keywords
def listnames(self): def listnames(self) -> List[str]:
return [x.name for x in self.listchain()] return [x.name for x in self.listchain()]
def addfinalizer(self, fin: Callable[[], object]) -> None: def addfinalizer(self, fin: Callable[[], object]) -> None:
@ -323,12 +343,13 @@ class Node(metaclass=NodeMeta):
""" """
self.session._setupstate.addfinalizer(fin, self) self.session._setupstate.addfinalizer(fin, self)
def getparent(self, cls): def getparent(self, cls: "Type[_NodeType]") -> Optional[_NodeType]:
""" get the next parent node (including ourself) """ get the next parent node (including ourself)
which is an instance of the given class""" which is an instance of the given class"""
current = self # type: Optional[Node] current = self # type: Optional[Node]
while current and not isinstance(current, cls): while current and not isinstance(current, cls):
current = current.parent current = current.parent
assert current is None or isinstance(current, cls)
return current return current
def _prunetraceback(self, excinfo): def _prunetraceback(self, excinfo):
@ -479,7 +500,12 @@ class FSHookProxy:
class FSCollector(Collector): class FSCollector(Collector):
def __init__( def __init__(
self, fspath: py.path.local, parent=None, config=None, session=None, nodeid=None self,
fspath: py.path.local,
parent=None,
config: Optional[Config] = None,
session: Optional["Session"] = None,
nodeid: Optional[str] = None,
) -> None: ) -> None:
name = fspath.basename name = fspath.basename
if parent is not None: if parent is not None:
@ -579,7 +605,14 @@ class Item(Node):
nextitem = None nextitem = None
def __init__(self, name, parent=None, config=None, session=None, nodeid=None): def __init__(
self,
name,
parent=None,
config: Optional[Config] = None,
session: Optional["Session"] = None,
nodeid: Optional[str] = None,
) -> None:
super().__init__(name, parent, config, session, nodeid=nodeid) super().__init__(name, parent, config, session, nodeid=nodeid)
self._report_sections = [] # type: List[Tuple[str, str, str]] self._report_sections = [] # type: List[Tuple[str, str, str]]

View File

@ -423,7 +423,9 @@ class PyCollector(PyobjMixin, nodes.Collector):
return item return item
def _genfunctions(self, name, funcobj): def _genfunctions(self, name, funcobj):
module = self.getparent(Module).obj modulecol = self.getparent(Module)
assert modulecol is not None
module = modulecol.obj
clscol = self.getparent(Class) clscol = self.getparent(Class)
cls = clscol and clscol.obj or None cls = clscol and clscol.obj or None
fm = self.session._fixturemanager fm = self.session._fixturemanager
@ -437,7 +439,7 @@ class PyCollector(PyobjMixin, nodes.Collector):
methods = [] methods = []
if hasattr(module, "pytest_generate_tests"): if hasattr(module, "pytest_generate_tests"):
methods.append(module.pytest_generate_tests) methods.append(module.pytest_generate_tests)
if hasattr(cls, "pytest_generate_tests"): if cls is not None and hasattr(cls, "pytest_generate_tests"):
methods.append(cls().pytest_generate_tests) methods.append(cls().pytest_generate_tests)
self.ihook.pytest_generate_tests.call_extra(methods, dict(metafunc=metafunc)) self.ihook.pytest_generate_tests.call_extra(methods, dict(metafunc=metafunc))