Type annotate misc functions

This commit is contained in:
Ran Benita 2020-05-01 14:40:16 +03:00
parent d95132178c
commit e68a26199c
8 changed files with 73 additions and 44 deletions

View File

@ -8,6 +8,7 @@ import json
import os import os
from typing import Dict from typing import Dict
from typing import Generator from typing import Generator
from typing import Iterable
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Set from typing import Set
@ -27,10 +28,12 @@ from _pytest.compat import order_preserving_dict
from _pytest.config import Config from _pytest.config import Config
from _pytest.config import ExitCode from _pytest.config import ExitCode
from _pytest.config.argparsing import Parser from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
from _pytest.main import Session from _pytest.main import Session
from _pytest.python import Module from _pytest.python import Module
from _pytest.reports import TestReport from _pytest.reports import TestReport
README_CONTENT = """\ README_CONTENT = """\
# pytest cache directory # # pytest cache directory #
@ -52,8 +55,8 @@ Signature: 8a477f597d28d172789f06886806bc55
@attr.s @attr.s
class Cache: class Cache:
_cachedir = attr.ib(repr=False) _cachedir = attr.ib(type=Path, repr=False)
_config = attr.ib(repr=False) _config = attr.ib(type=Config, repr=False)
# sub-directory under cache-dir for directories created by "makedir" # sub-directory under cache-dir for directories created by "makedir"
_CACHE_PREFIX_DIRS = "d" _CACHE_PREFIX_DIRS = "d"
@ -62,14 +65,14 @@ class Cache:
_CACHE_PREFIX_VALUES = "v" _CACHE_PREFIX_VALUES = "v"
@classmethod @classmethod
def for_config(cls, config): def for_config(cls, config: Config) -> "Cache":
cachedir = cls.cache_dir_from_config(config) cachedir = cls.cache_dir_from_config(config)
if config.getoption("cacheclear") and cachedir.is_dir(): if config.getoption("cacheclear") and cachedir.is_dir():
cls.clear_cache(cachedir) cls.clear_cache(cachedir)
return cls(cachedir, config) return cls(cachedir, config)
@classmethod @classmethod
def clear_cache(cls, cachedir: Path): def clear_cache(cls, cachedir: Path) -> None:
"""Clears the sub-directories used to hold cached directories and values.""" """Clears the sub-directories used to hold cached directories and values."""
for prefix in (cls._CACHE_PREFIX_DIRS, cls._CACHE_PREFIX_VALUES): for prefix in (cls._CACHE_PREFIX_DIRS, cls._CACHE_PREFIX_VALUES):
d = cachedir / prefix d = cachedir / prefix
@ -77,10 +80,10 @@ class Cache:
rm_rf(d) rm_rf(d)
@staticmethod @staticmethod
def cache_dir_from_config(config): def cache_dir_from_config(config: Config):
return resolve_from_str(config.getini("cache_dir"), config.rootdir) return resolve_from_str(config.getini("cache_dir"), config.rootdir)
def warn(self, fmt, **args): def warn(self, fmt: str, **args: object) -> None:
import warnings import warnings
from _pytest.warning_types import PytestCacheWarning from _pytest.warning_types import PytestCacheWarning
@ -90,7 +93,7 @@ class Cache:
stacklevel=3, stacklevel=3,
) )
def makedir(self, name): def makedir(self, name: str) -> py.path.local:
""" return a directory path object with the given name. If the """ return a directory path object with the given name. If the
directory does not yet exist, it will be created. You can use it directory does not yet exist, it will be created. You can use it
to manage files likes e. g. store/retrieve database to manage files likes e. g. store/retrieve database
@ -100,14 +103,14 @@ class Cache:
Make sure the name contains your plugin or application Make sure the name contains your plugin or application
identifiers to prevent clashes with other cache users. identifiers to prevent clashes with other cache users.
""" """
name = Path(name) path = Path(name)
if len(name.parts) > 1: if len(path.parts) > 1:
raise ValueError("name is not allowed to contain path separators") raise ValueError("name is not allowed to contain path separators")
res = self._cachedir.joinpath(self._CACHE_PREFIX_DIRS, name) res = self._cachedir.joinpath(self._CACHE_PREFIX_DIRS, path)
res.mkdir(exist_ok=True, parents=True) res.mkdir(exist_ok=True, parents=True)
return py.path.local(res) return py.path.local(res)
def _getvaluepath(self, key): def _getvaluepath(self, key: str) -> Path:
return self._cachedir.joinpath(self._CACHE_PREFIX_VALUES, Path(key)) return self._cachedir.joinpath(self._CACHE_PREFIX_VALUES, Path(key))
def get(self, key, default): def get(self, key, default):
@ -128,7 +131,7 @@ class Cache:
except (ValueError, OSError): except (ValueError, OSError):
return default return default
def set(self, key, value): def set(self, key, value) -> None:
""" save value for the given key. """ save value for the given key.
:param key: must be a ``/`` separated value. Usually the first :param key: must be a ``/`` separated value. Usually the first
@ -158,7 +161,7 @@ class Cache:
with f: with f:
f.write(data) f.write(data)
def _ensure_supporting_files(self): def _ensure_supporting_files(self) -> None:
"""Create supporting files in the cache dir that are not really part of the cache.""" """Create supporting files in the cache dir that are not really part of the cache."""
readme_path = self._cachedir / "README.md" readme_path = self._cachedir / "README.md"
readme_path.write_text(README_CONTENT) readme_path.write_text(README_CONTENT)
@ -172,12 +175,12 @@ class Cache:
class LFPluginCollWrapper: class LFPluginCollWrapper:
def __init__(self, lfplugin: "LFPlugin"): def __init__(self, lfplugin: "LFPlugin") -> None:
self.lfplugin = lfplugin self.lfplugin = lfplugin
self._collected_at_least_one_failure = False self._collected_at_least_one_failure = False
@pytest.hookimpl(hookwrapper=True) @pytest.hookimpl(hookwrapper=True)
def pytest_make_collect_report(self, collector) -> Generator: def pytest_make_collect_report(self, collector: nodes.Collector) -> Generator:
if isinstance(collector, Session): if isinstance(collector, Session):
out = yield out = yield
res = out.get_result() # type: CollectReport res = out.get_result() # type: CollectReport
@ -220,11 +223,13 @@ class LFPluginCollWrapper:
class LFPluginCollSkipfiles: class LFPluginCollSkipfiles:
def __init__(self, lfplugin: "LFPlugin"): def __init__(self, lfplugin: "LFPlugin") -> None:
self.lfplugin = lfplugin self.lfplugin = lfplugin
@pytest.hookimpl @pytest.hookimpl
def pytest_make_collect_report(self, collector) -> Optional[CollectReport]: def pytest_make_collect_report(
self, collector: nodes.Collector
) -> Optional[CollectReport]:
if isinstance(collector, Module): if isinstance(collector, Module):
if Path(str(collector.fspath)) not in self.lfplugin._last_failed_paths: if Path(str(collector.fspath)) not in self.lfplugin._last_failed_paths:
self.lfplugin._skipped_files += 1 self.lfplugin._skipped_files += 1
@ -262,9 +267,10 @@ class LFPlugin:
result = {rootpath / nodeid.split("::")[0] for nodeid in self.lastfailed} result = {rootpath / nodeid.split("::")[0] for nodeid in self.lastfailed}
return {x for x in result if x.exists()} return {x for x in result if x.exists()}
def pytest_report_collectionfinish(self): def pytest_report_collectionfinish(self) -> Optional[str]:
if self.active and self.config.getoption("verbose") >= 0: if self.active and self.config.getoption("verbose") >= 0:
return "run-last-failure: %s" % self._report_status return "run-last-failure: %s" % self._report_status
return None
def pytest_runtest_logreport(self, report: TestReport) -> None: def pytest_runtest_logreport(self, report: TestReport) -> None:
if (report.when == "call" and report.passed) or report.skipped: if (report.when == "call" and report.passed) or report.skipped:
@ -347,9 +353,10 @@ class LFPlugin:
class NFPlugin: class NFPlugin:
""" Plugin which implements the --nf (run new-first) option """ """ Plugin which implements the --nf (run new-first) option """
def __init__(self, config): def __init__(self, config: Config) -> None:
self.config = config self.config = config
self.active = config.option.newfirst self.active = config.option.newfirst
assert config.cache is not None
self.cached_nodeids = set(config.cache.get("cache/nodeids", [])) self.cached_nodeids = set(config.cache.get("cache/nodeids", []))
@pytest.hookimpl(hookwrapper=True, tryfirst=True) @pytest.hookimpl(hookwrapper=True, tryfirst=True)
@ -374,7 +381,7 @@ class NFPlugin:
else: else:
self.cached_nodeids.update(item.nodeid for item in items) self.cached_nodeids.update(item.nodeid for item in items)
def _get_increasing_order(self, items): def _get_increasing_order(self, items: Iterable[nodes.Item]) -> List[nodes.Item]:
return sorted(items, key=lambda item: item.fspath.mtime(), reverse=True) return sorted(items, key=lambda item: item.fspath.mtime(), reverse=True)
def pytest_sessionfinish(self) -> None: def pytest_sessionfinish(self) -> None:
@ -384,6 +391,8 @@ class NFPlugin:
if config.getoption("collectonly"): if config.getoption("collectonly"):
return return
assert config.cache is not None
config.cache.set("cache/nodeids", sorted(self.cached_nodeids)) config.cache.set("cache/nodeids", sorted(self.cached_nodeids))
@ -462,7 +471,7 @@ def pytest_configure(config: Config) -> None:
@pytest.fixture @pytest.fixture
def cache(request): def cache(request: FixtureRequest) -> Cache:
""" """
Return a cache object that can persist state between testing sessions. Return a cache object that can persist state between testing sessions.
@ -474,12 +483,14 @@ def cache(request):
Values can be any object handled by the json stdlib module. Values can be any object handled by the json stdlib module.
""" """
assert request.config.cache is not None
return request.config.cache return request.config.cache
def pytest_report_header(config): def pytest_report_header(config: Config) -> Optional[str]:
"""Display cachedir with --cache-show and if non-default.""" """Display cachedir with --cache-show and if non-default."""
if config.option.verbose > 0 or config.getini("cache_dir") != ".pytest_cache": if config.option.verbose > 0 or config.getini("cache_dir") != ".pytest_cache":
assert config.cache is not None
cachedir = config.cache._cachedir cachedir = config.cache._cachedir
# TODO: evaluate generating upward relative paths # TODO: evaluate generating upward relative paths
# starting with .., ../.. if sensible # starting with .., ../.. if sensible
@ -489,11 +500,14 @@ def pytest_report_header(config):
except ValueError: except ValueError:
displaypath = cachedir displaypath = cachedir
return "cachedir: {}".format(displaypath) return "cachedir: {}".format(displaypath)
return None
def cacheshow(config, session): def cacheshow(config: Config, session: Session) -> int:
from pprint import pformat from pprint import pformat
assert config.cache is not None
tw = TerminalWriter() tw = TerminalWriter()
tw.line("cachedir: " + str(config.cache._cachedir)) tw.line("cachedir: " + str(config.cache._cachedir))
if not config.cache._cachedir.is_dir(): if not config.cache._cachedir.is_dir():

