From 28761c8da1ec2f16a63fd283e89196f100e7ea03 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Thu, 18 Jul 2019 00:39:48 +0300 Subject: [PATCH 1/2] Have AssertionRewritingHook derive from importlib.abc.MetaPathFinder This is nice for self-documentation, and is the type required by mypy for adding to sys.meta_path. --- src/_pytest/assertion/rewrite.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 0567e8fb8..0782bfbee 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -2,6 +2,7 @@ import ast import errno import functools +import importlib.abc import importlib.machinery import importlib.util import io @@ -37,7 +38,7 @@ AST_IS = ast.Is() AST_NONE = ast.NameConstant(None) -class AssertionRewritingHook: +class AssertionRewritingHook(importlib.abc.MetaPathFinder): """PEP302/PEP451 import hook which rewrites asserts.""" def __init__(self, config): From 7259c453d6c1dba6727cd328e6db5635ccf5821c Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Sun, 14 Jul 2019 18:45:40 +0300 Subject: [PATCH 2/2] Fix some check_untyped_defs = True mypy warnings --- src/_pytest/_code/code.py | 27 +++++++++------ src/_pytest/_code/source.py | 13 +++---- src/_pytest/assertion/__init__.py | 11 ++++-- src/_pytest/assertion/rewrite.py | 47 ++++++++++++++------------ src/_pytest/assertion/util.py | 15 +++++---- src/_pytest/config/__init__.py | 50 +++++++++++++++++---------- src/_pytest/config/argparsing.py | 38 +++++++++++++-------- src/_pytest/config/findpaths.py | 12 ++++++- src/_pytest/mark/evaluate.py | 2 ++ src/_pytest/mark/structures.py | 2 +- src/_pytest/nodes.py | 56 ++++++++++++++++++++----------- src/_pytest/reports.py | 7 ++-- src/_pytest/runner.py | 22 +++++++++--- 13 files changed, 196 insertions(+), 106 deletions(-) diff --git a/src/_pytest/_code/code.py b/src/_pytest/_code/code.py index 7d72234e7..744e9cf66 100644 --- a/src/_pytest/_code/code.py +++ b/src/_pytest/_code/code.py @@ -5,10 +5,15 @@ import traceback from inspect import CO_VARARGS from inspect import CO_VARKEYWORDS from traceback import format_exception_only +from types import CodeType from types import TracebackType +from typing import Any +from typing import Dict from typing import Generic +from typing import List from typing import Optional from typing import Pattern +from typing import Set from typing import Tuple from typing import TypeVar from typing import Union @@ -29,7 +34,7 @@ if False: # TYPE_CHECKING class Code: """ wrapper around Python code objects """ - def __init__(self, rawcode): + def __init__(self, rawcode) -> None: if not hasattr(rawcode, "co_filename"): rawcode = getrawcode(rawcode) try: @@ -38,7 +43,7 @@ class Code: self.name = rawcode.co_name except AttributeError: raise TypeError("not a code object: {!r}".format(rawcode)) - self.raw = rawcode + self.raw = rawcode # type: CodeType def __eq__(self, other): return self.raw == other.raw @@ -351,7 +356,7 @@ class Traceback(list): """ return the index of the frame/TracebackEntry where recursion originates if appropriate, None if no recursion occurred """ - cache = {} + cache = {} # type: Dict[Tuple[Any, int, int], List[Dict[str, Any]]] for i, entry in enumerate(self): # id for the code.raw is needed to work around # the strange metaprogramming in the decorator lib from pypi @@ -650,7 +655,7 @@ class FormattedExcinfo: args.append((argname, saferepr(argvalue))) return ReprFuncArgs(args) - def get_source(self, source, line_index=-1, excinfo=None, short=False): + def get_source(self, source, line_index=-1, excinfo=None, short=False) -> List[str]: """ return formatted and marked up source lines. """ import _pytest._code @@ -722,7 +727,7 @@ class FormattedExcinfo: else: line_index = entry.lineno - entry.getfirstlinesource() - lines = [] + lines = [] # type: List[str] style = entry._repr_style if style is None: style = self.style @@ -799,7 +804,7 @@ class FormattedExcinfo: exc_msg=str(e), max_frames=max_frames, total=len(traceback), - ) + ) # type: Optional[str] traceback = traceback[:max_frames] + traceback[-max_frames:] else: if recursionindex is not None: @@ -812,10 +817,12 @@ class FormattedExcinfo: def repr_excinfo(self, excinfo): - repr_chain = [] + repr_chain = ( + [] + ) # type: List[Tuple[ReprTraceback, Optional[ReprFileLocation], Optional[str]]] e = excinfo.value descr = None - seen = set() + seen = set() # type: Set[int] while e is not None and id(e) not in seen: seen.add(id(e)) if excinfo: @@ -868,8 +875,8 @@ class TerminalRepr: class ExceptionRepr(TerminalRepr): - def __init__(self): - self.sections = [] + def __init__(self) -> None: + self.sections = [] # type: List[Tuple[str, str, str]] def addsection(self, name, content, sep="-"): self.sections.append((name, content, sep)) diff --git a/src/_pytest/_code/source.py b/src/_pytest/_code/source.py index ea2fc5e3f..db78bbd0d 100644 --- a/src/_pytest/_code/source.py +++ b/src/_pytest/_code/source.py @@ -7,6 +7,7 @@ import tokenize import warnings from ast import PyCF_ONLY_AST as _AST_FLAG from bisect import bisect_right +from typing import List import py @@ -19,11 +20,11 @@ class Source: _compilecounter = 0 def __init__(self, *parts, **kwargs): - self.lines = lines = [] + self.lines = lines = [] # type: List[str] de = kwargs.get("deindent", True) for part in parts: if not part: - partlines = [] + partlines = [] # type: List[str] elif isinstance(part, Source): partlines = part.lines elif isinstance(part, (tuple, list)): @@ -157,8 +158,7 @@ class Source: source = "\n".join(self.lines) + "\n" try: co = compile(source, filename, mode, flag) - except SyntaxError: - ex = sys.exc_info()[1] + except SyntaxError as ex: # re-represent syntax errors from parsing python strings msglines = self.lines[: ex.lineno] if ex.offset: @@ -173,7 +173,8 @@ class Source: if flag & _AST_FLAG: return co lines = [(x + "\n") for x in self.lines] - linecache.cache[filename] = (1, None, lines, filename) + # Type ignored because linecache.cache is private. + linecache.cache[filename] = (1, None, lines, filename) # type: ignore return co @@ -282,7 +283,7 @@ def get_statement_startend2(lineno, node): return start, end -def getstatementrange_ast(lineno, source, assertion=False, astnode=None): +def getstatementrange_ast(lineno, source: Source, assertion=False, astnode=None): if astnode is None: content = str(source) # See #4260: diff --git a/src/_pytest/assertion/__init__.py b/src/_pytest/assertion/__init__.py index 126929b6a..3b42b356d 100644 --- a/src/_pytest/assertion/__init__.py +++ b/src/_pytest/assertion/__init__.py @@ -2,6 +2,7 @@ support for presenting detailed information in failing assertions. """ import sys +from typing import Optional from _pytest.assertion import rewrite from _pytest.assertion import truncate @@ -52,7 +53,9 @@ def register_assert_rewrite(*names): importhook = hook break else: - importhook = DummyRewriteHook() + # TODO(typing): Add a protocol for mark_rewrite() and use it + # for importhook and for PytestPluginManager.rewrite_hook. + importhook = DummyRewriteHook() # type: ignore importhook.mark_rewrite(*names) @@ -69,7 +72,7 @@ class AssertionState: def __init__(self, config, mode): self.mode = mode self.trace = config.trace.root.get("assertion") - self.hook = None + self.hook = None # type: Optional[rewrite.AssertionRewritingHook] def install_importhook(config): @@ -108,6 +111,7 @@ def pytest_runtest_setup(item): """ def callbinrepr(op, left, right): + # type: (str, object, object) -> Optional[str] """Call the pytest_assertrepr_compare hook and prepare the result This uses the first result from the hook and then ensures the @@ -133,12 +137,13 @@ def pytest_runtest_setup(item): if item.config.getvalue("assertmode") == "rewrite": res = res.replace("%", "%%") return res + return None util._reprcompare = callbinrepr if item.ihook.pytest_assertion_pass.get_hookimpls(): - def call_assertion_pass_hook(lineno, expl, orig): + def call_assertion_pass_hook(lineno, orig, expl): item.ihook.pytest_assertion_pass( item=item, lineno=lineno, orig=orig, expl=expl ) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 0782bfbee..df5131449 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -17,6 +17,7 @@ from typing import Dict from typing import List from typing import Optional from typing import Set +from typing import Tuple import atomicwrites @@ -48,13 +49,13 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder): except ValueError: self.fnpats = ["test_*.py", "*_test.py"] self.session = None - self._rewritten_names = set() - self._must_rewrite = set() + self._rewritten_names = set() # type: Set[str] + self._must_rewrite = set() # type: Set[str] # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file, # which might result in infinite recursion (#3506) self._writing_pyc = False self._basenames_to_check_rewrite = {"conftest"} - self._marked_for_rewrite_cache = {} + self._marked_for_rewrite_cache = {} # type: Dict[str, bool] self._session_paths_checked = False def set_session(self, session): @@ -203,7 +204,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder): return self._is_marked_for_rewrite(name, state) - def _is_marked_for_rewrite(self, name, state): + def _is_marked_for_rewrite(self, name: str, state): try: return self._marked_for_rewrite_cache[name] except KeyError: @@ -218,7 +219,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder): self._marked_for_rewrite_cache[name] = False return False - def mark_rewrite(self, *names): + def mark_rewrite(self, *names: str) -> None: """Mark import names as needing to be rewritten. The named module or package as well as any nested modules will @@ -385,6 +386,7 @@ def _format_boolop(explanations, is_or): def _call_reprcompare(ops, results, expls, each_obj): + # type: (Tuple[str, ...], Tuple[bool, ...], Tuple[str, ...], Tuple[object, ...]) -> str for i, res, expl in zip(range(len(ops)), results, expls): try: done = not res @@ -400,11 +402,13 @@ def _call_reprcompare(ops, results, expls, each_obj): def _call_assertion_pass(lineno, orig, expl): + # type: (int, str, str) -> None if util._assertion_pass is not None: - util._assertion_pass(lineno=lineno, orig=orig, expl=expl) + util._assertion_pass(lineno, orig, expl) def _check_if_assertion_pass_impl(): + # type: () -> bool """Checks if any plugins implement the pytest_assertion_pass hook in order not to generate explanation unecessarily (might be expensive)""" return True if util._assertion_pass else False @@ -578,7 +582,7 @@ class AssertionRewriter(ast.NodeVisitor): def _assert_expr_to_lineno(self): return _get_assertion_exprs(self.source) - def run(self, mod): + def run(self, mod: ast.Module) -> None: """Find all assert statements in *mod* and rewrite them.""" if not mod.body: # Nothing to do. @@ -620,12 +624,12 @@ class AssertionRewriter(ast.NodeVisitor): ] mod.body[pos:pos] = imports # Collect asserts. - nodes = [mod] + nodes = [mod] # type: List[ast.AST] while nodes: node = nodes.pop() for name, field in ast.iter_fields(node): if isinstance(field, list): - new = [] + new = [] # type: List for i, child in enumerate(field): if isinstance(child, ast.Assert): # Transform assert. @@ -699,7 +703,7 @@ class AssertionRewriter(ast.NodeVisitor): .explanation_param(). """ - self.explanation_specifiers = {} + self.explanation_specifiers = {} # type: Dict[str, ast.expr] self.stack.append(self.explanation_specifiers) def pop_format_context(self, expl_expr): @@ -742,7 +746,8 @@ class AssertionRewriter(ast.NodeVisitor): from _pytest.warning_types import PytestAssertRewriteWarning import warnings - warnings.warn_explicit( + # Ignore type: typeshed bug https://github.com/python/typeshed/pull/3121 + warnings.warn_explicit( # type: ignore PytestAssertRewriteWarning( "assertion is always true, perhaps remove parentheses?" ), @@ -751,15 +756,15 @@ class AssertionRewriter(ast.NodeVisitor): lineno=assert_.lineno, ) - self.statements = [] - self.variables = [] + self.statements = [] # type: List[ast.stmt] + self.variables = [] # type: List[str] self.variable_counter = itertools.count() if self.enable_assertion_pass_hook: - self.format_variables = [] + self.format_variables = [] # type: List[str] - self.stack = [] - self.expl_stmts = [] + self.stack = [] # type: List[Dict[str, ast.expr]] + self.expl_stmts = [] # type: List[ast.stmt] self.push_format_context() # Rewrite assert into a bunch of statements. top_condition, explanation = self.visit(assert_.test) @@ -897,7 +902,7 @@ warn_explicit( # Process each operand, short-circuiting if needed. for i, v in enumerate(boolop.values): if i: - fail_inner = [] + fail_inner = [] # type: List[ast.stmt] # cond is set in a prior loop iteration below self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa self.expl_stmts = fail_inner @@ -908,10 +913,10 @@ warn_explicit( call = ast.Call(app, [expl_format], []) self.expl_stmts.append(ast.Expr(call)) if i < levels: - cond = res + cond = res # type: ast.expr if is_or: cond = ast.UnaryOp(ast.Not(), cond) - inner = [] + inner = [] # type: List[ast.stmt] self.statements.append(ast.If(cond, inner, [])) self.statements = body = inner self.statements = save @@ -977,7 +982,7 @@ warn_explicit( expl = pat % (res_expl, res_expl, value_expl, attr.attr) return res, expl - def visit_Compare(self, comp): + def visit_Compare(self, comp: ast.Compare): self.push_format_context() left_res, left_expl = self.visit(comp.left) if isinstance(comp.left, (ast.Compare, ast.BoolOp)): @@ -1010,7 +1015,7 @@ warn_explicit( ast.Tuple(results, ast.Load()), ) if len(comp.ops) > 1: - res = ast.BoolOp(ast.And(), load_names) + res = ast.BoolOp(ast.And(), load_names) # type: ast.expr else: res = load_names[0] return res, self.explanation_param(self.pop_format_context(expl_call)) diff --git a/src/_pytest/assertion/util.py b/src/_pytest/assertion/util.py index 732194ec2..11c7bdf6f 100644 --- a/src/_pytest/assertion/util.py +++ b/src/_pytest/assertion/util.py @@ -1,6 +1,9 @@ """Utilities for assertion debugging""" import pprint from collections.abc import Sequence +from typing import Callable +from typing import List +from typing import Optional import _pytest._code from _pytest import outcomes @@ -10,11 +13,11 @@ from _pytest._io.saferepr import saferepr # interpretation code and assertion rewriter to detect this plugin was # loaded and in turn call the hooks defined here as part of the # DebugInterpreter. -_reprcompare = None +_reprcompare = None # type: Optional[Callable[[str, object, object], Optional[str]]] # Works similarly as _reprcompare attribute. Is populated with the hook call # when pytest_runtest_setup is called. -_assertion_pass = None +_assertion_pass = None # type: Optional[Callable[[int, str, str], None]] def format_explanation(explanation): @@ -177,7 +180,7 @@ def _diff_text(left, right, verbose=0): """ from difflib import ndiff - explanation = [] + explanation = [] # type: List[str] def escape_for_readable_diff(binary_text): """ @@ -235,7 +238,7 @@ def _compare_eq_verbose(left, right): left_lines = repr(left).splitlines(keepends) right_lines = repr(right).splitlines(keepends) - explanation = [] + explanation = [] # type: List[str] explanation += ["-" + line for line in left_lines] explanation += ["+" + line for line in right_lines] @@ -259,7 +262,7 @@ def _compare_eq_iterable(left, right, verbose=0): def _compare_eq_sequence(left, right, verbose=0): comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes) - explanation = [] + explanation = [] # type: List[str] len_left = len(left) len_right = len(right) for i in range(min(len_left, len_right)): @@ -327,7 +330,7 @@ def _compare_eq_set(left, right, verbose=0): def _compare_eq_dict(left, right, verbose=0): - explanation = [] + explanation = [] # type: List[str] set_left = set(left) set_right = set(right) common = set_left.intersection(set_right) diff --git a/src/_pytest/config/__init__.py b/src/_pytest/config/__init__.py index b861563e9..d547f033d 100644 --- a/src/_pytest/config/__init__.py +++ b/src/_pytest/config/__init__.py @@ -9,6 +9,15 @@ import types import warnings from functools import lru_cache from pathlib import Path +from types import TracebackType +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple import attr import py @@ -32,6 +41,10 @@ from _pytest.outcomes import fail from _pytest.outcomes import Skipped from _pytest.warning_types import PytestConfigWarning +if False: # TYPE_CHECKING + from typing import Type + + hookimpl = HookimplMarker("pytest") hookspec = HookspecMarker("pytest") @@ -40,7 +53,7 @@ class ConftestImportFailure(Exception): def __init__(self, path, excinfo): Exception.__init__(self, path, excinfo) self.path = path - self.excinfo = excinfo + self.excinfo = excinfo # type: Tuple[Type[Exception], Exception, TracebackType] def main(args=None, plugins=None): @@ -237,14 +250,18 @@ class PytestPluginManager(PluginManager): def __init__(self): super().__init__("pytest") - self._conftest_plugins = set() + # The objects are module objects, only used generically. + self._conftest_plugins = set() # type: Set[object] # state related to local conftest plugins - self._dirpath2confmods = {} - self._conftestpath2mod = {} + # Maps a py.path.local to a list of module objects. + self._dirpath2confmods = {} # type: Dict[Any, List[object]] + # Maps a py.path.local to a module object. + self._conftestpath2mod = {} # type: Dict[Any, object] self._confcutdir = None self._noconftest = False - self._duplicatepaths = set() + # Set of py.path.local's. + self._duplicatepaths = set() # type: Set[Any] self.add_hookspecs(_pytest.hookspec) self.register(self) @@ -653,7 +670,7 @@ class Config: args = attr.ib() plugins = attr.ib() - dir = attr.ib() + dir = attr.ib(type=Path) def __init__(self, pluginmanager, *, invocation_params=None): from .argparsing import Parser, FILE_OR_DIR @@ -674,10 +691,10 @@ class Config: self.pluginmanager = pluginmanager self.trace = self.pluginmanager.trace.root.get("config") self.hook = self.pluginmanager.hook - self._inicache = {} - self._override_ini = () - self._opt2dest = {} - self._cleanup = [] + self._inicache = {} # type: Dict[str, Any] + self._override_ini = () # type: Sequence[str] + self._opt2dest = {} # type: Dict[str, str] + self._cleanup = [] # type: List[Callable[[], None]] self.pluginmanager.register(self, "pytestconfig") self._configured = False self.hook.pytest_addoption.call_historic(kwargs=dict(parser=self._parser)) @@ -778,7 +795,7 @@ class Config: def pytest_load_initial_conftests(self, early_config): self.pluginmanager._set_initial_conftests(early_config.known_args_namespace) - def _initini(self, args): + def _initini(self, args) -> None: ns, unknown_args = self._parser.parse_known_and_unknown_args( args, namespace=copy.copy(self.option) ) @@ -879,8 +896,7 @@ class Config: self.hook.pytest_load_initial_conftests( early_config=self, args=args, parser=self._parser ) - except ConftestImportFailure: - e = sys.exc_info()[1] + except ConftestImportFailure as e: if ns.help or ns.version: # we don't want to prevent --help/--version to work # so just let is pass and print a warning at the end @@ -946,7 +962,7 @@ class Config: assert isinstance(x, list) x.append(line) # modifies the cached list inline - def getini(self, name): + def getini(self, name: str): """ return configuration value from an :ref:`ini file `. If the specified name hasn't been registered through a prior :py:func:`parser.addini <_pytest.config.Parser.addini>` @@ -957,7 +973,7 @@ class Config: self._inicache[name] = val = self._getini(name) return val - def _getini(self, name): + def _getini(self, name: str) -> Any: try: description, type, default = self._parser._inidict[name] except KeyError: @@ -1002,7 +1018,7 @@ class Config: values.append(relroot) return values - def _get_override_ini_value(self, name): + def _get_override_ini_value(self, name: str) -> Optional[str]: value = None # override_ini is a list of "ini=value" options # always use the last item if multiple values are set for same ini-name, @@ -1017,7 +1033,7 @@ class Config: value = user_ini_value return value - def getoption(self, name, default=notset, skip=False): + def getoption(self, name: str, default=notset, skip: bool = False): """ return command line option value. :arg name: name of the option. You may also specify diff --git a/src/_pytest/config/argparsing.py b/src/_pytest/config/argparsing.py index 8994ff7d9..4bf3b54ba 100644 --- a/src/_pytest/config/argparsing.py +++ b/src/_pytest/config/argparsing.py @@ -2,6 +2,11 @@ import argparse import sys import warnings from gettext import gettext +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple import py @@ -21,12 +26,12 @@ class Parser: def __init__(self, usage=None, processopt=None): self._anonymous = OptionGroup("custom options", parser=self) - self._groups = [] + self._groups = [] # type: List[OptionGroup] self._processopt = processopt self._usage = usage - self._inidict = {} - self._ininames = [] - self.extra_info = {} + self._inidict = {} # type: Dict[str, Tuple[str, Optional[str], Any]] + self._ininames = [] # type: List[str] + self.extra_info = {} # type: Dict[str, Any] def processoption(self, option): if self._processopt: @@ -80,7 +85,7 @@ class Parser: args = [str(x) if isinstance(x, py.path.local) else x for x in args] return self.optparser.parse_args(args, namespace=namespace) - def _getparser(self): + def _getparser(self) -> "MyOptionParser": from _pytest._argcomplete import filescompleter optparser = MyOptionParser(self, self.extra_info, prog=self.prog) @@ -94,7 +99,10 @@ class Parser: a = option.attrs() arggroup.add_argument(*n, **a) # bash like autocompletion for dirs (appending '/') - optparser.add_argument(FILE_OR_DIR, nargs="*").completer = filescompleter + # Type ignored because typeshed doesn't know about argcomplete. + optparser.add_argument( # type: ignore + FILE_OR_DIR, nargs="*" + ).completer = filescompleter return optparser def parse_setoption(self, args, option, namespace=None): @@ -103,13 +111,15 @@ class Parser: setattr(option, name, value) return getattr(parsedoption, FILE_OR_DIR) - def parse_known_args(self, args, namespace=None): + def parse_known_args(self, args, namespace=None) -> argparse.Namespace: """parses and returns a namespace object with known arguments at this point. """ return self.parse_known_and_unknown_args(args, namespace=namespace)[0] - def parse_known_and_unknown_args(self, args, namespace=None): + def parse_known_and_unknown_args( + self, args, namespace=None + ) -> Tuple[argparse.Namespace, List[str]]: """parses and returns a namespace object with known arguments, and the remaining arguments unknown at this point. """ @@ -163,8 +173,8 @@ class Argument: def __init__(self, *names, **attrs): """store parms in private vars for use in add_argument""" self._attrs = attrs - self._short_opts = [] - self._long_opts = [] + self._short_opts = [] # type: List[str] + self._long_opts = [] # type: List[str] self.dest = attrs.get("dest") if "%default" in (attrs.get("help") or ""): warnings.warn( @@ -268,8 +278,8 @@ class Argument: ) self._long_opts.append(opt) - def __repr__(self): - args = [] + def __repr__(self) -> str: + args = [] # type: List[str] if self._short_opts: args += ["_short_opts: " + repr(self._short_opts)] if self._long_opts: @@ -286,7 +296,7 @@ class OptionGroup: def __init__(self, name, description="", parser=None): self.name = name self.description = description - self.options = [] + self.options = [] # type: List[Argument] self.parser = parser def addoption(self, *optnames, **attrs): @@ -421,7 +431,7 @@ class DropShorterLongHelpFormatter(argparse.HelpFormatter): option_map = getattr(action, "map_long_option", {}) if option_map is None: option_map = {} - short_long = {} + short_long = {} # type: Dict[str, str] for option in options: if len(option) == 2 or option[2] == " ": continue diff --git a/src/_pytest/config/findpaths.py b/src/_pytest/config/findpaths.py index ec991316a..f06c9cfff 100644 --- a/src/_pytest/config/findpaths.py +++ b/src/_pytest/config/findpaths.py @@ -1,10 +1,15 @@ import os +from typing import List +from typing import Optional import py from .exceptions import UsageError from _pytest.outcomes import fail +if False: + from . import Config # noqa: F401 + def exists(path, ignore=EnvironmentError): try: @@ -102,7 +107,12 @@ def get_dirs_from_args(args): CFG_PYTEST_SECTION = "[pytest] section in {filename} files is no longer supported, change to [tool:pytest] instead." -def determine_setup(inifile, args, rootdir_cmd_arg=None, config=None): +def determine_setup( + inifile: str, + args: List[str], + rootdir_cmd_arg: Optional[str] = None, + config: Optional["Config"] = None, +): dirs = get_dirs_from_args(args) if inifile: iniconfig = py.iniconfig.IniConfig(inifile) diff --git a/src/_pytest/mark/evaluate.py b/src/_pytest/mark/evaluate.py index 898278e30..b9f2d61f8 100644 --- a/src/_pytest/mark/evaluate.py +++ b/src/_pytest/mark/evaluate.py @@ -51,6 +51,8 @@ class MarkEvaluator: except TEST_OUTCOME: self.exc = sys.exc_info() if isinstance(self.exc[1], SyntaxError): + # TODO: Investigate why SyntaxError.offset is Optional, and if it can be None here. + assert self.exc[1].offset is not None msg = [" " * (self.exc[1].offset + 4) + "^"] msg.append("SyntaxError: invalid syntax") else: diff --git a/src/_pytest/mark/structures.py b/src/_pytest/mark/structures.py index 332c86bde..f8cf55b4c 100644 --- a/src/_pytest/mark/structures.py +++ b/src/_pytest/mark/structures.py @@ -292,7 +292,7 @@ class MarkGenerator: _config = None _markers = set() # type: Set[str] - def __getattr__(self, name): + def __getattr__(self, name: str) -> MarkDecorator: if name[0] == "_": raise AttributeError("Marker name must NOT start with underscore") diff --git a/src/_pytest/nodes.py b/src/_pytest/nodes.py index 9b78dca38..b1bbc2943 100644 --- a/src/_pytest/nodes.py +++ b/src/_pytest/nodes.py @@ -1,14 +1,26 @@ import os import warnings from functools import lru_cache +from typing import Any +from typing import Dict +from typing import List +from typing import Set +from typing import Tuple +from typing import Union import py import _pytest._code from _pytest.compat import getfslineno +from _pytest.mark.structures import Mark +from _pytest.mark.structures import MarkDecorator from _pytest.mark.structures import NodeKeywords from _pytest.outcomes import fail +if False: # TYPE_CHECKING + # Imported here due to circular import. + from _pytest.fixtures import FixtureDef + SEP = "/" tracebackcutdir = py.path.local(_pytest.__file__).dirpath() @@ -78,13 +90,13 @@ class Node: self.keywords = NodeKeywords(self) #: the marker objects belonging to this node - self.own_markers = [] + self.own_markers = [] # type: List[Mark] #: allow adding of extra keywords to use for matching - self.extra_keyword_matches = set() + self.extra_keyword_matches = set() # type: Set[str] # used for storing artificial fixturedefs for direct parametrization - self._name2pseudofixturedef = {} + self._name2pseudofixturedef = {} # type: Dict[str, FixtureDef] if nodeid is not None: assert "::()" not in nodeid @@ -127,7 +139,8 @@ class Node: ) ) path, lineno = get_fslocation_from_item(self) - warnings.warn_explicit( + # Type ignored: https://github.com/python/typeshed/pull/3121 + warnings.warn_explicit( # type: ignore warning, category=None, filename=str(path), @@ -160,7 +173,9 @@ class Node: chain.reverse() return chain - def add_marker(self, marker, append=True): + def add_marker( + self, marker: Union[str, MarkDecorator], append: bool = True + ) -> None: """dynamically add a marker object to the node. :type marker: ``str`` or ``pytest.mark.*`` object @@ -168,17 +183,19 @@ class Node: ``append=True`` whether to append the marker, if ``False`` insert at position ``0``. """ - from _pytest.mark import MarkDecorator, MARK_GEN + from _pytest.mark import MARK_GEN - if isinstance(marker, str): - marker = getattr(MARK_GEN, marker) - elif not isinstance(marker, MarkDecorator): - raise ValueError("is not a string or pytest.mark.* Marker") - self.keywords[marker.name] = marker - if append: - self.own_markers.append(marker.mark) + if isinstance(marker, MarkDecorator): + marker_ = marker + elif isinstance(marker, str): + marker_ = getattr(MARK_GEN, marker) else: - self.own_markers.insert(0, marker.mark) + raise ValueError("is not a string or pytest.mark.* Marker") + self.keywords[marker_.name] = marker + if append: + self.own_markers.append(marker_.mark) + else: + self.own_markers.insert(0, marker_.mark) def iter_markers(self, name=None): """ @@ -211,7 +228,7 @@ class Node: def listextrakeywords(self): """ Return a set of all extra keywords in self and any parents.""" - extra_keywords = set() + extra_keywords = set() # type: Set[str] for item in self.listchain(): extra_keywords.update(item.extra_keyword_matches) return extra_keywords @@ -239,7 +256,8 @@ class Node: pass def _repr_failure_py(self, excinfo, style=None): - if excinfo.errisinstance(fail.Exception): + # Type ignored: see comment where fail.Exception is defined. + if excinfo.errisinstance(fail.Exception): # type: ignore if not excinfo.value.pytrace: return str(excinfo.value) fm = self.session._fixturemanager @@ -385,13 +403,13 @@ class Item(Node): def __init__(self, name, parent=None, config=None, session=None, nodeid=None): super().__init__(name, parent, config, session, nodeid=nodeid) - self._report_sections = [] + self._report_sections = [] # type: List[Tuple[str, str, str]] #: user properties is a list of tuples (name, value) that holds user #: defined properties for this test. - self.user_properties = [] + self.user_properties = [] # type: List[Tuple[str, Any]] - def add_report_section(self, when, key, content): + def add_report_section(self, when: str, key: str, content: str) -> None: """ Adds a new report section, similar to what's done internally to add stdout and stderr captured output:: diff --git a/src/_pytest/reports.py b/src/_pytest/reports.py index 4682d5b6e..732408323 100644 --- a/src/_pytest/reports.py +++ b/src/_pytest/reports.py @@ -1,5 +1,6 @@ from pprint import pprint from typing import Optional +from typing import Union import py @@ -221,7 +222,6 @@ class BaseReport: reprcrash = reportdict["longrepr"]["reprcrash"] unserialized_entries = [] - reprentry = None for entry_data in reprtraceback["reprentries"]: data = entry_data["data"] entry_type = entry_data["type"] @@ -242,7 +242,7 @@ class BaseReport: reprlocals=reprlocals, filelocrepr=reprfileloc, style=data["style"], - ) + ) # type: Union[ReprEntry, ReprEntryNative] elif entry_type == "ReprEntryNative": reprentry = ReprEntryNative(data["lines"]) else: @@ -352,7 +352,8 @@ class TestReport(BaseReport): if not isinstance(excinfo, ExceptionInfo): outcome = "failed" longrepr = excinfo - elif excinfo.errisinstance(skip.Exception): + # Type ignored -- see comment where skip.Exception is defined. + elif excinfo.errisinstance(skip.Exception): # type: ignore outcome = "skipped" r = excinfo._getreprcrash() longrepr = (str(r.path), r.lineno, r.message) diff --git a/src/_pytest/runner.py b/src/_pytest/runner.py index 8aae163c3..7d8b74a80 100644 --- a/src/_pytest/runner.py +++ b/src/_pytest/runner.py @@ -3,6 +3,10 @@ import bdb import os import sys from time import time +from typing import Callable +from typing import Dict +from typing import List +from typing import Tuple import attr @@ -10,10 +14,14 @@ from .reports import CollectErrorRepr from .reports import CollectReport from .reports import TestReport from _pytest._code.code import ExceptionInfo +from _pytest.nodes import Node from _pytest.outcomes import Exit from _pytest.outcomes import Skipped from _pytest.outcomes import TEST_OUTCOME +if False: # TYPE_CHECKING + from typing import Type + # # pytest plugin hooks @@ -118,6 +126,7 @@ def pytest_runtest_call(item): except Exception: # Store trace info to allow postmortem debugging type, value, tb = sys.exc_info() + assert tb is not None tb = tb.tb_next # Skip *this* frame sys.last_type = type sys.last_value = value @@ -185,7 +194,7 @@ def check_interactive_exception(call, report): def call_runtest_hook(item, when, **kwds): hookname = "pytest_runtest_" + when ihook = getattr(item.ihook, hookname) - reraise = (Exit,) + reraise = (Exit,) # type: Tuple[Type[BaseException], ...] if not item.config.getoption("usepdb", False): reraise += (KeyboardInterrupt,) return CallInfo.from_call( @@ -252,7 +261,8 @@ def pytest_make_collect_report(collector): skip_exceptions = [Skipped] unittest = sys.modules.get("unittest") if unittest is not None: - skip_exceptions.append(unittest.SkipTest) + # Type ignored because unittest is loaded dynamically. + skip_exceptions.append(unittest.SkipTest) # type: ignore if call.excinfo.errisinstance(tuple(skip_exceptions)): outcome = "skipped" r = collector._repr_failure_py(call.excinfo, "line").reprcrash @@ -266,7 +276,7 @@ def pytest_make_collect_report(collector): rep = CollectReport( collector.nodeid, outcome, longrepr, getattr(call, "result", None) ) - rep.call = call # see collect_one_node + rep.call = call # type: ignore # see collect_one_node return rep @@ -274,8 +284,8 @@ class SetupState: """ shared state for setting up/tearing down test items or collectors. """ def __init__(self): - self.stack = [] - self._finalizers = {} + self.stack = [] # type: List[Node] + self._finalizers = {} # type: Dict[Node, List[Callable[[], None]]] def addfinalizer(self, finalizer, colitem): """ attach a finalizer to the given colitem. """ @@ -302,6 +312,7 @@ class SetupState: exc = sys.exc_info() if exc: _, val, tb = exc + assert val is not None raise val.with_traceback(tb) def _teardown_with_finalization(self, colitem): @@ -335,6 +346,7 @@ class SetupState: exc = sys.exc_info() if exc: _, val, tb = exc + assert val is not None raise val.with_traceback(tb) def prepare(self, colitem):