Add type annotations to some of _pytest.pytester

This commit is contained in:
Ran Benita 2019-11-06 21:35:39 +02:00
parent 58f2849bf6
commit 265a9eb6a2
1 changed files with 81 additions and 37 deletions

View File

@ -1,4 +1,5 @@
"""(disabled by default) support for testing pytest and pytest plugins."""
import collections.abc
import gc
import importlib
import os
@ -8,9 +9,15 @@ import subprocess
import sys
import time
import traceback
from collections.abc import Sequence
from fnmatch import fnmatch
from io import StringIO
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from weakref import WeakKeyDictionary
@ -21,10 +28,16 @@ from _pytest._code import Source
from _pytest._io.saferepr import saferepr
from _pytest.capture import MultiCapture
from _pytest.capture import SysCapture
from _pytest.fixtures import FixtureRequest
from _pytest.main import ExitCode
from _pytest.main import Session
from _pytest.monkeypatch import MonkeyPatch
from _pytest.pathlib import Path
from _pytest.reports import TestReport
if False: # TYPE_CHECKING
from typing import Type
IGNORE_PAM = [ # filenames added when obtaining details about the current user
"/var/lib/sss/mc/passwd"
@ -142,7 +155,7 @@ class LsofFdLeakChecker:
@pytest.fixture
def _pytest(request):
def _pytest(request: FixtureRequest) -> "PytestArg":
"""Return a helper which offers a gethookrecorder(hook) method which
returns a HookRecorder instance which helps to make assertions about called
hooks.
@ -152,10 +165,10 @@ def _pytest(request):
class PytestArg:
def __init__(self, request):
def __init__(self, request: FixtureRequest) -> None:
self.request = request
def gethookrecorder(self, hook):
def gethookrecorder(self, hook) -> "HookRecorder":
hookrecorder = HookRecorder(hook._pm)
self.request.addfinalizer(hookrecorder.finish_recording)
return hookrecorder
@ -176,6 +189,11 @@ class ParsedCall:
del d["_name"]
return "<ParsedCall {!r}(**{!r})>".format(self._name, d)
if False: # TYPE_CHECKING
# The class has undetermined attributes, this tells mypy about it.
def __getattr__(self, key):
raise NotImplementedError()
class HookRecorder:
"""Record all hooks called in a plugin manager.
@ -185,27 +203,27 @@ class HookRecorder:
"""
def __init__(self, pluginmanager):
def __init__(self, pluginmanager) -> None:
self._pluginmanager = pluginmanager
self.calls = []
self.calls = [] # type: List[ParsedCall]
def before(hook_name, hook_impls, kwargs):
def before(hook_name: str, hook_impls, kwargs) -> None:
self.calls.append(ParsedCall(hook_name, kwargs))
def after(outcome, hook_name, hook_impls, kwargs):
def after(outcome, hook_name: str, hook_impls, kwargs) -> None:
pass
self._undo_wrapping = pluginmanager.add_hookcall_monitoring(before, after)
def finish_recording(self):
def finish_recording(self) -> None:
self._undo_wrapping()
def getcalls(self, names):
def getcalls(self, names: Union[str, Iterable[str]]) -> List[ParsedCall]:
if isinstance(names, str):
names = names.split()
return [call for call in self.calls if call._name in names]
def assert_contains(self, entries):
def assert_contains(self, entries) -> None:
__tracebackhide__ = True
i = 0
entries = list(entries)
@ -226,7 +244,7 @@ class HookRecorder:
else:
pytest.fail("could not find {!r} check {!r}".format(name, check))
def popcall(self, name):
def popcall(self, name: str) -> ParsedCall:
__tracebackhide__ = True
for i, call in enumerate(self.calls):
if call._name == name:
@ -236,20 +254,27 @@ class HookRecorder:
lines.extend([" %s" % x for x in self.calls])
pytest.fail("\n".join(lines))
def getcall(self, name):
def getcall(self, name: str) -> ParsedCall:
values = self.getcalls(name)
assert len(values) == 1, (name, values)
return values[0]
# functionality for test reports
def getreports(self, names="pytest_runtest_logreport pytest_collectreport"):
def getreports(
self,
names: Union[
str, Iterable[str]
] = "pytest_runtest_logreport pytest_collectreport",
) -> List[TestReport]:
return [x.report for x in self.getcalls(names)]
def matchreport(
self,
inamepart="",
names="pytest_runtest_logreport pytest_collectreport",
inamepart: str = "",
names: Union[
str, Iterable[str]
] = "pytest_runtest_logreport pytest_collectreport",
when=None,
):
"""return a testreport whose dotted import path matches"""
@ -275,13 +300,20 @@ class HookRecorder:
)
return values[0]
def getfailures(self, names="pytest_runtest_logreport pytest_collectreport"):
def getfailures(
self,
names: Union[
str, Iterable[str]
] = "pytest_runtest_logreport pytest_collectreport",
) -> List[TestReport]:
return [rep for rep in self.getreports(names) if rep.failed]
def getfailedcollections(self):
def getfailedcollections(self) -> List[TestReport]:
return self.getfailures("pytest_collectreport")
def listoutcomes(self):
def listoutcomes(
self
) -> Tuple[List[TestReport], List[TestReport], List[TestReport]]:
passed = []
skipped = []
failed = []
@ -296,31 +328,31 @@ class HookRecorder:
failed.append(rep)
return passed, skipped, failed
def countoutcomes(self):
def countoutcomes(self) -> List[int]:
return [len(x) for x in self.listoutcomes()]
def assertoutcome(self, passed=0, skipped=0, failed=0):
def assertoutcome(self, passed: int = 0, skipped: int = 0, failed: int = 0) -> None:
realpassed, realskipped, realfailed = self.listoutcomes()
assert passed == len(realpassed)
assert skipped == len(realskipped)
assert failed == len(realfailed)
def clear(self):
def clear(self) -> None:
self.calls[:] = []
@pytest.fixture
def linecomp(request):
def linecomp(request: FixtureRequest) -> "LineComp":
return LineComp()
@pytest.fixture(name="LineMatcher")
def LineMatcher_fixture(request):
def LineMatcher_fixture(request: FixtureRequest) -> "Type[LineMatcher]":
return LineMatcher
@pytest.fixture
def testdir(request, tmpdir_factory):
def testdir(request: FixtureRequest, tmpdir_factory) -> "Testdir":
return Testdir(request, tmpdir_factory)
@ -363,7 +395,13 @@ class RunResult:
:ivar duration: duration in seconds
"""
def __init__(self, ret: Union[int, ExitCode], outlines, errlines, duration) -> None:
def __init__(
self,
ret: Union[int, ExitCode],
outlines: Sequence[str],
errlines: Sequence[str],
duration: float,
) -> None:
try:
self.ret = pytest.ExitCode(ret) # type: Union[int, ExitCode]
except ValueError:
@ -374,13 +412,13 @@ class RunResult:
self.stderr = LineMatcher(errlines)
self.duration = duration
def __repr__(self):
def __repr__(self) -> str:
return (
"<RunResult ret=%s len(stdout.lines)=%d len(stderr.lines)=%d duration=%.2fs>"
% (self.ret, len(self.stdout.lines), len(self.stderr.lines), self.duration)
)
def parseoutcomes(self):
def parseoutcomes(self) -> Dict[str, int]:
"""Return a dictionary of outcomestring->num from parsing the terminal
output that the test process produced.
@ -393,8 +431,14 @@ class RunResult:
raise ValueError("Pytest terminal summary report not found")
def assert_outcomes(
self, passed=0, skipped=0, failed=0, error=0, xpassed=0, xfailed=0
):
self,
passed: int = 0,
skipped: int = 0,
failed: int = 0,
error: int = 0,
xpassed: int = 0,
xfailed: int = 0,
) -> None:
"""Assert that the specified outcomes appear with the respective
numbers (0 means it didn't occur) in the text output from a test run.
@ -420,19 +464,19 @@ class RunResult:
class CwdSnapshot:
def __init__(self):
def __init__(self) -> None:
self.__saved = os.getcwd()
def restore(self):
def restore(self) -> None:
os.chdir(self.__saved)
class SysModulesSnapshot:
def __init__(self, preserve=None):
def __init__(self, preserve: Optional[Callable[[str], bool]] = None):
self.__preserve = preserve
self.__saved = dict(sys.modules)
def restore(self):
def restore(self) -> None:
if self.__preserve:
self.__saved.update(
(k, m) for k, m in sys.modules.items() if self.__preserve(k)
@ -442,10 +486,10 @@ class SysModulesSnapshot:
class SysPathsSnapshot:
def __init__(self):
def __init__(self) -> None:
self.__saved = list(sys.path), list(sys.meta_path)
def restore(self):
def restore(self) -> None:
sys.path[:], sys.meta_path[:] = self.__saved
@ -1357,7 +1401,7 @@ class LineMatcher:
:param str match_nickname: the nickname for the match function that
will be logged to stdout when a match occurs
"""
assert isinstance(lines2, Sequence)
assert isinstance(lines2, collections.abc.Sequence)
lines2 = self._getlines(lines2)
lines1 = self.lines[:]
nextline = None