View File

@ -57,6 +57,7 @@ if TYPE_CHECKING:
from _pytest import nodes from _pytest import nodes
from _pytest.main import Session from _pytest.main import Session
from _pytest.python import Metafunc from _pytest.python import Metafunc
from _pytest.python import CallSpec2
_Scope = Literal["session", "package", "module", "class", "function"] _Scope = Literal["session", "package", "module", "class", "function"]
@ -217,10 +218,11 @@ def get_parametrized_fixture_keys(item, scopenum):
the specified scope. """ the specified scope. """
assert scopenum < scopenum_function # function assert scopenum < scopenum_function # function
try: try:
cs = item.callspec callspec = item.callspec # type: ignore[attr-defined] # noqa: F821
except AttributeError: except AttributeError:
pass pass
else: else:
cs = callspec # type: CallSpec2
# cs.indices.items() is random order of argnames. Need to # cs.indices.items() is random order of argnames. Need to
# sort this so that different calls to # sort this so that different calls to
# get_parametrized_fixture_keys will be deterministic. # get_parametrized_fixture_keys will be deterministic.
@ -434,9 +436,9 @@ class FixtureRequest:
return fixturedefs[index] return fixturedefs[index]
@property @property
def config(self): def config(self) -> Config:
""" the pytest config object associated with this request. """ """ the pytest config object associated with this request. """
return self._pyfuncitem.config return self._pyfuncitem.config # type: ignore[no-any-return] # noqa: F723
@scopeproperty() @scopeproperty()
def function(self): def function(self):
@ -1464,7 +1466,7 @@ class FixtureManager:
else: else:
continue # will raise FixtureLookupError at setup time continue # will raise FixtureLookupError at setup time
def pytest_collection_modifyitems(self, items): def pytest_collection_modifyitems(self, items: "List[nodes.Item]") -> None:
# separate parametrized setups # separate parametrized setups
items[:] = reorder_items(items) items[:] = reorder_items(items)

