Fix some check_untyped_defs = True mypy warnings

This commit is contained in:
Ran Benita 2019-07-14 18:45:40 +03:00 committed by Ran Benita
parent 28761c8da1
commit 7259c453d6
13 changed files with 196 additions and 106 deletions

View File

@ -5,10 +5,15 @@ import traceback
from inspect import CO_VARARGS from inspect import CO_VARARGS
from inspect import CO_VARKEYWORDS from inspect import CO_VARKEYWORDS
from traceback import format_exception_only from traceback import format_exception_only
from types import CodeType
from types import TracebackType from types import TracebackType
from typing import Any
from typing import Dict
from typing import Generic from typing import Generic
from typing import List
from typing import Optional from typing import Optional
from typing import Pattern from typing import Pattern
from typing import Set
from typing import Tuple from typing import Tuple
from typing import TypeVar from typing import TypeVar
from typing import Union from typing import Union
@ -29,7 +34,7 @@ if False: # TYPE_CHECKING
class Code: class Code:
""" wrapper around Python code objects """ """ wrapper around Python code objects """
def __init__(self, rawcode): def __init__(self, rawcode) -> None:
if not hasattr(rawcode, "co_filename"): if not hasattr(rawcode, "co_filename"):
rawcode = getrawcode(rawcode) rawcode = getrawcode(rawcode)
try: try:
@ -38,7 +43,7 @@ class Code:
self.name = rawcode.co_name self.name = rawcode.co_name
except AttributeError: except AttributeError:
raise TypeError("not a code object: {!r}".format(rawcode)) raise TypeError("not a code object: {!r}".format(rawcode))
self.raw = rawcode self.raw = rawcode # type: CodeType
def __eq__(self, other): def __eq__(self, other):
return self.raw == other.raw return self.raw == other.raw
@ -351,7 +356,7 @@ class Traceback(list):
""" return the index of the frame/TracebackEntry where recursion """ return the index of the frame/TracebackEntry where recursion
originates if appropriate, None if no recursion occurred originates if appropriate, None if no recursion occurred
""" """
cache = {} cache = {} # type: Dict[Tuple[Any, int, int], List[Dict[str, Any]]]
for i, entry in enumerate(self): for i, entry in enumerate(self):
# id for the code.raw is needed to work around # id for the code.raw is needed to work around
# the strange metaprogramming in the decorator lib from pypi # the strange metaprogramming in the decorator lib from pypi
@ -650,7 +655,7 @@ class FormattedExcinfo:
args.append((argname, saferepr(argvalue))) args.append((argname, saferepr(argvalue)))
return ReprFuncArgs(args) return ReprFuncArgs(args)
def get_source(self, source, line_index=-1, excinfo=None, short=False): def get_source(self, source, line_index=-1, excinfo=None, short=False) -> List[str]:
""" return formatted and marked up source lines. """ """ return formatted and marked up source lines. """
import _pytest._code import _pytest._code
@ -722,7 +727,7 @@ class FormattedExcinfo:
else: else:
line_index = entry.lineno - entry.getfirstlinesource() line_index = entry.lineno - entry.getfirstlinesource()
lines = [] lines = [] # type: List[str]
style = entry._repr_style style = entry._repr_style
if style is None: if style is None:
style = self.style style = self.style
@ -799,7 +804,7 @@ class FormattedExcinfo:
exc_msg=str(e), exc_msg=str(e),
max_frames=max_frames, max_frames=max_frames,
total=len(traceback), total=len(traceback),
) ) # type: Optional[str]
traceback = traceback[:max_frames] + traceback[-max_frames:] traceback = traceback[:max_frames] + traceback[-max_frames:]
else: else:
if recursionindex is not None: if recursionindex is not None:
@ -812,10 +817,12 @@ class FormattedExcinfo:
def repr_excinfo(self, excinfo): def repr_excinfo(self, excinfo):
repr_chain = [] repr_chain = (
[]
) # type: List[Tuple[ReprTraceback, Optional[ReprFileLocation], Optional[str]]]
e = excinfo.value e = excinfo.value
descr = None descr = None
seen = set() seen = set() # type: Set[int]
while e is not None and id(e) not in seen: while e is not None and id(e) not in seen:
seen.add(id(e)) seen.add(id(e))
if excinfo: if excinfo:
@ -868,8 +875,8 @@ class TerminalRepr:
class ExceptionRepr(TerminalRepr): class ExceptionRepr(TerminalRepr):
def __init__(self): def __init__(self) -> None:
self.sections = [] self.sections = [] # type: List[Tuple[str, str, str]]
def addsection(self, name, content, sep="-"): def addsection(self, name, content, sep="-"):
self.sections.append((name, content, sep)) self.sections.append((name, content, sep))

View File

