Type annotate main.py and some parts related to collection

This commit is contained in:
Ran Benita 2020-05-01 14:40:15 +03:00
parent f8de424241
commit be00e12d47
9 changed files with 175 additions and 75 deletions

View File

@ -840,7 +840,7 @@ class Config:
self.cache = None # type: Optional[Cache] self.cache = None # type: Optional[Cache]
@property @property
def invocation_dir(self): def invocation_dir(self) -> py.path.local:
"""Backward compatibility""" """Backward compatibility"""
return py.path.local(str(self.invocation_params.dir)) return py.path.local(str(self.invocation_params.dir))

View File

@ -7,6 +7,7 @@ import traceback
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from typing import Dict from typing import Dict
from typing import Iterable
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Sequence from typing import Sequence
@ -109,13 +110,18 @@ def pytest_unconfigure() -> None:
RUNNER_CLASS = None RUNNER_CLASS = None
def pytest_collect_file(path: py.path.local, parent): def pytest_collect_file(
path: py.path.local, parent
) -> Optional[Union["DoctestModule", "DoctestTextfile"]]:
config = parent.config config = parent.config
if path.ext == ".py": if path.ext == ".py":
if config.option.doctestmodules and not _is_setup_py(path): if config.option.doctestmodules and not _is_setup_py(path):
return DoctestModule.from_parent(parent, fspath=path) mod = DoctestModule.from_parent(parent, fspath=path) # type: DoctestModule
return mod
elif _is_doctest(config, path, parent): elif _is_doctest(config, path, parent):
return DoctestTextfile.from_parent(parent, fspath=path) txt = DoctestTextfile.from_parent(parent, fspath=path) # type: DoctestTextfile
return txt
return None
def _is_setup_py(path: py.path.local) -> bool: def _is_setup_py(path: py.path.local) -> bool:
@ -365,7 +371,7 @@ def _get_continue_on_failure(config):
class DoctestTextfile(pytest.Module): class DoctestTextfile(pytest.Module):
obj = None obj = None
def collect(self): def collect(self) -> Iterable[DoctestItem]:
import doctest import doctest
# inspired by doctest.testfile; ideally we would use it directly, # inspired by doctest.testfile; ideally we would use it directly,
@ -444,7 +450,7 @@ def _patch_unwrap_mock_aware():
class DoctestModule(pytest.Module): class DoctestModule(pytest.Module):
def collect(self): def collect(self) -> Iterable[DoctestItem]:
import doctest import doctest
class MockAwareDocTestFinder(doctest.DocTestFinder): class MockAwareDocTestFinder(doctest.DocTestFinder):

View File

@ -6,6 +6,7 @@ from typing import Optional
from typing import Tuple from typing import Tuple
from typing import Union from typing import Union
import py.path
from pluggy import HookspecMarker from pluggy import HookspecMarker
from .deprecated import COLLECT_DIRECTORY_HOOK from .deprecated import COLLECT_DIRECTORY_HOOK
@ -20,9 +21,14 @@ if TYPE_CHECKING:
from _pytest.config import _PluggyPlugin from _pytest.config import _PluggyPlugin
from _pytest.config.argparsing import Parser from _pytest.config.argparsing import Parser
from _pytest.main import Session from _pytest.main import Session
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.python import Metafunc from _pytest.python import Metafunc
from _pytest.python import Module
from _pytest.python import PyCollector
from _pytest.reports import BaseReport from _pytest.reports import BaseReport
hookspec = HookspecMarker("pytest") hookspec = HookspecMarker("pytest")
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -249,7 +255,7 @@ def pytest_collect_directory(path, parent):
""" """
def pytest_collect_file(path, parent): def pytest_collect_file(path: py.path.local, parent) -> "Optional[Collector]":
""" return collection Node or None for the given path. Any new node """ return collection Node or None for the given path. Any new node
needs to have the specified ``parent`` as a parent. needs to have the specified ``parent`` as a parent.
@ -289,7 +295,7 @@ def pytest_make_collect_report(collector):
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_pycollect_makemodule(path, parent): def pytest_pycollect_makemodule(path: py.path.local, parent) -> "Optional[Module]":
""" return a Module collector or None for the given path. """ return a Module collector or None for the given path.
This hook will be called for each matching test module path. This hook will be called for each matching test module path.
The pytest_collect_file hook needs to be used if you want to The pytest_collect_file hook needs to be used if you want to
@ -302,7 +308,9 @@ def pytest_pycollect_makemodule(path, parent):
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_pycollect_makeitem(collector, name, obj): def pytest_pycollect_makeitem(
collector: "PyCollector", name: str, obj
) -> "Union[None, Item, Collector, List[Union[Item, Collector]]]":
""" return custom item/collector for a python object in a module, or None. """ return custom item/collector for a python object in a module, or None.
Stops at first non-None result, see :ref:`firstresult` """ Stops at first non-None result, see :ref:`firstresult` """

