Merge pull request #6141 from bluetech/type-annotations-7

Add type annotations to _pytest.{warning_types,_code.source,pytester}
This commit is contained in:
Ran Benita 2019-11-07 17:11:01 +02:00 committed by GitHub
commit e670ff76cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 147 additions and 65 deletions

View File

@ -7,10 +7,17 @@ import tokenize
import warnings
from ast import PyCF_ONLY_AST as _AST_FLAG
from bisect import bisect_right
from types import FrameType
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import py
from _pytest.compat import overload
class Source:
""" an immutable object holding a source code fragment,
@ -19,7 +26,7 @@ class Source:
_compilecounter = 0
def __init__(self, *parts, **kwargs):
def __init__(self, *parts, **kwargs) -> None:
self.lines = lines = [] # type: List[str]
de = kwargs.get("deindent", True)
for part in parts:
@ -48,7 +55,15 @@ class Source:
# Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore
def __getitem__(self, key):
@overload
def __getitem__(self, key: int) -> str:
raise NotImplementedError()
@overload # noqa: F811
def __getitem__(self, key: slice) -> "Source":
raise NotImplementedError()
def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]: # noqa: F811
if isinstance(key, int):
return self.lines[key]
else:
@ -58,10 +73,10 @@ class Source:
newsource.lines = self.lines[key.start : key.stop]
return newsource
def __len__(self):
def __len__(self) -> int:
return len(self.lines)
def strip(self):
def strip(self) -> "Source":
""" return new source object with trailing
and leading blank lines removed.
"""
@ -74,18 +89,20 @@ class Source:
source.lines[:] = self.lines[start:end]
return source
def putaround(self, before="", after="", indent=" " * 4):
def putaround(
self, before: str = "", after: str = "", indent: str = " " * 4
) -> "Source":
""" return a copy of the source object with
'before' and 'after' wrapped around it.
"""
before = Source(before)
after = Source(after)
beforesource = Source(before)
aftersource = Source(after)
newsource = Source()
lines = [(indent + line) for line in self.lines]
newsource.lines = before.lines + lines + after.lines
newsource.lines = beforesource.lines + lines + aftersource.lines
return newsource
def indent(self, indent=" " * 4):
def indent(self, indent: str = " " * 4) -> "Source":
""" return a copy of the source object with
all lines indented by the given indent-string.
"""
@ -93,14 +110,14 @@ class Source:
newsource.lines = [(indent + line) for line in self.lines]
return newsource
def getstatement(self, lineno):
def getstatement(self, lineno: int) -> "Source":
""" return Source statement which contains the
given linenumber (counted from 0).
"""
start, end = self.getstatementrange(lineno)
return self[start:end]
def getstatementrange(self, lineno):
def getstatementrange(self, lineno: int):
""" return (start, end) tuple which spans the minimal
statement region which containing the given lineno.
"""
@ -109,13 +126,13 @@ class Source:
ast, start, end = getstatementrange_ast(lineno, self)
return start, end
def deindent(self):
def deindent(self) -> "Source":
"""return a new source object deindented."""
newsource = Source()
newsource.lines[:] = deindent(self.lines)
return newsource
def isparseable(self, deindent=True):
def isparseable(self, deindent: bool = True) -> bool:
""" return True if source is parseable, heuristically
deindenting it by default.
"""
@ -135,11 +152,16 @@ class Source:
else:
return True
def __str__(self):
def __str__(self) -> str:
return "\n".join(self.lines)
def compile(
self, filename=None, mode="exec", flag=0, dont_inherit=0, _genframe=None
self,
filename=None,
mode="exec",
flag: int = 0,
dont_inherit: int = 0,
_genframe: Optional[FrameType] = None,
):
""" return compiled code object. if filename is None
invent an artificial filename which displays
@ -183,7 +205,7 @@ class Source:
#
def compile_(source, filename=None, mode="exec", flags=0, dont_inherit=0):
def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: int = 0):
""" compile the given source to a raw code object,
and maintain an internal cache which allows later
retrieval of the source code for the code object
@ -233,7 +255,7 @@ def getfslineno(obj):
#
def findsource(obj):
def findsource(obj) -> Tuple[Optional[Source], int]:
try:
sourcelines, lineno = inspect.findsource(obj)
except Exception:
@ -243,7 +265,7 @@ def findsource(obj):
return source, lineno
def getsource(obj, **kwargs):
def getsource(obj, **kwargs) -> Source:
from .code import getrawcode
obj = getrawcode(obj)
@ -255,21 +277,21 @@ def getsource(obj, **kwargs):
return Source(strsrc, **kwargs)
def deindent(lines):
def deindent(lines: Sequence[str]) -> List[str]:
return textwrap.dedent("\n".join(lines)).splitlines()
def get_statement_startend2(lineno, node):
def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]:
import ast
# flatten all statements and except handlers into one lineno-list
# AST's line numbers start indexing at 1
values = []
values = [] # type: List[int]
for x in ast.walk(node):
if isinstance(x, (ast.stmt, ast.ExceptHandler)):
values.append(x.lineno - 1)
for name in ("finalbody", "orelse"):
val = getattr(x, name, None)
val = getattr(x, name, None) # type: Optional[List[ast.stmt]]
if val:
# treat the finally/orelse part as its own statement
values.append(val[0].lineno - 1 - 1)
@ -283,7 +305,12 @@ def get_statement_startend2(lineno, node):
return start, end
def getstatementrange_ast(lineno, source: Source, assertion=False, astnode=None):
def getstatementrange_ast(
lineno: int,
source: Source,
assertion: bool = False,
astnode: Optional[ast.AST] = None,
) -> Tuple[ast.AST, int, int]:
if astnode is None:
content = str(source)
# See #4260:

View File

@ -1,4 +1,5 @@
"""(disabled by default) support for testing pytest and pytest plugins."""
import collections.abc
import gc
import importlib
import os
@ -8,9 +9,15 @@ import subprocess
import sys
import time
import traceback
from collections.abc import Sequence
from fnmatch import fnmatch
from io import StringIO
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from weakref import WeakKeyDictionary
@ -21,10 +28,16 @@ from _pytest._code import Source
from _pytest._io.saferepr import saferepr
from _pytest.capture import MultiCapture
from _pytest.capture import SysCapture
from _pytest.fixtures import FixtureRequest
from _pytest.main import ExitCode
from _pytest.main import Session
from _pytest.monkeypatch import MonkeyPatch
from _pytest.pathlib import Path
from _pytest.reports import TestReport
if False: # TYPE_CHECKING
from typing import Type
IGNORE_PAM = [ # filenames added when obtaining details about the current user
"/var/lib/sss/mc/passwd"
@ -142,7 +155,7 @@ class LsofFdLeakChecker:
@pytest.fixture
def _pytest(request):
def _pytest(request: FixtureRequest) -> "PytestArg":
"""Return a helper which offers a gethookrecorder(hook) method which
returns a HookRecorder instance which helps to make assertions about called
hooks.
@ -152,10 +165,10 @@ def _pytest(request):
class PytestArg:
def __init__(self, request):
def __init__(self, request: FixtureRequest) -> None:
self.request = request
def gethookrecorder(self, hook):
def gethookrecorder(self, hook) -> "HookRecorder":
hookrecorder = HookRecorder(hook._pm)
self.request.addfinalizer(hookrecorder.finish_recording)
return hookrecorder
@ -176,6 +189,11 @@ class ParsedCall:
del d["_name"]
return "<ParsedCall {!r}(**{!r})>".format(self._name, d)
if False: # TYPE_CHECKING
# The class has undetermined attributes, this tells mypy about it.
def __getattr__(self, key):
raise NotImplementedError()
class HookRecorder:
"""Record all hooks called in a plugin manager.
@ -185,27 +203,27 @@ class HookRecorder:
"""
def __init__(self, pluginmanager):
def __init__(self, pluginmanager) -> None:
self._pluginmanager = pluginmanager
self.calls = []
self.calls = [] # type: List[ParsedCall]
def before(hook_name, hook_impls, kwargs):
def before(hook_name: str, hook_impls, kwargs) -> None:
self.calls.append(ParsedCall(hook_name, kwargs))
def after(outcome, hook_name, hook_impls, kwargs):
def after(outcome, hook_name: str, hook_impls, kwargs) -> None:
pass
self._undo_wrapping = pluginmanager.add_hookcall_monitoring(before, after)
def finish_recording(self):
def finish_recording(self) -> None:
self._undo_wrapping()
def getcalls(self, names):
def getcalls(self, names: Union[str, Iterable[str]]) -> List[ParsedCall]:
if isinstance(names, str):
names = names.split()
return [call for call in self.calls if call._name in names]
def assert_contains(self, entries):
def assert_contains(self, entries) -> None:
__tracebackhide__ = True
i = 0
entries = list(entries)
@ -226,7 +244,7 @@ class HookRecorder:
else:
pytest.fail("could not find {!r} check {!r}".format(name, check))
def popcall(self, name):
def popcall(self, name: str) -> ParsedCall:
__tracebackhide__ = True
for i, call in enumerate(self.calls):
if call._name == name:
@ -236,20 +254,27 @@ class HookRecorder:
lines.extend([" %s" % x for x in self.calls])
pytest.fail("\n".join(lines))
def getcall(self, name):
def getcall(self, name: str) -> ParsedCall:
values = self.getcalls(name)
assert len(values) == 1, (name, values)
return values[0]
# functionality for test reports
def getreports(self, names="pytest_runtest_logreport pytest_collectreport"):
def getreports(
self,
names: Union[
str, Iterable[str]
] = "pytest_runtest_logreport pytest_collectreport",
) -> List[TestReport]:
return [x.report for x in self.getcalls(names)]
def matchreport(
self,
inamepart="",
names="pytest_runtest_logreport pytest_collectreport",
inamepart: str = "",
names: Union[
str, Iterable[str]
] = "pytest_runtest_logreport pytest_collectreport",
when=None,
):
"""return a testreport whose dotted import path matches"""
@ -275,13 +300,20 @@ class HookRecorder:
)
return values[0]
def getfailures(self, names="pytest_runtest_logreport pytest_collectreport"):
def getfailures(
self,
names: Union[
str, Iterable[str]
] = "pytest_runtest_logreport pytest_collectreport",
) -> List[TestReport]:
return [rep for rep in self.getreports(names) if rep.failed]
def getfailedcollections(self):
def getfailedcollections(self) -> List[TestReport]:
return self.getfailures("pytest_collectreport")
def listoutcomes(self):
def listoutcomes(
self
) -> Tuple[List[TestReport], List[TestReport], List[TestReport]]:
passed = []
skipped = []
failed = []
@ -296,31 +328,31 @@ class HookRecorder:
failed.append(rep)
return passed, skipped, failed
def countoutcomes(self):
def countoutcomes(self) -> List[int]:
return [len(x) for x in self.listoutcomes()]
def assertoutcome(self, passed=0, skipped=0, failed=0):
def assertoutcome(self, passed: int = 0, skipped: int = 0, failed: int = 0) -> None:
realpassed, realskipped, realfailed = self.listoutcomes()
assert passed == len(realpassed)
assert skipped == len(realskipped)
assert failed == len(realfailed)
def clear(self):
def clear(self) -> None:
self.calls[:] = []
@pytest.fixture
def linecomp(request):
def linecomp(request: FixtureRequest) -> "LineComp":
return LineComp()
@pytest.fixture(name="LineMatcher")
def LineMatcher_fixture(request):
def LineMatcher_fixture(request: FixtureRequest) -> "Type[LineMatcher]":
return LineMatcher
@pytest.fixture
def testdir(request, tmpdir_factory):
def testdir(request: FixtureRequest, tmpdir_factory) -> "Testdir":
return Testdir(request, tmpdir_factory)
@ -363,7 +395,13 @@ class RunResult:
:ivar duration: duration in seconds
"""
def __init__(self, ret: Union[int, ExitCode], outlines, errlines, duration) -> None:
def __init__(
self,
ret: Union[int, ExitCode],
outlines: Sequence[str],
errlines: Sequence[str],
duration: float,
) -> None:
try:
self.ret = pytest.ExitCode(ret) # type: Union[int, ExitCode]
except ValueError:
@ -374,13 +412,13 @@ class RunResult:
self.stderr = LineMatcher(errlines)
self.duration = duration
def __repr__(self):
def __repr__(self) -> str:
return (
"<RunResult ret=%s len(stdout.lines)=%d len(stderr.lines)=%d duration=%.2fs>"
% (self.ret, len(self.stdout.lines), len(self.stderr.lines), self.duration)
)
def parseoutcomes(self):
def parseoutcomes(self) -> Dict[str, int]:
"""Return a dictionary of outcomestring->num from parsing the terminal
output that the test process produced.
@ -393,8 +431,14 @@ class RunResult:
raise ValueError("Pytest terminal summary report not found")
def assert_outcomes(
self, passed=0, skipped=0, failed=0, error=0, xpassed=0, xfailed=0
):
self,
passed: int = 0,
skipped: int = 0,
failed: int = 0,
error: int = 0,
xpassed: int = 0,
xfailed: int = 0,
) -> None:
"""Assert that the specified outcomes appear with the respective
numbers (0 means it didn't occur) in the text output from a test run.
@ -420,19 +464,19 @@ class RunResult:
class CwdSnapshot:
def __init__(self):
def __init__(self) -> None:
self.__saved = os.getcwd()
def restore(self):
def restore(self) -> None:
os.chdir(self.__saved)
class SysModulesSnapshot:
def __init__(self, preserve=None):
def __init__(self, preserve: Optional[Callable[[str], bool]] = None):
self.__preserve = preserve
self.__saved = dict(sys.modules)
def restore(self):
def restore(self) -> None:
if self.__preserve:
self.__saved.update(
(k, m) for k, m in sys.modules.items() if self.__preserve(k)
@ -442,10 +486,10 @@ class SysModulesSnapshot:
class SysPathsSnapshot:
def __init__(self):
def __init__(self) -> None:
self.__saved = list(sys.path), list(sys.meta_path)
def restore(self):
def restore(self) -> None:
sys.path[:], sys.meta_path[:] = self.__saved
@ -1357,7 +1401,7 @@ class LineMatcher:
:param str match_nickname: the nickname for the match function that
will be logged to stdout when a match occurs
"""
assert isinstance(lines2, Sequence)
assert isinstance(lines2, collections.abc.Sequence)
lines2 = self._getlines(lines2)
lines1 = self.lines[:]
nextline = None

