diff --git a/src/_pytest/config/__init__.py b/src/_pytest/config/__init__.py index d9abc17b4..ff6aee744 100644 --- a/src/_pytest/config/__init__.py +++ b/src/_pytest/config/__init__.py @@ -840,7 +840,7 @@ class Config: self.cache = None # type: Optional[Cache] @property - def invocation_dir(self): + def invocation_dir(self) -> py.path.local: """Backward compatibility""" return py.path.local(str(self.invocation_params.dir)) diff --git a/src/_pytest/doctest.py b/src/_pytest/doctest.py index 50f115cd1..026476b8a 100644 --- a/src/_pytest/doctest.py +++ b/src/_pytest/doctest.py @@ -7,6 +7,7 @@ import traceback import warnings from contextlib import contextmanager from typing import Dict +from typing import Iterable from typing import List from typing import Optional from typing import Sequence @@ -109,13 +110,18 @@ def pytest_unconfigure() -> None: RUNNER_CLASS = None -def pytest_collect_file(path: py.path.local, parent): +def pytest_collect_file( + path: py.path.local, parent +) -> Optional[Union["DoctestModule", "DoctestTextfile"]]: config = parent.config if path.ext == ".py": if config.option.doctestmodules and not _is_setup_py(path): - return DoctestModule.from_parent(parent, fspath=path) + mod = DoctestModule.from_parent(parent, fspath=path) # type: DoctestModule + return mod elif _is_doctest(config, path, parent): - return DoctestTextfile.from_parent(parent, fspath=path) + txt = DoctestTextfile.from_parent(parent, fspath=path) # type: DoctestTextfile + return txt + return None def _is_setup_py(path: py.path.local) -> bool: @@ -365,7 +371,7 @@ def _get_continue_on_failure(config): class DoctestTextfile(pytest.Module): obj = None - def collect(self): + def collect(self) -> Iterable[DoctestItem]: import doctest # inspired by doctest.testfile; ideally we would use it directly, @@ -444,7 +450,7 @@ def _patch_unwrap_mock_aware(): class DoctestModule(pytest.Module): - def collect(self): + def collect(self) -> Iterable[DoctestItem]: import doctest class MockAwareDocTestFinder(doctest.DocTestFinder): diff --git a/src/_pytest/hookspec.py b/src/_pytest/hookspec.py index 8b4505691..1321eff54 100644 --- a/src/_pytest/hookspec.py +++ b/src/_pytest/hookspec.py @@ -6,6 +6,7 @@ from typing import Optional from typing import Tuple from typing import Union +import py.path from pluggy import HookspecMarker from .deprecated import COLLECT_DIRECTORY_HOOK @@ -20,9 +21,14 @@ if TYPE_CHECKING: from _pytest.config import _PluggyPlugin from _pytest.config.argparsing import Parser from _pytest.main import Session + from _pytest.nodes import Collector + from _pytest.nodes import Item from _pytest.python import Metafunc + from _pytest.python import Module + from _pytest.python import PyCollector from _pytest.reports import BaseReport + hookspec = HookspecMarker("pytest") # ------------------------------------------------------------------------- @@ -249,7 +255,7 @@ def pytest_collect_directory(path, parent): """ -def pytest_collect_file(path, parent): +def pytest_collect_file(path: py.path.local, parent) -> "Optional[Collector]": """ return collection Node or None for the given path. Any new node needs to have the specified ``parent`` as a parent. @@ -289,7 +295,7 @@ def pytest_make_collect_report(collector): @hookspec(firstresult=True) -def pytest_pycollect_makemodule(path, parent): +def pytest_pycollect_makemodule(path: py.path.local, parent) -> "Optional[Module]": """ return a Module collector or None for the given path. This hook will be called for each matching test module path. The pytest_collect_file hook needs to be used if you want to @@ -302,7 +308,9 @@ def pytest_pycollect_makemodule(path, parent): @hookspec(firstresult=True) -def pytest_pycollect_makeitem(collector, name, obj): +def pytest_pycollect_makeitem( + collector: "PyCollector", name: str, obj +) -> "Union[None, Item, Collector, List[Union[Item, Collector]]]": """ return custom item/collector for a python object in a module, or None. Stops at first non-None result, see :ref:`firstresult` """ diff --git a/src/_pytest/main.py b/src/_pytest/main.py index 4eb47be2c..a0007d226 100644 --- a/src/_pytest/main.py +++ b/src/_pytest/main.py @@ -7,9 +7,11 @@ import sys from typing import Callable from typing import Dict from typing import FrozenSet +from typing import Iterator from typing import List from typing import Optional from typing import Sequence +from typing import Set from typing import Tuple from typing import Union @@ -18,12 +20,14 @@ import py import _pytest._code from _pytest import nodes +from _pytest.compat import overload from _pytest.compat import TYPE_CHECKING from _pytest.config import Config from _pytest.config import directory_arg from _pytest.config import ExitCode from _pytest.config import hookimpl from _pytest.config import UsageError +from _pytest.config.argparsing import Parser from _pytest.fixtures import FixtureManager from _pytest.outcomes import exit from _pytest.reports import CollectReport @@ -38,7 +42,7 @@ if TYPE_CHECKING: from _pytest.python import Package -def pytest_addoption(parser): +def pytest_addoption(parser: Parser) -> None: parser.addini( "norecursedirs", "directory patterns to avoid for recursion", @@ -241,7 +245,7 @@ def wrap_session( return session.exitstatus -def pytest_cmdline_main(config): +def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]: return wrap_session(config, _main) @@ -258,11 +262,11 @@ def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]: return None -def pytest_collection(session): +def pytest_collection(session: "Session") -> Sequence[nodes.Item]: return session.perform_collect() -def pytest_runtestloop(session): +def pytest_runtestloop(session: "Session") -> bool: if session.testsfailed and not session.config.option.continue_on_collection_errors: raise session.Interrupted( "%d error%s during collection" @@ -282,7 +286,7 @@ def pytest_runtestloop(session): return True -def _in_venv(path): +def _in_venv(path: py.path.local) -> bool: """Attempts to detect if ``path`` is the root of a Virtual Environment by checking for the existence of the appropriate activate script""" bindir = path.join("Scripts" if sys.platform.startswith("win") else "bin") @@ -328,7 +332,7 @@ def pytest_ignore_collect( return None -def pytest_collection_modifyitems(items, config): +def pytest_collection_modifyitems(items, config: Config) -> None: deselect_prefixes = tuple(config.getoption("deselect") or []) if not deselect_prefixes: return @@ -385,8 +389,8 @@ class Session(nodes.FSCollector): ) self.testsfailed = 0 self.testscollected = 0 - self.shouldstop = False - self.shouldfail = False + self.shouldstop = False # type: Union[bool, str] + self.shouldfail = False # type: Union[bool, str] self.trace = config.trace.root.get("collection") self.startdir = config.invocation_dir self._initialpaths = frozenset() # type: FrozenSet[py.path.local] @@ -412,10 +416,11 @@ class Session(nodes.FSCollector): self.config.pluginmanager.register(self, name="session") @classmethod - def from_config(cls, config): - return cls._create(config) + def from_config(cls, config: Config) -> "Session": + session = cls._create(config) # type: Session + return session - def __repr__(self): + def __repr__(self) -> str: return "<%s %s exitstatus=%r testsfailed=%d testscollected=%d>" % ( self.__class__.__name__, self.name, @@ -429,14 +434,14 @@ class Session(nodes.FSCollector): return self._bestrelpathcache[node_path] @hookimpl(tryfirst=True) - def pytest_collectstart(self): + def pytest_collectstart(self) -> None: if self.shouldfail: raise self.Failed(self.shouldfail) if self.shouldstop: raise self.Interrupted(self.shouldstop) @hookimpl(tryfirst=True) - def pytest_runtest_logreport(self, report): + def pytest_runtest_logreport(self, report) -> None: if report.failed and not hasattr(report, "wasxfail"): self.testsfailed += 1 maxfail = self.config.getvalue("maxfail") @@ -445,13 +450,27 @@ class Session(nodes.FSCollector): pytest_collectreport = pytest_runtest_logreport - def isinitpath(self, path): + def isinitpath(self, path: py.path.local) -> bool: return path in self._initialpaths def gethookproxy(self, fspath: py.path.local): return super()._gethookproxy(fspath) - def perform_collect(self, args=None, genitems=True): + @overload + def perform_collect( + self, args: Optional[Sequence[str]] = ..., genitems: "Literal[True]" = ... + ) -> Sequence[nodes.Item]: + raise NotImplementedError() + + @overload # noqa: F811 + def perform_collect( # noqa: F811 + self, args: Optional[Sequence[str]] = ..., genitems: bool = ... + ) -> Sequence[Union[nodes.Item, nodes.Collector]]: + raise NotImplementedError() + + def perform_collect( # noqa: F811 + self, args: Optional[Sequence[str]] = None, genitems: bool = True + ) -> Sequence[Union[nodes.Item, nodes.Collector]]: hook = self.config.hook try: items = self._perform_collect(args, genitems) @@ -464,15 +483,29 @@ class Session(nodes.FSCollector): self.testscollected = len(items) return items - def _perform_collect(self, args, genitems): + @overload + def _perform_collect( + self, args: Optional[Sequence[str]], genitems: "Literal[True]" + ) -> Sequence[nodes.Item]: + raise NotImplementedError() + + @overload # noqa: F811 + def _perform_collect( # noqa: F811 + self, args: Optional[Sequence[str]], genitems: bool + ) -> Sequence[Union[nodes.Item, nodes.Collector]]: + raise NotImplementedError() + + def _perform_collect( # noqa: F811 + self, args: Optional[Sequence[str]], genitems: bool + ) -> Sequence[Union[nodes.Item, nodes.Collector]]: if args is None: args = self.config.args self.trace("perform_collect", self, args) self.trace.root.indent += 1 - self._notfound = [] + self._notfound = [] # type: List[Tuple[str, NoMatch]] initialpaths = [] # type: List[py.path.local] self._initial_parts = [] # type: List[Tuple[py.path.local, List[str]]] - self.items = items = [] + self.items = items = [] # type: List[nodes.Item] for arg in args: fspath, parts = self._parsearg(arg) self._initial_parts.append((fspath, parts)) @@ -495,7 +528,7 @@ class Session(nodes.FSCollector): self.items.extend(self.genitems(node)) return items - def collect(self): + def collect(self) -> Iterator[Union[nodes.Item, nodes.Collector]]: for fspath, parts in self._initial_parts: self.trace("processing argument", (fspath, parts)) self.trace.root.indent += 1 @@ -513,7 +546,9 @@ class Session(nodes.FSCollector): self._collection_node_cache3.clear() self._collection_pkg_roots.clear() - def _collect(self, argpath, names): + def _collect( + self, argpath: py.path.local, names: List[str] + ) -> Iterator[Union[nodes.Item, nodes.Collector]]: from _pytest.python import Package # Start with a Session root, and delve to argpath item (dir or file) @@ -541,7 +576,7 @@ class Session(nodes.FSCollector): if argpath.check(dir=1): assert not names, "invalid arg {!r}".format((argpath, names)) - seen_dirs = set() + seen_dirs = set() # type: Set[py.path.local] for path in argpath.visit( fil=self._visit_filter, rec=self._recurse, bf=True, sort=True ): @@ -582,8 +617,9 @@ class Session(nodes.FSCollector): # Module itself, so just use that. If this special case isn't taken, then all # the files in the package will be yielded. if argpath.basename == "__init__.py": + assert isinstance(m[0], nodes.Collector) try: - yield next(m[0].collect()) + yield next(iter(m[0].collect())) except StopIteration: # The package collects nothing with only an __init__.py # file in it, which gets ignored by the default @@ -593,10 +629,11 @@ class Session(nodes.FSCollector): yield from m @staticmethod - def _visit_filter(f): - return f.check(file=1) + def _visit_filter(f: py.path.local) -> bool: + # TODO: Remove type: ignore once `py` is typed. + return f.check(file=1) # type: ignore - def _tryconvertpyarg(self, x): + def _tryconvertpyarg(self, x: str) -> str: """Convert a dotted module name to path.""" try: spec = importlib.util.find_spec(x) @@ -605,14 +642,14 @@ class Session(nodes.FSCollector): # ValueError: not a module name except (AttributeError, ImportError, ValueError): return x - if spec is None or spec.origin in {None, "namespace"}: + if spec is None or spec.origin is None or spec.origin == "namespace": return x elif spec.submodule_search_locations: return os.path.dirname(spec.origin) else: return spec.origin - def _parsearg(self, arg): + def _parsearg(self, arg: str) -> Tuple[py.path.local, List[str]]: """ return (fspath, names) tuple after checking the file exists. """ strpath, *parts = str(arg).split("::") if self.config.option.pyargs: @@ -628,7 +665,9 @@ class Session(nodes.FSCollector): fspath = fspath.realpath() return (fspath, parts) - def matchnodes(self, matching, names): + def matchnodes( + self, matching: Sequence[Union[nodes.Item, nodes.Collector]], names: List[str], + ) -> Sequence[Union[nodes.Item, nodes.Collector]]: self.trace("matchnodes", matching, names) self.trace.root.indent += 1 nodes = self._matchnodes(matching, names) @@ -639,13 +678,15 @@ class Session(nodes.FSCollector): raise NoMatch(matching, names[:1]) return nodes - def _matchnodes(self, matching, names): + def _matchnodes( + self, matching: Sequence[Union[nodes.Item, nodes.Collector]], names: List[str], + ) -> Sequence[Union[nodes.Item, nodes.Collector]]: if not matching or not names: return matching name = names[0] assert name nextnames = names[1:] - resultnodes = [] + resultnodes = [] # type: List[Union[nodes.Item, nodes.Collector]] for node in matching: if isinstance(node, nodes.Item): if not names: @@ -676,7 +717,9 @@ class Session(nodes.FSCollector): node.ihook.pytest_collectreport(report=rep) return resultnodes - def genitems(self, node): + def genitems( + self, node: Union[nodes.Item, nodes.Collector] + ) -> Iterator[nodes.Item]: self.trace("genitems", node) if isinstance(node, nodes.Item): node.ihook.pytest_itemcollected(item=node) diff --git a/src/_pytest/nodes.py b/src/_pytest/nodes.py index c9b633579..010dce925 100644 --- a/src/_pytest/nodes.py +++ b/src/_pytest/nodes.py @@ -4,8 +4,10 @@ from functools import lru_cache from typing import Any from typing import Callable from typing import Dict +from typing import Iterable from typing import List from typing import Optional +from typing import Sequence from typing import Set from typing import Tuple from typing import Union @@ -226,7 +228,7 @@ class Node(metaclass=NodeMeta): # methods for ordering nodes @property - def nodeid(self): + def nodeid(self) -> str: """ a ::-separated string denoting its collection tree address. """ return self._nodeid @@ -423,7 +425,7 @@ class Collector(Node): class CollectError(Exception): """ an error during collection, contains a custom message. """ - def collect(self): + def collect(self) -> Iterable[Union["Item", "Collector"]]: """ returns a list of children (items and collectors) for this collection node. """ @@ -522,6 +524,9 @@ class FSCollector(Collector): proxy = self.config.hook return proxy + def gethookproxy(self, fspath: py.path.local): + raise NotImplementedError() + def _recurse(self, dirpath: py.path.local) -> bool: if dirpath.basename == "__pycache__": return False @@ -535,7 +540,12 @@ class FSCollector(Collector): ihook.pytest_collect_directory(path=dirpath, parent=self) return True - def _collectfile(self, path, handle_dupes=True): + def isinitpath(self, path: py.path.local) -> bool: + raise NotImplementedError() + + def _collectfile( + self, path: py.path.local, handle_dupes: bool = True + ) -> Sequence[Collector]: assert ( path.isfile() ), "{!r} is not a file (isdir={!r}, exists={!r}, islink={!r})".format( @@ -555,7 +565,7 @@ class FSCollector(Collector): else: duplicate_paths.add(path) - return ihook.pytest_collect_file(path=path, parent=self) + return ihook.pytest_collect_file(path=path, parent=self) # type: ignore[no-any-return] # noqa: F723 class File(FSCollector): diff --git a/src/_pytest/python.py b/src/_pytest/python.py index e46d498ab..e05aa398d 100644 --- a/src/_pytest/python.py +++ b/src/_pytest/python.py @@ -43,9 +43,9 @@ from _pytest.compat import REGEX_TYPE from _pytest.compat import safe_getattr from _pytest.compat import safe_isclass from _pytest.compat import STRING_TYPES +from _pytest.compat import TYPE_CHECKING from _pytest.config import Config from _pytest.config import ExitCode -from _pytest.compat import TYPE_CHECKING from _pytest.config import hookimpl from _pytest.config.argparsing import Parser from _pytest.deprecated import FUNCARGNAMES @@ -184,16 +184,20 @@ def pytest_pyfunc_call(pyfuncitem: "Function"): return True -def pytest_collect_file(path, parent): +def pytest_collect_file(path: py.path.local, parent) -> Optional["Module"]: ext = path.ext if ext == ".py": if not parent.session.isinitpath(path): if not path_matches_patterns( path, parent.config.getini("python_files") + ["__init__.py"] ): - return + return None ihook = parent.session.gethookproxy(path) - return ihook.pytest_pycollect_makemodule(path=path, parent=parent) + module = ihook.pytest_pycollect_makemodule( + path=path, parent=parent + ) # type: Module + return module + return None def path_matches_patterns(path, patterns): @@ -201,14 +205,16 @@ def path_matches_patterns(path, patterns): return any(path.fnmatch(pattern) for pattern in patterns) -def pytest_pycollect_makemodule(path, parent): +def pytest_pycollect_makemodule(path: py.path.local, parent) -> "Module": if path.basename == "__init__.py": - return Package.from_parent(parent, fspath=path) - return Module.from_parent(parent, fspath=path) + pkg = Package.from_parent(parent, fspath=path) # type: Package + return pkg + mod = Module.from_parent(parent, fspath=path) # type: Module + return mod @hookimpl(hookwrapper=True) -def pytest_pycollect_makeitem(collector, name, obj): +def pytest_pycollect_makeitem(collector: "PyCollector", name: str, obj): outcome = yield res = outcome.get_result() if res is not None: @@ -372,7 +378,7 @@ class PyCollector(PyobjMixin, nodes.Collector): return True return False - def collect(self): + def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]: if not getattr(self.obj, "__test__", True): return [] @@ -381,8 +387,8 @@ class PyCollector(PyobjMixin, nodes.Collector): dicts = [getattr(self.obj, "__dict__", {})] for basecls in self.obj.__class__.__mro__: dicts.append(basecls.__dict__) - seen = set() - values = [] + seen = set() # type: Set[str] + values = [] # type: List[Union[nodes.Item, nodes.Collector]] for dic in dicts: # Note: seems like the dict can change during iteration - # be careful not to remove the list() without consideration. @@ -404,9 +410,16 @@ class PyCollector(PyobjMixin, nodes.Collector): values.sort(key=sort_key) return values - def _makeitem(self, name, obj): + def _makeitem( + self, name: str, obj + ) -> Union[ + None, nodes.Item, nodes.Collector, List[Union[nodes.Item, nodes.Collector]] + ]: # assert self.ihook.fspath == self.fspath, self - return self.ihook.pytest_pycollect_makeitem(collector=self, name=name, obj=obj) + item = self.ihook.pytest_pycollect_makeitem( + collector=self, name=name, obj=obj + ) # type: Union[None, nodes.Item, nodes.Collector, List[Union[nodes.Item, nodes.Collector]]] + return item def _genfunctions(self, name, funcobj): module = self.getparent(Module).obj @@ -458,7 +471,7 @@ class Module(nodes.File, PyCollector): def _getobj(self): return self._importtestmodule() - def collect(self): + def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]: self._inject_setup_module_fixture() self._inject_setup_function_fixture() self.session._fixturemanager.parsefactories(self) @@ -603,17 +616,17 @@ class Package(Module): def gethookproxy(self, fspath: py.path.local): return super()._gethookproxy(fspath) - def isinitpath(self, path): + def isinitpath(self, path: py.path.local) -> bool: return path in self.session._initialpaths - def collect(self): + def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]: this_path = self.fspath.dirpath() init_module = this_path.join("__init__.py") if init_module.check(file=1) and path_matches_patterns( init_module, self.config.getini("python_files") ): yield Module.from_parent(self, fspath=init_module) - pkg_prefixes = set() + pkg_prefixes = set() # type: Set[py.path.local] for path in this_path.visit(rec=self._recurse, bf=True, sort=True): # We will visit our own __init__.py file, in which case we skip it. is_file = path.isfile() @@ -670,10 +683,11 @@ class Class(PyCollector): """ return super().from_parent(name=name, parent=parent) - def collect(self): + def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]: if not safe_getattr(self.obj, "__test__", True): return [] if hasinit(self.obj): + assert self.parent is not None self.warn( PytestCollectionWarning( "cannot collect test class %r because it has a " @@ -683,6 +697,7 @@ class Class(PyCollector): ) return [] elif hasnew(self.obj): + assert self.parent is not None self.warn( PytestCollectionWarning( "cannot collect test class %r because it has a " @@ -756,7 +771,7 @@ class Instance(PyCollector): def _getobj(self): return self.parent.obj() - def collect(self): + def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]: self.session._fixturemanager.parsefactories(self) return super().collect() diff --git a/src/_pytest/reports.py b/src/_pytest/reports.py index 178df6004..908ba7d3b 100644 --- a/src/_pytest/reports.py +++ b/src/_pytest/reports.py @@ -21,7 +21,8 @@ from _pytest._code.code import ReprTraceback from _pytest._code.code import TerminalRepr from _pytest._io import TerminalWriter from _pytest.compat import TYPE_CHECKING -from _pytest.nodes import Node +from _pytest.nodes import Collector +from _pytest.nodes import Item from _pytest.outcomes import skip from _pytest.pathlib import Path @@ -316,7 +317,13 @@ class CollectReport(BaseReport): when = "collect" def __init__( - self, nodeid: str, outcome, longrepr, result: List[Node], sections=(), **extra + self, + nodeid: str, + outcome, + longrepr, + result: Optional[List[Union[Item, Collector]]], + sections=(), + **extra ) -> None: self.nodeid = nodeid self.outcome = outcome diff --git a/src/_pytest/runner.py b/src/_pytest/runner.py index c7f6d8811..a2b9ee207 100644 --- a/src/_pytest/runner.py +++ b/src/_pytest/runner.py @@ -404,10 +404,10 @@ class SetupState: raise e -def collect_one_node(collector): +def collect_one_node(collector: Collector) -> CollectReport: ihook = collector.ihook ihook.pytest_collectstart(collector=collector) - rep = ihook.pytest_make_collect_report(collector=collector) + rep = ihook.pytest_make_collect_report(collector=collector) # type: CollectReport call = rep.__dict__.pop("call", None) if call and check_interactive_exception(call, rep): ihook.pytest_exception_interact(node=collector, call=call, report=rep) diff --git a/src/_pytest/unittest.py b/src/_pytest/unittest.py index 0d9133f60..b2e6ab89d 100644 --- a/src/_pytest/unittest.py +++ b/src/_pytest/unittest.py @@ -1,32 +1,43 @@ """ discovery and running of std-library "unittest" style tests. """ import sys import traceback +from typing import Iterable +from typing import Optional +from typing import Union import _pytest._code import pytest from _pytest.compat import getimfunc from _pytest.compat import is_async_function from _pytest.config import hookimpl +from _pytest.nodes import Collector +from _pytest.nodes import Item from _pytest.outcomes import exit from _pytest.outcomes import fail from _pytest.outcomes import skip from _pytest.outcomes import xfail from _pytest.python import Class from _pytest.python import Function +from _pytest.python import PyCollector from _pytest.runner import CallInfo from _pytest.skipping import skipped_by_mark_key from _pytest.skipping import unexpectedsuccess_key -def pytest_pycollect_makeitem(collector, name, obj): +def pytest_pycollect_makeitem( + collector: PyCollector, name: str, obj +) -> Optional["UnitTestCase"]: # has unittest been imported and is obj a subclass of its TestCase? try: - if not issubclass(obj, sys.modules["unittest"].TestCase): - return + ut = sys.modules["unittest"] + # Type ignored because `ut` is an opaque module. + if not issubclass(obj, ut.TestCase): # type: ignore + return None except Exception: - return + return None # yes, so let's collect it - return UnitTestCase.from_parent(collector, name=name, obj=obj) + item = UnitTestCase.from_parent(collector, name=name, obj=obj) # type: UnitTestCase + return item class UnitTestCase(Class): @@ -34,7 +45,7 @@ class UnitTestCase(Class): # to declare that our children do not support funcargs nofuncargs = True - def collect(self): + def collect(self) -> Iterable[Union[Item, Collector]]: from unittest import TestLoader cls = self.obj @@ -61,8 +72,8 @@ class UnitTestCase(Class): runtest = getattr(self.obj, "runTest", None) if runtest is not None: ut = sys.modules.get("twisted.trial.unittest", None) - if ut is None or runtest != ut.TestCase.runTest: - # TODO: callobj consistency + # Type ignored because `ut` is an opaque module. + if ut is None or runtest != ut.TestCase.runTest: # type: ignore yield TestCaseFunction.from_parent(self, name="runTest") def _inject_setup_teardown_fixtures(self, cls):