@ -7,6 +7,7 @@ import tokenize
import warnings import warnings
from ast import PyCF_ONLY_AST as _AST_FLAG from ast import PyCF_ONLY_AST as _AST_FLAG
from bisect import bisect_right from bisect import bisect_right
from typing import List
import py import py
@ -19,11 +20,11 @@ class Source:
_compilecounter = 0 _compilecounter = 0
def __init__(self, *parts, **kwargs): def __init__(self, *parts, **kwargs):
self.lines = lines = [] self.lines = lines = [] # type: List[str]
de = kwargs.get("deindent", True) de = kwargs.get("deindent", True)
for part in parts: for part in parts:
if not part: if not part:
partlines = [] partlines = [] # type: List[str]
elif isinstance(part, Source): elif isinstance(part, Source):
partlines = part.lines partlines = part.lines
elif isinstance(part, (tuple, list)): elif isinstance(part, (tuple, list)):
@ -157,8 +158,7 @@ class Source:
source = "\n".join(self.lines) + "\n" source = "\n".join(self.lines) + "\n"
try: try:
co = compile(source, filename, mode, flag) co = compile(source, filename, mode, flag)
except SyntaxError: except SyntaxError as ex:
ex = sys.exc_info()[1]
# re-represent syntax errors from parsing python strings # re-represent syntax errors from parsing python strings
msglines = self.lines[: ex.lineno] msglines = self.lines[: ex.lineno]
if ex.offset: if ex.offset:
@ -173,7 +173,8 @@ class Source:
if flag & _AST_FLAG: if flag & _AST_FLAG:
return co return co
lines = [(x + "\n") for x in self.lines] lines = [(x + "\n") for x in self.lines]
linecache.cache[filename] = (1, None, lines, filename) # Type ignored because linecache.cache is private.
linecache.cache[filename] = (1, None, lines, filename) # type: ignore
return co return co
@ -282,7 +283,7 @@ def get_statement_startend2(lineno, node):
return start, end return start, end
def getstatementrange_ast(lineno, source, assertion=False, astnode=None): def getstatementrange_ast(lineno, source: Source, assertion=False, astnode=None):
if astnode is None: if astnode is None:
content = str(source) content = str(source)
# See #4260: # See #4260:

View File

@ -2,6 +2,7 @@
support for presenting detailed information in failing assertions. support for presenting detailed information in failing assertions.
""" """
import sys import sys
from typing import Optional
from _pytest.assertion import rewrite from _pytest.assertion import rewrite
from _pytest.assertion import truncate from _pytest.assertion import truncate
@ -52,7 +53,9 @@ def register_assert_rewrite(*names):
importhook = hook importhook = hook
break break
else: else:
importhook = DummyRewriteHook() # TODO(typing): Add a protocol for mark_rewrite() and use it
# for importhook and for PytestPluginManager.rewrite_hook.
importhook = DummyRewriteHook() # type: ignore
importhook.mark_rewrite(*names) importhook.mark_rewrite(*names)
@ -69,7 +72,7 @@ class AssertionState:
def __init__(self, config, mode): def __init__(self, config, mode):
self.mode = mode self.mode = mode
self.trace = config.trace.root.get("assertion") self.trace = config.trace.root.get("assertion")
self.hook = None self.hook = None # type: Optional[rewrite.AssertionRewritingHook]
def install_importhook(config): def install_importhook(config):
@ -108,6 +111,7 @@ def pytest_runtest_setup(item):
""" """
def callbinrepr(op, left, right): def callbinrepr(op, left, right):
# type: (str, object, object) -> Optional[str]
"""Call the pytest_assertrepr_compare hook and prepare the result """Call the pytest_assertrepr_compare hook and prepare the result
This uses the first result from the hook and then ensures the This uses the first result from the hook and then ensures the
@ -133,12 +137,13 @@ def pytest_runtest_setup(item):
if item.config.getvalue("assertmode") == "rewrite": if item.config.getvalue("assertmode") == "rewrite":
res = res.replace("%", "%%") res = res.replace("%", "%%")
return res return res
return None
util._reprcompare = callbinrepr util._reprcompare = callbinrepr
if item.ihook.pytest_assertion_pass.get_hookimpls(): if item.ihook.pytest_assertion_pass.get_hookimpls():
def call_assertion_pass_hook(lineno, expl, orig): def call_assertion_pass_hook(lineno, orig, expl):
item.ihook.pytest_assertion_pass( item.ihook.pytest_assertion_pass(
item=item, lineno=lineno, orig=orig, expl=expl item=item, lineno=lineno, orig=orig, expl=expl
) )

View File

