Merge pull request #6580 from blueyed/typing-testdir-init

typing: Testdir.__init__
This commit is contained in:
Daniel Hahler 2020-01-28 00:58:11 +01:00 committed by GitHub
commit d0cb16010b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 9 deletions

View File

@ -48,6 +48,13 @@ if TYPE_CHECKING:
from typing import Type from typing import Type
_PluggyPlugin = object
"""A type to represent plugin objects.
Plugins can be any namespace, so we can't narrow it down much, but we use an
alias to make the intent clear.
Ideally this type would be provided by pluggy itself."""
hookimpl = HookimplMarker("pytest") hookimpl = HookimplMarker("pytest")
hookspec = HookspecMarker("pytest") hookspec = HookspecMarker("pytest")

View File

@ -29,12 +29,17 @@ from _pytest._io.saferepr import saferepr
from _pytest.capture import MultiCapture from _pytest.capture import MultiCapture
from _pytest.capture import SysCapture from _pytest.capture import SysCapture
from _pytest.compat import TYPE_CHECKING from _pytest.compat import TYPE_CHECKING
from _pytest.config import _PluggyPlugin
from _pytest.fixtures import FixtureRequest from _pytest.fixtures import FixtureRequest
from _pytest.main import ExitCode from _pytest.main import ExitCode
from _pytest.main import Session from _pytest.main import Session
from _pytest.monkeypatch import MonkeyPatch from _pytest.monkeypatch import MonkeyPatch
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.pathlib import Path from _pytest.pathlib import Path
from _pytest.python import Module
from _pytest.reports import TestReport from _pytest.reports import TestReport
from _pytest.tmpdir import TempdirFactory
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Type from typing import Type
@ -534,13 +539,15 @@ class Testdir:
class TimeoutExpired(Exception): class TimeoutExpired(Exception):
pass pass
def __init__(self, request, tmpdir_factory): def __init__(self, request: FixtureRequest, tmpdir_factory: TempdirFactory) -> None:
self.request = request self.request = request
self._mod_collections = WeakKeyDictionary() self._mod_collections = (
WeakKeyDictionary()
) # type: WeakKeyDictionary[Module, List[Union[Item, Collector]]]
name = request.function.__name__ name = request.function.__name__
self.tmpdir = tmpdir_factory.mktemp(name, numbered=True) self.tmpdir = tmpdir_factory.mktemp(name, numbered=True)
self.test_tmproot = tmpdir_factory.mktemp("tmp-" + name, numbered=True) self.test_tmproot = tmpdir_factory.mktemp("tmp-" + name, numbered=True)
self.plugins = [] self.plugins = [] # type: List[Union[str, _PluggyPlugin]]
self._cwd_snapshot = CwdSnapshot() self._cwd_snapshot = CwdSnapshot()
self._sys_path_snapshot = SysPathsSnapshot() self._sys_path_snapshot = SysPathsSnapshot()
self._sys_modules_snapshot = self.__take_sys_modules_snapshot() self._sys_modules_snapshot = self.__take_sys_modules_snapshot()
@ -1064,7 +1071,9 @@ class Testdir:
self.config = config = self.parseconfigure(path, *configargs) self.config = config = self.parseconfigure(path, *configargs)
return self.getnode(config, path) return self.getnode(config, path)
def collect_by_name(self, modcol, name): def collect_by_name(
self, modcol: Module, name: str
) -> Optional[Union[Item, Collector]]:
"""Return the collection node for name from the module collection. """Return the collection node for name from the module collection.
This will search a module collection node for a collection node This will search a module collection node for a collection node
@ -1073,13 +1082,13 @@ class Testdir:
:param modcol: a module collection node; see :py:meth:`getmodulecol` :param modcol: a module collection node; see :py:meth:`getmodulecol`
:param name: the name of the node to return :param name: the name of the node to return
""" """
if modcol not in self._mod_collections: if modcol not in self._mod_collections:
self._mod_collections[modcol] = list(modcol.collect()) self._mod_collections[modcol] = list(modcol.collect())
for colitem in self._mod_collections[modcol]: for colitem in self._mod_collections[modcol]:
if colitem.name == name: if colitem.name == name:
return colitem return colitem
return None
def popen( def popen(
self, self,

View File

@ -9,6 +9,7 @@ import pytest
from _pytest.main import _in_venv from _pytest.main import _in_venv
from _pytest.main import ExitCode from _pytest.main import ExitCode
from _pytest.main import Session from _pytest.main import Session
from _pytest.pytester import Testdir
class TestCollector: class TestCollector:
@ -18,7 +19,7 @@ class TestCollector:
assert not issubclass(Collector, Item) assert not issubclass(Collector, Item)
assert not issubclass(Item, Collector) assert not issubclass(Item, Collector)
def test_check_equality(self, testdir): def test_check_equality(self, testdir: Testdir) -> None:
modcol = testdir.getmodulecol( modcol = testdir.getmodulecol(
""" """
def test_pass(): pass def test_pass(): pass
@ -40,12 +41,15 @@ class TestCollector:
assert fn1 != fn3 assert fn1 != fn3
for fn in fn1, fn2, fn3: for fn in fn1, fn2, fn3:
assert fn != 3 assert isinstance(fn, pytest.Function)
assert fn != 3 # type: ignore[comparison-overlap] # noqa: F821
assert fn != modcol assert fn != modcol
assert fn != [1, 2, 3] assert fn != [1, 2, 3] # type: ignore[comparison-overlap] # noqa: F821
assert [1, 2, 3] != fn assert [1, 2, 3] != fn # type: ignore[comparison-overlap] # noqa: F821
assert modcol != fn assert modcol != fn
assert testdir.collect_by_name(modcol, "doesnotexist") is None
def test_getparent(self, testdir): def test_getparent(self, testdir):
modcol = testdir.getmodulecol( modcol = testdir.getmodulecol(
""" """