Type annotate main.py and some parts related to collection

This commit is contained in:
Ran Benita 2020-05-01 14:40:15 +03:00
parent f8de424241
commit be00e12d47
9 changed files with 175 additions and 75 deletions

View File

@ -840,7 +840,7 @@ class Config:
self.cache = None # type: Optional[Cache]
@property
def invocation_dir(self):
def invocation_dir(self) -> py.path.local:
"""Backward compatibility"""
return py.path.local(str(self.invocation_params.dir))

View File

@ -7,6 +7,7 @@ import traceback
import warnings
from contextlib import contextmanager
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
@ -109,13 +110,18 @@ def pytest_unconfigure() -> None:
RUNNER_CLASS = None
def pytest_collect_file(path: py.path.local, parent):
def pytest_collect_file(
path: py.path.local, parent
) -> Optional[Union["DoctestModule", "DoctestTextfile"]]:
config = parent.config
if path.ext == ".py":
if config.option.doctestmodules and not _is_setup_py(path):
return DoctestModule.from_parent(parent, fspath=path)
mod = DoctestModule.from_parent(parent, fspath=path) # type: DoctestModule
return mod
elif _is_doctest(config, path, parent):
return DoctestTextfile.from_parent(parent, fspath=path)
txt = DoctestTextfile.from_parent(parent, fspath=path) # type: DoctestTextfile
return txt
return None
def _is_setup_py(path: py.path.local) -> bool:
@ -365,7 +371,7 @@ def _get_continue_on_failure(config):
class DoctestTextfile(pytest.Module):
obj = None
def collect(self):
def collect(self) -> Iterable[DoctestItem]:
import doctest
# inspired by doctest.testfile; ideally we would use it directly,
@ -444,7 +450,7 @@ def _patch_unwrap_mock_aware():
class DoctestModule(pytest.Module):
def collect(self):
def collect(self) -> Iterable[DoctestItem]:
import doctest
class MockAwareDocTestFinder(doctest.DocTestFinder):

View File