@ -17,6 +17,7 @@ from typing import Dict
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Set from typing import Set
from typing import Tuple
import atomicwrites import atomicwrites
@ -48,13 +49,13 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder):
except ValueError: except ValueError:
self.fnpats = ["test_*.py", "*_test.py"] self.fnpats = ["test_*.py", "*_test.py"]
self.session = None self.session = None
self._rewritten_names = set() self._rewritten_names = set() # type: Set[str]
self._must_rewrite = set() self._must_rewrite = set() # type: Set[str]
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file, # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
# which might result in infinite recursion (#3506) # which might result in infinite recursion (#3506)
self._writing_pyc = False self._writing_pyc = False
self._basenames_to_check_rewrite = {"conftest"} self._basenames_to_check_rewrite = {"conftest"}
self._marked_for_rewrite_cache = {} self._marked_for_rewrite_cache = {} # type: Dict[str, bool]
self._session_paths_checked = False self._session_paths_checked = False
def set_session(self, session): def set_session(self, session):
@ -203,7 +204,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder):
return self._is_marked_for_rewrite(name, state) return self._is_marked_for_rewrite(name, state)
def _is_marked_for_rewrite(self, name, state): def _is_marked_for_rewrite(self, name: str, state):
try: try:
return self._marked_for_rewrite_cache[name] return self._marked_for_rewrite_cache[name]
except KeyError: except KeyError:
@ -218,7 +219,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder):
self._marked_for_rewrite_cache[name] = False self._marked_for_rewrite_cache[name] = False
return False return False
def mark_rewrite(self, *names): def mark_rewrite(self, *names: str) -> None:
"""Mark import names as needing to be rewritten. """Mark import names as needing to be rewritten.
The named module or package as well as any nested modules will The named module or package as well as any nested modules will
@ -385,6 +386,7 @@ def _format_boolop(explanations, is_or):
def _call_reprcompare(ops, results, expls, each_obj): def _call_reprcompare(ops, results, expls, each_obj):
# type: (Tuple[str, ...], Tuple[bool, ...], Tuple[str, ...], Tuple[object, ...]) -> str
for i, res, expl in zip(range(len(ops)), results, expls): for i, res, expl in zip(range(len(ops)), results, expls):
try: try:
done = not res done = not res
@ -400,11 +402,13 @@ def _call_reprcompare(ops, results, expls, each_obj):
def _call_assertion_pass(lineno, orig, expl): def _call_assertion_pass(lineno, orig, expl):
# type: (int, str, str) -> None
if util._assertion_pass is not None: if util._assertion_pass is not None:
util._assertion_pass(lineno=lineno, orig=orig, expl=expl) util._assertion_pass(lineno, orig, expl)
def _check_if_assertion_pass_impl(): def _check_if_assertion_pass_impl():
# type: () -> bool
"""Checks if any plugins implement the pytest_assertion_pass hook """Checks if any plugins implement the pytest_assertion_pass hook
in order not to generate explanation unecessarily (might be expensive)""" in order not to generate explanation unecessarily (might be expensive)"""
return True if util._assertion_pass else False return True if util._assertion_pass else False
@ -578,7 +582,7 @@ class AssertionRewriter(ast.NodeVisitor):
def _assert_expr_to_lineno(self): def _assert_expr_to_lineno(self):
return _get_assertion_exprs(self.source) return _get_assertion_exprs(self.source)
def run(self, mod): def run(self, mod: ast.Module) -> None:
"""Find all assert statements in *mod* and rewrite them.""" """Find all assert statements in *mod* and rewrite them."""
if not mod.body: if not mod.body:
# Nothing to do. # Nothing to do.
@ -620,12 +624,12 @@ class AssertionRewriter(ast.NodeVisitor):
] ]
mod.body[pos:pos] = imports mod.body[pos:pos] = imports
# Collect asserts. # Collect asserts.
nodes = [mod] nodes = [mod] # type: List[ast.AST]
while nodes: while nodes:
node = nodes.pop() node = nodes.pop()
for name, field in ast.iter_fields(node): for name, field in ast.iter_fields(node):
if isinstance(field, list): if isinstance(field, list):
new = [] new = [] # type: List
for i, child in enumerate(field): for i, child in enumerate(field):
if isinstance(child, ast.Assert): if isinstance(child, ast.Assert):
# Transform assert. # Transform assert.
@ -699,7 +703,7 @@ class AssertionRewriter(ast.NodeVisitor):
.explanation_param(). .explanation_param().
""" """
self.explanation_specifiers = {} self.explanation_specifiers = {} # type: Dict[str, ast.expr]
self.stack.append(self.explanation_specifiers) self.stack.append(self.explanation_specifiers)
def pop_format_context(self, expl_expr): def pop_format_context(self, expl_expr):
@ -742,7 +746,8 @@ class AssertionRewriter(ast.NodeVisitor):
from _pytest.warning_types import PytestAssertRewriteWarning from _pytest.warning_types import PytestAssertRewriteWarning
import warnings import warnings
warnings.warn_explicit( # Ignore type: typeshed bug https://github.com/python/typeshed/pull/3121
warnings.warn_explicit( # type: ignore
PytestAssertRewriteWarning( PytestAssertRewriteWarning(
"assertion is always true, perhaps remove parentheses?" "assertion is always true, perhaps remove parentheses?"
), ),
@ -751,15 +756,15 @@ class AssertionRewriter(ast.NodeVisitor):
lineno=assert_.lineno, lineno=assert_.lineno,
) )
self.statements = [] self.statements = [] # type: List[ast.stmt]
self.variables = [] self.variables = [] # type: List[str]
self.variable_counter = itertools.count() self.variable_counter = itertools.count()
if self.enable_assertion_pass_hook: if self.enable_assertion_pass_hook:
self.format_variables = [] self.format_variables = [] # type: List[str]
self.stack = [] self.stack = [] # type: List[Dict[str, ast.expr]]
self.expl_stmts = [] self.expl_stmts = [] # type: List[ast.stmt]
self.push_format_context() self.push_format_context()
# Rewrite assert into a bunch of statements. # Rewrite assert into a bunch of statements.
top_condition, explanation = self.visit(assert_.test) top_condition, explanation = self.visit(assert_.test)
@ -897,7 +902,7 @@ warn_explicit(
# Process each operand, short-circuiting if needed. # Process each operand, short-circuiting if needed.
for i, v in enumerate(boolop.values): for i, v in enumerate(boolop.values):
if i: if i:
fail_inner = [] fail_inner = [] # type: List[ast.stmt]
# cond is set in a prior loop iteration below # cond is set in a prior loop iteration below
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
self.expl_stmts = fail_inner self.expl_stmts = fail_inner
@ -908,10 +913,10 @@ warn_explicit(
call = ast.Call(app, [expl_format], []) call = ast.Call(app, [expl_format], [])
self.expl_stmts.append(ast.Expr(call)) self.expl_stmts.append(ast.Expr(call))
if i < levels: if i < levels:
cond = res cond = res # type: ast.expr
if is_or: if is_or:
cond = ast.UnaryOp(ast.Not(), cond) cond = ast.UnaryOp(ast.Not(), cond)
inner = [] inner = [] # type: List[ast.stmt]
self.statements.append(ast.If(cond, inner, [])) self.statements.append(ast.If(cond, inner, []))
self.statements = body = inner self.statements = body = inner
self.statements = save self.statements = save
@ -977,7 +982,7 @@ warn_explicit(
expl = pat % (res_expl, res_expl, value_expl, attr.attr) expl = pat % (res_expl, res_expl, value_expl, attr.attr)
return res, expl return res, expl
def visit_Compare(self, comp): def visit_Compare(self, comp: ast.Compare):
self.push_format_context() self.push_format_context()
left_res, left_expl = self.visit(comp.left) left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, (ast.Compare, ast.BoolOp)): if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
@ -1010,7 +1015,7 @@ warn_explicit(
ast.Tuple(results, ast.Load()), ast.Tuple(results, ast.Load()),
) )
if len(comp.ops) > 1: if len(comp.ops) > 1:
res = ast.BoolOp(ast.And(), load_names) res = ast.BoolOp(ast.And(), load_names) # type: ast.expr
else: else:
res = load_names[0] res = load_names[0]
return res, self.explanation_param(self.pop_format_context(expl_call)) return res, self.explanation_param(self.pop_format_context(expl_call))

View File

@ -1,6 +1,9 @@
"""Utilities for assertion debugging""" """Utilities for assertion debugging"""
import pprint import pprint
from collections.abc import Sequence from collections.abc import Sequence
from typing import Callable
from typing import List
from typing import Optional
import _pytest._code import _pytest._code
from _pytest import outcomes from _pytest import outcomes
@ -10,11 +13,11 @@ from _pytest._io.saferepr import saferepr
# interpretation code and assertion rewriter to detect this plugin was # interpretation code and assertion rewriter to detect this plugin was
# loaded and in turn call the hooks defined here as part of the # loaded and in turn call the hooks defined here as part of the
# DebugInterpreter. # DebugInterpreter.
_reprcompare = None _reprcompare = None # type: Optional[Callable[[str, object, object], Optional[str]]]
# Works similarly as _reprcompare attribute. Is populated with the hook call # Works similarly as _reprcompare attribute. Is populated with the hook call
# when pytest_runtest_setup is called. # when pytest_runtest_setup is called.
_assertion_pass = None _assertion_pass = None # type: Optional[Callable[[int, str, str], None]]
def format_explanation(explanation): def format_explanation(explanation):
@ -177,7 +180,7 @@ def _diff_text(left, right, verbose=0):
""" """
from difflib import ndiff from difflib import ndiff
explanation = [] explanation = [] # type: List[str]
def escape_for_readable_diff(binary_text): def escape_for_readable_diff(binary_text):
""" """
@ -235,7 +238,7 @@ def _compare_eq_verbose(left, right):
left_lines = repr(left).splitlines(keepends) left_lines = repr(left).splitlines(keepends)
right_lines = repr(right).splitlines(keepends) right_lines = repr(right).splitlines(keepends)
explanation = [] explanation = [] # type: List[str]
explanation += ["-" + line for line in left_lines] explanation += ["-" + line for line in left_lines]
explanation += ["+" + line for line in right_lines] explanation += ["+" + line for line in right_lines]
@ -259,7 +262,7 @@ def _compare_eq_iterable(left, right, verbose=0):
def _compare_eq_sequence(left, right, verbose=0): def _compare_eq_sequence(left, right, verbose=0):
comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes) comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes)
explanation = [] explanation = [] # type: List[str]
len_left = len(left) len_left = len(left)
len_right = len(right) len_right = len(right)
for i in range(min(len_left, len_right)): for i in range(min(len_left, len_right)):
@ -327,7 +330,7 @@ def _compare_eq_set(left, right, verbose=0):
def _compare_eq_dict(left, right, verbose=0): def _compare_eq_dict(left, right, verbose=0):
explanation = [] explanation = [] # type: List[str]
set_left = set(left) set_left = set(left)
set_right = set(right) set_right = set(right)
common = set_left.intersection(set_right) common = set_left.intersection(set_right)

