Add type annotations to some of _pytest.pytester
This commit is contained in:
parent
58f2849bf6
commit
265a9eb6a2
|
@ -1,4 +1,5 @@
|
||||||
"""(disabled by default) support for testing pytest and pytest plugins."""
|
"""(disabled by default) support for testing pytest and pytest plugins."""
|
||||||
|
import collections.abc
|
||||||
import gc
|
import gc
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
|
@ -8,9 +9,15 @@ import subprocess
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from collections.abc import Sequence
|
|
||||||
from fnmatch import fnmatch
|
from fnmatch import fnmatch
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
from typing import Callable
|
||||||
|
from typing import Dict
|
||||||
|
from typing import Iterable
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Sequence
|
||||||
|
from typing import Tuple
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from weakref import WeakKeyDictionary
|
from weakref import WeakKeyDictionary
|
||||||
|
|
||||||
|
@ -21,10 +28,16 @@ from _pytest._code import Source
|
||||||
from _pytest._io.saferepr import saferepr
|
from _pytest._io.saferepr import saferepr
|
||||||
from _pytest.capture import MultiCapture
|
from _pytest.capture import MultiCapture
|
||||||
from _pytest.capture import SysCapture
|
from _pytest.capture import SysCapture
|
||||||
|
from _pytest.fixtures import FixtureRequest
|
||||||
from _pytest.main import ExitCode
|
from _pytest.main import ExitCode
|
||||||
from _pytest.main import Session
|
from _pytest.main import Session
|
||||||
from _pytest.monkeypatch import MonkeyPatch
|
from _pytest.monkeypatch import MonkeyPatch
|
||||||
from _pytest.pathlib import Path
|
from _pytest.pathlib import Path
|
||||||
|
from _pytest.reports import TestReport
|
||||||
|
|
||||||
|
if False: # TYPE_CHECKING
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
|
||||||
IGNORE_PAM = [ # filenames added when obtaining details about the current user
|
IGNORE_PAM = [ # filenames added when obtaining details about the current user
|
||||||
"/var/lib/sss/mc/passwd"
|
"/var/lib/sss/mc/passwd"
|
||||||
|
@ -142,7 +155,7 @@ class LsofFdLeakChecker:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def _pytest(request):
|
def _pytest(request: FixtureRequest) -> "PytestArg":
|
||||||
"""Return a helper which offers a gethookrecorder(hook) method which
|
"""Return a helper which offers a gethookrecorder(hook) method which
|
||||||
returns a HookRecorder instance which helps to make assertions about called
|
returns a HookRecorder instance which helps to make assertions about called
|
||||||
hooks.
|
hooks.
|
||||||
|
@ -152,10 +165,10 @@ def _pytest(request):
|
||||||
|
|
||||||
|
|
||||||
class PytestArg:
|
class PytestArg:
|
||||||
def __init__(self, request):
|
def __init__(self, request: FixtureRequest) -> None:
|
||||||
self.request = request
|
self.request = request
|
||||||
|
|
||||||
def gethookrecorder(self, hook):
|
def gethookrecorder(self, hook) -> "HookRecorder":
|
||||||
hookrecorder = HookRecorder(hook._pm)
|
hookrecorder = HookRecorder(hook._pm)
|
||||||
self.request.addfinalizer(hookrecorder.finish_recording)
|
self.request.addfinalizer(hookrecorder.finish_recording)
|
||||||
return hookrecorder
|
return hookrecorder
|
||||||
|
@ -176,6 +189,11 @@ class ParsedCall:
|
||||||
del d["_name"]
|
del d["_name"]
|
||||||
return "<ParsedCall {!r}(**{!r})>".format(self._name, d)
|
return "<ParsedCall {!r}(**{!r})>".format(self._name, d)
|
||||||
|
|
||||||
|
if False: # TYPE_CHECKING
|
||||||
|
# The class has undetermined attributes, this tells mypy about it.
|
||||||
|
def __getattr__(self, key):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class HookRecorder:
|
class HookRecorder:
|
||||||
"""Record all hooks called in a plugin manager.
|
"""Record all hooks called in a plugin manager.
|
||||||
|
@ -185,27 +203,27 @@ class HookRecorder:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, pluginmanager):
|
def __init__(self, pluginmanager) -> None:
|
||||||
self._pluginmanager = pluginmanager
|
self._pluginmanager = pluginmanager
|
||||||
self.calls = []
|
self.calls = [] # type: List[ParsedCall]
|
||||||
|
|
||||||
def before(hook_name, hook_impls, kwargs):
|
def before(hook_name: str, hook_impls, kwargs) -> None:
|
||||||
self.calls.append(ParsedCall(hook_name, kwargs))
|
self.calls.append(ParsedCall(hook_name, kwargs))
|
||||||
|
|
||||||
def after(outcome, hook_name, hook_impls, kwargs):
|
def after(outcome, hook_name: str, hook_impls, kwargs) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self._undo_wrapping = pluginmanager.add_hookcall_monitoring(before, after)
|
self._undo_wrapping = pluginmanager.add_hookcall_monitoring(before, after)
|
||||||
|
|
||||||
def finish_recording(self):
|
def finish_recording(self) -> None:
|
||||||
self._undo_wrapping()
|
self._undo_wrapping()
|
||||||
|
|
||||||
def getcalls(self, names):
|
def getcalls(self, names: Union[str, Iterable[str]]) -> List[ParsedCall]:
|
||||||
if isinstance(names, str):
|
if isinstance(names, str):
|
||||||
names = names.split()
|
names = names.split()
|
||||||
return [call for call in self.calls if call._name in names]
|
return [call for call in self.calls if call._name in names]
|
||||||
|
|
||||||
def assert_contains(self, entries):
|
def assert_contains(self, entries) -> None:
|
||||||
__tracebackhide__ = True
|
__tracebackhide__ = True
|
||||||
i = 0
|
i = 0
|
||||||
entries = list(entries)
|
entries = list(entries)
|
||||||
|
@ -226,7 +244,7 @@ class HookRecorder:
|
||||||
else:
|
else:
|
||||||
pytest.fail("could not find {!r} check {!r}".format(name, check))
|
pytest.fail("could not find {!r} check {!r}".format(name, check))
|
||||||
|
|
||||||
def popcall(self, name):
|
def popcall(self, name: str) -> ParsedCall:
|
||||||
__tracebackhide__ = True
|
__tracebackhide__ = True
|
||||||
for i, call in enumerate(self.calls):
|
for i, call in enumerate(self.calls):
|
||||||
if call._name == name:
|
if call._name == name:
|
||||||
|
@ -236,20 +254,27 @@ class HookRecorder:
|
||||||
lines.extend([" %s" % x for x in self.calls])
|
lines.extend([" %s" % x for x in self.calls])
|
||||||
pytest.fail("\n".join(lines))
|
pytest.fail("\n".join(lines))
|
||||||
|
|
||||||
def getcall(self, name):
|
def getcall(self, name: str) -> ParsedCall:
|
||||||
values = self.getcalls(name)
|
values = self.getcalls(name)
|
||||||
assert len(values) == 1, (name, values)
|
assert len(values) == 1, (name, values)
|
||||||
return values[0]
|
return values[0]
|
||||||
|
|
||||||
# functionality for test reports
|
# functionality for test reports
|
||||||
|
|
||||||
def getreports(self, names="pytest_runtest_logreport pytest_collectreport"):
|
def getreports(
|
||||||
|
self,
|
||||||
|
names: Union[
|
||||||
|
str, Iterable[str]
|
||||||
|
] = "pytest_runtest_logreport pytest_collectreport",
|
||||||
|
) -> List[TestReport]:
|
||||||
return [x.report for x in self.getcalls(names)]
|
return [x.report for x in self.getcalls(names)]
|
||||||
|
|
||||||
def matchreport(
|
def matchreport(
|
||||||
self,
|
self,
|
||||||
inamepart="",
|
inamepart: str = "",
|
||||||
names="pytest_runtest_logreport pytest_collectreport",
|
names: Union[
|
||||||
|
str, Iterable[str]
|
||||||
|
] = "pytest_runtest_logreport pytest_collectreport",
|
||||||
when=None,
|
when=None,
|
||||||
):
|
):
|
||||||
"""return a testreport whose dotted import path matches"""
|
"""return a testreport whose dotted import path matches"""
|
||||||
|
@ -275,13 +300,20 @@ class HookRecorder:
|
||||||
)
|
)
|
||||||
return values[0]
|
return values[0]
|
||||||
|
|
||||||
def getfailures(self, names="pytest_runtest_logreport pytest_collectreport"):
|
def getfailures(
|
||||||
|
self,
|
||||||
|
names: Union[
|
||||||
|
str, Iterable[str]
|
||||||
|
] = "pytest_runtest_logreport pytest_collectreport",
|
||||||
|
) -> List[TestReport]:
|
||||||
return [rep for rep in self.getreports(names) if rep.failed]
|
return [rep for rep in self.getreports(names) if rep.failed]
|
||||||
|
|
||||||
def getfailedcollections(self):
|
def getfailedcollections(self) -> List[TestReport]:
|
||||||
return self.getfailures("pytest_collectreport")
|
return self.getfailures("pytest_collectreport")
|
||||||
|
|
||||||
def listoutcomes(self):
|
def listoutcomes(
|
||||||
|
self
|
||||||
|
) -> Tuple[List[TestReport], List[TestReport], List[TestReport]]:
|
||||||
passed = []
|
passed = []
|
||||||
skipped = []
|
skipped = []
|
||||||
failed = []
|
failed = []
|
||||||
|
@ -296,31 +328,31 @@ class HookRecorder:
|
||||||
failed.append(rep)
|
failed.append(rep)
|
||||||
return passed, skipped, failed
|
return passed, skipped, failed
|
||||||
|
|
||||||
def countoutcomes(self):
|
def countoutcomes(self) -> List[int]:
|
||||||
return [len(x) for x in self.listoutcomes()]
|
return [len(x) for x in self.listoutcomes()]
|
||||||
|
|
||||||
def assertoutcome(self, passed=0, skipped=0, failed=0):
|
def assertoutcome(self, passed: int = 0, skipped: int = 0, failed: int = 0) -> None:
|
||||||
realpassed, realskipped, realfailed = self.listoutcomes()
|
realpassed, realskipped, realfailed = self.listoutcomes()
|
||||||
assert passed == len(realpassed)
|
assert passed == len(realpassed)
|
||||||
assert skipped == len(realskipped)
|
assert skipped == len(realskipped)
|
||||||
assert failed == len(realfailed)
|
assert failed == len(realfailed)
|
||||||
|
|
||||||
def clear(self):
|
def clear(self) -> None:
|
||||||
self.calls[:] = []
|
self.calls[:] = []
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def linecomp(request):
|
def linecomp(request: FixtureRequest) -> "LineComp":
|
||||||
return LineComp()
|
return LineComp()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="LineMatcher")
|
@pytest.fixture(name="LineMatcher")
|
||||||
def LineMatcher_fixture(request):
|
def LineMatcher_fixture(request: FixtureRequest) -> "Type[LineMatcher]":
|
||||||
return LineMatcher
|
return LineMatcher
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def testdir(request, tmpdir_factory):
|
def testdir(request: FixtureRequest, tmpdir_factory) -> "Testdir":
|
||||||
return Testdir(request, tmpdir_factory)
|
return Testdir(request, tmpdir_factory)
|
||||||
|
|
||||||
|
|
||||||
|
@ -363,7 +395,13 @@ class RunResult:
|
||||||
:ivar duration: duration in seconds
|
:ivar duration: duration in seconds
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, ret: Union[int, ExitCode], outlines, errlines, duration) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
ret: Union[int, ExitCode],
|
||||||
|
outlines: Sequence[str],
|
||||||
|
errlines: Sequence[str],
|
||||||
|
duration: float,
|
||||||
|
) -> None:
|
||||||
try:
|
try:
|
||||||
self.ret = pytest.ExitCode(ret) # type: Union[int, ExitCode]
|
self.ret = pytest.ExitCode(ret) # type: Union[int, ExitCode]
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
@ -374,13 +412,13 @@ class RunResult:
|
||||||
self.stderr = LineMatcher(errlines)
|
self.stderr = LineMatcher(errlines)
|
||||||
self.duration = duration
|
self.duration = duration
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return (
|
return (
|
||||||
"<RunResult ret=%s len(stdout.lines)=%d len(stderr.lines)=%d duration=%.2fs>"
|
"<RunResult ret=%s len(stdout.lines)=%d len(stderr.lines)=%d duration=%.2fs>"
|
||||||
% (self.ret, len(self.stdout.lines), len(self.stderr.lines), self.duration)
|
% (self.ret, len(self.stdout.lines), len(self.stderr.lines), self.duration)
|
||||||
)
|
)
|
||||||
|
|
||||||
def parseoutcomes(self):
|
def parseoutcomes(self) -> Dict[str, int]:
|
||||||
"""Return a dictionary of outcomestring->num from parsing the terminal
|
"""Return a dictionary of outcomestring->num from parsing the terminal
|
||||||
output that the test process produced.
|
output that the test process produced.
|
||||||
|
|
||||||
|
@ -393,8 +431,14 @@ class RunResult:
|
||||||
raise ValueError("Pytest terminal summary report not found")
|
raise ValueError("Pytest terminal summary report not found")
|
||||||
|
|
||||||
def assert_outcomes(
|
def assert_outcomes(
|
||||||
self, passed=0, skipped=0, failed=0, error=0, xpassed=0, xfailed=0
|
self,
|
||||||
):
|
passed: int = 0,
|
||||||
|
skipped: int = 0,
|
||||||
|
failed: int = 0,
|
||||||
|
error: int = 0,
|
||||||
|
xpassed: int = 0,
|
||||||
|
xfailed: int = 0,
|
||||||
|
) -> None:
|
||||||
"""Assert that the specified outcomes appear with the respective
|
"""Assert that the specified outcomes appear with the respective
|
||||||
numbers (0 means it didn't occur) in the text output from a test run.
|
numbers (0 means it didn't occur) in the text output from a test run.
|
||||||
|
|
||||||
|
@ -420,19 +464,19 @@ class RunResult:
|
||||||
|
|
||||||
|
|
||||||
class CwdSnapshot:
|
class CwdSnapshot:
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.__saved = os.getcwd()
|
self.__saved = os.getcwd()
|
||||||
|
|
||||||
def restore(self):
|
def restore(self) -> None:
|
||||||
os.chdir(self.__saved)
|
os.chdir(self.__saved)
|
||||||
|
|
||||||
|
|
||||||
class SysModulesSnapshot:
|
class SysModulesSnapshot:
|
||||||
def __init__(self, preserve=None):
|
def __init__(self, preserve: Optional[Callable[[str], bool]] = None):
|
||||||
self.__preserve = preserve
|
self.__preserve = preserve
|
||||||
self.__saved = dict(sys.modules)
|
self.__saved = dict(sys.modules)
|
||||||
|
|
||||||
def restore(self):
|
def restore(self) -> None:
|
||||||
if self.__preserve:
|
if self.__preserve:
|
||||||
self.__saved.update(
|
self.__saved.update(
|
||||||
(k, m) for k, m in sys.modules.items() if self.__preserve(k)
|
(k, m) for k, m in sys.modules.items() if self.__preserve(k)
|
||||||
|
@ -442,10 +486,10 @@ class SysModulesSnapshot:
|
||||||
|
|
||||||
|
|
||||||
class SysPathsSnapshot:
|
class SysPathsSnapshot:
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.__saved = list(sys.path), list(sys.meta_path)
|
self.__saved = list(sys.path), list(sys.meta_path)
|
||||||
|
|
||||||
def restore(self):
|
def restore(self) -> None:
|
||||||
sys.path[:], sys.meta_path[:] = self.__saved
|
sys.path[:], sys.meta_path[:] = self.__saved
|
||||||
|
|
||||||
|
|
||||||
|
@ -1357,7 +1401,7 @@ class LineMatcher:
|
||||||
:param str match_nickname: the nickname for the match function that
|
:param str match_nickname: the nickname for the match function that
|
||||||
will be logged to stdout when a match occurs
|
will be logged to stdout when a match occurs
|
||||||
"""
|
"""
|
||||||
assert isinstance(lines2, Sequence)
|
assert isinstance(lines2, collections.abc.Sequence)
|
||||||
lines2 = self._getlines(lines2)
|
lines2 = self._getlines(lines2)
|
||||||
lines1 = self.lines[:]
|
lines1 = self.lines[:]
|
||||||
nextline = None
|
nextline = None
|
||||||
|
|
Loading…
Reference in New Issue