Fix some check_untyped_defs = True mypy warnings

This commit is contained in:
Ran Benita 2019-07-14 18:45:40 +03:00 committed by Ran Benita
parent 28761c8da1
commit 7259c453d6
13 changed files with 196 additions and 106 deletions

View File

@ -5,10 +5,15 @@ import traceback
from inspect import CO_VARARGS
from inspect import CO_VARKEYWORDS
from traceback import format_exception_only
from types import CodeType
from types import TracebackType
from typing import Any
from typing import Dict
from typing import Generic
from typing import List
from typing import Optional
from typing import Pattern
from typing import Set
from typing import Tuple
from typing import TypeVar
from typing import Union
@ -29,7 +34,7 @@ if False: # TYPE_CHECKING
class Code:
""" wrapper around Python code objects """
def __init__(self, rawcode):
def __init__(self, rawcode) -> None:
if not hasattr(rawcode, "co_filename"):
rawcode = getrawcode(rawcode)
try:
@ -38,7 +43,7 @@ class Code:
self.name = rawcode.co_name
except AttributeError:
raise TypeError("not a code object: {!r}".format(rawcode))
self.raw = rawcode
self.raw = rawcode # type: CodeType
def __eq__(self, other):
return self.raw == other.raw
@ -351,7 +356,7 @@ class Traceback(list):
""" return the index of the frame/TracebackEntry where recursion
originates if appropriate, None if no recursion occurred
"""
cache = {}
cache = {} # type: Dict[Tuple[Any, int, int], List[Dict[str, Any]]]
for i, entry in enumerate(self):
# id for the code.raw is needed to work around
# the strange metaprogramming in the decorator lib from pypi
@ -650,7 +655,7 @@ class FormattedExcinfo:
args.append((argname, saferepr(argvalue)))
return ReprFuncArgs(args)
def get_source(self, source, line_index=-1, excinfo=None, short=False):
def get_source(self, source, line_index=-1, excinfo=None, short=False) -> List[str]:
""" return formatted and marked up source lines. """
import _pytest._code
@ -722,7 +727,7 @@ class FormattedExcinfo:
else:
line_index = entry.lineno - entry.getfirstlinesource()
lines = []
lines = [] # type: List[str]
style = entry._repr_style
if style is None:
style = self.style
@ -799,7 +804,7 @@ class FormattedExcinfo:
exc_msg=str(e),
max_frames=max_frames,
total=len(traceback),
)
) # type: Optional[str]
traceback = traceback[:max_frames] + traceback[-max_frames:]
else:
if recursionindex is not None:
@ -812,10 +817,12 @@ class FormattedExcinfo:
def repr_excinfo(self, excinfo):
repr_chain = []
repr_chain = (
[]
) # type: List[Tuple[ReprTraceback, Optional[ReprFileLocation], Optional[str]]]
e = excinfo.value
descr = None
seen = set()
seen = set() # type: Set[int]
while e is not None and id(e) not in seen:
seen.add(id(e))
if excinfo:
@ -868,8 +875,8 @@ class TerminalRepr:
class ExceptionRepr(TerminalRepr):
def __init__(self):
self.sections = []
def __init__(self) -> None:
self.sections = [] # type: List[Tuple[str, str, str]]
def addsection(self, name, content, sep="-"):
self.sections.append((name, content, sep))

View File

@ -7,6 +7,7 @@ import tokenize
import warnings
from ast import PyCF_ONLY_AST as _AST_FLAG
from bisect import bisect_right
from typing import List
import py
@ -19,11 +20,11 @@ class Source:
_compilecounter = 0
def __init__(self, *parts, **kwargs):
self.lines = lines = []
self.lines = lines = [] # type: List[str]
de = kwargs.get("deindent", True)
for part in parts:
if not part:
partlines = []
partlines = [] # type: List[str]
elif isinstance(part, Source):
partlines = part.lines
elif isinstance(part, (tuple, list)):
@ -157,8 +158,7 @@ class Source:
source = "\n".join(self.lines) + "\n"
try:
co = compile(source, filename, mode, flag)
except SyntaxError:
ex = sys.exc_info()[1]
except SyntaxError as ex:
# re-represent syntax errors from parsing python strings
msglines = self.lines[: ex.lineno]
if ex.offset:
@ -173,7 +173,8 @@ class Source:
if flag & _AST_FLAG:
return co
lines = [(x + "\n") for x in self.lines]
linecache.cache[filename] = (1, None, lines, filename)
# Type ignored because linecache.cache is private.
linecache.cache[filename] = (1, None, lines, filename) # type: ignore
return co
@ -282,7 +283,7 @@ def get_statement_startend2(lineno, node):
return start, end
def getstatementrange_ast(lineno, source, assertion=False, astnode=None):
def getstatementrange_ast(lineno, source: Source, assertion=False, astnode=None):
if astnode is None:
content = str(source)
# See #4260:

View File

@ -2,6 +2,7 @@
support for presenting detailed information in failing assertions.
"""
import sys
from typing import Optional
from _pytest.assertion import rewrite
from _pytest.assertion import truncate
@ -52,7 +53,9 @@ def register_assert_rewrite(*names):
importhook = hook
break
else:
importhook = DummyRewriteHook()
# TODO(typing): Add a protocol for mark_rewrite() and use it
# for importhook and for PytestPluginManager.rewrite_hook.
importhook = DummyRewriteHook() # type: ignore
importhook.mark_rewrite(*names)
@ -69,7 +72,7 @@ class AssertionState:
def __init__(self, config, mode):
self.mode = mode
self.trace = config.trace.root.get("assertion")
self.hook = None
self.hook = None # type: Optional[rewrite.AssertionRewritingHook]
def install_importhook(config):
@ -108,6 +111,7 @@ def pytest_runtest_setup(item):
"""
def callbinrepr(op, left, right):
# type: (str, object, object) -> Optional[str]
"""Call the pytest_assertrepr_compare hook and prepare the result
This uses the first result from the hook and then ensures the
@ -133,12 +137,13 @@ def pytest_runtest_setup(item):
if item.config.getvalue("assertmode") == "rewrite":
res = res.replace("%", "%%")
return res
return None
util._reprcompare = callbinrepr
if item.ihook.pytest_assertion_pass.get_hookimpls():
def call_assertion_pass_hook(lineno, expl, orig):
def call_assertion_pass_hook(lineno, orig, expl):
item.ihook.pytest_assertion_pass(
item=item, lineno=lineno, orig=orig, expl=expl
)

View File

@ -17,6 +17,7 @@ from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
import atomicwrites
@ -48,13 +49,13 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder):
except ValueError:
self.fnpats = ["test_*.py", "*_test.py"]
self.session = None
self._rewritten_names = set()
self._must_rewrite = set()
self._rewritten_names = set() # type: Set[str]
self._must_rewrite = set() # type: Set[str]
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
# which might result in infinite recursion (#3506)
self._writing_pyc = False
self._basenames_to_check_rewrite = {"conftest"}
self._marked_for_rewrite_cache = {}
self._marked_for_rewrite_cache = {} # type: Dict[str, bool]
self._session_paths_checked = False
def set_session(self, session):
@ -203,7 +204,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder):
return self._is_marked_for_rewrite(name, state)
def _is_marked_for_rewrite(self, name, state):
def _is_marked_for_rewrite(self, name: str, state):
try:
return self._marked_for_rewrite_cache[name]
except KeyError:
@ -218,7 +219,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder):
self._marked_for_rewrite_cache[name] = False
return False
def mark_rewrite(self, *names):
def mark_rewrite(self, *names: str) -> None:
"""Mark import names as needing to be rewritten.
The named module or package as well as any nested modules will
@ -385,6 +386,7 @@ def _format_boolop(explanations, is_or):
def _call_reprcompare(ops, results, expls, each_obj):
# type: (Tuple[str, ...], Tuple[bool, ...], Tuple[str, ...], Tuple[object, ...]) -> str
for i, res, expl in zip(range(len(ops)), results, expls):
try:
done = not res
@ -400,11 +402,13 @@ def _call_reprcompare(ops, results, expls, each_obj):
def _call_assertion_pass(lineno, orig, expl):
# type: (int, str, str) -> None
if util._assertion_pass is not None:
util._assertion_pass(lineno=lineno, orig=orig, expl=expl)
util._assertion_pass(lineno, orig, expl)
def _check_if_assertion_pass_impl():
# type: () -> bool
"""Checks if any plugins implement the pytest_assertion_pass hook
in order not to generate explanation unecessarily (might be expensive)"""
return True if util._assertion_pass else False
@ -578,7 +582,7 @@ class AssertionRewriter(ast.NodeVisitor):
def _assert_expr_to_lineno(self):
return _get_assertion_exprs(self.source)
def run(self, mod):
def run(self, mod: ast.Module) -> None:
"""Find all assert statements in *mod* and rewrite them."""
if not mod.body:
# Nothing to do.
@ -620,12 +624,12 @@ class AssertionRewriter(ast.NodeVisitor):
]
mod.body[pos:pos] = imports
# Collect asserts.
nodes = [mod]
nodes = [mod] # type: List[ast.AST]
while nodes:
node = nodes.pop()
for name, field in ast.iter_fields(node):
if isinstance(field, list):
new = []
new = [] # type: List
for i, child in enumerate(field):
if isinstance(child, ast.Assert):
# Transform assert.
@ -699,7 +703,7 @@ class AssertionRewriter(ast.NodeVisitor):
.explanation_param().
"""
self.explanation_specifiers = {}
self.explanation_specifiers = {} # type: Dict[str, ast.expr]
self.stack.append(self.explanation_specifiers)
def pop_format_context(self, expl_expr):
@ -742,7 +746,8 @@ class AssertionRewriter(ast.NodeVisitor):
from _pytest.warning_types import PytestAssertRewriteWarning
import warnings
warnings.warn_explicit(
# Ignore type: typeshed bug https://github.com/python/typeshed/pull/3121
warnings.warn_explicit( # type: ignore
PytestAssertRewriteWarning(
"assertion is always true, perhaps remove parentheses?"
),
@ -751,15 +756,15 @@ class AssertionRewriter(ast.NodeVisitor):
lineno=assert_.lineno,
)
self.statements = []
self.variables = []
self.statements = [] # type: List[ast.stmt]
self.variables = [] # type: List[str]
self.variable_counter = itertools.count()
if self.enable_assertion_pass_hook:
self.format_variables = []
self.format_variables = [] # type: List[str]
self.stack = []
self.expl_stmts = []
self.stack = [] # type: List[Dict[str, ast.expr]]
self.expl_stmts = [] # type: List[ast.stmt]
self.push_format_context()
# Rewrite assert into a bunch of statements.
top_condition, explanation = self.visit(assert_.test)
@ -897,7 +902,7 @@ warn_explicit(
# Process each operand, short-circuiting if needed.
for i, v in enumerate(boolop.values):
if i:
fail_inner = []
fail_inner = [] # type: List[ast.stmt]
# cond is set in a prior loop iteration below
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
self.expl_stmts = fail_inner
@ -908,10 +913,10 @@ warn_explicit(
call = ast.Call(app, [expl_format], [])
self.expl_stmts.append(ast.Expr(call))
if i < levels:
cond = res
cond = res # type: ast.expr
if is_or:
cond = ast.UnaryOp(ast.Not(), cond)
inner = []
inner = [] # type: List[ast.stmt]
self.statements.append(ast.If(cond, inner, []))
self.statements = body = inner
self.statements = save
@ -977,7 +982,7 @@ warn_explicit(
expl = pat % (res_expl, res_expl, value_expl, attr.attr)
return res, expl
def visit_Compare(self, comp):
def visit_Compare(self, comp: ast.Compare):
self.push_format_context()
left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
@ -1010,7 +1015,7 @@ warn_explicit(
ast.Tuple(results, ast.Load()),
)
if len(comp.ops) > 1:
res = ast.BoolOp(ast.And(), load_names)
res = ast.BoolOp(ast.And(), load_names) # type: ast.expr
else:
res = load_names[0]
return res, self.explanation_param(self.pop_format_context(expl_call))

View File

@ -1,6 +1,9 @@
"""Utilities for assertion debugging"""
import pprint
from collections.abc import Sequence
from typing import Callable
from typing import List
from typing import Optional
import _pytest._code
from _pytest import outcomes
@ -10,11 +13,11 @@ from _pytest._io.saferepr import saferepr
# interpretation code and assertion rewriter to detect this plugin was
# loaded and in turn call the hooks defined here as part of the
# DebugInterpreter.
_reprcompare = None
_reprcompare = None # type: Optional[Callable[[str, object, object], Optional[str]]]
# Works similarly as _reprcompare attribute. Is populated with the hook call
# when pytest_runtest_setup is called.
_assertion_pass = None
_assertion_pass = None # type: Optional[Callable[[int, str, str], None]]
def format_explanation(explanation):
@ -177,7 +180,7 @@ def _diff_text(left, right, verbose=0):
"""
from difflib import ndiff
explanation = []
explanation = [] # type: List[str]
def escape_for_readable_diff(binary_text):
"""
@ -235,7 +238,7 @@ def _compare_eq_verbose(left, right):
left_lines = repr(left).splitlines(keepends)
right_lines = repr(right).splitlines(keepends)
explanation = []
explanation = [] # type: List[str]
explanation += ["-" + line for line in left_lines]
explanation += ["+" + line for line in right_lines]
@ -259,7 +262,7 @@ def _compare_eq_iterable(left, right, verbose=0):
def _compare_eq_sequence(left, right, verbose=0):
comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes)
explanation = []
explanation = [] # type: List[str]
len_left = len(left)
len_right = len(right)
for i in range(min(len_left, len_right)):
@ -327,7 +330,7 @@ def _compare_eq_set(left, right, verbose=0):
def _compare_eq_dict(left, right, verbose=0):
explanation = []
explanation = [] # type: List[str]
set_left = set(left)
set_right = set(right)
common = set_left.intersection(set_right)

View File

@ -9,6 +9,15 @@ import types
import warnings
from functools import lru_cache
from pathlib import Path
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
import attr
import py
@ -32,6 +41,10 @@ from _pytest.outcomes import fail
from _pytest.outcomes import Skipped
from _pytest.warning_types import PytestConfigWarning
if False: # TYPE_CHECKING
from typing import Type
hookimpl = HookimplMarker("pytest")
hookspec = HookspecMarker("pytest")
@ -40,7 +53,7 @@ class ConftestImportFailure(Exception):
def __init__(self, path, excinfo):
Exception.__init__(self, path, excinfo)
self.path = path
self.excinfo = excinfo
self.excinfo = excinfo # type: Tuple[Type[Exception], Exception, TracebackType]
def main(args=None, plugins=None):
@ -237,14 +250,18 @@ class PytestPluginManager(PluginManager):
def __init__(self):
super().__init__("pytest")
self._conftest_plugins = set()
# The objects are module objects, only used generically.
self._conftest_plugins = set() # type: Set[object]
# state related to local conftest plugins
self._dirpath2confmods = {}
self._conftestpath2mod = {}
# Maps a py.path.local to a list of module objects.
self._dirpath2confmods = {} # type: Dict[Any, List[object]]
# Maps a py.path.local to a module object.
self._conftestpath2mod = {} # type: Dict[Any, object]
self._confcutdir = None
self._noconftest = False
self._duplicatepaths = set()
# Set of py.path.local's.
self._duplicatepaths = set() # type: Set[Any]
self.add_hookspecs(_pytest.hookspec)
self.register(self)
@ -653,7 +670,7 @@ class Config:
args = attr.ib()
plugins = attr.ib()
dir = attr.ib()
dir = attr.ib(type=Path)
def __init__(self, pluginmanager, *, invocation_params=None):
from .argparsing import Parser, FILE_OR_DIR
@ -674,10 +691,10 @@ class Config:
self.pluginmanager = pluginmanager
self.trace = self.pluginmanager.trace.root.get("config")
self.hook = self.pluginmanager.hook
self._inicache = {}
self._override_ini = ()
self._opt2dest = {}
self._cleanup = []
self._inicache = {} # type: Dict[str, Any]
self._override_ini = () # type: Sequence[str]
self._opt2dest = {} # type: Dict[str, str]
self._cleanup = [] # type: List[Callable[[], None]]
self.pluginmanager.register(self, "pytestconfig")
self._configured = False
self.hook.pytest_addoption.call_historic(kwargs=dict(parser=self._parser))
@ -778,7 +795,7 @@ class Config:
def pytest_load_initial_conftests(self, early_config):
self.pluginmanager._set_initial_conftests(early_config.known_args_namespace)
def _initini(self, args):
def _initini(self, args) -> None:
ns, unknown_args = self._parser.parse_known_and_unknown_args(
args, namespace=copy.copy(self.option)
)
@ -879,8 +896,7 @@ class Config:
self.hook.pytest_load_initial_conftests(
early_config=self, args=args, parser=self._parser
)
except ConftestImportFailure:
e = sys.exc_info()[1]
except ConftestImportFailure as e:
if ns.help or ns.version:
# we don't want to prevent --help/--version to work
# so just let is pass and print a warning at the end
@ -946,7 +962,7 @@ class Config:
assert isinstance(x, list)
x.append(line) # modifies the cached list inline
def getini(self, name):
def getini(self, name: str):
""" return configuration value from an :ref:`ini file <inifiles>`. If the
specified name hasn't been registered through a prior
:py:func:`parser.addini <_pytest.config.Parser.addini>`
@ -957,7 +973,7 @@ class Config:
self._inicache[name] = val = self._getini(name)
return val
def _getini(self, name):
def _getini(self, name: str) -> Any:
try:
description, type, default = self._parser._inidict[name]
except KeyError:
@ -1002,7 +1018,7 @@ class Config:
values.append(relroot)
return values
def _get_override_ini_value(self, name):
def _get_override_ini_value(self, name: str) -> Optional[str]:
value = None
# override_ini is a list of "ini=value" options
# always use the last item if multiple values are set for same ini-name,
@ -1017,7 +1033,7 @@ class Config:
value = user_ini_value
return value
def getoption(self, name, default=notset, skip=False):
def getoption(self, name: str, default=notset, skip: bool = False):
""" return command line option value.
:arg name: name of the option. You may also specify

View File

@ -2,6 +2,11 @@ import argparse
import sys
import warnings
from gettext import gettext
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import py
@ -21,12 +26,12 @@ class Parser:
def __init__(self, usage=None, processopt=None):
self._anonymous = OptionGroup("custom options", parser=self)
self._groups = []
self._groups = [] # type: List[OptionGroup]
self._processopt = processopt
self._usage = usage
self._inidict = {}
self._ininames = []
self.extra_info = {}
self._inidict = {} # type: Dict[str, Tuple[str, Optional[str], Any]]
self._ininames = [] # type: List[str]
self.extra_info = {} # type: Dict[str, Any]
def processoption(self, option):
if self._processopt:
@ -80,7 +85,7 @@ class Parser:
args = [str(x) if isinstance(x, py.path.local) else x for x in args]
return self.optparser.parse_args(args, namespace=namespace)
def _getparser(self):
def _getparser(self) -> "MyOptionParser":
from _pytest._argcomplete import filescompleter
optparser = MyOptionParser(self, self.extra_info, prog=self.prog)
@ -94,7 +99,10 @@ class Parser:
a = option.attrs()
arggroup.add_argument(*n, **a)
# bash like autocompletion for dirs (appending '/')
optparser.add_argument(FILE_OR_DIR, nargs="*").completer = filescompleter
# Type ignored because typeshed doesn't know about argcomplete.
optparser.add_argument( # type: ignore
FILE_OR_DIR, nargs="*"
).completer = filescompleter
return optparser
def parse_setoption(self, args, option, namespace=None):
@ -103,13 +111,15 @@ class Parser:
setattr(option, name, value)
return getattr(parsedoption, FILE_OR_DIR)
def parse_known_args(self, args, namespace=None):
def parse_known_args(self, args, namespace=None) -> argparse.Namespace:
"""parses and returns a namespace object with known arguments at this
point.
"""
return self.parse_known_and_unknown_args(args, namespace=namespace)[0]
def parse_known_and_unknown_args(self, args, namespace=None):
def parse_known_and_unknown_args(
self, args, namespace=None
) -> Tuple[argparse.Namespace, List[str]]:
"""parses and returns a namespace object with known arguments, and
the remaining arguments unknown at this point.
"""
@ -163,8 +173,8 @@ class Argument:
def __init__(self, *names, **attrs):
"""store parms in private vars for use in add_argument"""
self._attrs = attrs
self._short_opts = []
self._long_opts = []
self._short_opts = [] # type: List[str]
self._long_opts = [] # type: List[str]
self.dest = attrs.get("dest")
if "%default" in (attrs.get("help") or ""):
warnings.warn(
@ -268,8 +278,8 @@ class Argument:
)
self._long_opts.append(opt)
def __repr__(self):
args = []
def __repr__(self) -> str:
args = [] # type: List[str]
if self._short_opts:
args += ["_short_opts: " + repr(self._short_opts)]
if self._long_opts:
@ -286,7 +296,7 @@ class OptionGroup:
def __init__(self, name, description="", parser=None):
self.name = name
self.description = description
self.options = []
self.options = [] # type: List[Argument]
self.parser = parser
def addoption(self, *optnames, **attrs):
@ -421,7 +431,7 @@ class DropShorterLongHelpFormatter(argparse.HelpFormatter):
option_map = getattr(action, "map_long_option", {})
if option_map is None:
option_map = {}
short_long = {}
short_long = {} # type: Dict[str, str]
for option in options:
if len(option) == 2 or option[2] == " ":
continue

View File

@ -1,10 +1,15 @@
import os
from typing import List
from typing import Optional
import py
from .exceptions import UsageError
from _pytest.outcomes import fail
if False:
from . import Config # noqa: F401
def exists(path, ignore=EnvironmentError):
try:
@ -102,7 +107,12 @@ def get_dirs_from_args(args):
CFG_PYTEST_SECTION = "[pytest] section in {filename} files is no longer supported, change to [tool:pytest] instead."
def determine_setup(inifile, args, rootdir_cmd_arg=None, config=None):
def determine_setup(
inifile: str,
args: List[str],
rootdir_cmd_arg: Optional[str] = None,
config: Optional["Config"] = None,
):
dirs = get_dirs_from_args(args)
if inifile:
iniconfig = py.iniconfig.IniConfig(inifile)

View File

@ -51,6 +51,8 @@ class MarkEvaluator:
except TEST_OUTCOME:
self.exc = sys.exc_info()
if isinstance(self.exc[1], SyntaxError):
# TODO: Investigate why SyntaxError.offset is Optional, and if it can be None here.
assert self.exc[1].offset is not None
msg = [" " * (self.exc[1].offset + 4) + "^"]
msg.append("SyntaxError: invalid syntax")
else:

View File

@ -292,7 +292,7 @@ class MarkGenerator:
_config = None
_markers = set() # type: Set[str]
def __getattr__(self, name):
def __getattr__(self, name: str) -> MarkDecorator:
if name[0] == "_":
raise AttributeError("Marker name must NOT start with underscore")

View File

@ -1,14 +1,26 @@
import os
import warnings
from functools import lru_cache
from typing import Any
from typing import Dict
from typing import List
from typing import Set
from typing import Tuple
from typing import Union
import py
import _pytest._code
from _pytest.compat import getfslineno
from _pytest.mark.structures import Mark
from _pytest.mark.structures import MarkDecorator
from _pytest.mark.structures import NodeKeywords
from _pytest.outcomes import fail
if False: # TYPE_CHECKING
# Imported here due to circular import.
from _pytest.fixtures import FixtureDef
SEP = "/"
tracebackcutdir = py.path.local(_pytest.__file__).dirpath()
@ -78,13 +90,13 @@ class Node:
self.keywords = NodeKeywords(self)
#: the marker objects belonging to this node
self.own_markers = []
self.own_markers = [] # type: List[Mark]
#: allow adding of extra keywords to use for matching
self.extra_keyword_matches = set()
self.extra_keyword_matches = set() # type: Set[str]
# used for storing artificial fixturedefs for direct parametrization
self._name2pseudofixturedef = {}
self._name2pseudofixturedef = {} # type: Dict[str, FixtureDef]
if nodeid is not None:
assert "::()" not in nodeid
@ -127,7 +139,8 @@ class Node:
)
)
path, lineno = get_fslocation_from_item(self)
warnings.warn_explicit(
# Type ignored: https://github.com/python/typeshed/pull/3121
warnings.warn_explicit( # type: ignore
warning,
category=None,
filename=str(path),
@ -160,7 +173,9 @@ class Node:
chain.reverse()
return chain
def add_marker(self, marker, append=True):
def add_marker(
self, marker: Union[str, MarkDecorator], append: bool = True
) -> None:
"""dynamically add a marker object to the node.
:type marker: ``str`` or ``pytest.mark.*`` object
@ -168,17 +183,19 @@ class Node:
``append=True`` whether to append the marker,
if ``False`` insert at position ``0``.
"""
from _pytest.mark import MarkDecorator, MARK_GEN
from _pytest.mark import MARK_GEN
if isinstance(marker, str):
marker = getattr(MARK_GEN, marker)
elif not isinstance(marker, MarkDecorator):
raise ValueError("is not a string or pytest.mark.* Marker")
self.keywords[marker.name] = marker
if append:
self.own_markers.append(marker.mark)
if isinstance(marker, MarkDecorator):
marker_ = marker
elif isinstance(marker, str):
marker_ = getattr(MARK_GEN, marker)
else:
self.own_markers.insert(0, marker.mark)
raise ValueError("is not a string or pytest.mark.* Marker")
self.keywords[marker_.name] = marker
if append:
self.own_markers.append(marker_.mark)
else:
self.own_markers.insert(0, marker_.mark)
def iter_markers(self, name=None):
"""
@ -211,7 +228,7 @@ class Node:
def listextrakeywords(self):
""" Return a set of all extra keywords in self and any parents."""
extra_keywords = set()
extra_keywords = set() # type: Set[str]
for item in self.listchain():
extra_keywords.update(item.extra_keyword_matches)
return extra_keywords
@ -239,7 +256,8 @@ class Node:
pass
def _repr_failure_py(self, excinfo, style=None):
if excinfo.errisinstance(fail.Exception):
# Type ignored: see comment where fail.Exception is defined.
if excinfo.errisinstance(fail.Exception): # type: ignore
if not excinfo.value.pytrace:
return str(excinfo.value)
fm = self.session._fixturemanager
@ -385,13 +403,13 @@ class Item(Node):
def __init__(self, name, parent=None, config=None, session=None, nodeid=None):
super().__init__(name, parent, config, session, nodeid=nodeid)
self._report_sections = []
self._report_sections = [] # type: List[Tuple[str, str, str]]
#: user properties is a list of tuples (name, value) that holds user
#: defined properties for this test.
self.user_properties = []
self.user_properties = [] # type: List[Tuple[str, Any]]
def add_report_section(self, when, key, content):
def add_report_section(self, when: str, key: str, content: str) -> None:
"""
Adds a new report section, similar to what's done internally to add stdout and
stderr captured output::

View File

@ -1,5 +1,6 @@
from pprint import pprint
from typing import Optional
from typing import Union
import py
@ -221,7 +222,6 @@ class BaseReport:
reprcrash = reportdict["longrepr"]["reprcrash"]
unserialized_entries = []
reprentry = None
for entry_data in reprtraceback["reprentries"]:
data = entry_data["data"]
entry_type = entry_data["type"]
@ -242,7 +242,7 @@ class BaseReport:
reprlocals=reprlocals,
filelocrepr=reprfileloc,
style=data["style"],
)
) # type: Union[ReprEntry, ReprEntryNative]
elif entry_type == "ReprEntryNative":
reprentry = ReprEntryNative(data["lines"])
else:
@ -352,7 +352,8 @@ class TestReport(BaseReport):
if not isinstance(excinfo, ExceptionInfo):
outcome = "failed"
longrepr = excinfo
elif excinfo.errisinstance(skip.Exception):
# Type ignored -- see comment where skip.Exception is defined.
elif excinfo.errisinstance(skip.Exception): # type: ignore
outcome = "skipped"
r = excinfo._getreprcrash()
longrepr = (str(r.path), r.lineno, r.message)

View File

@ -3,6 +3,10 @@ import bdb
import os
import sys
from time import time
from typing import Callable
from typing import Dict
from typing import List
from typing import Tuple
import attr
@ -10,10 +14,14 @@ from .reports import CollectErrorRepr
from .reports import CollectReport
from .reports import TestReport
from _pytest._code.code import ExceptionInfo
from _pytest.nodes import Node
from _pytest.outcomes import Exit
from _pytest.outcomes import Skipped
from _pytest.outcomes import TEST_OUTCOME
if False: # TYPE_CHECKING
from typing import Type
#
# pytest plugin hooks
@ -118,6 +126,7 @@ def pytest_runtest_call(item):
except Exception:
# Store trace info to allow postmortem debugging
type, value, tb = sys.exc_info()
assert tb is not None
tb = tb.tb_next # Skip *this* frame
sys.last_type = type
sys.last_value = value
@ -185,7 +194,7 @@ def check_interactive_exception(call, report):
def call_runtest_hook(item, when, **kwds):
hookname = "pytest_runtest_" + when
ihook = getattr(item.ihook, hookname)
reraise = (Exit,)
reraise = (Exit,) # type: Tuple[Type[BaseException], ...]
if not item.config.getoption("usepdb", False):
reraise += (KeyboardInterrupt,)
return CallInfo.from_call(
@ -252,7 +261,8 @@ def pytest_make_collect_report(collector):
skip_exceptions = [Skipped]
unittest = sys.modules.get("unittest")
if unittest is not None:
skip_exceptions.append(unittest.SkipTest)
# Type ignored because unittest is loaded dynamically.
skip_exceptions.append(unittest.SkipTest) # type: ignore
if call.excinfo.errisinstance(tuple(skip_exceptions)):
outcome = "skipped"
r = collector._repr_failure_py(call.excinfo, "line").reprcrash
@ -266,7 +276,7 @@ def pytest_make_collect_report(collector):
rep = CollectReport(
collector.nodeid, outcome, longrepr, getattr(call, "result", None)
)
rep.call = call # see collect_one_node
rep.call = call # type: ignore # see collect_one_node
return rep
@ -274,8 +284,8 @@ class SetupState:
""" shared state for setting up/tearing down test items or collectors. """
def __init__(self):
self.stack = []
self._finalizers = {}
self.stack = [] # type: List[Node]
self._finalizers = {} # type: Dict[Node, List[Callable[[], None]]]
def addfinalizer(self, finalizer, colitem):
""" attach a finalizer to the given colitem. """
@ -302,6 +312,7 @@ class SetupState:
exc = sys.exc_info()
if exc:
_, val, tb = exc
assert val is not None
raise val.with_traceback(tb)
def _teardown_with_finalization(self, colitem):
@ -335,6 +346,7 @@ class SetupState:
exc = sys.exc_info()
if exc:
_, val, tb = exc
assert val is not None
raise val.with_traceback(tb)
def prepare(self, colitem):