View File

@ -9,6 +9,15 @@ import types
import warnings import warnings
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
import attr import attr
import py import py
@ -32,6 +41,10 @@ from _pytest.outcomes import fail
from _pytest.outcomes import Skipped from _pytest.outcomes import Skipped
from _pytest.warning_types import PytestConfigWarning from _pytest.warning_types import PytestConfigWarning
if False: # TYPE_CHECKING
from typing import Type
hookimpl = HookimplMarker("pytest") hookimpl = HookimplMarker("pytest")
hookspec = HookspecMarker("pytest") hookspec = HookspecMarker("pytest")
@ -40,7 +53,7 @@ class ConftestImportFailure(Exception):
def __init__(self, path, excinfo): def __init__(self, path, excinfo):
Exception.__init__(self, path, excinfo) Exception.__init__(self, path, excinfo)
self.path = path self.path = path
self.excinfo = excinfo self.excinfo = excinfo # type: Tuple[Type[Exception], Exception, TracebackType]
def main(args=None, plugins=None): def main(args=None, plugins=None):
@ -237,14 +250,18 @@ class PytestPluginManager(PluginManager):
def __init__(self): def __init__(self):
super().__init__("pytest") super().__init__("pytest")
self._conftest_plugins = set() # The objects are module objects, only used generically.
self._conftest_plugins = set() # type: Set[object]
# state related to local conftest plugins # state related to local conftest plugins
self._dirpath2confmods = {} # Maps a py.path.local to a list of module objects.
self._conftestpath2mod = {} self._dirpath2confmods = {} # type: Dict[Any, List[object]]
# Maps a py.path.local to a module object.
self._conftestpath2mod = {} # type: Dict[Any, object]
self._confcutdir = None self._confcutdir = None
self._noconftest = False self._noconftest = False
self._duplicatepaths = set() # Set of py.path.local's.
self._duplicatepaths = set() # type: Set[Any]
self.add_hookspecs(_pytest.hookspec) self.add_hookspecs(_pytest.hookspec)
self.register(self) self.register(self)
@ -653,7 +670,7 @@ class Config:
args = attr.ib() args = attr.ib()
plugins = attr.ib() plugins = attr.ib()
dir = attr.ib() dir = attr.ib(type=Path)
def __init__(self, pluginmanager, *, invocation_params=None): def __init__(self, pluginmanager, *, invocation_params=None):
from .argparsing import Parser, FILE_OR_DIR from .argparsing import Parser, FILE_OR_DIR
@ -674,10 +691,10 @@ class Config:
self.pluginmanager = pluginmanager self.pluginmanager = pluginmanager
self.trace = self.pluginmanager.trace.root.get("config") self.trace = self.pluginmanager.trace.root.get("config")
self.hook = self.pluginmanager.hook self.hook = self.pluginmanager.hook
self._inicache = {} self._inicache = {} # type: Dict[str, Any]
self._override_ini = () self._override_ini = () # type: Sequence[str]
self._opt2dest = {} self._opt2dest = {} # type: Dict[str, str]
self._cleanup = [] self._cleanup = [] # type: List[Callable[[], None]]
self.pluginmanager.register(self, "pytestconfig") self.pluginmanager.register(self, "pytestconfig")
self._configured = False self._configured = False
self.hook.pytest_addoption.call_historic(kwargs=dict(parser=self._parser)) self.hook.pytest_addoption.call_historic(kwargs=dict(parser=self._parser))
@ -778,7 +795,7 @@ class Config:
def pytest_load_initial_conftests(self, early_config): def pytest_load_initial_conftests(self, early_config):
self.pluginmanager._set_initial_conftests(early_config.known_args_namespace) self.pluginmanager._set_initial_conftests(early_config.known_args_namespace)
def _initini(self, args): def _initini(self, args) -> None:
ns, unknown_args = self._parser.parse_known_and_unknown_args( ns, unknown_args = self._parser.parse_known_and_unknown_args(
args, namespace=copy.copy(self.option) args, namespace=copy.copy(self.option)
) )
@ -879,8 +896,7 @@ class Config:
self.hook.pytest_load_initial_conftests( self.hook.pytest_load_initial_conftests(
early_config=self, args=args, parser=self._parser early_config=self, args=args, parser=self._parser
) )
except ConftestImportFailure: except ConftestImportFailure as e:
e = sys.exc_info()[1]
if ns.help or ns.version: if ns.help or ns.version:
# we don't want to prevent --help/--version to work # we don't want to prevent --help/--version to work
# so just let is pass and print a warning at the end # so just let is pass and print a warning at the end
@ -946,7 +962,7 @@ class Config:
assert isinstance(x, list) assert isinstance(x, list)
x.append(line) # modifies the cached list inline x.append(line) # modifies the cached list inline
def getini(self, name): def getini(self, name: str):
""" return configuration value from an :ref:`ini file <inifiles>`. If the """ return configuration value from an :ref:`ini file <inifiles>`. If the
specified name hasn't been registered through a prior specified name hasn't been registered through a prior
:py:func:`parser.addini <_pytest.config.Parser.addini>` :py:func:`parser.addini <_pytest.config.Parser.addini>`
@ -957,7 +973,7 @@ class Config:
self._inicache[name] = val = self._getini(name) self._inicache[name] = val = self._getini(name)
return val return val
def _getini(self, name): def _getini(self, name: str) -> Any:
try: try:
description, type, default = self._parser._inidict[name] description, type, default = self._parser._inidict[name]
except KeyError: except KeyError:
@ -1002,7 +1018,7 @@ class Config:
values.append(relroot) values.append(relroot)
return values return values
def _get_override_ini_value(self, name): def _get_override_ini_value(self, name: str) -> Optional[str]:
value = None value = None
# override_ini is a list of "ini=value" options # override_ini is a list of "ini=value" options
# always use the last item if multiple values are set for same ini-name, # always use the last item if multiple values are set for same ini-name,
@ -1017,7 +1033,7 @@ class Config:
value = user_ini_value value = user_ini_value
return value return value
def getoption(self, name, default=notset, skip=False): def getoption(self, name: str, default=notset, skip: bool = False):
""" return command line option value. """ return command line option value.
:arg name: name of the option. You may also specify :arg name: name of the option. You may also specify