View File

@ -223,7 +223,9 @@ def pytest_collection(session: "Session") -> Optional[Any]:
""" """
def pytest_collection_modifyitems(session: "Session", config: "Config", items): def pytest_collection_modifyitems(
session: "Session", config: "Config", items: List["Item"]
) -> None:
""" called after collection has been performed, may filter or re-order """ called after collection has been performed, may filter or re-order
the items in-place. the items in-place.

View File

@ -333,7 +333,7 @@ def pytest_ignore_collect(
return None return None
def pytest_collection_modifyitems(items, config: Config) -> None: def pytest_collection_modifyitems(items: List[nodes.Item], config: Config) -> None:
deselect_prefixes = tuple(config.getoption("deselect") or []) deselect_prefixes = tuple(config.getoption("deselect") or [])
if not deselect_prefixes: if not deselect_prefixes:
return return
@ -487,18 +487,18 @@ class Session(nodes.FSCollector):
@overload @overload
def _perform_collect( def _perform_collect(
self, args: Optional[Sequence[str]], genitems: "Literal[True]" self, args: Optional[Sequence[str]], genitems: "Literal[True]"
) -> Sequence[nodes.Item]: ) -> List[nodes.Item]:
raise NotImplementedError() raise NotImplementedError()
@overload # noqa: F811 @overload # noqa: F811
def _perform_collect( # noqa: F811 def _perform_collect( # noqa: F811
self, args: Optional[Sequence[str]], genitems: bool self, args: Optional[Sequence[str]], genitems: bool
) -> Sequence[Union[nodes.Item, nodes.Collector]]: ) -> Union[List[Union[nodes.Item]], List[Union[nodes.Item, nodes.Collector]]]:
raise NotImplementedError() raise NotImplementedError()
def _perform_collect( # noqa: F811 def _perform_collect( # noqa: F811
self, args: Optional[Sequence[str]], genitems: bool self, args: Optional[Sequence[str]], genitems: bool
) -> Sequence[Union[nodes.Item, nodes.Collector]]: ) -> Union[List[Union[nodes.Item]], List[Union[nodes.Item, nodes.Collector]]]:
if args is None: if args is None:
args = self.config.args args = self.config.args
self.trace("perform_collect", self, args) self.trace("perform_collect", self, args)