View File

@ -1,6 +1,14 @@
from typing import Any
from typing import Generic
from typing import TypeVar
import attr
if False: # TYPE_CHECKING
from typing import Type # noqa: F401 (used in type string)
class PytestWarning(UserWarning):
"""
Bases: :class:`UserWarning`.
@ -72,7 +80,7 @@ class PytestExperimentalApiWarning(PytestWarning, FutureWarning):
__module__ = "pytest"
@classmethod
def simple(cls, apiname):
def simple(cls, apiname: str) -> "PytestExperimentalApiWarning":
return cls(
"{apiname} is an experimental api that may change over time".format(
apiname=apiname
@ -103,17 +111,20 @@ class PytestUnknownMarkWarning(PytestWarning):
__module__ = "pytest"
_W = TypeVar("_W", bound=PytestWarning)
@attr.s
class UnformattedWarning:
class UnformattedWarning(Generic[_W]):
"""Used to hold warnings that need to format their message at runtime, as opposed to a direct message.
Using this class avoids to keep all the warning types and messages in this module, avoiding misuse.
"""
category = attr.ib()
template = attr.ib()
category = attr.ib(type="Type[_W]")
template = attr.ib(type=str)
def format(self, **kwargs):
def format(self, **kwargs: Any) -> _W:
"""Returns an instance of the warning category, formatted with given kwargs"""
return self.category(self.template.format(**kwargs))