View File

@ -2,6 +2,11 @@ import argparse
import sys import sys
import warnings import warnings
from gettext import gettext from gettext import gettext
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import py import py
@ -21,12 +26,12 @@ class Parser:
def __init__(self, usage=None, processopt=None): def __init__(self, usage=None, processopt=None):
self._anonymous = OptionGroup("custom options", parser=self) self._anonymous = OptionGroup("custom options", parser=self)
self._groups = [] self._groups = [] # type: List[OptionGroup]
self._processopt = processopt self._processopt = processopt
self._usage = usage self._usage = usage
self._inidict = {} self._inidict = {} # type: Dict[str, Tuple[str, Optional[str], Any]]
self._ininames = [] self._ininames = [] # type: List[str]
self.extra_info = {} self.extra_info = {} # type: Dict[str, Any]
def processoption(self, option): def processoption(self, option):
if self._processopt: if self._processopt:
@ -80,7 +85,7 @@ class Parser:
args = [str(x) if isinstance(x, py.path.local) else x for x in args] args = [str(x) if isinstance(x, py.path.local) else x for x in args]
return self.optparser.parse_args(args, namespace=namespace) return self.optparser.parse_args(args, namespace=namespace)
def _getparser(self): def _getparser(self) -> "MyOptionParser":
from _pytest._argcomplete import filescompleter from _pytest._argcomplete import filescompleter
optparser = MyOptionParser(self, self.extra_info, prog=self.prog) optparser = MyOptionParser(self, self.extra_info, prog=self.prog)
@ -94,7 +99,10 @@ class Parser:
a = option.attrs() a = option.attrs()
arggroup.add_argument(*n, **a) arggroup.add_argument(*n, **a)
# bash like autocompletion for dirs (appending '/') # bash like autocompletion for dirs (appending '/')
optparser.add_argument(FILE_OR_DIR, nargs="*").completer = filescompleter # Type ignored because typeshed doesn't know about argcomplete.
optparser.add_argument( # type: ignore
FILE_OR_DIR, nargs="*"
).completer = filescompleter
return optparser return optparser
def parse_setoption(self, args, option, namespace=None): def parse_setoption(self, args, option, namespace=None):
@ -103,13 +111,15 @@ class Parser:
setattr(option, name, value) setattr(option, name, value)
return getattr(parsedoption, FILE_OR_DIR) return getattr(parsedoption, FILE_OR_DIR)
def parse_known_args(self, args, namespace=None): def parse_known_args(self, args, namespace=None) -> argparse.Namespace:
"""parses and returns a namespace object with known arguments at this """parses and returns a namespace object with known arguments at this
point. point.
""" """
return self.parse_known_and_unknown_args(args, namespace=namespace)[0] return self.parse_known_and_unknown_args(args, namespace=namespace)[0]
def parse_known_and_unknown_args(self, args, namespace=None): def parse_known_and_unknown_args(
self, args, namespace=None
) -> Tuple[argparse.Namespace, List[str]]:
"""parses and returns a namespace object with known arguments, and """parses and returns a namespace object with known arguments, and
the remaining arguments unknown at this point. the remaining arguments unknown at this point.
""" """
@ -163,8 +173,8 @@ class Argument:
def __init__(self, *names, **attrs): def __init__(self, *names, **attrs):
"""store parms in private vars for use in add_argument""" """store parms in private vars for use in add_argument"""
self._attrs = attrs self._attrs = attrs
self._short_opts = [] self._short_opts = [] # type: List[str]
self._long_opts = [] self._long_opts = [] # type: List[str]
self.dest = attrs.get("dest") self.dest = attrs.get("dest")
if "%default" in (attrs.get("help") or ""): if "%default" in (attrs.get("help") or ""):
warnings.warn( warnings.warn(
@ -268,8 +278,8 @@ class Argument:
) )
self._long_opts.append(opt) self._long_opts.append(opt)
def __repr__(self): def __repr__(self) -> str:
args = [] args = [] # type: List[str]
if self._short_opts: if self._short_opts:
args += ["_short_opts: " + repr(self._short_opts)] args += ["_short_opts: " + repr(self._short_opts)]
if self._long_opts: if self._long_opts:
@ -286,7 +296,7 @@ class OptionGroup:
def __init__(self, name, description="", parser=None): def __init__(self, name, description="", parser=None):
self.name = name self.name = name
self.description = description self.description = description
self.options = [] self.options = [] # type: List[Argument]
self.parser = parser self.parser = parser
def addoption(self, *optnames, **attrs): def addoption(self, *optnames, **attrs):
@ -421,7 +431,7 @@ class DropShorterLongHelpFormatter(argparse.HelpFormatter):
option_map = getattr(action, "map_long_option", {}) option_map = getattr(action, "map_long_option", {})
if option_map is None: if option_map is None:
option_map = {} option_map = {}
short_long = {} short_long = {} # type: Dict[str, str]
for option in options: for option in options:
if len(option) == 2 or option[2] == " ": if len(option) == 2 or option[2] == " ":
continue continue