View File

@ -2,6 +2,7 @@
import typing import typing
import warnings import warnings
from typing import AbstractSet from typing import AbstractSet
from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
@ -173,7 +174,7 @@ class KeywordMatcher:
return False return False
def deselect_by_keyword(items, config: Config) -> None: def deselect_by_keyword(items: "List[Item]", config: Config) -> None:
keywordexpr = config.option.keyword.lstrip() keywordexpr = config.option.keyword.lstrip()
if not keywordexpr: if not keywordexpr:
return return
@ -229,7 +230,7 @@ class MarkMatcher:
return name in self.own_mark_names return name in self.own_mark_names
def deselect_by_mark(items, config: Config) -> None: def deselect_by_mark(items: "List[Item]", config: Config) -> None:
matchexpr = config.option.markexpr matchexpr = config.option.markexpr
if not matchexpr: if not matchexpr:
return return
@ -254,7 +255,7 @@ def deselect_by_mark(items, config: Config) -> None:
items[:] = remaining items[:] = remaining
def pytest_collection_modifyitems(items, config: Config) -> None: def pytest_collection_modifyitems(items: "List[Item]", config: Config) -> None:
deselect_by_keyword(items, config) deselect_by_keyword(items, config)
deselect_by_mark(items, config) deselect_by_mark(items, config)

View File

