Add typing to `from_parent` return values (#11916)

Up to now the return values of `from_parent` were untyped, this is an
attempt to make it work with `typing.Self`.
This commit is contained in:
Ran Benita 2024-02-23 09:35:57 +02:00 committed by GitHub
parent 1640f2e454
commit 010ce2ab0f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 41 additions and 33 deletions

View File

@ -47,6 +47,7 @@ from _pytest.warning_types import PytestWarning
if TYPE_CHECKING: if TYPE_CHECKING:
import doctest import doctest
from typing import Self
DOCTEST_REPORT_CHOICE_NONE = "none" DOCTEST_REPORT_CHOICE_NONE = "none"
DOCTEST_REPORT_CHOICE_CDIFF = "cdiff" DOCTEST_REPORT_CHOICE_CDIFF = "cdiff"
@ -133,11 +134,9 @@ def pytest_collect_file(
if config.option.doctestmodules and not any( if config.option.doctestmodules and not any(
(_is_setup_py(file_path), _is_main_py(file_path)) (_is_setup_py(file_path), _is_main_py(file_path))
): ):
mod: DoctestModule = DoctestModule.from_parent(parent, path=file_path) return DoctestModule.from_parent(parent, path=file_path)
return mod
elif _is_doctest(config, file_path, parent): elif _is_doctest(config, file_path, parent):
txt: DoctestTextfile = DoctestTextfile.from_parent(parent, path=file_path) return DoctestTextfile.from_parent(parent, path=file_path)
return txt
return None return None
@ -272,14 +271,14 @@ class DoctestItem(Item):
self._initrequest() self._initrequest()
@classmethod @classmethod
def from_parent( # type: ignore def from_parent( # type: ignore[override]
cls, cls,
parent: "Union[DoctestTextfile, DoctestModule]", parent: "Union[DoctestTextfile, DoctestModule]",
*, *,
name: str, name: str,
runner: "doctest.DocTestRunner", runner: "doctest.DocTestRunner",
dtest: "doctest.DocTest", dtest: "doctest.DocTest",
): ) -> "Self":
# incompatible signature due to imposed limits on subclass # incompatible signature due to imposed limits on subclass
"""The public named constructor.""" """The public named constructor."""
return super().from_parent(name=name, parent=parent, runner=runner, dtest=dtest) return super().from_parent(name=name, parent=parent, runner=runner, dtest=dtest)

View File

@ -21,6 +21,7 @@ from typing import Optional
from typing import overload from typing import overload
from typing import Sequence from typing import Sequence
from typing import Tuple from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union from typing import Union
import warnings import warnings
@ -49,6 +50,10 @@ from _pytest.runner import SetupState
from _pytest.warning_types import PytestWarning from _pytest.warning_types import PytestWarning
if TYPE_CHECKING:
from typing import Self
def pytest_addoption(parser: Parser) -> None: def pytest_addoption(parser: Parser) -> None:
parser.addini( parser.addini(
"norecursedirs", "norecursedirs",
@ -491,16 +496,16 @@ class Dir(nodes.Directory):
@classmethod @classmethod
def from_parent( # type: ignore[override] def from_parent( # type: ignore[override]
cls, cls,
parent: nodes.Collector, # type: ignore[override] parent: nodes.Collector,
*, *,
path: Path, path: Path,
) -> "Dir": ) -> "Self":
"""The public constructor. """The public constructor.
:param parent: The parent collector of this Dir. :param parent: The parent collector of this Dir.
:param path: The directory's path. :param path: The directory's path.
""" """
return super().from_parent(parent=parent, path=path) # type: ignore[no-any-return] return super().from_parent(parent=parent, path=path)
def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]: def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
config = self.config config = self.config

View File

@ -11,6 +11,7 @@ from typing import Iterable
from typing import Iterator from typing import Iterator
from typing import List from typing import List
from typing import MutableMapping from typing import MutableMapping
from typing import NoReturn
from typing import Optional from typing import Optional
from typing import overload from typing import overload
from typing import Set from typing import Set
@ -41,6 +42,8 @@ from _pytest.warning_types import PytestWarning
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Self
# Imported here due to circular import. # Imported here due to circular import.
from _pytest._code.code import _TracebackStyle from _pytest._code.code import _TracebackStyle
from _pytest.main import Session from _pytest.main import Session
@ -51,6 +54,7 @@ SEP = "/"
tracebackcutdir = Path(_pytest.__file__).parent tracebackcutdir = Path(_pytest.__file__).parent
_T = TypeVar("_T")
_NodeType = TypeVar("_NodeType", bound="Node") _NodeType = TypeVar("_NodeType", bound="Node")
@ -69,33 +73,33 @@ class NodeMeta(abc.ABCMeta):
progress on detangling the :class:`Node` classes. progress on detangling the :class:`Node` classes.
""" """
def __call__(self, *k, **kw): def __call__(cls, *k, **kw) -> NoReturn:
msg = ( msg = (
"Direct construction of {name} has been deprecated, please use {name}.from_parent.\n" "Direct construction of {name} has been deprecated, please use {name}.from_parent.\n"
"See " "See "
"https://docs.pytest.org/en/stable/deprecations.html#node-construction-changed-to-node-from-parent" "https://docs.pytest.org/en/stable/deprecations.html#node-construction-changed-to-node-from-parent"
" for more details." " for more details."
).format(name=f"{self.__module__}.{self.__name__}") ).format(name=f"{cls.__module__}.{cls.__name__}")
fail(msg, pytrace=False) fail(msg, pytrace=False)
def _create(self, *k, **kw): def _create(cls: Type[_T], *k, **kw) -> _T:
try: try:
return super().__call__(*k, **kw) return super().__call__(*k, **kw) # type: ignore[no-any-return,misc]
except TypeError: except TypeError:
sig = signature(getattr(self, "__init__")) sig = signature(getattr(cls, "__init__"))
known_kw = {k: v for k, v in kw.items() if k in sig.parameters} known_kw = {k: v for k, v in kw.items() if k in sig.parameters}
from .warning_types import PytestDeprecationWarning from .warning_types import PytestDeprecationWarning
warnings.warn( warnings.warn(
PytestDeprecationWarning( PytestDeprecationWarning(
f"{self} is not using a cooperative constructor and only takes {set(known_kw)}.\n" f"{cls} is not using a cooperative constructor and only takes {set(known_kw)}.\n"
"See https://docs.pytest.org/en/stable/deprecations.html" "See https://docs.pytest.org/en/stable/deprecations.html"
"#constructors-of-custom-pytest-node-subclasses-should-take-kwargs " "#constructors-of-custom-pytest-node-subclasses-should-take-kwargs "
"for more details." "for more details."
) )
) )
return super().__call__(*k, **known_kw) return super().__call__(*k, **known_kw) # type: ignore[no-any-return,misc]
class Node(abc.ABC, metaclass=NodeMeta): class Node(abc.ABC, metaclass=NodeMeta):
@ -181,7 +185,7 @@ class Node(abc.ABC, metaclass=NodeMeta):
self._store = self.stash self._store = self.stash
@classmethod @classmethod
def from_parent(cls, parent: "Node", **kw): def from_parent(cls, parent: "Node", **kw) -> "Self":
"""Public constructor for Nodes. """Public constructor for Nodes.
This indirection got introduced in order to enable removing This indirection got introduced in order to enable removing
@ -583,7 +587,7 @@ class FSCollector(Collector, abc.ABC):
*, *,
path: Optional[Path] = None, path: Optional[Path] = None,
**kw, **kw,
): ) -> "Self":
"""The public constructor.""" """The public constructor."""
return super().from_parent(parent=parent, path=path, **kw) return super().from_parent(parent=parent, path=path, **kw)