View File

@ -1,10 +1,15 @@
import os import os
from typing import List
from typing import Optional
import py import py
from .exceptions import UsageError from .exceptions import UsageError
from _pytest.outcomes import fail from _pytest.outcomes import fail
if False:
from . import Config # noqa: F401
def exists(path, ignore=EnvironmentError): def exists(path, ignore=EnvironmentError):
try: try:
@ -102,7 +107,12 @@ def get_dirs_from_args(args):
CFG_PYTEST_SECTION = "[pytest] section in {filename} files is no longer supported, change to [tool:pytest] instead." CFG_PYTEST_SECTION = "[pytest] section in {filename} files is no longer supported, change to [tool:pytest] instead."
def determine_setup(inifile, args, rootdir_cmd_arg=None, config=None): def determine_setup(
inifile: str,
args: List[str],
rootdir_cmd_arg: Optional[str] = None,
config: Optional["Config"] = None,
):
dirs = get_dirs_from_args(args) dirs = get_dirs_from_args(args)
if inifile: if inifile:
iniconfig = py.iniconfig.IniConfig(inifile) iniconfig = py.iniconfig.IniConfig(inifile)

View File

@ -51,6 +51,8 @@ class MarkEvaluator:
except TEST_OUTCOME: except TEST_OUTCOME:
self.exc = sys.exc_info() self.exc = sys.exc_info()
if isinstance(self.exc[1], SyntaxError): if isinstance(self.exc[1], SyntaxError):
# TODO: Investigate why SyntaxError.offset is Optional, and if it can be None here.
assert self.exc[1].offset is not None
msg = [" " * (self.exc[1].offset + 4) + "^"] msg = [" " * (self.exc[1].offset + 4) + "^"]
msg.append("SyntaxError: invalid syntax") msg.append("SyntaxError: invalid syntax")
else: else:

View File

@ -292,7 +292,7 @@ class MarkGenerator:
_config = None _config = None
_markers = set() # type: Set[str] _markers = set() # type: Set[str]
def __getattr__(self, name): def __getattr__(self, name: str) -> MarkDecorator:
if name[0] == "_": if name[0] == "_":
raise AttributeError("Marker name must NOT start with underscore") raise AttributeError("Marker name must NOT start with underscore")

View File

