Merge pull request #6141 from bluetech/type-annotations-7
Add type annotations to _pytest.{warning_types,_code.source,pytester}
This commit is contained in:
commit
e670ff76cb
|
@ -7,10 +7,17 @@ import tokenize
|
|||
import warnings
|
||||
from ast import PyCF_ONLY_AST as _AST_FLAG
|
||||
from bisect import bisect_right
|
||||
from types import FrameType
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import py
|
||||
|
||||
from _pytest.compat import overload
|
||||
|
||||
|
||||
class Source:
|
||||
""" an immutable object holding a source code fragment,
|
||||
|
@ -19,7 +26,7 @@ class Source:
|
|||
|
||||
_compilecounter = 0
|
||||
|
||||
def __init__(self, *parts, **kwargs):
|
||||
def __init__(self, *parts, **kwargs) -> None:
|
||||
self.lines = lines = [] # type: List[str]
|
||||
de = kwargs.get("deindent", True)
|
||||
for part in parts:
|
||||
|
@ -48,7 +55,15 @@ class Source:
|
|||
# Ignore type because of https://github.com/python/mypy/issues/4266.
|
||||
__hash__ = None # type: ignore
|
||||
|
||||
def __getitem__(self, key):
|
||||
@overload
|
||||
def __getitem__(self, key: int) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@overload # noqa: F811
|
||||
def __getitem__(self, key: slice) -> "Source":
|
||||
raise NotImplementedError()
|
||||
|
||||
def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]: # noqa: F811
|
||||
if isinstance(key, int):
|
||||
return self.lines[key]
|
||||
else:
|
||||
|
@ -58,10 +73,10 @@ class Source:
|
|||
newsource.lines = self.lines[key.start : key.stop]
|
||||
return newsource
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self.lines)
|
||||
|
||||
def strip(self):
|
||||
def strip(self) -> "Source":
|
||||
""" return new source object with trailing
|
||||
and leading blank lines removed.
|
||||
"""
|
||||
|
@ -74,18 +89,20 @@ class Source:
|
|||
source.lines[:] = self.lines[start:end]
|
||||
return source
|
||||
|
||||
def putaround(self, before="", after="", indent=" " * 4):
|
||||
def putaround(
|
||||
self, before: str = "", after: str = "", indent: str = " " * 4
|
||||
) -> "Source":
|
||||
""" return a copy of the source object with
|
||||
'before' and 'after' wrapped around it.
|
||||
"""
|
||||
before = Source(before)
|
||||
after = Source(after)
|
||||
beforesource = Source(before)
|
||||
aftersource = Source(after)
|
||||
newsource = Source()
|
||||
lines = [(indent + line) for line in self.lines]
|
||||
newsource.lines = before.lines + lines + after.lines
|
||||
newsource.lines = beforesource.lines + lines + aftersource.lines
|
||||
return newsource
|
||||
|
||||
def indent(self, indent=" " * 4):
|
||||
def indent(self, indent: str = " " * 4) -> "Source":
|
||||
""" return a copy of the source object with
|
||||
all lines indented by the given indent-string.
|
||||
"""
|
||||
|
@ -93,14 +110,14 @@ class Source:
|
|||
newsource.lines = [(indent + line) for line in self.lines]
|
||||
return newsource
|
||||
|
||||
def getstatement(self, lineno):
|
||||
def getstatement(self, lineno: int) -> "Source":
|
||||
""" return Source statement which contains the
|
||||
given linenumber (counted from 0).
|
||||
"""
|
||||
start, end = self.getstatementrange(lineno)
|
||||
return self[start:end]
|
||||
|
||||
def getstatementrange(self, lineno):
|
||||
def getstatementrange(self, lineno: int):
|
||||
""" return (start, end) tuple which spans the minimal
|
||||
statement region which containing the given lineno.
|
||||
"""
|
||||
|
@ -109,13 +126,13 @@ class Source:
|
|||
ast, start, end = getstatementrange_ast(lineno, self)
|
||||
return start, end
|
||||
|
||||
def deindent(self):
|
||||
def deindent(self) -> "Source":
|
||||
"""return a new source object deindented."""
|
||||
newsource = Source()
|
||||
newsource.lines[:] = deindent(self.lines)
|
||||
return newsource
|
||||
|
||||
def isparseable(self, deindent=True):
|
||||
def isparseable(self, deindent: bool = True) -> bool:
|
||||
""" return True if source is parseable, heuristically
|
||||
deindenting it by default.
|
||||
"""
|
||||
|
@ -135,11 +152,16 @@ class Source:
|
|||
else:
|
||||
return True
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return "\n".join(self.lines)
|
||||
|
||||
def compile(
|
||||
self, filename=None, mode="exec", flag=0, dont_inherit=0, _genframe=None
|
||||
self,
|
||||
filename=None,
|
||||
mode="exec",
|
||||
flag: int = 0,
|
||||
dont_inherit: int = 0,
|
||||
_genframe: Optional[FrameType] = None,
|
||||
):
|
||||
""" return compiled code object. if filename is None
|
||||
invent an artificial filename which displays
|
||||
|
@ -183,7 +205,7 @@ class Source:
|
|||
#
|
||||
|
||||
|
||||
def compile_(source, filename=None, mode="exec", flags=0, dont_inherit=0):
|
||||
def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: int = 0):
|
||||
""" compile the given source to a raw code object,
|
||||
and maintain an internal cache which allows later
|
||||
retrieval of the source code for the code object
|
||||
|
@ -233,7 +255,7 @@ def getfslineno(obj):
|
|||
#
|
||||
|
||||
|
||||
def findsource(obj):
|
||||
def findsource(obj) -> Tuple[Optional[Source], int]:
|
||||
try:
|
||||
sourcelines, lineno = inspect.findsource(obj)
|
||||
except Exception:
|
||||
|
@ -243,7 +265,7 @@ def findsource(obj):
|
|||
return source, lineno
|
||||
|
||||
|
||||
def getsource(obj, **kwargs):
|
||||
def getsource(obj, **kwargs) -> Source:
|
||||
from .code import getrawcode
|
||||
|
||||
obj = getrawcode(obj)
|
||||
|
@ -255,21 +277,21 @@ def getsource(obj, **kwargs):
|
|||
return Source(strsrc, **kwargs)
|
||||
|
||||
|
||||
def deindent(lines):
|
||||
def deindent(lines: Sequence[str]) -> List[str]:
|
||||
return textwrap.dedent("\n".join(lines)).splitlines()
|
||||
|
||||
|
||||
def get_statement_startend2(lineno, node):
|
||||
def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]:
|
||||
import ast
|
||||
|
||||
# flatten all statements and except handlers into one lineno-list
|
||||
# AST's line numbers start indexing at 1
|
||||
values = []
|
||||
values = [] # type: List[int]
|
||||
for x in ast.walk(node):
|
||||
if isinstance(x, (ast.stmt, ast.ExceptHandler)):
|
||||
values.append(x.lineno - 1)
|
||||
for name in ("finalbody", "orelse"):
|
||||
val = getattr(x, name, None)
|
||||
val = getattr(x, name, None) # type: Optional[List[ast.stmt]]
|
||||
if val:
|
||||
# treat the finally/orelse part as its own statement
|
||||
values.append(val[0].lineno - 1 - 1)
|
||||
|
@ -283,7 +305,12 @@ def get_statement_startend2(lineno, node):
|
|||
return start, end
|
||||
|
||||
|
||||
def getstatementrange_ast(lineno, source: Source, assertion=False, astnode=None):
|
||||
def getstatementrange_ast(
|
||||
lineno: int,
|
||||
source: Source,
|
||||
assertion: bool = False,
|
||||
astnode: Optional[ast.AST] = None,
|
||||
) -> Tuple[ast.AST, int, int]:
|
||||
if astnode is None:
|
||||
content = str(source)
|
||||
# See #4260:
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""(disabled by default) support for testing pytest and pytest plugins."""
|
||||
import collections.abc
|
||||
import gc
|
||||
import importlib
|
||||
import os
|
||||
|
@ -8,9 +9,15 @@ import subprocess
|
|||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Sequence
|
||||
from fnmatch import fnmatch
|
||||
from io import StringIO
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
|
@ -21,10 +28,16 @@ from _pytest._code import Source
|
|||
from _pytest._io.saferepr import saferepr
|
||||
from _pytest.capture import MultiCapture
|
||||
from _pytest.capture import SysCapture
|
||||
from _pytest.fixtures import FixtureRequest
|
||||
from _pytest.main import ExitCode
|
||||
from _pytest.main import Session
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from _pytest.pathlib import Path
|
||||
from _pytest.reports import TestReport
|
||||
|
||||
if False: # TYPE_CHECKING
|
||||
from typing import Type
|
||||
|
||||
|
||||
IGNORE_PAM = [ # filenames added when obtaining details about the current user
|
||||
"/var/lib/sss/mc/passwd"
|
||||
|
@ -142,7 +155,7 @@ class LsofFdLeakChecker:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def _pytest(request):
|
||||
def _pytest(request: FixtureRequest) -> "PytestArg":
|
||||
"""Return a helper which offers a gethookrecorder(hook) method which
|
||||
returns a HookRecorder instance which helps to make assertions about called
|
||||
hooks.
|
||||
|
@ -152,10 +165,10 @@ def _pytest(request):
|
|||
|
||||
|
||||
class PytestArg:
|
||||
def __init__(self, request):
|
||||
def __init__(self, request: FixtureRequest) -> None:
|
||||
self.request = request
|
||||
|
||||
def gethookrecorder(self, hook):
|
||||
def gethookrecorder(self, hook) -> "HookRecorder":
|
||||
hookrecorder = HookRecorder(hook._pm)
|
||||
self.request.addfinalizer(hookrecorder.finish_recording)
|
||||
return hookrecorder
|
||||
|
@ -176,6 +189,11 @@ class ParsedCall:
|
|||
del d["_name"]
|
||||
return "<ParsedCall {!r}(**{!r})>".format(self._name, d)
|
||||
|
||||
if False: # TYPE_CHECKING
|
||||
# The class has undetermined attributes, this tells mypy about it.
|
||||
def __getattr__(self, key):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class HookRecorder:
|
||||
"""Record all hooks called in a plugin manager.
|
||||
|
@ -185,27 +203,27 @@ class HookRecorder:
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, pluginmanager):
|
||||
def __init__(self, pluginmanager) -> None:
|
||||
self._pluginmanager = pluginmanager
|
||||
self.calls = []
|
||||
self.calls = [] # type: List[ParsedCall]
|
||||
|
||||
def before(hook_name, hook_impls, kwargs):
|
||||
def before(hook_name: str, hook_impls, kwargs) -> None:
|
||||
self.calls.append(ParsedCall(hook_name, kwargs))
|
||||
|
||||
def after(outcome, hook_name, hook_impls, kwargs):
|
||||
def after(outcome, hook_name: str, hook_impls, kwargs) -> None:
|
||||
pass
|
||||
|
||||
self._undo_wrapping = pluginmanager.add_hookcall_monitoring(before, after)
|
||||
|
||||
def finish_recording(self):
|
||||
def finish_recording(self) -> None:
|
||||
self._undo_wrapping()
|
||||
|
||||
def getcalls(self, names):
|
||||
def getcalls(self, names: Union[str, Iterable[str]]) -> List[ParsedCall]:
|
||||
if isinstance(names, str):
|
||||
names = names.split()
|
||||
return [call for call in self.calls if call._name in names]
|
||||
|
||||
def assert_contains(self, entries):
|
||||
def assert_contains(self, entries) -> None:
|
||||
__tracebackhide__ = True
|
||||
i = 0
|
||||
entries = list(entries)
|
||||
|
@ -226,7 +244,7 @@ class HookRecorder:
|
|||
else:
|
||||
pytest.fail("could not find {!r} check {!r}".format(name, check))
|
||||
|
||||
def popcall(self, name):
|
||||
def popcall(self, name: str) -> ParsedCall:
|
||||
__tracebackhide__ = True
|
||||
for i, call in enumerate(self.calls):
|
||||
if call._name == name:
|
||||
|
@ -236,20 +254,27 @@ class HookRecorder:
|
|||
lines.extend([" %s" % x for x in self.calls])
|
||||
pytest.fail("\n".join(lines))
|
||||
|
||||
def getcall(self, name):
|
||||
def getcall(self, name: str) -> ParsedCall:
|
||||
values = self.getcalls(name)
|
||||
assert len(values) == 1, (name, values)
|
||||
return values[0]
|
||||
|
||||
# functionality for test reports
|
||||
|
||||
def getreports(self, names="pytest_runtest_logreport pytest_collectreport"):
|
||||
def getreports(
|
||||
self,
|
||||
names: Union[
|
||||
str, Iterable[str]
|
||||
] = "pytest_runtest_logreport pytest_collectreport",
|
||||
) -> List[TestReport]:
|
||||
return [x.report for x in self.getcalls(names)]
|
||||
|
||||
def matchreport(
|
||||
self,
|
||||
inamepart="",
|
||||
names="pytest_runtest_logreport pytest_collectreport",
|
||||
inamepart: str = "",
|
||||
names: Union[
|
||||
str, Iterable[str]
|
||||
] = "pytest_runtest_logreport pytest_collectreport",
|
||||
when=None,
|
||||
):
|
||||
"""return a testreport whose dotted import path matches"""
|
||||
|
@ -275,13 +300,20 @@ class HookRecorder:
|
|||
)
|
||||
return values[0]
|
||||
|
||||
def getfailures(self, names="pytest_runtest_logreport pytest_collectreport"):
|
||||
def getfailures(
|
||||
self,
|
||||
names: Union[
|
||||
str, Iterable[str]
|
||||
] = "pytest_runtest_logreport pytest_collectreport",
|
||||
) -> List[TestReport]:
|
||||
return [rep for rep in self.getreports(names) if rep.failed]
|
||||
|
||||
def getfailedcollections(self):
|
||||
def getfailedcollections(self) -> List[TestReport]:
|
||||
return self.getfailures("pytest_collectreport")
|
||||
|
||||
def listoutcomes(self):
|
||||
def listoutcomes(
|
||||
self
|
||||
) -> Tuple[List[TestReport], List[TestReport], List[TestReport]]:
|
||||
passed = []
|
||||
skipped = []
|
||||
failed = []
|
||||
|
@ -296,31 +328,31 @@ class HookRecorder:
|
|||
failed.append(rep)
|
||||
return passed, skipped, failed
|
||||
|
||||
def countoutcomes(self):
|
||||
def countoutcomes(self) -> List[int]:
|
||||
return [len(x) for x in self.listoutcomes()]
|
||||
|
||||
def assertoutcome(self, passed=0, skipped=0, failed=0):
|
||||
def assertoutcome(self, passed: int = 0, skipped: int = 0, failed: int = 0) -> None:
|
||||
realpassed, realskipped, realfailed = self.listoutcomes()
|
||||
assert passed == len(realpassed)
|
||||
assert skipped == len(realskipped)
|
||||
assert failed == len(realfailed)
|
||||
|
||||
def clear(self):
|
||||
def clear(self) -> None:
|
||||
self.calls[:] = []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def linecomp(request):
|
||||
def linecomp(request: FixtureRequest) -> "LineComp":
|
||||
return LineComp()
|
||||
|
||||
|
||||
@pytest.fixture(name="LineMatcher")
|
||||
def LineMatcher_fixture(request):
|
||||
def LineMatcher_fixture(request: FixtureRequest) -> "Type[LineMatcher]":
|
||||
return LineMatcher
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def testdir(request, tmpdir_factory):
|
||||
def testdir(request: FixtureRequest, tmpdir_factory) -> "Testdir":
|
||||
return Testdir(request, tmpdir_factory)
|
||||
|
||||
|
||||
|
@ -363,7 +395,13 @@ class RunResult:
|
|||
:ivar duration: duration in seconds
|
||||
"""
|
||||
|
||||
def __init__(self, ret: Union[int, ExitCode], outlines, errlines, duration) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
ret: Union[int, ExitCode],
|
||||
outlines: Sequence[str],
|
||||
errlines: Sequence[str],
|
||||
duration: float,
|
||||
) -> None:
|
||||
try:
|
||||
self.ret = pytest.ExitCode(ret) # type: Union[int, ExitCode]
|
||||
except ValueError:
|
||||
|
@ -374,13 +412,13 @@ class RunResult:
|
|||
self.stderr = LineMatcher(errlines)
|
||||
self.duration = duration
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
"<RunResult ret=%s len(stdout.lines)=%d len(stderr.lines)=%d duration=%.2fs>"
|
||||
% (self.ret, len(self.stdout.lines), len(self.stderr.lines), self.duration)
|
||||
)
|
||||
|
||||
def parseoutcomes(self):
|
||||
def parseoutcomes(self) -> Dict[str, int]:
|
||||
"""Return a dictionary of outcomestring->num from parsing the terminal
|
||||
output that the test process produced.
|
||||
|
||||
|
@ -393,8 +431,14 @@ class RunResult:
|
|||
raise ValueError("Pytest terminal summary report not found")
|
||||
|
||||
def assert_outcomes(
|
||||
self, passed=0, skipped=0, failed=0, error=0, xpassed=0, xfailed=0
|
||||
):
|
||||
self,
|
||||
passed: int = 0,
|
||||
skipped: int = 0,
|
||||
failed: int = 0,
|
||||
error: int = 0,
|
||||
xpassed: int = 0,
|
||||
xfailed: int = 0,
|
||||
) -> None:
|
||||
"""Assert that the specified outcomes appear with the respective
|
||||
numbers (0 means it didn't occur) in the text output from a test run.
|
||||
|
||||
|
@ -420,19 +464,19 @@ class RunResult:
|
|||
|
||||
|
||||
class CwdSnapshot:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.__saved = os.getcwd()
|
||||
|
||||
def restore(self):
|
||||
def restore(self) -> None:
|
||||
os.chdir(self.__saved)
|
||||
|
||||
|
||||
class SysModulesSnapshot:
|
||||
def __init__(self, preserve=None):
|
||||
def __init__(self, preserve: Optional[Callable[[str], bool]] = None):
|
||||
self.__preserve = preserve
|
||||
self.__saved = dict(sys.modules)
|
||||
|
||||
def restore(self):
|
||||
def restore(self) -> None:
|
||||
if self.__preserve:
|
||||
self.__saved.update(
|
||||
(k, m) for k, m in sys.modules.items() if self.__preserve(k)
|
||||
|
@ -442,10 +486,10 @@ class SysModulesSnapshot:
|
|||
|
||||
|
||||
class SysPathsSnapshot:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.__saved = list(sys.path), list(sys.meta_path)
|
||||
|
||||
def restore(self):
|
||||
def restore(self) -> None:
|
||||
sys.path[:], sys.meta_path[:] = self.__saved
|
||||
|
||||
|
||||
|
@ -1357,7 +1401,7 @@ class LineMatcher:
|
|||
:param str match_nickname: the nickname for the match function that
|
||||
will be logged to stdout when a match occurs
|
||||
"""
|
||||
assert isinstance(lines2, Sequence)
|
||||
assert isinstance(lines2, collections.abc.Sequence)
|
||||
lines2 = self._getlines(lines2)
|
||||
lines1 = self.lines[:]
|
||||
nextline = None
|
||||
|
|
|
@ -1,6 +1,14 @@
|
|||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
|
||||
import attr
|
||||
|
||||
|
||||
if False: # TYPE_CHECKING
|
||||
from typing import Type # noqa: F401 (used in type string)
|
||||
|
||||
|
||||
class PytestWarning(UserWarning):
|
||||
"""
|
||||
Bases: :class:`UserWarning`.
|
||||
|
@ -72,7 +80,7 @@ class PytestExperimentalApiWarning(PytestWarning, FutureWarning):
|
|||
__module__ = "pytest"
|
||||
|
||||
@classmethod
|
||||
def simple(cls, apiname):
|
||||
def simple(cls, apiname: str) -> "PytestExperimentalApiWarning":
|
||||
return cls(
|
||||
"{apiname} is an experimental api that may change over time".format(
|
||||
apiname=apiname
|
||||
|
@ -103,17 +111,20 @@ class PytestUnknownMarkWarning(PytestWarning):
|
|||
__module__ = "pytest"
|
||||
|
||||
|
||||
_W = TypeVar("_W", bound=PytestWarning)
|
||||
|
||||
|
||||
@attr.s
|
||||
class UnformattedWarning:
|
||||
class UnformattedWarning(Generic[_W]):
|
||||
"""Used to hold warnings that need to format their message at runtime, as opposed to a direct message.
|
||||
|
||||
Using this class avoids to keep all the warning types and messages in this module, avoiding misuse.
|
||||
"""
|
||||
|
||||
category = attr.ib()
|
||||
template = attr.ib()
|
||||
category = attr.ib(type="Type[_W]")
|
||||
template = attr.ib(type=str)
|
||||
|
||||
def format(self, **kwargs):
|
||||
def format(self, **kwargs: Any) -> _W:
|
||||
"""Returns an instance of the warning category, formatted with given kwargs"""
|
||||
return self.category(self.template.format(**kwargs))
|
||||
|
||||
|
|
Loading…
Reference in New Issue