View File

@ -7,9 +7,11 @@ import sys
from typing import Callable from typing import Callable
from typing import Dict from typing import Dict
from typing import FrozenSet from typing import FrozenSet
from typing import Iterator
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Sequence from typing import Sequence
from typing import Set
from typing import Tuple from typing import Tuple
from typing import Union from typing import Union
@ -18,12 +20,14 @@ import py
import _pytest._code import _pytest._code
from _pytest import nodes from _pytest import nodes
from _pytest.compat import overload
from _pytest.compat import TYPE_CHECKING from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config from _pytest.config import Config
from _pytest.config import directory_arg from _pytest.config import directory_arg
from _pytest.config import ExitCode from _pytest.config import ExitCode
from _pytest.config import hookimpl from _pytest.config import hookimpl
from _pytest.config import UsageError from _pytest.config import UsageError
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureManager from _pytest.fixtures import FixtureManager
from _pytest.outcomes import exit from _pytest.outcomes import exit
from _pytest.reports import CollectReport from _pytest.reports import CollectReport
@ -38,7 +42,7 @@ if TYPE_CHECKING:
from _pytest.python import Package from _pytest.python import Package
def pytest_addoption(parser): def pytest_addoption(parser: Parser) -> None:
parser.addini( parser.addini(
"norecursedirs", "norecursedirs",
"directory patterns to avoid for recursion", "directory patterns to avoid for recursion",
@ -241,7 +245,7 @@ def wrap_session(
return session.exitstatus return session.exitstatus
def pytest_cmdline_main(config): def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
return wrap_session(config, _main) return wrap_session(config, _main)
@ -258,11 +262,11 @@ def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
return None return None
def pytest_collection(session): def pytest_collection(session: "Session") -> Sequence[nodes.Item]:
return session.perform_collect() return session.perform_collect()
def pytest_runtestloop(session): def pytest_runtestloop(session: "Session") -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors: if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted( raise session.Interrupted(
"%d error%s during collection" "%d error%s during collection"
@ -282,7 +286,7 @@ def pytest_runtestloop(session):
return True return True
def _in_venv(path): def _in_venv(path: py.path.local) -> bool:
"""Attempts to detect if ``path`` is the root of a Virtual Environment by """Attempts to detect if ``path`` is the root of a Virtual Environment by
checking for the existence of the appropriate activate script""" checking for the existence of the appropriate activate script"""
bindir = path.join("Scripts" if sys.platform.startswith("win") else "bin") bindir = path.join("Scripts" if sys.platform.startswith("win") else "bin")
@ -328,7 +332,7 @@ def pytest_ignore_collect(
return None return None
def pytest_collection_modifyitems(items, config): def pytest_collection_modifyitems(items, config: Config) -> None:
deselect_prefixes = tuple(config.getoption("deselect") or []) deselect_prefixes = tuple(config.getoption("deselect") or [])
if not deselect_prefixes: if not deselect_prefixes:
return return
@ -385,8 +389,8 @@ class Session(nodes.FSCollector):
) )
self.testsfailed = 0 self.testsfailed = 0
self.testscollected = 0 self.testscollected = 0
self.shouldstop = False self.shouldstop = False # type: Union[bool, str]
self.shouldfail = False self.shouldfail = False # type: Union[bool, str]
self.trace = config.trace.root.get("collection") self.trace = config.trace.root.get("collection")
self.startdir = config.invocation_dir self.startdir = config.invocation_dir
self._initialpaths = frozenset() # type: FrozenSet[py.path.local] self._initialpaths = frozenset() # type: FrozenSet[py.path.local]
@ -412,10 +416,11 @@ class Session(nodes.FSCollector):
self.config.pluginmanager.register(self, name="session") self.config.pluginmanager.register(self, name="session")
@classmethod @classmethod
def from_config(cls, config): def from_config(cls, config: Config) -> "Session":
return cls._create(config) session = cls._create(config) # type: Session
return session
def __repr__(self): def __repr__(self) -> str:
return "<%s %s exitstatus=%r testsfailed=%d testscollected=%d>" % ( return "<%s %s exitstatus=%r testsfailed=%d testscollected=%d>" % (
self.__class__.__name__, self.__class__.__name__,
self.name, self.name,
@ -429,14 +434,14 @@ class Session(nodes.FSCollector):
return self._bestrelpathcache[node_path] return self._bestrelpathcache[node_path]
@hookimpl(tryfirst=True) @hookimpl(tryfirst=True)
def pytest_collectstart(self): def pytest_collectstart(self) -> None:
if self.shouldfail: if self.shouldfail:
raise self.Failed(self.shouldfail) raise self.Failed(self.shouldfail)
if self.shouldstop: if self.shouldstop:
raise self.Interrupted(self.shouldstop) raise self.Interrupted(self.shouldstop)
@hookimpl(tryfirst=True) @hookimpl(tryfirst=True)
def pytest_runtest_logreport(self, report): def pytest_runtest_logreport(self, report) -> None:
if report.failed and not hasattr(report, "wasxfail"): if report.failed and not hasattr(report, "wasxfail"):
self.testsfailed += 1 self.testsfailed += 1
maxfail = self.config.getvalue("maxfail") maxfail = self.config.getvalue("maxfail")
@ -445,13 +450,27 @@ class Session(nodes.FSCollector):
pytest_collectreport = pytest_runtest_logreport pytest_collectreport = pytest_runtest_logreport
def isinitpath(self, path): def isinitpath(self, path: py.path.local) -> bool:
return path in self._initialpaths return path in self._initialpaths
def gethookproxy(self, fspath: py.path.local): def gethookproxy(self, fspath: py.path.local):
return super()._gethookproxy(fspath) return super()._gethookproxy(fspath)
def perform_collect(self, args=None, genitems=True): @overload
def perform_collect(
self, args: Optional[Sequence[str]] = ..., genitems: "Literal[True]" = ...
) -> Sequence[nodes.Item]:
raise NotImplementedError()
@overload # noqa: F811
def perform_collect( # noqa: F811
self, args: Optional[Sequence[str]] = ..., genitems: bool = ...
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
raise NotImplementedError()
def perform_collect( # noqa: F811
self, args: Optional[Sequence[str]] = None, genitems: bool = True
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
hook = self.config.hook hook = self.config.hook
try: try:
items = self._perform_collect(args, genitems) items = self._perform_collect(args, genitems)
@ -464,15 +483,29 @@ class Session(nodes.FSCollector):
self.testscollected = len(items) self.testscollected = len(items)
return items return items
def _perform_collect(self, args, genitems): @overload
def _perform_collect(
self, args: Optional[Sequence[str]], genitems: "Literal[True]"
) -> Sequence[nodes.Item]:
raise NotImplementedError()
@overload # noqa: F811
def _perform_collect( # noqa: F811
self, args: Optional[Sequence[str]], genitems: bool
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
raise NotImplementedError()
def _perform_collect( # noqa: F811
self, args: Optional[Sequence[str]], genitems: bool
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
if args is None: if args is None:
args = self.config.args args = self.config.args
self.trace("perform_collect", self, args) self.trace("perform_collect", self, args)
self.trace.root.indent += 1 self.trace.root.indent += 1
self._notfound = [] self._notfound = [] # type: List[Tuple[str, NoMatch]]
initialpaths = [] # type: List[py.path.local] initialpaths = [] # type: List[py.path.local]
self._initial_parts = [] # type: List[Tuple[py.path.local, List[str]]] self._initial_parts = [] # type: List[Tuple[py.path.local, List[str]]]
self.items = items = [] self.items = items = [] # type: List[nodes.Item]
for arg in args: for arg in args:
fspath, parts = self._parsearg(arg) fspath, parts = self._parsearg(arg)
self._initial_parts.append((fspath, parts)) self._initial_parts.append((fspath, parts))
@ -495,7 +528,7 @@ class Session(nodes.FSCollector):
self.items.extend(self.genitems(node)) self.items.extend(self.genitems(node))
return items return items
def collect(self): def collect(self) -> Iterator[Union[nodes.Item, nodes.Collector]]:
for fspath, parts in self._initial_parts: for fspath, parts in self._initial_parts:
self.trace("processing argument", (fspath, parts)) self.trace("processing argument", (fspath, parts))
self.trace.root.indent += 1 self.trace.root.indent += 1
@ -513,7 +546,9 @@ class Session(nodes.FSCollector):
self._collection_node_cache3.clear() self._collection_node_cache3.clear()
self._collection_pkg_roots.clear() self._collection_pkg_roots.clear()
def _collect(self, argpath, names): def _collect(
self, argpath: py.path.local, names: List[str]
) -> Iterator[Union[nodes.Item, nodes.Collector]]:
from _pytest.python import Package from _pytest.python import Package
# Start with a Session root, and delve to argpath item (dir or file) # Start with a Session root, and delve to argpath item (dir or file)
@ -541,7 +576,7 @@ class Session(nodes.FSCollector):
if argpath.check(dir=1): if argpath.check(dir=1):
assert not names, "invalid arg {!r}".format((argpath, names)) assert not names, "invalid arg {!r}".format((argpath, names))
seen_dirs = set() seen_dirs = set() # type: Set[py.path.local]
for path in argpath.visit( for path in argpath.visit(
fil=self._visit_filter, rec=self._recurse, bf=True, sort=True fil=self._visit_filter, rec=self._recurse, bf=True, sort=True
): ):
@ -582,8 +617,9 @@ class Session(nodes.FSCollector):
# Module itself, so just use that. If this special case isn't taken, then all # Module itself, so just use that. If this special case isn't taken, then all
# the files in the package will be yielded. # the files in the package will be yielded.
if argpath.basename == "__init__.py": if argpath.basename == "__init__.py":
assert isinstance(m[0], nodes.Collector)
try: try:
yield next(m[0].collect()) yield next(iter(m[0].collect()))
except StopIteration: except StopIteration:
# The package collects nothing with only an __init__.py # The package collects nothing with only an __init__.py
# file in it, which gets ignored by the default # file in it, which gets ignored by the default
@ -593,10 +629,11 @@ class Session(nodes.FSCollector):
yield from m yield from m
@staticmethod @staticmethod
def _visit_filter(f): def _visit_filter(f: py.path.local) -> bool:
return f.check(file=1) # TODO: Remove type: ignore once `py` is typed.
return f.check(file=1) # type: ignore
def _tryconvertpyarg(self, x): def _tryconvertpyarg(self, x: str) -> str:
"""Convert a dotted module name to path.""" """Convert a dotted module name to path."""
try: try:
spec = importlib.util.find_spec(x) spec = importlib.util.find_spec(x)
@ -605,14 +642,14 @@ class Session(nodes.FSCollector):
# ValueError: not a module name # ValueError: not a module name
except (AttributeError, ImportError, ValueError): except (AttributeError, ImportError, ValueError):
return x return x
if spec is None or spec.origin in {None, "namespace"}: if spec is None or spec.origin is None or spec.origin == "namespace":
return x return x
elif spec.submodule_search_locations: elif spec.submodule_search_locations:
return os.path.dirname(spec.origin) return os.path.dirname(spec.origin)
else: else:
return spec.origin return spec.origin
def _parsearg(self, arg): def _parsearg(self, arg: str) -> Tuple[py.path.local, List[str]]:
""" return (fspath, names) tuple after checking the file exists. """ """ return (fspath, names) tuple after checking the file exists. """
strpath, *parts = str(arg).split("::") strpath, *parts = str(arg).split("::")
if self.config.option.pyargs: if self.config.option.pyargs:
@ -628,7 +665,9 @@ class Session(nodes.FSCollector):
fspath = fspath.realpath() fspath = fspath.realpath()
return (fspath, parts) return (fspath, parts)
def matchnodes(self, matching, names): def matchnodes(
self, matching: Sequence[Union[nodes.Item, nodes.Collector]], names: List[str],
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
self.trace("matchnodes", matching, names) self.trace("matchnodes", matching, names)
self.trace.root.indent += 1 self.trace.root.indent += 1
nodes = self._matchnodes(matching, names) nodes = self._matchnodes(matching, names)
@ -639,13 +678,15 @@ class Session(nodes.FSCollector):
raise NoMatch(matching, names[:1]) raise NoMatch(matching, names[:1])
return nodes return nodes
def _matchnodes(self, matching, names): def _matchnodes(
self, matching: Sequence[Union[nodes.Item, nodes.Collector]], names: List[str],
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
if not matching or not names: if not matching or not names:
return matching return matching
name = names[0] name = names[0]
assert name assert name
nextnames = names[1:] nextnames = names[1:]
resultnodes = [] resultnodes = [] # type: List[Union[nodes.Item, nodes.Collector]]
for node in matching: for node in matching:
if isinstance(node, nodes.Item): if isinstance(node, nodes.Item):
if not names: if not names:
@ -676,7 +717,9 @@ class Session(nodes.FSCollector):
node.ihook.pytest_collectreport(report=rep) node.ihook.pytest_collectreport(report=rep)
return resultnodes return resultnodes
def genitems(self, node): def genitems(
self, node: Union[nodes.Item, nodes.Collector]
) -> Iterator[nodes.Item]:
self.trace("genitems", node) self.trace("genitems", node)
if isinstance(node, nodes.Item): if isinstance(node, nodes.Item):
node.ihook.pytest_itemcollected(item=node) node.ihook.pytest_itemcollected(item=node)

View File

@ -4,8 +4,10 @@ from functools import lru_cache
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import Dict from typing import Dict
from typing import Iterable
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Sequence
from typing import Set from typing import Set
from typing import Tuple from typing import Tuple
from typing import Union from typing import Union
@ -226,7 +228,7 @@ class Node(metaclass=NodeMeta):
# methods for ordering nodes # methods for ordering nodes
@property @property
def nodeid(self): def nodeid(self) -> str:
""" a ::-separated string denoting its collection tree address. """ """ a ::-separated string denoting its collection tree address. """
return self._nodeid return self._nodeid
@ -423,7 +425,7 @@ class Collector(Node):
class CollectError(Exception): class CollectError(Exception):
""" an error during collection, contains a custom message. """ """ an error during collection, contains a custom message. """
def collect(self): def collect(self) -> Iterable[Union["Item", "Collector"]]:
""" returns a list of children (items and collectors) """ returns a list of children (items and collectors)
for this collection node. for this collection node.
""" """
@ -522,6 +524,9 @@ class FSCollector(Collector):
proxy = self.config.hook proxy = self.config.hook
return proxy return proxy
def gethookproxy(self, fspath: py.path.local):
raise NotImplementedError()
def _recurse(self, dirpath: py.path.local) -> bool: def _recurse(self, dirpath: py.path.local) -> bool:
if dirpath.basename == "__pycache__": if dirpath.basename == "__pycache__":
return False return False
@ -535,7 +540,12 @@ class FSCollector(Collector):
ihook.pytest_collect_directory(path=dirpath, parent=self) ihook.pytest_collect_directory(path=dirpath, parent=self)
return True return True
def _collectfile(self, path, handle_dupes=True): def isinitpath(self, path: py.path.local) -> bool:
raise NotImplementedError()
def _collectfile(
self, path: py.path.local, handle_dupes: bool = True
) -> Sequence[Collector]:
assert ( assert (
path.isfile() path.isfile()
), "{!r} is not a file (isdir={!r}, exists={!r}, islink={!r})".format( ), "{!r} is not a file (isdir={!r}, exists={!r}, islink={!r})".format(
@ -555,7 +565,7 @@ class FSCollector(Collector):
else: else:
duplicate_paths.add(path) duplicate_paths.add(path)
return ihook.pytest_collect_file(path=path, parent=self) return ihook.pytest_collect_file(path=path, parent=self) # type: ignore[no-any-return] # noqa: F723
class File(FSCollector): class File(FSCollector):

View File

@ -43,9 +43,9 @@ from _pytest.compat import REGEX_TYPE
from _pytest.compat import safe_getattr from _pytest.compat import safe_getattr
from _pytest.compat import safe_isclass from _pytest.compat import safe_isclass
from _pytest.compat import STRING_TYPES from _pytest.compat import STRING_TYPES
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config from _pytest.config import Config
from _pytest.config import ExitCode from _pytest.config import ExitCode
from _pytest.compat import TYPE_CHECKING
from _pytest.config import hookimpl from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser from _pytest.config.argparsing import Parser
from _pytest.deprecated import FUNCARGNAMES from _pytest.deprecated import FUNCARGNAMES
@ -184,16 +184,20 @@ def pytest_pyfunc_call(pyfuncitem: "Function"):
return True return True
def pytest_collect_file(path, parent): def pytest_collect_file(path: py.path.local, parent) -> Optional["Module"]:
ext = path.ext ext = path.ext
if ext == ".py": if ext == ".py":
if not parent.session.isinitpath(path): if not parent.session.isinitpath(path):
if not path_matches_patterns( if not path_matches_patterns(
path, parent.config.getini("python_files") + ["__init__.py"] path, parent.config.getini("python_files") + ["__init__.py"]
): ):
return return None
ihook = parent.session.gethookproxy(path) ihook = parent.session.gethookproxy(path)
return ihook.pytest_pycollect_makemodule(path=path, parent=parent) module = ihook.pytest_pycollect_makemodule(
path=path, parent=parent
) # type: Module
return module
return None
def path_matches_patterns(path, patterns): def path_matches_patterns(path, patterns):
@ -201,14 +205,16 @@ def path_matches_patterns(path, patterns):
return any(path.fnmatch(pattern) for pattern in patterns) return any(path.fnmatch(pattern) for pattern in patterns)
def pytest_pycollect_makemodule(path, parent): def pytest_pycollect_makemodule(path: py.path.local, parent) -> "Module":
if path.basename == "__init__.py": if path.basename == "__init__.py":
return Package.from_parent(parent, fspath=path) pkg = Package.from_parent(parent, fspath=path) # type: Package
return Module.from_parent(parent, fspath=path) return pkg
mod = Module.from_parent(parent, fspath=path) # type: Module
return mod
@hookimpl(hookwrapper=True) @hookimpl(hookwrapper=True)
def pytest_pycollect_makeitem(collector, name, obj): def pytest_pycollect_makeitem(collector: "PyCollector", name: str, obj):
outcome = yield outcome = yield
res = outcome.get_result() res = outcome.get_result()
if res is not None: if res is not None:
@ -372,7 +378,7 @@ class PyCollector(PyobjMixin, nodes.Collector):
return True return True
return False return False
def collect(self): def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
if not getattr(self.obj, "__test__", True): if not getattr(self.obj, "__test__", True):
return [] return []
@ -381,8 +387,8 @@ class PyCollector(PyobjMixin, nodes.Collector):
dicts = [getattr(self.obj, "__dict__", {})] dicts = [getattr(self.obj, "__dict__", {})]
for basecls in self.obj.__class__.__mro__: for basecls in self.obj.__class__.__mro__:
dicts.append(basecls.__dict__) dicts.append(basecls.__dict__)
seen = set() seen = set() # type: Set[str]
values = [] values = [] # type: List[Union[nodes.Item, nodes.Collector]]
for dic in dicts: for dic in dicts:
# Note: seems like the dict can change during iteration - # Note: seems like the dict can change during iteration -
# be careful not to remove the list() without consideration. # be careful not to remove the list() without consideration.
@ -404,9 +410,16 @@ class PyCollector(PyobjMixin, nodes.Collector):
values.sort(key=sort_key) values.sort(key=sort_key)
return values return values
def _makeitem(self, name, obj): def _makeitem(
self, name: str, obj
) -> Union[
None, nodes.Item, nodes.Collector, List[Union[nodes.Item, nodes.Collector]]
]:
# assert self.ihook.fspath == self.fspath, self # assert self.ihook.fspath == self.fspath, self
return self.ihook.pytest_pycollect_makeitem(collector=self, name=name, obj=obj) item = self.ihook.pytest_pycollect_makeitem(
collector=self, name=name, obj=obj
) # type: Union[None, nodes.Item, nodes.Collector, List[Union[nodes.Item, nodes.Collector]]]
return item
def _genfunctions(self, name, funcobj): def _genfunctions(self, name, funcobj):
module = self.getparent(Module).obj module = self.getparent(Module).obj
@ -458,7 +471,7 @@ class Module(nodes.File, PyCollector):
def _getobj(self): def _getobj(self):
return self._importtestmodule() return self._importtestmodule()
def collect(self): def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
self._inject_setup_module_fixture() self._inject_setup_module_fixture()
self._inject_setup_function_fixture() self._inject_setup_function_fixture()
self.session._fixturemanager.parsefactories(self) self.session._fixturemanager.parsefactories(self)
@ -603,17 +616,17 @@ class Package(Module):
def gethookproxy(self, fspath: py.path.local): def gethookproxy(self, fspath: py.path.local):
return super()._gethookproxy(fspath) return super()._gethookproxy(fspath)
def isinitpath(self, path): def isinitpath(self, path: py.path.local) -> bool:
return path in self.session._initialpaths return path in self.session._initialpaths
def collect(self): def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
this_path = self.fspath.dirpath() this_path = self.fspath.dirpath()
init_module = this_path.join("__init__.py") init_module = this_path.join("__init__.py")
if init_module.check(file=1) and path_matches_patterns( if init_module.check(file=1) and path_matches_patterns(
init_module, self.config.getini("python_files") init_module, self.config.getini("python_files")
): ):
yield Module.from_parent(self, fspath=init_module) yield Module.from_parent(self, fspath=init_module)
pkg_prefixes = set() pkg_prefixes = set() # type: Set[py.path.local]
for path in this_path.visit(rec=self._recurse, bf=True, sort=True): for path in this_path.visit(rec=self._recurse, bf=True, sort=True):
# We will visit our own __init__.py file, in which case we skip it. # We will visit our own __init__.py file, in which case we skip it.
is_file = path.isfile() is_file = path.isfile()
@ -670,10 +683,11 @@ class Class(PyCollector):
""" """
return super().from_parent(name=name, parent=parent) return super().from_parent(name=name, parent=parent)
def collect(self): def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
if not safe_getattr(self.obj, "__test__", True): if not safe_getattr(self.obj, "__test__", True):
return [] return []
if hasinit(self.obj): if hasinit(self.obj):
assert self.parent is not None
self.warn( self.warn(
PytestCollectionWarning( PytestCollectionWarning(
"cannot collect test class %r because it has a " "cannot collect test class %r because it has a "
@ -683,6 +697,7 @@ class Class(PyCollector):
) )
return [] return []
elif hasnew(self.obj): elif hasnew(self.obj):
assert self.parent is not None
self.warn( self.warn(
PytestCollectionWarning( PytestCollectionWarning(
"cannot collect test class %r because it has a " "cannot collect test class %r because it has a "
@ -756,7 +771,7 @@ class Instance(PyCollector):
def _getobj(self): def _getobj(self):
return self.parent.obj() return self.parent.obj()
def collect(self): def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
self.session._fixturemanager.parsefactories(self) self.session._fixturemanager.parsefactories(self)
return super().collect() return super().collect()

View File

@ -21,7 +21,8 @@ from _pytest._code.code import ReprTraceback
from _pytest._code.code import TerminalRepr from _pytest._code.code import TerminalRepr
from _pytest._io import TerminalWriter from _pytest._io import TerminalWriter
from _pytest.compat import TYPE_CHECKING from _pytest.compat import TYPE_CHECKING
from _pytest.nodes import Node from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.outcomes import skip from _pytest.outcomes import skip
from _pytest.pathlib import Path from _pytest.pathlib import Path
@ -316,7 +317,13 @@ class CollectReport(BaseReport):
when = "collect" when = "collect"
def __init__( def __init__(
self, nodeid: str, outcome, longrepr, result: List[Node], sections=(), **extra self,
nodeid: str,
outcome,
longrepr,
result: Optional[List[Union[Item, Collector]]],
sections=(),
**extra
) -> None: ) -> None:
self.nodeid = nodeid self.nodeid = nodeid
self.outcome = outcome self.outcome = outcome

View File

@ -404,10 +404,10 @@ class SetupState:
raise e raise e
def collect_one_node(collector): def collect_one_node(collector: Collector) -> CollectReport:
ihook = collector.ihook ihook = collector.ihook
ihook.pytest_collectstart(collector=collector) ihook.pytest_collectstart(collector=collector)
rep = ihook.pytest_make_collect_report(collector=collector) rep = ihook.pytest_make_collect_report(collector=collector) # type: CollectReport
call = rep.__dict__.pop("call", None) call = rep.__dict__.pop("call", None)
if call and check_interactive_exception(call, rep): if call and check_interactive_exception(call, rep):
ihook.pytest_exception_interact(node=collector, call=call, report=rep) ihook.pytest_exception_interact(node=collector, call=call, report=rep)

View File

@ -1,32 +1,43 @@
""" discovery and running of std-library "unittest" style tests. """ """ discovery and running of std-library "unittest" style tests. """
import sys import sys
import traceback import traceback
from typing import Iterable
from typing import Optional
from typing import Union
import _pytest._code import _pytest._code
import pytest import pytest
from _pytest.compat import getimfunc from _pytest.compat import getimfunc
from _pytest.compat import is_async_function from _pytest.compat import is_async_function
from _pytest.config import hookimpl from _pytest.config import hookimpl
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.outcomes import exit from _pytest.outcomes import exit
from _pytest.outcomes import fail from _pytest.outcomes import fail
from _pytest.outcomes import skip from _pytest.outcomes import skip
from _pytest.outcomes import xfail from _pytest.outcomes import xfail
from _pytest.python import Class from _pytest.python import Class
from _pytest.python import Function from _pytest.python import Function
from _pytest.python import PyCollector
from _pytest.runner import CallInfo from _pytest.runner import CallInfo
from _pytest.skipping import skipped_by_mark_key from _pytest.skipping import skipped_by_mark_key
from _pytest.skipping import unexpectedsuccess_key from _pytest.skipping import unexpectedsuccess_key
def pytest_pycollect_makeitem(collector, name, obj): def pytest_pycollect_makeitem(
collector: PyCollector, name: str, obj
) -> Optional["UnitTestCase"]:
# has unittest been imported and is obj a subclass of its TestCase? # has unittest been imported and is obj a subclass of its TestCase?
try: try:
if not issubclass(obj, sys.modules["unittest"].TestCase): ut = sys.modules["unittest"]
return # Type ignored because `ut` is an opaque module.
if not issubclass(obj, ut.TestCase): # type: ignore
return None
except Exception: except Exception:
return return None
# yes, so let's collect it # yes, so let's collect it
return UnitTestCase.from_parent(collector, name=name, obj=obj) item = UnitTestCase.from_parent(collector, name=name, obj=obj) # type: UnitTestCase
return item
class UnitTestCase(Class): class UnitTestCase(Class):
@ -34,7 +45,7 @@ class UnitTestCase(Class):
# to declare that our children do not support funcargs # to declare that our children do not support funcargs
nofuncargs = True nofuncargs = True
def collect(self): def collect(self) -> Iterable[Union[Item, Collector]]:
from unittest import TestLoader from unittest import TestLoader
cls = self.obj cls = self.obj
@ -61,8 +72,8 @@ class UnitTestCase(Class):
runtest = getattr(self.obj, "runTest", None) runtest = getattr(self.obj, "runTest", None)
if runtest is not None: if runtest is not None:
ut = sys.modules.get("twisted.trial.unittest", None) ut = sys.modules.get("twisted.trial.unittest", None)
if ut is None or runtest != ut.TestCase.runTest: # Type ignored because `ut` is an opaque module.
# TODO: callobj consistency if ut is None or runtest != ut.TestCase.runTest: # type: ignore
yield TestCaseFunction.from_parent(self, name="runTest") yield TestCaseFunction.from_parent(self, name="runTest")
def _inject_setup_teardown_fixtures(self, cls): def _inject_setup_teardown_fixtures(self, cls):