@ -1,14 +1,26 @@
import os import os
import warnings import warnings
from functools import lru_cache from functools import lru_cache
from typing import Any
from typing import Dict
from typing import List
from typing import Set
from typing import Tuple
from typing import Union
import py import py
import _pytest._code import _pytest._code
from _pytest.compat import getfslineno from _pytest.compat import getfslineno
from _pytest.mark.structures import Mark
from _pytest.mark.structures import MarkDecorator
from _pytest.mark.structures import NodeKeywords from _pytest.mark.structures import NodeKeywords
from _pytest.outcomes import fail from _pytest.outcomes import fail
if False: # TYPE_CHECKING
# Imported here due to circular import.
from _pytest.fixtures import FixtureDef
SEP = "/" SEP = "/"
tracebackcutdir = py.path.local(_pytest.__file__).dirpath() tracebackcutdir = py.path.local(_pytest.__file__).dirpath()
@ -78,13 +90,13 @@ class Node:
self.keywords = NodeKeywords(self) self.keywords = NodeKeywords(self)
#: the marker objects belonging to this node #: the marker objects belonging to this node
self.own_markers = [] self.own_markers = [] # type: List[Mark]
#: allow adding of extra keywords to use for matching #: allow adding of extra keywords to use for matching
self.extra_keyword_matches = set() self.extra_keyword_matches = set() # type: Set[str]
# used for storing artificial fixturedefs for direct parametrization # used for storing artificial fixturedefs for direct parametrization
self._name2pseudofixturedef = {} self._name2pseudofixturedef = {} # type: Dict[str, FixtureDef]
if nodeid is not None: if nodeid is not None:
assert "::()" not in nodeid assert "::()" not in nodeid
@ -127,7 +139,8 @@ class Node:
) )
) )
path, lineno = get_fslocation_from_item(self) path, lineno = get_fslocation_from_item(self)
warnings.warn_explicit( # Type ignored: https://github.com/python/typeshed/pull/3121
warnings.warn_explicit( # type: ignore
warning, warning,
category=None, category=None,
filename=str(path), filename=str(path),
@ -160,7 +173,9 @@ class Node:
chain.reverse() chain.reverse()
return chain return chain
def add_marker(self, marker, append=True): def add_marker(
self, marker: Union[str, MarkDecorator], append: bool = True
) -> None:
"""dynamically add a marker object to the node. """dynamically add a marker object to the node.
:type marker: ``str`` or ``pytest.mark.*`` object :type marker: ``str`` or ``pytest.mark.*`` object
@ -168,17 +183,19 @@ class Node:
``append=True`` whether to append the marker, ``append=True`` whether to append the marker,
if ``False`` insert at position ``0``. if ``False`` insert at position ``0``.
""" """
from _pytest.mark import MarkDecorator, MARK_GEN from _pytest.mark import MARK_GEN
if isinstance(marker, str): if isinstance(marker, MarkDecorator):
marker = getattr(MARK_GEN, marker) marker_ = marker
elif not isinstance(marker, MarkDecorator): elif isinstance(marker, str):
raise ValueError("is not a string or pytest.mark.* Marker") marker_ = getattr(MARK_GEN, marker)
self.keywords[marker.name] = marker
if append:
self.own_markers.append(marker.mark)
else: else:
self.own_markers.insert(0, marker.mark) raise ValueError("is not a string or pytest.mark.* Marker")
self.keywords[marker_.name] = marker
if append:
self.own_markers.append(marker_.mark)
else:
self.own_markers.insert(0, marker_.mark)
def iter_markers(self, name=None): def iter_markers(self, name=None):
""" """
@ -211,7 +228,7 @@ class Node:
def listextrakeywords(self): def listextrakeywords(self):
""" Return a set of all extra keywords in self and any parents.""" """ Return a set of all extra keywords in self and any parents."""
extra_keywords = set() extra_keywords = set() # type: Set[str]
for item in self.listchain(): for item in self.listchain():
extra_keywords.update(item.extra_keyword_matches) extra_keywords.update(item.extra_keyword_matches)
return extra_keywords return extra_keywords
@ -239,7 +256,8 @@ class Node:
pass pass
def _repr_failure_py(self, excinfo, style=None): def _repr_failure_py(self, excinfo, style=None):
if excinfo.errisinstance(fail.Exception): # Type ignored: see comment where fail.Exception is defined.
if excinfo.errisinstance(fail.Exception): # type: ignore
if not excinfo.value.pytrace: if not excinfo.value.pytrace:
return str(excinfo.value) return str(excinfo.value)
fm = self.session._fixturemanager fm = self.session._fixturemanager
@ -385,13 +403,13 @@ class Item(Node):
def __init__(self, name, parent=None, config=None, session=None, nodeid=None): def __init__(self, name, parent=None, config=None, session=None, nodeid=None):
super().__init__(name, parent, config, session, nodeid=nodeid) super().__init__(name, parent, config, session, nodeid=nodeid)
self._report_sections = [] self._report_sections = [] # type: List[Tuple[str, str, str]]
#: user properties is a list of tuples (name, value) that holds user #: user properties is a list of tuples (name, value) that holds user
#: defined properties for this test. #: defined properties for this test.
self.user_properties = [] self.user_properties = [] # type: List[Tuple[str, Any]]
def add_report_section(self, when, key, content): def add_report_section(self, when: str, key: str, content: str) -> None:
""" """
Adds a new report section, similar to what's done internally to add stdout and Adds a new report section, similar to what's done internally to add stdout and
stderr captured output:: stderr captured output::

View File

@ -1,5 +1,6 @@
from pprint import pprint from pprint import pprint
from typing import Optional from typing import Optional
from typing import Union
import py import py
@ -221,7 +222,6 @@ class BaseReport:
reprcrash = reportdict["longrepr"]["reprcrash"] reprcrash = reportdict["longrepr"]["reprcrash"]
unserialized_entries = [] unserialized_entries = []
reprentry = None
for entry_data in reprtraceback["reprentries"]: for entry_data in reprtraceback["reprentries"]:
data = entry_data["data"] data = entry_data["data"]
entry_type = entry_data["type"] entry_type = entry_data["type"]
@ -242,7 +242,7 @@ class BaseReport:
reprlocals=reprlocals, reprlocals=reprlocals,
filelocrepr=reprfileloc, filelocrepr=reprfileloc,
style=data["style"], style=data["style"],
) ) # type: Union[ReprEntry, ReprEntryNative]
elif entry_type == "ReprEntryNative": elif entry_type == "ReprEntryNative":
reprentry = ReprEntryNative(data["lines"]) reprentry = ReprEntryNative(data["lines"])
else: else:
@ -352,7 +352,8 @@ class TestReport(BaseReport):
if not isinstance(excinfo, ExceptionInfo): if not isinstance(excinfo, ExceptionInfo):
outcome = "failed" outcome = "failed"
longrepr = excinfo longrepr = excinfo
elif excinfo.errisinstance(skip.Exception): # Type ignored -- see comment where skip.Exception is defined.
elif excinfo.errisinstance(skip.Exception): # type: ignore
outcome = "skipped" outcome = "skipped"
r = excinfo._getreprcrash() r = excinfo._getreprcrash()
longrepr = (str(r.path), r.lineno, r.message) longrepr = (str(r.path), r.lineno, r.message)