@ -6,6 +6,7 @@ from typing import Optional
from typing import Tuple
from typing import Union
import py.path
from pluggy import HookspecMarker
from .deprecated import COLLECT_DIRECTORY_HOOK
@ -20,9 +21,14 @@ if TYPE_CHECKING:
from _pytest.config import _PluggyPlugin
from _pytest.config.argparsing import Parser
from _pytest.main import Session
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.python import Metafunc
from _pytest.python import Module
from _pytest.python import PyCollector
from _pytest.reports import BaseReport
hookspec = HookspecMarker("pytest")
# -------------------------------------------------------------------------
@ -249,7 +255,7 @@ def pytest_collect_directory(path, parent):
"""
def pytest_collect_file(path, parent):
def pytest_collect_file(path: py.path.local, parent) -> "Optional[Collector]":
""" return collection Node or None for the given path. Any new node
needs to have the specified ``parent`` as a parent.
@ -289,7 +295,7 @@ def pytest_make_collect_report(collector):
@hookspec(firstresult=True)
def pytest_pycollect_makemodule(path, parent):
def pytest_pycollect_makemodule(path: py.path.local, parent) -> "Optional[Module]":
""" return a Module collector or None for the given path.
This hook will be called for each matching test module path.
The pytest_collect_file hook needs to be used if you want to
@ -302,7 +308,9 @@ def pytest_pycollect_makemodule(path, parent):
@hookspec(firstresult=True)
def pytest_pycollect_makeitem(collector, name, obj):
def pytest_pycollect_makeitem(
collector: "PyCollector", name: str, obj
) -> "Union[None, Item, Collector, List[Union[Item, Collector]]]":
""" return custom item/collector for a python object in a module, or None.
Stops at first non-None result, see :ref:`firstresult` """

View File

@ -7,9 +7,11 @@ import sys
from typing import Callable
from typing import Dict
from typing import FrozenSet
from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Union
@ -18,12 +20,14 @@ import py
import _pytest._code
from _pytest import nodes
from _pytest.compat import overload
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config import directory_arg
from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config import UsageError
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureManager
from _pytest.outcomes import exit
from _pytest.reports import CollectReport
@ -38,7 +42,7 @@ if TYPE_CHECKING:
from _pytest.python import Package
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
parser.addini(
"norecursedirs",
"directory patterns to avoid for recursion",
@ -241,7 +245,7 @@ def wrap_session(
return session.exitstatus
def pytest_cmdline_main(config):
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
return wrap_session(config, _main)
@ -258,11 +262,11 @@ def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
return None
def pytest_collection(session):
def pytest_collection(session: "Session") -> Sequence[nodes.Item]:
return session.perform_collect()
def pytest_runtestloop(session):
def pytest_runtestloop(session: "Session") -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
@ -282,7 +286,7 @@ def pytest_runtestloop(session):
return True
def _in_venv(path):
def _in_venv(path: py.path.local) -> bool:
"""Attempts to detect if ``path`` is the root of a Virtual Environment by
checking for the existence of the appropriate activate script"""
bindir = path.join("Scripts" if sys.platform.startswith("win") else "bin")
@ -328,7 +332,7 @@ def pytest_ignore_collect(
return None
def pytest_collection_modifyitems(items, config):
def pytest_collection_modifyitems(items, config: Config) -> None:
deselect_prefixes = tuple(config.getoption("deselect") or [])
if not deselect_prefixes:
return
@ -385,8 +389,8 @@ class Session(nodes.FSCollector):
)
self.testsfailed = 0
self.testscollected = 0
self.shouldstop = False
self.shouldfail = False
self.shouldstop = False # type: Union[bool, str]
self.shouldfail = False # type: Union[bool, str]
self.trace = config.trace.root.get("collection")
self.startdir = config.invocation_dir
self._initialpaths = frozenset() # type: FrozenSet[py.path.local]
@ -412,10 +416,11 @@ class Session(nodes.FSCollector):
self.config.pluginmanager.register(self, name="session")
@classmethod
def from_config(cls, config):
return cls._create(config)
def from_config(cls, config: Config) -> "Session":
session = cls._create(config) # type: Session
return session
def __repr__(self):
def __repr__(self) -> str:
return "<%s %s exitstatus=%r testsfailed=%d testscollected=%d>" % (
self.__class__.__name__,
self.name,
@ -429,14 +434,14 @@ class Session(nodes.FSCollector):
return self._bestrelpathcache[node_path]
@hookimpl(tryfirst=True)
def pytest_collectstart(self):
def pytest_collectstart(self) -> None:
if self.shouldfail:
raise self.Failed(self.shouldfail)
if self.shouldstop:
raise self.Interrupted(self.shouldstop)
@hookimpl(tryfirst=True)
def pytest_runtest_logreport(self, report):
def pytest_runtest_logreport(self, report) -> None:
if report.failed and not hasattr(report, "wasxfail"):
self.testsfailed += 1
maxfail = self.config.getvalue("maxfail")
@ -445,13 +450,27 @@ class Session(nodes.FSCollector):
pytest_collectreport = pytest_runtest_logreport
def isinitpath(self, path):
def isinitpath(self, path: py.path.local) -> bool:
return path in self._initialpaths
def gethookproxy(self, fspath: py.path.local):
return super()._gethookproxy(fspath)
def perform_collect(self, args=None, genitems=True):
@overload
def perform_collect(
self, args: Optional[Sequence[str]] = ..., genitems: "Literal[True]" = ...
) -> Sequence[nodes.Item]:
raise NotImplementedError()
@overload # noqa: F811
def perform_collect( # noqa: F811
self, args: Optional[Sequence[str]] = ..., genitems: bool = ...
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
raise NotImplementedError()
def perform_collect( # noqa: F811
self, args: Optional[Sequence[str]] = None, genitems: bool = True
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
hook = self.config.hook
try:
items = self._perform_collect(args, genitems)
@ -464,15 +483,29 @@ class Session(nodes.FSCollector):
self.testscollected = len(items)
return items
def _perform_collect(self, args, genitems):
@overload
def _perform_collect(
self, args: Optional[Sequence[str]], genitems: "Literal[True]"
) -> Sequence[nodes.Item]:
raise NotImplementedError()
@overload # noqa: F811
def _perform_collect( # noqa: F811
self, args: Optional[Sequence[str]], genitems: bool
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
raise NotImplementedError()
def _perform_collect( # noqa: F811
self, args: Optional[Sequence[str]], genitems: bool
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
if args is None:
args = self.config.args
self.trace("perform_collect", self, args)
self.trace.root.indent += 1
self._notfound = []
self._notfound = [] # type: List[Tuple[str, NoMatch]]
initialpaths = [] # type: List[py.path.local]
self._initial_parts = [] # type: List[Tuple[py.path.local, List[str]]]
self.items = items = []
self.items = items = [] # type: List[nodes.Item]
for arg in args:
fspath, parts = self._parsearg(arg)
self._initial_parts.append((fspath, parts))
@ -495,7 +528,7 @@ class Session(nodes.FSCollector):
self.items.extend(self.genitems(node))
return items
def collect(self):
def collect(self) -> Iterator[Union[nodes.Item, nodes.Collector]]:
for fspath, parts in self._initial_parts:
self.trace("processing argument", (fspath, parts))
self.trace.root.indent += 1
@ -513,7 +546,9 @@ class Session(nodes.FSCollector):
self._collection_node_cache3.clear()
self._collection_pkg_roots.clear()
def _collect(self, argpath, names):
def _collect(
self, argpath: py.path.local, names: List[str]
) -> Iterator[Union[nodes.Item, nodes.Collector]]:
from _pytest.python import Package
# Start with a Session root, and delve to argpath item (dir or file)
@ -541,7 +576,7 @@ class Session(nodes.FSCollector):
if argpath.check(dir=1):
assert not names, "invalid arg {!r}".format((argpath, names))
seen_dirs = set()
seen_dirs = set() # type: Set[py.path.local]
for path in argpath.visit(
fil=self._visit_filter, rec=self._recurse, bf=True, sort=True
):
@ -582,8 +617,9 @@ class Session(nodes.FSCollector):
# Module itself, so just use that. If this special case isn't taken, then all
# the files in the package will be yielded.
if argpath.basename == "__init__.py":
assert isinstance(m[0], nodes.Collector)
try:
yield next(m[0].collect())
yield next(iter(m[0].collect()))
except StopIteration:
# The package collects nothing with only an __init__.py
# file in it, which gets ignored by the default
@ -593,10 +629,11 @@ class Session(nodes.FSCollector):
yield from m
@staticmethod
def _visit_filter(f):
return f.check(file=1)
def _visit_filter(f: py.path.local) -> bool:
# TODO: Remove type: ignore once `py` is typed.
return f.check(file=1) # type: ignore
def _tryconvertpyarg(self, x):
def _tryconvertpyarg(self, x: str) -> str:
"""Convert a dotted module name to path."""
try:
spec = importlib.util.find_spec(x)
@ -605,14 +642,14 @@ class Session(nodes.FSCollector):
# ValueError: not a module name
except (AttributeError, ImportError, ValueError):
return x
if spec is None or spec.origin in {None, "namespace"}:
if spec is None or spec.origin is None or spec.origin == "namespace":
return x
elif spec.submodule_search_locations:
return os.path.dirname(spec.origin)
else:
return spec.origin
def _parsearg(self, arg):
def _parsearg(self, arg: str) -> Tuple[py.path.local, List[str]]:
""" return (fspath, names) tuple after checking the file exists. """
strpath, *parts = str(arg).split("::")
if self.config.option.pyargs:
@ -628,7 +665,9 @@ class Session(nodes.FSCollector):
fspath = fspath.realpath()
return (fspath, parts)
def matchnodes(self, matching, names):
def matchnodes(
self, matching: Sequence[Union[nodes.Item, nodes.Collector]], names: List[str],
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
self.trace("matchnodes", matching, names)
self.trace.root.indent += 1
nodes = self._matchnodes(matching, names)
@ -639,13 +678,15 @@ class Session(nodes.FSCollector):
raise NoMatch(matching, names[:1])
return nodes
def _matchnodes(self, matching, names):
def _matchnodes(
self, matching: Sequence[Union[nodes.Item, nodes.Collector]], names: List[str],
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
if not matching or not names:
return matching
name = names[0]
assert name
nextnames = names[1:]
resultnodes = []
resultnodes = [] # type: List[Union[nodes.Item, nodes.Collector]]
for node in matching:
if isinstance(node, nodes.Item):
if not names:
@ -676,7 +717,9 @@ class Session(nodes.FSCollector):
node.ihook.pytest_collectreport(report=rep)
return resultnodes
def genitems(self, node):
def genitems(
self, node: Union[nodes.Item, nodes.Collector]
) -> Iterator[nodes.Item]:
self.trace("genitems", node)
if isinstance(node, nodes.Item):
node.ihook.pytest_itemcollected(item=node)

View File

@ -4,8 +4,10 @@ from functools import lru_cache
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Union
@ -226,7 +228,7 @@ class Node(metaclass=NodeMeta):
# methods for ordering nodes
@property
def nodeid(self):
def nodeid(self) -> str:
""" a ::-separated string denoting its collection tree address. """
return self._nodeid
@ -423,7 +425,7 @@ class Collector(Node):
class CollectError(Exception):
""" an error during collection, contains a custom message. """
def collect(self):
def collect(self) -> Iterable[Union["Item", "Collector"]]:
""" returns a list of children (items and collectors)
for this collection node.
"""
@ -522,6 +524,9 @@ class FSCollector(Collector):
proxy = self.config.hook
return proxy
def gethookproxy(self, fspath: py.path.local):
raise NotImplementedError()
def _recurse(self, dirpath: py.path.local) -> bool:
if dirpath.basename == "__pycache__":
return False
@ -535,7 +540,12 @@ class FSCollector(Collector):
ihook.pytest_collect_directory(path=dirpath, parent=self)
return True
def _collectfile(self, path, handle_dupes=True):
def isinitpath(self, path: py.path.local) -> bool:
raise NotImplementedError()
def _collectfile(
self, path: py.path.local, handle_dupes: bool = True
) -> Sequence[Collector]:
assert (
path.isfile()
), "{!r} is not a file (isdir={!r}, exists={!r}, islink={!r})".format(
@ -555,7 +565,7 @@ class FSCollector(Collector):
else:
duplicate_paths.add(path)
return ihook.pytest_collect_file(path=path, parent=self)
return ihook.pytest_collect_file(path=path, parent=self) # type: ignore[no-any-return] # noqa: F723
class File(FSCollector):

View File

@ -43,9 +43,9 @@ from _pytest.compat import REGEX_TYPE
from _pytest.compat import safe_getattr
from _pytest.compat import safe_isclass
from _pytest.compat import STRING_TYPES
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.compat import TYPE_CHECKING
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
from _pytest.deprecated import FUNCARGNAMES
@ -184,16 +184,20 @@ def pytest_pyfunc_call(pyfuncitem: "Function"):
return True
def pytest_collect_file(path, parent):
def pytest_collect_file(path: py.path.local, parent) -> Optional["Module"]:
ext = path.ext
if ext == ".py":
if not parent.session.isinitpath(path):
if not path_matches_patterns(
path, parent.config.getini("python_files") + ["__init__.py"]
):
return
return None
ihook = parent.session.gethookproxy(path)
return ihook.pytest_pycollect_makemodule(path=path, parent=parent)
module = ihook.pytest_pycollect_makemodule(
path=path, parent=parent
) # type: Module
return module
return None
def path_matches_patterns(path, patterns):
@ -201,14 +205,16 @@ def path_matches_patterns(path, patterns):
return any(path.fnmatch(pattern) for pattern in patterns)
def pytest_pycollect_makemodule(path, parent):
def pytest_pycollect_makemodule(path: py.path.local, parent) -> "Module":
if path.basename == "__init__.py":
return Package.from_parent(parent, fspath=path)
return Module.from_parent(parent, fspath=path)
pkg = Package.from_parent(parent, fspath=path) # type: Package
return pkg
mod = Module.from_parent(parent, fspath=path) # type: Module
return mod
@hookimpl(hookwrapper=True)
def pytest_pycollect_makeitem(collector, name, obj):
def pytest_pycollect_makeitem(collector: "PyCollector", name: str, obj):
outcome = yield
res = outcome.get_result()
if res is not None:
@ -372,7 +378,7 @@ class PyCollector(PyobjMixin, nodes.Collector):
return True
return False
def collect(self):
def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
if not getattr(self.obj, "__test__", True):
return []
@ -381,8 +387,8 @@ class PyCollector(PyobjMixin, nodes.Collector):
dicts = [getattr(self.obj, "__dict__", {})]
for basecls in self.obj.__class__.__mro__:
dicts.append(basecls.__dict__)
seen = set()
values = []
seen = set() # type: Set[str]
values = [] # type: List[Union[nodes.Item, nodes.Collector]]
for dic in dicts:
# Note: seems like the dict can change during iteration -
# be careful not to remove the list() without consideration.
@ -404,9 +410,16 @@ class PyCollector(PyobjMixin, nodes.Collector):
values.sort(key=sort_key)
return values
def _makeitem(self, name, obj):
def _makeitem(
self, name: str, obj
) -> Union[
None, nodes.Item, nodes.Collector, List[Union[nodes.Item, nodes.Collector]]
]:
# assert self.ihook.fspath == self.fspath, self
return self.ihook.pytest_pycollect_makeitem(collector=self, name=name, obj=obj)
item = self.ihook.pytest_pycollect_makeitem(
collector=self, name=name, obj=obj
) # type: Union[None, nodes.Item, nodes.Collector, List[Union[nodes.Item, nodes.Collector]]]
return item
def _genfunctions(self, name, funcobj):
module = self.getparent(Module).obj
@ -458,7 +471,7 @@ class Module(nodes.File, PyCollector):
def _getobj(self):
return self._importtestmodule()
def collect(self):
def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
self._inject_setup_module_fixture()
self._inject_setup_function_fixture()
self.session._fixturemanager.parsefactories(self)
@ -603,17 +616,17 @@ class Package(Module):
def gethookproxy(self, fspath: py.path.local):
return super()._gethookproxy(fspath)
def isinitpath(self, path):
def isinitpath(self, path: py.path.local) -> bool:
return path in self.session._initialpaths
def collect(self):
def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
this_path = self.fspath.dirpath()
init_module = this_path.join("__init__.py")
if init_module.check(file=1) and path_matches_patterns(
init_module, self.config.getini("python_files")
):
yield Module.from_parent(self, fspath=init_module)
pkg_prefixes = set()
pkg_prefixes = set() # type: Set[py.path.local]
for path in this_path.visit(rec=self._recurse, bf=True, sort=True):
# We will visit our own __init__.py file, in which case we skip it.
is_file = path.isfile()
@ -670,10 +683,11 @@ class Class(PyCollector):
"""
return super().from_parent(name=name, parent=parent)
def collect(self):
def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
if not safe_getattr(self.obj, "__test__", True):
return []
if hasinit(self.obj):
assert self.parent is not None
self.warn(
PytestCollectionWarning(
"cannot collect test class %r because it has a "
@ -683,6 +697,7 @@ class Class(PyCollector):
)
return []
elif hasnew(self.obj):
assert self.parent is not None
self.warn(
PytestCollectionWarning(
"cannot collect test class %r because it has a "
@ -756,7 +771,7 @@ class Instance(PyCollector):
def _getobj(self):
return self.parent.obj()
def collect(self):
def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
self.session._fixturemanager.parsefactories(self)
return super().collect()

View File

@ -21,7 +21,8 @@ from _pytest._code.code import ReprTraceback
from _pytest._code.code import TerminalRepr
from _pytest._io import TerminalWriter
from _pytest.compat import TYPE_CHECKING
from _pytest.nodes import Node
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.outcomes import skip
from _pytest.pathlib import Path
@ -316,7 +317,13 @@ class CollectReport(BaseReport):
when = "collect"
def __init__(
self, nodeid: str, outcome, longrepr, result: List[Node], sections=(), **extra
self,
nodeid: str,
outcome,
longrepr,
result: Optional[List[Union[Item, Collector]]],
sections=(),
**extra
) -> None:
self.nodeid = nodeid
self.outcome = outcome

View File

@ -404,10 +404,10 @@ class SetupState:
raise e
def collect_one_node(collector):
def collect_one_node(collector: Collector) -> CollectReport:
ihook = collector.ihook
ihook.pytest_collectstart(collector=collector)
rep = ihook.pytest_make_collect_report(collector=collector)
rep = ihook.pytest_make_collect_report(collector=collector) # type: CollectReport
call = rep.__dict__.pop("call", None)
if call and check_interactive_exception(call, rep):
ihook.pytest_exception_interact(node=collector, call=call, report=rep)

View File

@ -1,32 +1,43 @@
""" discovery and running of std-library "unittest" style tests. """
import sys
import traceback
from typing import Iterable
from typing import Optional
from typing import Union
import _pytest._code
import pytest
from _pytest.compat import getimfunc
from _pytest.compat import is_async_function
from _pytest.config import hookimpl
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.outcomes import exit
from _pytest.outcomes import fail
from _pytest.outcomes import skip
from _pytest.outcomes import xfail
from _pytest.python import Class
from _pytest.python import Function
from _pytest.python import PyCollector
from _pytest.runner import CallInfo
from _pytest.skipping import skipped_by_mark_key
from _pytest.skipping import unexpectedsuccess_key
def pytest_pycollect_makeitem(collector, name, obj):
def pytest_pycollect_makeitem(
collector: PyCollector, name: str, obj
) -> Optional["UnitTestCase"]:
# has unittest been imported and is obj a subclass of its TestCase?
try:
if not issubclass(obj, sys.modules["unittest"].TestCase):
return
ut = sys.modules["unittest"]
# Type ignored because `ut` is an opaque module.
if not issubclass(obj, ut.TestCase): # type: ignore
return None
except Exception:
return
return None
# yes, so let's collect it
return UnitTestCase.from_parent(collector, name=name, obj=obj)
item = UnitTestCase.from_parent(collector, name=name, obj=obj) # type: UnitTestCase
return item
class UnitTestCase(Class):
@ -34,7 +45,7 @@ class UnitTestCase(Class):
# to declare that our children do not support funcargs
nofuncargs = True
def collect(self):
def collect(self) -> Iterable[Union[Item, Collector]]:
from unittest import TestLoader
cls = self.obj
@ -61,8 +72,8 @@ class UnitTestCase(Class):
runtest = getattr(self.obj, "runTest", None)
if runtest is not None:
ut = sys.modules.get("twisted.trial.unittest", None)
if ut is None or runtest != ut.TestCase.runTest:
# TODO: callobj consistency
# Type ignored because `ut` is an opaque module.
if ut is None or runtest != ut.TestCase.runTest: # type: ignore
yield TestCaseFunction.from_parent(self, name="runTest")
def _inject_setup_teardown_fixtures(self, cls):