Merge pull request #6141 from bluetech/type-annotations-7

Add type annotations to _pytest.{warning_types,_code.source,pytester}
This commit is contained in:
Ran Benita 2019-11-07 17:11:01 +02:00 committed by GitHub
commit e670ff76cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 147 additions and 65 deletions

View File

@ -7,10 +7,17 @@ import tokenize
import warnings import warnings
from ast import PyCF_ONLY_AST as _AST_FLAG from ast import PyCF_ONLY_AST as _AST_FLAG
from bisect import bisect_right from bisect import bisect_right
from types import FrameType
from typing import List from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import py import py
from _pytest.compat import overload
class Source: class Source:
""" an immutable object holding a source code fragment, """ an immutable object holding a source code fragment,
@ -19,7 +26,7 @@ class Source:
_compilecounter = 0 _compilecounter = 0
def __init__(self, *parts, **kwargs): def __init__(self, *parts, **kwargs) -> None:
self.lines = lines = [] # type: List[str] self.lines = lines = [] # type: List[str]
de = kwargs.get("deindent", True) de = kwargs.get("deindent", True)
for part in parts: for part in parts:
@ -48,7 +55,15 @@ class Source:
# Ignore type because of https://github.com/python/mypy/issues/4266. # Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore __hash__ = None # type: ignore
def __getitem__(self, key): @overload
def __getitem__(self, key: int) -> str:
raise NotImplementedError()
@overload # noqa: F811
def __getitem__(self, key: slice) -> "Source":
raise NotImplementedError()
def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]: # noqa: F811
if isinstance(key, int): if isinstance(key, int):
return self.lines[key] return self.lines[key]
else: else:
@ -58,10 +73,10 @@ class Source:
newsource.lines = self.lines[key.start : key.stop] newsource.lines = self.lines[key.start : key.stop]
return newsource return newsource
def __len__(self): def __len__(self) -> int:
return len(self.lines) return len(self.lines)
def strip(self): def strip(self) -> "Source":
""" return new source object with trailing """ return new source object with trailing
and leading blank lines removed. and leading blank lines removed.
""" """
@ -74,18 +89,20 @@ class Source:
source.lines[:] = self.lines[start:end] source.lines[:] = self.lines[start:end]
return source return source
def putaround(self, before="", after="", indent=" " * 4): def putaround(
self, before: str = "", after: str = "", indent: str = " " * 4
) -> "Source":
""" return a copy of the source object with """ return a copy of the source object with
'before' and 'after' wrapped around it. 'before' and 'after' wrapped around it.
""" """
before = Source(before) beforesource = Source(before)
after = Source(after) aftersource = Source(after)
newsource = Source() newsource = Source()
lines = [(indent + line) for line in self.lines] lines = [(indent + line) for line in self.lines]
newsource.lines = before.lines + lines + after.lines newsource.lines = beforesource.lines + lines + aftersource.lines
return newsource return newsource
def indent(self, indent=" " * 4): def indent(self, indent: str = " " * 4) -> "Source":
""" return a copy of the source object with """ return a copy of the source object with
all lines indented by the given indent-string. all lines indented by the given indent-string.
""" """
@ -93,14 +110,14 @@ class Source:
newsource.lines = [(indent + line) for line in self.lines] newsource.lines = [(indent + line) for line in self.lines]
return newsource return newsource
def getstatement(self, lineno): def getstatement(self, lineno: int) -> "Source":
""" return Source statement which contains the """ return Source statement which contains the
given linenumber (counted from 0). given linenumber (counted from 0).
""" """
start, end = self.getstatementrange(lineno) start, end = self.getstatementrange(lineno)
return self[start:end] return self[start:end]
def getstatementrange(self, lineno): def getstatementrange(self, lineno: int):
""" return (start, end) tuple which spans the minimal """ return (start, end) tuple which spans the minimal
statement region which containing the given lineno. statement region which containing the given lineno.
""" """
@ -109,13 +126,13 @@ class Source:
ast, start, end = getstatementrange_ast(lineno, self) ast, start, end = getstatementrange_ast(lineno, self)
return start, end return start, end
def deindent(self): def deindent(self) -> "Source":
"""return a new source object deindented.""" """return a new source object deindented."""
newsource = Source() newsource = Source()
newsource.lines[:] = deindent(self.lines) newsource.lines[:] = deindent(self.lines)
return newsource return newsource
def isparseable(self, deindent=True): def isparseable(self, deindent: bool = True) -> bool:
""" return True if source is parseable, heuristically """ return True if source is parseable, heuristically
deindenting it by default. deindenting it by default.
""" """
@ -135,11 +152,16 @@ class Source:
else: else:
return True return True
def __str__(self): def __str__(self) -> str:
return "\n".join(self.lines) return "\n".join(self.lines)
def compile( def compile(
self, filename=None, mode="exec", flag=0, dont_inherit=0, _genframe=None self,
filename=None,
mode="exec",
flag: int = 0,
dont_inherit: int = 0,
_genframe: Optional[FrameType] = None,
): ):
""" return compiled code object. if filename is None """ return compiled code object. if filename is None
invent an artificial filename which displays invent an artificial filename which displays
@ -183,7 +205,7 @@ class Source:
# #
def compile_(source, filename=None, mode="exec", flags=0, dont_inherit=0): def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: int = 0):
""" compile the given source to a raw code object, """ compile the given source to a raw code object,
and maintain an internal cache which allows later and maintain an internal cache which allows later
retrieval of the source code for the code object retrieval of the source code for the code object
@ -233,7 +255,7 @@ def getfslineno(obj):
# #
def findsource(obj): def findsource(obj) -> Tuple[Optional[Source], int]:
try: try:
sourcelines, lineno = inspect.findsource(obj) sourcelines, lineno = inspect.findsource(obj)
except Exception: except Exception:
@ -243,7 +265,7 @@ def findsource(obj):
return source, lineno return source, lineno
def getsource(obj, **kwargs): def getsource(obj, **kwargs) -> Source:
from .code import getrawcode from .code import getrawcode
obj = getrawcode(obj) obj = getrawcode(obj)
@ -255,21 +277,21 @@ def getsource(obj, **kwargs):
return Source(strsrc, **kwargs) return Source(strsrc, **kwargs)
def deindent(lines): def deindent(lines: Sequence[str]) -> List[str]:
return textwrap.dedent("\n".join(lines)).splitlines() return textwrap.dedent("\n".join(lines)).splitlines()
def get_statement_startend2(lineno, node): def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]:
import ast import ast
# flatten all statements and except handlers into one lineno-list # flatten all statements and except handlers into one lineno-list
# AST's line numbers start indexing at 1 # AST's line numbers start indexing at 1
values = [] values = [] # type: List[int]
for x in ast.walk(node): for x in ast.walk(node):
if isinstance(x, (ast.stmt, ast.ExceptHandler)): if isinstance(x, (ast.stmt, ast.ExceptHandler)):
values.append(x.lineno - 1) values.append(x.lineno - 1)
for name in ("finalbody", "orelse"): for name in ("finalbody", "orelse"):
val = getattr(x, name, None) val = getattr(x, name, None) # type: Optional[List[ast.stmt]]
if val: if val:
# treat the finally/orelse part as its own statement # treat the finally/orelse part as its own statement
values.append(val[0].lineno - 1 - 1) values.append(val[0].lineno - 1 - 1)
@ -283,7 +305,12 @@ def get_statement_startend2(lineno, node):
return start, end return start, end
def getstatementrange_ast(lineno, source: Source, assertion=False, astnode=None): def getstatementrange_ast(
lineno: int,
source: Source,
assertion: bool = False,
astnode: Optional[ast.AST] = None,
) -> Tuple[ast.AST, int, int]:
if astnode is None: if astnode is None:
content = str(source) content = str(source)
# See #4260: # See #4260:

View File

@ -1,4 +1,5 @@
"""(disabled by default) support for testing pytest and pytest plugins.""" """(disabled by default) support for testing pytest and pytest plugins."""
import collections.abc
import gc import gc
import importlib import importlib
import os import os
@ -8,9 +9,15 @@ import subprocess
import sys import sys
import time import time
import traceback import traceback
from collections.abc import Sequence
from fnmatch import fnmatch from fnmatch import fnmatch
from io import StringIO from io import StringIO
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union from typing import Union
from weakref import WeakKeyDictionary from weakref import WeakKeyDictionary
@ -21,10 +28,16 @@ from _pytest._code import Source
from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr
from _pytest.capture import MultiCapture from _pytest.capture import MultiCapture
from _pytest.capture import SysCapture from _pytest.capture import SysCapture
from _pytest.fixtures import FixtureRequest
from _pytest.main import ExitCode from _pytest.main import ExitCode
from _pytest.main import Session from _pytest.main import Session
from _pytest.monkeypatch import MonkeyPatch from _pytest.monkeypatch import MonkeyPatch
from _pytest.pathlib import Path from _pytest.pathlib import Path
from _pytest.reports import TestReport
if False: # TYPE_CHECKING
from typing import Type
IGNORE_PAM = [ # filenames added when obtaining details about the current user IGNORE_PAM = [ # filenames added when obtaining details about the current user
"/var/lib/sss/mc/passwd" "/var/lib/sss/mc/passwd"
@ -142,7 +155,7 @@ class LsofFdLeakChecker:
@pytest.fixture @pytest.fixture
def _pytest(request): def _pytest(request: FixtureRequest) -> "PytestArg":
"""Return a helper which offers a gethookrecorder(hook) method which """Return a helper which offers a gethookrecorder(hook) method which
returns a HookRecorder instance which helps to make assertions about called returns a HookRecorder instance which helps to make assertions about called
hooks. hooks.
@ -152,10 +165,10 @@ def _pytest(request):
class PytestArg: class PytestArg:
def __init__(self, request): def __init__(self, request: FixtureRequest) -> None:
self.request = request self.request = request
def gethookrecorder(self, hook): def gethookrecorder(self, hook) -> "HookRecorder":
hookrecorder = HookRecorder(hook._pm) hookrecorder = HookRecorder(hook._pm)
self.request.addfinalizer(hookrecorder.finish_recording) self.request.addfinalizer(hookrecorder.finish_recording)
return hookrecorder return hookrecorder
@ -176,6 +189,11 @@ class ParsedCall:
del d["_name"] del d["_name"]
return "<ParsedCall {!r}(**{!r})>".format(self._name, d) return "<ParsedCall {!r}(**{!r})>".format(self._name, d)
if False: # TYPE_CHECKING
# The class has undetermined attributes, this tells mypy about it.
def __getattr__(self, key):
raise NotImplementedError()
class HookRecorder: class HookRecorder:
"""Record all hooks called in a plugin manager. """Record all hooks called in a plugin manager.
@ -185,27 +203,27 @@ class HookRecorder:
""" """
def __init__(self, pluginmanager): def __init__(self, pluginmanager) -> None:
self._pluginmanager = pluginmanager self._pluginmanager = pluginmanager
self.calls = [] self.calls = [] # type: List[ParsedCall]
def before(hook_name, hook_impls, kwargs): def before(hook_name: str, hook_impls, kwargs) -> None:
self.calls.append(ParsedCall(hook_name, kwargs)) self.calls.append(ParsedCall(hook_name, kwargs))
def after(outcome, hook_name, hook_impls, kwargs): def after(outcome, hook_name: str, hook_impls, kwargs) -> None:
pass pass
self._undo_wrapping = pluginmanager.add_hookcall_monitoring(before, after) self._undo_wrapping = pluginmanager.add_hookcall_monitoring(before, after)
def finish_recording(self): def finish_recording(self) -> None:
self._undo_wrapping() self._undo_wrapping()
def getcalls(self, names): def getcalls(self, names: Union[str, Iterable[str]]) -> List[ParsedCall]:
if isinstance(names, str): if isinstance(names, str):
names = names.split() names = names.split()
return [call for call in self.calls if call._name in names] return [call for call in self.calls if call._name in names]
def assert_contains(self, entries): def assert_contains(self, entries) -> None:
__tracebackhide__ = True __tracebackhide__ = True
i = 0 i = 0
entries = list(entries) entries = list(entries)
@ -226,7 +244,7 @@ class HookRecorder:
else: else:
pytest.fail("could not find {!r} check {!r}".format(name, check)) pytest.fail("could not find {!r} check {!r}".format(name, check))
def popcall(self, name): def popcall(self, name: str) -> ParsedCall:
__tracebackhide__ = True __tracebackhide__ = True
for i, call in enumerate(self.calls): for i, call in enumerate(self.calls):
if call._name == name: if call._name == name:
@ -236,20 +254,27 @@ class HookRecorder:
lines.extend([" %s" % x for x in self.calls]) lines.extend([" %s" % x for x in self.calls])
pytest.fail("\n".join(lines)) pytest.fail("\n".join(lines))
def getcall(self, name): def getcall(self, name: str) -> ParsedCall:
values = self.getcalls(name) values = self.getcalls(name)
assert len(values) == 1, (name, values) assert len(values) == 1, (name, values)
return values[0] return values[0]
# functionality for test reports # functionality for test reports
def getreports(self, names="pytest_runtest_logreport pytest_collectreport"): def getreports(
self,
names: Union[
str, Iterable[str]
] = "pytest_runtest_logreport pytest_collectreport",
) -> List[TestReport]:
return [x.report for x in self.getcalls(names)] return [x.report for x in self.getcalls(names)]
def matchreport( def matchreport(
self, self,
inamepart="", inamepart: str = "",
names="pytest_runtest_logreport pytest_collectreport", names: Union[
str, Iterable[str]
] = "pytest_runtest_logreport pytest_collectreport",
when=None, when=None,
): ):
"""return a testreport whose dotted import path matches""" """return a testreport whose dotted import path matches"""
@ -275,13 +300,20 @@ class HookRecorder:
) )
return values[0] return values[0]
def getfailures(self, names="pytest_runtest_logreport pytest_collectreport"): def getfailures(
self,
names: Union[
str, Iterable[str]
] = "pytest_runtest_logreport pytest_collectreport",
) -> List[TestReport]:
return [rep for rep in self.getreports(names) if rep.failed] return [rep for rep in self.getreports(names) if rep.failed]
def getfailedcollections(self): def getfailedcollections(self) -> List[TestReport]:
return self.getfailures("pytest_collectreport") return self.getfailures("pytest_collectreport")
def listoutcomes(self): def listoutcomes(
self
) -> Tuple[List[TestReport], List[TestReport], List[TestReport]]:
passed = [] passed = []
skipped = [] skipped = []
failed = [] failed = []
@ -296,31 +328,31 @@ class HookRecorder:
failed.append(rep) failed.append(rep)
return passed, skipped, failed return passed, skipped, failed
def countoutcomes(self): def countoutcomes(self) -> List[int]:
return [len(x) for x in self.listoutcomes()] return [len(x) for x in self.listoutcomes()]
def assertoutcome(self, passed=0, skipped=0, failed=0): def assertoutcome(self, passed: int = 0, skipped: int = 0, failed: int = 0) -> None:
realpassed, realskipped, realfailed = self.listoutcomes() realpassed, realskipped, realfailed = self.listoutcomes()
assert passed == len(realpassed) assert passed == len(realpassed)
assert skipped == len(realskipped) assert skipped == len(realskipped)
assert failed == len(realfailed) assert failed == len(realfailed)
def clear(self): def clear(self) -> None:
self.calls[:] = [] self.calls[:] = []
@pytest.fixture @pytest.fixture
def linecomp(request): def linecomp(request: FixtureRequest) -> "LineComp":
return LineComp() return LineComp()
@pytest.fixture(name="LineMatcher") @pytest.fixture(name="LineMatcher")
def LineMatcher_fixture(request): def LineMatcher_fixture(request: FixtureRequest) -> "Type[LineMatcher]":
return LineMatcher return LineMatcher
@pytest.fixture @pytest.fixture
def testdir(request, tmpdir_factory): def testdir(request: FixtureRequest, tmpdir_factory) -> "Testdir":
return Testdir(request, tmpdir_factory) return Testdir(request, tmpdir_factory)
@ -363,7 +395,13 @@ class RunResult:
:ivar duration: duration in seconds :ivar duration: duration in seconds
""" """
def __init__(self, ret: Union[int, ExitCode], outlines, errlines, duration) -> None: def __init__(
self,
ret: Union[int, ExitCode],
outlines: Sequence[str],
errlines: Sequence[str],
duration: float,
) -> None:
try: try:
self.ret = pytest.ExitCode(ret) # type: Union[int, ExitCode] self.ret = pytest.ExitCode(ret) # type: Union[int, ExitCode]
except ValueError: except ValueError:
@ -374,13 +412,13 @@ class RunResult:
self.stderr = LineMatcher(errlines) self.stderr = LineMatcher(errlines)
self.duration = duration self.duration = duration
def __repr__(self): def __repr__(self) -> str:
return ( return (
"<RunResult ret=%s len(stdout.lines)=%d len(stderr.lines)=%d duration=%.2fs>" "<RunResult ret=%s len(stdout.lines)=%d len(stderr.lines)=%d duration=%.2fs>"
% (self.ret, len(self.stdout.lines), len(self.stderr.lines), self.duration) % (self.ret, len(self.stdout.lines), len(self.stderr.lines), self.duration)
) )
def parseoutcomes(self): def parseoutcomes(self) -> Dict[str, int]:
"""Return a dictionary of outcomestring->num from parsing the terminal """Return a dictionary of outcomestring->num from parsing the terminal
output that the test process produced. output that the test process produced.
@ -393,8 +431,14 @@ class RunResult:
raise ValueError("Pytest terminal summary report not found") raise ValueError("Pytest terminal summary report not found")
def assert_outcomes( def assert_outcomes(
self, passed=0, skipped=0, failed=0, error=0, xpassed=0, xfailed=0 self,
): passed: int = 0,
skipped: int = 0,
failed: int = 0,
error: int = 0,
xpassed: int = 0,
xfailed: int = 0,
) -> None:
"""Assert that the specified outcomes appear with the respective """Assert that the specified outcomes appear with the respective
numbers (0 means it didn't occur) in the text output from a test run. numbers (0 means it didn't occur) in the text output from a test run.
@ -420,19 +464,19 @@ class RunResult:
class CwdSnapshot: class CwdSnapshot:
def __init__(self): def __init__(self) -> None:
self.__saved = os.getcwd() self.__saved = os.getcwd()
def restore(self): def restore(self) -> None:
os.chdir(self.__saved) os.chdir(self.__saved)
class SysModulesSnapshot: class SysModulesSnapshot:
def __init__(self, preserve=None): def __init__(self, preserve: Optional[Callable[[str], bool]] = None):
self.__preserve = preserve self.__preserve = preserve
self.__saved = dict(sys.modules) self.__saved = dict(sys.modules)
def restore(self): def restore(self) -> None:
if self.__preserve: if self.__preserve:
self.__saved.update( self.__saved.update(
(k, m) for k, m in sys.modules.items() if self.__preserve(k) (k, m) for k, m in sys.modules.items() if self.__preserve(k)
@ -442,10 +486,10 @@ class SysModulesSnapshot:
class SysPathsSnapshot: class SysPathsSnapshot:
def __init__(self): def __init__(self) -> None:
self.__saved = list(sys.path), list(sys.meta_path) self.__saved = list(sys.path), list(sys.meta_path)
def restore(self): def restore(self) -> None:
sys.path[:], sys.meta_path[:] = self.__saved sys.path[:], sys.meta_path[:] = self.__saved
@ -1357,7 +1401,7 @@ class LineMatcher:
:param str match_nickname: the nickname for the match function that :param str match_nickname: the nickname for the match function that
will be logged to stdout when a match occurs will be logged to stdout when a match occurs
""" """
assert isinstance(lines2, Sequence) assert isinstance(lines2, collections.abc.Sequence)
lines2 = self._getlines(lines2) lines2 = self._getlines(lines2)
lines1 = self.lines[:] lines1 = self.lines[:]
nextline = None nextline = None

View File

@ -1,6 +1,14 @@
from typing import Any
from typing import Generic
from typing import TypeVar
import attr import attr
if False: # TYPE_CHECKING
from typing import Type # noqa: F401 (used in type string)
class PytestWarning(UserWarning): class PytestWarning(UserWarning):
""" """
Bases: :class:`UserWarning`. Bases: :class:`UserWarning`.
@ -72,7 +80,7 @@ class PytestExperimentalApiWarning(PytestWarning, FutureWarning):
__module__ = "pytest" __module__ = "pytest"
@classmethod @classmethod
def simple(cls, apiname): def simple(cls, apiname: str) -> "PytestExperimentalApiWarning":
return cls( return cls(
"{apiname} is an experimental api that may change over time".format( "{apiname} is an experimental api that may change over time".format(
apiname=apiname apiname=apiname
@ -103,17 +111,20 @@ class PytestUnknownMarkWarning(PytestWarning):
__module__ = "pytest" __module__ = "pytest"
_W = TypeVar("_W", bound=PytestWarning)
@attr.s @attr.s
class UnformattedWarning: class UnformattedWarning(Generic[_W]):
"""Used to hold warnings that need to format their message at runtime, as opposed to a direct message. """Used to hold warnings that need to format their message at runtime, as opposed to a direct message.
Using this class avoids to keep all the warning types and messages in this module, avoiding misuse. Using this class avoids to keep all the warning types and messages in this module, avoiding misuse.
""" """
category = attr.ib() category = attr.ib(type="Type[_W]")
template = attr.ib() template = attr.ib(type=str)
def format(self, **kwargs): def format(self, **kwargs: Any) -> _W:
"""Returns an instance of the warning category, formatted with given kwargs""" """Returns an instance of the warning category, formatted with given kwargs"""
return self.category(self.template.format(**kwargs)) return self.category(self.template.format(**kwargs))