View File

@ -3,6 +3,10 @@ import bdb
import os import os
import sys import sys
from time import time from time import time
from typing import Callable
from typing import Dict
from typing import List
from typing import Tuple
import attr import attr
@ -10,10 +14,14 @@ from .reports import CollectErrorRepr
from .reports import CollectReport from .reports import CollectReport
from .reports import TestReport from .reports import TestReport
from _pytest._code.code import ExceptionInfo from _pytest._code.code import ExceptionInfo
from _pytest.nodes import Node
from _pytest.outcomes import Exit from _pytest.outcomes import Exit
from _pytest.outcomes import Skipped from _pytest.outcomes import Skipped
from _pytest.outcomes import TEST_OUTCOME from _pytest.outcomes import TEST_OUTCOME
if False: # TYPE_CHECKING
from typing import Type
# #
# pytest plugin hooks # pytest plugin hooks
@ -118,6 +126,7 @@ def pytest_runtest_call(item):
except Exception: except Exception:
# Store trace info to allow postmortem debugging # Store trace info to allow postmortem debugging
type, value, tb = sys.exc_info() type, value, tb = sys.exc_info()
assert tb is not None
tb = tb.tb_next # Skip *this* frame tb = tb.tb_next # Skip *this* frame
sys.last_type = type sys.last_type = type
sys.last_value = value sys.last_value = value
@ -185,7 +194,7 @@ def check_interactive_exception(call, report):
def call_runtest_hook(item, when, **kwds): def call_runtest_hook(item, when, **kwds):
hookname = "pytest_runtest_" + when hookname = "pytest_runtest_" + when
ihook = getattr(item.ihook, hookname) ihook = getattr(item.ihook, hookname)
reraise = (Exit,) reraise = (Exit,) # type: Tuple[Type[BaseException], ...]
if not item.config.getoption("usepdb", False): if not item.config.getoption("usepdb", False):
reraise += (KeyboardInterrupt,) reraise += (KeyboardInterrupt,)
return CallInfo.from_call( return CallInfo.from_call(
@ -252,7 +261,8 @@ def pytest_make_collect_report(collector):
skip_exceptions = [Skipped] skip_exceptions = [Skipped]
unittest = sys.modules.get("unittest") unittest = sys.modules.get("unittest")
if unittest is not None: if unittest is not None:
skip_exceptions.append(unittest.SkipTest) # Type ignored because unittest is loaded dynamically.
skip_exceptions.append(unittest.SkipTest) # type: ignore
if call.excinfo.errisinstance(tuple(skip_exceptions)): if call.excinfo.errisinstance(tuple(skip_exceptions)):
outcome = "skipped" outcome = "skipped"
r = collector._repr_failure_py(call.excinfo, "line").reprcrash r = collector._repr_failure_py(call.excinfo, "line").reprcrash
@ -266,7 +276,7 @@ def pytest_make_collect_report(collector):
rep = CollectReport( rep = CollectReport(
collector.nodeid, outcome, longrepr, getattr(call, "result", None) collector.nodeid, outcome, longrepr, getattr(call, "result", None)
) )
rep.call = call # see collect_one_node rep.call = call # type: ignore # see collect_one_node
return rep return rep
@ -274,8 +284,8 @@ class SetupState:
""" shared state for setting up/tearing down test items or collectors. """ """ shared state for setting up/tearing down test items or collectors. """
def __init__(self): def __init__(self):
self.stack = [] self.stack = [] # type: List[Node]
self._finalizers = {} self._finalizers = {} # type: Dict[Node, List[Callable[[], None]]]
def addfinalizer(self, finalizer, colitem): def addfinalizer(self, finalizer, colitem):
""" attach a finalizer to the given colitem. """ """ attach a finalizer to the given colitem. """
@ -302,6 +312,7 @@ class SetupState:
exc = sys.exc_info() exc = sys.exc_info()
if exc: if exc:
_, val, tb = exc _, val, tb = exc
assert val is not None
raise val.with_traceback(tb) raise val.with_traceback(tb)
def _teardown_with_finalization(self, colitem): def _teardown_with_finalization(self, colitem):
@ -335,6 +346,7 @@ class SetupState:
exc = sys.exc_info() exc = sys.exc_info()
if exc: if exc:
_, val, tb = exc _, val, tb = exc
assert val is not None
raise val.with_traceback(tb) raise val.with_traceback(tb)
def prepare(self, colitem): def prepare(self, colitem):