Add type annotations to some of _pytest.pytester

This commit is contained in:
Ran Benita 2019-11-06 21:35:39 +02:00
parent 58f2849bf6
commit 265a9eb6a2
1 changed files with 81 additions and 37 deletions

View File

@ -1,4 +1,5 @@
"""(disabled by default) support for testing pytest and pytest plugins.""" """(disabled by default) support for testing pytest and pytest plugins."""
import collections.abc
import gc import gc
import importlib import importlib
import os import os
@ -8,9 +9,15 @@ import subprocess
import sys import sys
import time import time
import traceback import traceback
from collections.abc import Sequence
from fnmatch import fnmatch from fnmatch import fnmatch
from io import StringIO from io import StringIO
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union from typing import Union
from weakref import WeakKeyDictionary from weakref import WeakKeyDictionary
@ -21,10 +28,16 @@ from _pytest._code import Source
from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr
from _pytest.capture import MultiCapture from _pytest.capture import MultiCapture
from _pytest.capture import SysCapture from _pytest.capture import SysCapture
from _pytest.fixtures import FixtureRequest
from _pytest.main import ExitCode from _pytest.main import ExitCode
from _pytest.main import Session from _pytest.main import Session
from _pytest.monkeypatch import MonkeyPatch from _pytest.monkeypatch import MonkeyPatch
from _pytest.pathlib import Path from _pytest.pathlib import Path
from _pytest.reports import TestReport
if False: # TYPE_CHECKING
from typing import Type
IGNORE_PAM = [ # filenames added when obtaining details about the current user IGNORE_PAM = [ # filenames added when obtaining details about the current user
"/var/lib/sss/mc/passwd" "/var/lib/sss/mc/passwd"
@ -142,7 +155,7 @@ class LsofFdLeakChecker:
@pytest.fixture @pytest.fixture
def _pytest(request): def _pytest(request: FixtureRequest) -> "PytestArg":
"""Return a helper which offers a gethookrecorder(hook) method which """Return a helper which offers a gethookrecorder(hook) method which
returns a HookRecorder instance which helps to make assertions about called returns a HookRecorder instance which helps to make assertions about called
hooks. hooks.
@ -152,10 +165,10 @@ def _pytest(request):
class PytestArg: class PytestArg:
def __init__(self, request): def __init__(self, request: FixtureRequest) -> None:
self.request = request self.request = request
def gethookrecorder(self, hook): def gethookrecorder(self, hook) -> "HookRecorder":
hookrecorder = HookRecorder(hook._pm) hookrecorder = HookRecorder(hook._pm)
self.request.addfinalizer(hookrecorder.finish_recording) self.request.addfinalizer(hookrecorder.finish_recording)
return hookrecorder return hookrecorder
@ -176,6 +189,11 @@ class ParsedCall:
del d["_name"] del d["_name"]
return "<ParsedCall {!r}(**{!r})>".format(self._name, d) return "<ParsedCall {!r}(**{!r})>".format(self._name, d)
if False: # TYPE_CHECKING
# The class has undetermined attributes, this tells mypy about it.
def __getattr__(self, key):
raise NotImplementedError()
class HookRecorder: class HookRecorder:
"""Record all hooks called in a plugin manager. """Record all hooks called in a plugin manager.
@ -185,27 +203,27 @@ class HookRecorder:
""" """
def __init__(self, pluginmanager): def __init__(self, pluginmanager) -> None:
self._pluginmanager = pluginmanager self._pluginmanager = pluginmanager
self.calls = [] self.calls = [] # type: List[ParsedCall]
def before(hook_name, hook_impls, kwargs): def before(hook_name: str, hook_impls, kwargs) -> None:
self.calls.append(ParsedCall(hook_name, kwargs)) self.calls.append(ParsedCall(hook_name, kwargs))
def after(outcome, hook_name, hook_impls, kwargs): def after(outcome, hook_name: str, hook_impls, kwargs) -> None:
pass pass
self._undo_wrapping = pluginmanager.add_hookcall_monitoring(before, after) self._undo_wrapping = pluginmanager.add_hookcall_monitoring(before, after)
def finish_recording(self): def finish_recording(self) -> None:
self._undo_wrapping() self._undo_wrapping()
def getcalls(self, names): def getcalls(self, names: Union[str, Iterable[str]]) -> List[ParsedCall]:
if isinstance(names, str): if isinstance(names, str):
names = names.split() names = names.split()
return [call for call in self.calls if call._name in names] return [call for call in self.calls if call._name in names]
def assert_contains(self, entries): def assert_contains(self, entries) -> None:
__tracebackhide__ = True __tracebackhide__ = True
i = 0 i = 0
entries = list(entries) entries = list(entries)
@ -226,7 +244,7 @@ class HookRecorder:
else: else:
pytest.fail("could not find {!r} check {!r}".format(name, check)) pytest.fail("could not find {!r} check {!r}".format(name, check))
def popcall(self, name): def popcall(self, name: str) -> ParsedCall:
__tracebackhide__ = True __tracebackhide__ = True
for i, call in enumerate(self.calls): for i, call in enumerate(self.calls):
if call._name == name: if call._name == name:
@ -236,20 +254,27 @@ class HookRecorder:
lines.extend([" %s" % x for x in self.calls]) lines.extend([" %s" % x for x in self.calls])
pytest.fail("\n".join(lines)) pytest.fail("\n".join(lines))
def getcall(self, name): def getcall(self, name: str) -> ParsedCall:
values = self.getcalls(name) values = self.getcalls(name)
assert len(values) == 1, (name, values) assert len(values) == 1, (name, values)
return values[0] return values[0]
# functionality for test reports # functionality for test reports
def getreports(self, names="pytest_runtest_logreport pytest_collectreport"): def getreports(
self,
names: Union[
str, Iterable[str]
] = "pytest_runtest_logreport pytest_collectreport",
) -> List[TestReport]:
return [x.report for x in self.getcalls(names)] return [x.report for x in self.getcalls(names)]
def matchreport( def matchreport(
self, self,
inamepart="", inamepart: str = "",
names="pytest_runtest_logreport pytest_collectreport", names: Union[
str, Iterable[str]
] = "pytest_runtest_logreport pytest_collectreport",
when=None, when=None,
): ):
"""return a testreport whose dotted import path matches""" """return a testreport whose dotted import path matches"""
@ -275,13 +300,20 @@ class HookRecorder:
) )
return values[0] return values[0]
def getfailures(self, names="pytest_runtest_logreport pytest_collectreport"): def getfailures(
self,
names: Union[
str, Iterable[str]
] = "pytest_runtest_logreport pytest_collectreport",
) -> List[TestReport]:
return [rep for rep in self.getreports(names) if rep.failed] return [rep for rep in self.getreports(names) if rep.failed]
def getfailedcollections(self): def getfailedcollections(self) -> List[TestReport]:
return self.getfailures("pytest_collectreport") return self.getfailures("pytest_collectreport")
def listoutcomes(self): def listoutcomes(
self
) -> Tuple[List[TestReport], List[TestReport], List[TestReport]]:
passed = [] passed = []
skipped = [] skipped = []
failed = [] failed = []
@ -296,31 +328,31 @@ class HookRecorder:
failed.append(rep) failed.append(rep)
return passed, skipped, failed return passed, skipped, failed
def countoutcomes(self): def countoutcomes(self) -> List[int]:
return [len(x) for x in self.listoutcomes()] return [len(x) for x in self.listoutcomes()]
def assertoutcome(self, passed=0, skipped=0, failed=0): def assertoutcome(self, passed: int = 0, skipped: int = 0, failed: int = 0) -> None:
realpassed, realskipped, realfailed = self.listoutcomes() realpassed, realskipped, realfailed = self.listoutcomes()
assert passed == len(realpassed) assert passed == len(realpassed)
assert skipped == len(realskipped) assert skipped == len(realskipped)
assert failed == len(realfailed) assert failed == len(realfailed)
def clear(self): def clear(self) -> None:
self.calls[:] = [] self.calls[:] = []
@pytest.fixture @pytest.fixture
def linecomp(request): def linecomp(request: FixtureRequest) -> "LineComp":
return LineComp() return LineComp()
@pytest.fixture(name="LineMatcher") @pytest.fixture(name="LineMatcher")
def LineMatcher_fixture(request): def LineMatcher_fixture(request: FixtureRequest) -> "Type[LineMatcher]":
return LineMatcher return LineMatcher
@pytest.fixture @pytest.fixture
def testdir(request, tmpdir_factory): def testdir(request: FixtureRequest, tmpdir_factory) -> "Testdir":
return Testdir(request, tmpdir_factory) return Testdir(request, tmpdir_factory)
@ -363,7 +395,13 @@ class RunResult:
:ivar duration: duration in seconds :ivar duration: duration in seconds
""" """
def __init__(self, ret: Union[int, ExitCode], outlines, errlines, duration) -> None: def __init__(
self,
ret: Union[int, ExitCode],
outlines: Sequence[str],
errlines: Sequence[str],
duration: float,
) -> None:
try: try:
self.ret = pytest.ExitCode(ret) # type: Union[int, ExitCode] self.ret = pytest.ExitCode(ret) # type: Union[int, ExitCode]
except ValueError: except ValueError:
@ -374,13 +412,13 @@ class RunResult:
self.stderr = LineMatcher(errlines) self.stderr = LineMatcher(errlines)
self.duration = duration self.duration = duration
def __repr__(self): def __repr__(self) -> str:
return ( return (
"<RunResult ret=%s len(stdout.lines)=%d len(stderr.lines)=%d duration=%.2fs>" "<RunResult ret=%s len(stdout.lines)=%d len(stderr.lines)=%d duration=%.2fs>"
% (self.ret, len(self.stdout.lines), len(self.stderr.lines), self.duration) % (self.ret, len(self.stdout.lines), len(self.stderr.lines), self.duration)
) )
def parseoutcomes(self): def parseoutcomes(self) -> Dict[str, int]:
"""Return a dictionary of outcomestring->num from parsing the terminal """Return a dictionary of outcomestring->num from parsing the terminal
output that the test process produced. output that the test process produced.
@ -393,8 +431,14 @@ class RunResult:
raise ValueError("Pytest terminal summary report not found") raise ValueError("Pytest terminal summary report not found")
def assert_outcomes( def assert_outcomes(
self, passed=0, skipped=0, failed=0, error=0, xpassed=0, xfailed=0 self,
): passed: int = 0,
skipped: int = 0,
failed: int = 0,
error: int = 0,
xpassed: int = 0,
xfailed: int = 0,
) -> None:
"""Assert that the specified outcomes appear with the respective """Assert that the specified outcomes appear with the respective
numbers (0 means it didn't occur) in the text output from a test run. numbers (0 means it didn't occur) in the text output from a test run.
@ -420,19 +464,19 @@ class RunResult:
class CwdSnapshot: class CwdSnapshot:
def __init__(self): def __init__(self) -> None:
self.__saved = os.getcwd() self.__saved = os.getcwd()
def restore(self): def restore(self) -> None:
os.chdir(self.__saved) os.chdir(self.__saved)
class SysModulesSnapshot: class SysModulesSnapshot:
def __init__(self, preserve=None): def __init__(self, preserve: Optional[Callable[[str], bool]] = None):
self.__preserve = preserve self.__preserve = preserve
self.__saved = dict(sys.modules) self.__saved = dict(sys.modules)
def restore(self): def restore(self) -> None:
if self.__preserve: if self.__preserve:
self.__saved.update( self.__saved.update(
(k, m) for k, m in sys.modules.items() if self.__preserve(k) (k, m) for k, m in sys.modules.items() if self.__preserve(k)
@ -442,10 +486,10 @@ class SysModulesSnapshot:
class SysPathsSnapshot: class SysPathsSnapshot:
def __init__(self): def __init__(self) -> None:
self.__saved = list(sys.path), list(sys.meta_path) self.__saved = list(sys.path), list(sys.meta_path)
def restore(self): def restore(self) -> None:
sys.path[:], sys.meta_path[:] = self.__saved sys.path[:], sys.meta_path[:] = self.__saved
@ -1357,7 +1401,7 @@ class LineMatcher:
:param str match_nickname: the nickname for the match function that :param str match_nickname: the nickname for the match function that
will be logged to stdout when a match occurs will be logged to stdout when a match occurs
""" """
assert isinstance(lines2, Sequence) assert isinstance(lines2, collections.abc.Sequence)
lines2 = self._getlines(lines2) lines2 = self._getlines(lines2)
lines1 = self.lines[:] lines1 = self.lines[:]
nextline = None nextline = None