View File

@ -27,6 +27,7 @@ from typing import Pattern
from typing import Sequence from typing import Sequence
from typing import Set from typing import Set
from typing import Tuple from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union from typing import Union
import warnings import warnings
@ -81,6 +82,10 @@ from _pytest.warning_types import PytestReturnNotNoneWarning
from _pytest.warning_types import PytestUnhandledCoroutineWarning from _pytest.warning_types import PytestUnhandledCoroutineWarning
if TYPE_CHECKING:
from typing import Self
_PYTEST_DIR = Path(_pytest.__file__).parent _PYTEST_DIR = Path(_pytest.__file__).parent
@ -204,8 +209,7 @@ def pytest_collect_directory(
) -> Optional[nodes.Collector]: ) -> Optional[nodes.Collector]:
pkginit = path / "__init__.py" pkginit = path / "__init__.py"
if pkginit.is_file(): if pkginit.is_file():
pkg: Package = Package.from_parent(parent, path=path) return Package.from_parent(parent, path=path)
return pkg
return None return None
@ -230,8 +234,7 @@ def path_matches_patterns(path: Path, patterns: Iterable[str]) -> bool:
def pytest_pycollect_makemodule(module_path: Path, parent) -> "Module": def pytest_pycollect_makemodule(module_path: Path, parent) -> "Module":
mod: Module = Module.from_parent(parent, path=module_path) return Module.from_parent(parent, path=module_path)
return mod
@hookimpl(trylast=True) @hookimpl(trylast=True)
@ -242,8 +245,7 @@ def pytest_pycollect_makeitem(
# Nothing was collected elsewhere, let's do it here. # Nothing was collected elsewhere, let's do it here.
if safe_isclass(obj): if safe_isclass(obj):
if collector.istestclass(obj, name): if collector.istestclass(obj, name):
klass: Class = Class.from_parent(collector, name=name, obj=obj) return Class.from_parent(collector, name=name, obj=obj)
return klass
elif collector.istestfunction(obj, name): elif collector.istestfunction(obj, name):
# mock seems to store unbound methods (issue473), normalize it. # mock seems to store unbound methods (issue473), normalize it.
obj = getattr(obj, "__func__", obj) obj = getattr(obj, "__func__", obj)
@ -262,7 +264,7 @@ def pytest_pycollect_makeitem(
) )
elif getattr(obj, "__test__", True): elif getattr(obj, "__test__", True):
if is_generator(obj): if is_generator(obj):
res: Function = Function.from_parent(collector, name=name) res = Function.from_parent(collector, name=name)
reason = ( reason = (
f"yield tests were removed in pytest 4.0 - {name} will be ignored" f"yield tests were removed in pytest 4.0 - {name} will be ignored"
) )
@ -465,9 +467,7 @@ class PyCollector(PyobjMixin, nodes.Collector, abc.ABC):
clscol = self.getparent(Class) clscol = self.getparent(Class)
cls = clscol and clscol.obj or None cls = clscol and clscol.obj or None
definition: FunctionDefinition = FunctionDefinition.from_parent( definition = FunctionDefinition.from_parent(self, name=name, callobj=funcobj)
self, name=name, callobj=funcobj
)
fixtureinfo = definition._fixtureinfo fixtureinfo = definition._fixtureinfo
# pytest_generate_tests impls call metafunc.parametrize() which fills # pytest_generate_tests impls call metafunc.parametrize() which fills
@ -751,7 +751,7 @@ class Class(PyCollector):
"""Collector for test methods (and nested classes) in a Python class.""" """Collector for test methods (and nested classes) in a Python class."""
@classmethod @classmethod
def from_parent(cls, parent, *, name, obj=None, **kw): def from_parent(cls, parent, *, name, obj=None, **kw) -> "Self": # type: ignore[override]
"""The public constructor.""" """The public constructor."""
return super().from_parent(name=name, parent=parent, **kw) return super().from_parent(name=name, parent=parent, **kw)
@ -1730,8 +1730,9 @@ class Function(PyobjMixin, nodes.Item):
self.fixturenames = fixtureinfo.names_closure self.fixturenames = fixtureinfo.names_closure
self._initrequest() self._initrequest()
# todo: determine sound type limitations
@classmethod @classmethod
def from_parent(cls, parent, **kw): # todo: determine sound type limitations def from_parent(cls, parent, **kw) -> "Self":
"""The public constructor.""" """The public constructor."""
return super().from_parent(parent=parent, **kw) return super().from_parent(parent=parent, **kw)

View File

@ -55,8 +55,7 @@ def pytest_pycollect_makeitem(
except Exception: except Exception:
return None return None
# Yes, so let's collect it. # Yes, so let's collect it.
item: UnitTestCase = UnitTestCase.from_parent(collector, name=name, obj=obj) return UnitTestCase.from_parent(collector, name=name, obj=obj)
return item
class UnitTestCase(Class): class UnitTestCase(Class):

View File

@ -1613,7 +1613,7 @@ def test_fscollector_from_parent(pytester: Pytester, request: FixtureRequest) ->
assert collector.x == 10 assert collector.x == 10
def test_class_from_parent(pytester: Pytester, request: FixtureRequest) -> None: def test_class_from_parent(request: FixtureRequest) -> None:
"""Ensure Class.from_parent can forward custom arguments to the constructor.""" """Ensure Class.from_parent can forward custom arguments to the constructor."""
class MyCollector(pytest.Class): class MyCollector(pytest.Class):