From 265a9eb6a2916a8224d0561c269dc52d59d2ac48 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Wed, 6 Nov 2019 21:35:39 +0200 Subject: [PATCH] Add type annotations to some of _pytest.pytester --- src/_pytest/pytester.py | 118 +++++++++++++++++++++++++++------------- 1 file changed, 81 insertions(+), 37 deletions(-) diff --git a/src/_pytest/pytester.py b/src/_pytest/pytester.py index 14db8409e..6b45e077b 100644 --- a/src/_pytest/pytester.py +++ b/src/_pytest/pytester.py @@ -1,4 +1,5 @@ """(disabled by default) support for testing pytest and pytest plugins.""" +import collections.abc import gc import importlib import os @@ -8,9 +9,15 @@ import subprocess import sys import time import traceback -from collections.abc import Sequence from fnmatch import fnmatch from io import StringIO +from typing import Callable +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple from typing import Union from weakref import WeakKeyDictionary @@ -21,10 +28,16 @@ from _pytest._code import Source from _pytest._io.saferepr import saferepr from _pytest.capture import MultiCapture from _pytest.capture import SysCapture +from _pytest.fixtures import FixtureRequest from _pytest.main import ExitCode from _pytest.main import Session from _pytest.monkeypatch import MonkeyPatch from _pytest.pathlib import Path +from _pytest.reports import TestReport + +if False: # TYPE_CHECKING + from typing import Type + IGNORE_PAM = [ # filenames added when obtaining details about the current user "/var/lib/sss/mc/passwd" @@ -142,7 +155,7 @@ class LsofFdLeakChecker: @pytest.fixture -def _pytest(request): +def _pytest(request: FixtureRequest) -> "PytestArg": """Return a helper which offers a gethookrecorder(hook) method which returns a HookRecorder instance which helps to make assertions about called hooks. @@ -152,10 +165,10 @@ def _pytest(request): class PytestArg: - def __init__(self, request): + def __init__(self, request: FixtureRequest) -> None: self.request = request - def gethookrecorder(self, hook): + def gethookrecorder(self, hook) -> "HookRecorder": hookrecorder = HookRecorder(hook._pm) self.request.addfinalizer(hookrecorder.finish_recording) return hookrecorder @@ -176,6 +189,11 @@ class ParsedCall: del d["_name"] return "".format(self._name, d) + if False: # TYPE_CHECKING + # The class has undetermined attributes, this tells mypy about it. + def __getattr__(self, key): + raise NotImplementedError() + class HookRecorder: """Record all hooks called in a plugin manager. @@ -185,27 +203,27 @@ class HookRecorder: """ - def __init__(self, pluginmanager): + def __init__(self, pluginmanager) -> None: self._pluginmanager = pluginmanager - self.calls = [] + self.calls = [] # type: List[ParsedCall] - def before(hook_name, hook_impls, kwargs): + def before(hook_name: str, hook_impls, kwargs) -> None: self.calls.append(ParsedCall(hook_name, kwargs)) - def after(outcome, hook_name, hook_impls, kwargs): + def after(outcome, hook_name: str, hook_impls, kwargs) -> None: pass self._undo_wrapping = pluginmanager.add_hookcall_monitoring(before, after) - def finish_recording(self): + def finish_recording(self) -> None: self._undo_wrapping() - def getcalls(self, names): + def getcalls(self, names: Union[str, Iterable[str]]) -> List[ParsedCall]: if isinstance(names, str): names = names.split() return [call for call in self.calls if call._name in names] - def assert_contains(self, entries): + def assert_contains(self, entries) -> None: __tracebackhide__ = True i = 0 entries = list(entries) @@ -226,7 +244,7 @@ class HookRecorder: else: pytest.fail("could not find {!r} check {!r}".format(name, check)) - def popcall(self, name): + def popcall(self, name: str) -> ParsedCall: __tracebackhide__ = True for i, call in enumerate(self.calls): if call._name == name: @@ -236,20 +254,27 @@ class HookRecorder: lines.extend([" %s" % x for x in self.calls]) pytest.fail("\n".join(lines)) - def getcall(self, name): + def getcall(self, name: str) -> ParsedCall: values = self.getcalls(name) assert len(values) == 1, (name, values) return values[0] # functionality for test reports - def getreports(self, names="pytest_runtest_logreport pytest_collectreport"): + def getreports( + self, + names: Union[ + str, Iterable[str] + ] = "pytest_runtest_logreport pytest_collectreport", + ) -> List[TestReport]: return [x.report for x in self.getcalls(names)] def matchreport( self, - inamepart="", - names="pytest_runtest_logreport pytest_collectreport", + inamepart: str = "", + names: Union[ + str, Iterable[str] + ] = "pytest_runtest_logreport pytest_collectreport", when=None, ): """return a testreport whose dotted import path matches""" @@ -275,13 +300,20 @@ class HookRecorder: ) return values[0] - def getfailures(self, names="pytest_runtest_logreport pytest_collectreport"): + def getfailures( + self, + names: Union[ + str, Iterable[str] + ] = "pytest_runtest_logreport pytest_collectreport", + ) -> List[TestReport]: return [rep for rep in self.getreports(names) if rep.failed] - def getfailedcollections(self): + def getfailedcollections(self) -> List[TestReport]: return self.getfailures("pytest_collectreport") - def listoutcomes(self): + def listoutcomes( + self + ) -> Tuple[List[TestReport], List[TestReport], List[TestReport]]: passed = [] skipped = [] failed = [] @@ -296,31 +328,31 @@ class HookRecorder: failed.append(rep) return passed, skipped, failed - def countoutcomes(self): + def countoutcomes(self) -> List[int]: return [len(x) for x in self.listoutcomes()] - def assertoutcome(self, passed=0, skipped=0, failed=0): + def assertoutcome(self, passed: int = 0, skipped: int = 0, failed: int = 0) -> None: realpassed, realskipped, realfailed = self.listoutcomes() assert passed == len(realpassed) assert skipped == len(realskipped) assert failed == len(realfailed) - def clear(self): + def clear(self) -> None: self.calls[:] = [] @pytest.fixture -def linecomp(request): +def linecomp(request: FixtureRequest) -> "LineComp": return LineComp() @pytest.fixture(name="LineMatcher") -def LineMatcher_fixture(request): +def LineMatcher_fixture(request: FixtureRequest) -> "Type[LineMatcher]": return LineMatcher @pytest.fixture -def testdir(request, tmpdir_factory): +def testdir(request: FixtureRequest, tmpdir_factory) -> "Testdir": return Testdir(request, tmpdir_factory) @@ -363,7 +395,13 @@ class RunResult: :ivar duration: duration in seconds """ - def __init__(self, ret: Union[int, ExitCode], outlines, errlines, duration) -> None: + def __init__( + self, + ret: Union[int, ExitCode], + outlines: Sequence[str], + errlines: Sequence[str], + duration: float, + ) -> None: try: self.ret = pytest.ExitCode(ret) # type: Union[int, ExitCode] except ValueError: @@ -374,13 +412,13 @@ class RunResult: self.stderr = LineMatcher(errlines) self.duration = duration - def __repr__(self): + def __repr__(self) -> str: return ( "" % (self.ret, len(self.stdout.lines), len(self.stderr.lines), self.duration) ) - def parseoutcomes(self): + def parseoutcomes(self) -> Dict[str, int]: """Return a dictionary of outcomestring->num from parsing the terminal output that the test process produced. @@ -393,8 +431,14 @@ class RunResult: raise ValueError("Pytest terminal summary report not found") def assert_outcomes( - self, passed=0, skipped=0, failed=0, error=0, xpassed=0, xfailed=0 - ): + self, + passed: int = 0, + skipped: int = 0, + failed: int = 0, + error: int = 0, + xpassed: int = 0, + xfailed: int = 0, + ) -> None: """Assert that the specified outcomes appear with the respective numbers (0 means it didn't occur) in the text output from a test run. @@ -420,19 +464,19 @@ class RunResult: class CwdSnapshot: - def __init__(self): + def __init__(self) -> None: self.__saved = os.getcwd() - def restore(self): + def restore(self) -> None: os.chdir(self.__saved) class SysModulesSnapshot: - def __init__(self, preserve=None): + def __init__(self, preserve: Optional[Callable[[str], bool]] = None): self.__preserve = preserve self.__saved = dict(sys.modules) - def restore(self): + def restore(self) -> None: if self.__preserve: self.__saved.update( (k, m) for k, m in sys.modules.items() if self.__preserve(k) @@ -442,10 +486,10 @@ class SysModulesSnapshot: class SysPathsSnapshot: - def __init__(self): + def __init__(self) -> None: self.__saved = list(sys.path), list(sys.meta_path) - def restore(self): + def restore(self) -> None: sys.path[:], sys.meta_path[:] = self.__saved @@ -1357,7 +1401,7 @@ class LineMatcher: :param str match_nickname: the nickname for the match function that will be logged to stdout when a match occurs """ - assert isinstance(lines2, Sequence) + assert isinstance(lines2, collections.abc.Sequence) lines2 = self._getlines(lines2) lines1 = self.lines[:] nextline = None