@ -348,7 +348,7 @@ def make_numbered_dir_with_cleanup(
raise e raise e
def resolve_from_str(input, root): def resolve_from_str(input: str, root):
assert not isinstance(input, Path), "would break on py2" assert not isinstance(input, Path), "would break on py2"
root = Path(root) root = Path(root)
input = expanduser(input) input = expanduser(input)

View File

@ -647,8 +647,8 @@ class Testdir:
for basename, value in items: for basename, value in items:
p = self.tmpdir.join(basename).new(ext=ext) p = self.tmpdir.join(basename).new(ext=ext)
p.dirpath().ensure_dir() p.dirpath().ensure_dir()
source = Source(value) source_ = Source(value)
source = "\n".join(to_text(line) for line in source.lines) source = "\n".join(to_text(line) for line in source_.lines)
p.write(source.strip().encode(encoding), "wb") p.write(source.strip().encode(encoding), "wb")
if ret is None: if ret is None:
ret = p ret = p
@ -839,7 +839,7 @@ class Testdir:
config.hook.pytest_sessionfinish(session=session, exitstatus=ExitCode.OK) config.hook.pytest_sessionfinish(session=session, exitstatus=ExitCode.OK)
return res return res
def genitems(self, colitems): def genitems(self, colitems: List[Union[Item, Collector]]) -> List[Item]:
"""Generate all test items from a collection node. """Generate all test items from a collection node.
This recurses into the collection node and returns a list of all the This recurses into the collection node and returns a list of all the
@ -847,7 +847,7 @@ class Testdir:
""" """
session = colitems[0].session session = colitems[0].session
result = [] result = [] # type: List[Item]
for colitem in colitems: for colitem in colitems:
result.extend(session.genitems(colitem)) result.extend(session.genitems(colitem))
return result return result

View File

@ -1,4 +1,8 @@
from typing import List
from typing import Optional
import pytest import pytest
from _pytest import nodes
from _pytest.config import Config from _pytest.config import Config
from _pytest.config.argparsing import Parser from _pytest.config.argparsing import Parser
from _pytest.main import Session from _pytest.main import Session
@ -28,20 +32,23 @@ def pytest_configure(config: Config) -> None:
class StepwisePlugin: class StepwisePlugin:
def __init__(self, config): def __init__(self, config: Config) -> None:
self.config = config self.config = config
self.active = config.getvalue("stepwise") self.active = config.getvalue("stepwise")
self.session = None self.session = None # type: Optional[Session]
self.report_status = "" self.report_status = ""
if self.active: if self.active:
assert config.cache is not None
self.lastfailed = config.cache.get("cache/stepwise", None) self.lastfailed = config.cache.get("cache/stepwise", None)
self.skip = config.getvalue("stepwise_skip") self.skip = config.getvalue("stepwise_skip")
def pytest_sessionstart(self, session: Session) -> None: def pytest_sessionstart(self, session: Session) -> None:
self.session = session self.session = session
def pytest_collection_modifyitems(self, session, config, items): def pytest_collection_modifyitems(
self, session: Session, config: Config, items: List[nodes.Item]
) -> None:
if not self.active: if not self.active:
return return
if not self.lastfailed: if not self.lastfailed:
@ -89,6 +96,7 @@ class StepwisePlugin:
else: else:
# Mark test as the last failing and interrupt the test session. # Mark test as the last failing and interrupt the test session.
self.lastfailed = report.nodeid self.lastfailed = report.nodeid
assert self.session is not None
self.session.shouldstop = ( self.session.shouldstop = (
"Test failed, continuing from this test next run." "Test failed, continuing from this test next run."
) )
@ -100,11 +108,13 @@ class StepwisePlugin:
if report.nodeid == self.lastfailed: if report.nodeid == self.lastfailed:
self.lastfailed = None self.lastfailed = None
def pytest_report_collectionfinish(self): def pytest_report_collectionfinish(self) -> Optional[str]:
if self.active and self.config.getoption("verbose") >= 0 and self.report_status: if self.active and self.config.getoption("verbose") >= 0 and self.report_status:
return "stepwise: %s" % self.report_status return "stepwise: %s" % self.report_status
return None
def pytest_sessionfinish(self, session: Session) -> None: def pytest_sessionfinish(self, session: Session) -> None:
assert self.config.cache is not None
if self.active: if self.active:
self.config.cache.set("cache/stepwise", self.lastfailed) self.config.cache.set("cache/stepwise", self.lastfailed)
else: else: