diff --git a/src/_pytest/assertion/__init__.py b/src/_pytest/assertion/__init__.py index 6504db574..f404607c1 100644 --- a/src/_pytest/assertion/__init__.py +++ b/src/_pytest/assertion/__init__.py @@ -46,7 +46,7 @@ def pytest_addoption(parser: Parser) -> None: ) -def register_assert_rewrite(*names) -> None: +def register_assert_rewrite(*names: str) -> None: """Register one or more module names to be rewritten on import. This function will make sure that this module or all modules inside @@ -75,27 +75,27 @@ def register_assert_rewrite(*names) -> None: class DummyRewriteHook: """A no-op import hook for when rewriting is disabled.""" - def mark_rewrite(self, *names): + def mark_rewrite(self, *names: str) -> None: pass class AssertionState: """State for the assertion plugin.""" - def __init__(self, config, mode): + def __init__(self, config: Config, mode) -> None: self.mode = mode self.trace = config.trace.root.get("assertion") self.hook = None # type: Optional[rewrite.AssertionRewritingHook] -def install_importhook(config): +def install_importhook(config: Config) -> rewrite.AssertionRewritingHook: """Try to install the rewrite hook, raise SystemError if it fails.""" config._store[assertstate_key] = AssertionState(config, "rewrite") config._store[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config) sys.meta_path.insert(0, hook) config._store[assertstate_key].trace("installed rewrite import hook") - def undo(): + def undo() -> None: hook = config._store[assertstate_key].hook if hook is not None and hook in sys.meta_path: sys.meta_path.remove(hook) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index bd4ea022c..cec0c5501 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -13,11 +13,15 @@ import struct import sys import tokenize import types +from typing import Callable from typing import Dict +from typing import IO from typing import List from typing import Optional +from typing import Sequence from typing import Set from typing import Tuple +from typing import Union from _pytest._io.saferepr import saferepr from _pytest._version import version @@ -27,6 +31,8 @@ from _pytest.assertion.util import ( # noqa: F401 ) from _pytest.compat import fspath from _pytest.compat import TYPE_CHECKING +from _pytest.config import Config +from _pytest.main import Session from _pytest.pathlib import fnmatch_ex from _pytest.pathlib import Path from _pytest.pathlib import PurePath @@ -48,13 +54,13 @@ PYC_TAIL = "." + PYTEST_TAG + PYC_EXT class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader): """PEP302/PEP451 import hook which rewrites asserts.""" - def __init__(self, config): + def __init__(self, config: Config) -> None: self.config = config try: self.fnpats = config.getini("python_files") except ValueError: self.fnpats = ["test_*.py", "*_test.py"] - self.session = None + self.session = None # type: Optional[Session] self._rewritten_names = set() # type: Set[str] self._must_rewrite = set() # type: Set[str] # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file, @@ -64,14 +70,19 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) self._marked_for_rewrite_cache = {} # type: Dict[str, bool] self._session_paths_checked = False - def set_session(self, session): + def set_session(self, session: Optional[Session]) -> None: self.session = session self._session_paths_checked = False # Indirection so we can mock calls to find_spec originated from the hook during testing _find_spec = importlib.machinery.PathFinder.find_spec - def find_spec(self, name, path=None, target=None): + def find_spec( + self, + name: str, + path: Optional[Sequence[Union[str, bytes]]] = None, + target: Optional[types.ModuleType] = None, + ) -> Optional[importlib.machinery.ModuleSpec]: if self._writing_pyc: return None state = self.config._store[assertstate_key] @@ -79,7 +90,8 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) return None state.trace("find_module called for: %s" % name) - spec = self._find_spec(name, path) + # Type ignored because mypy is confused about the `self` binding here. + spec = self._find_spec(name, path) # type: ignore if ( # the import machinery could not find a file to import spec is None @@ -108,10 +120,14 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) submodule_search_locations=spec.submodule_search_locations, ) - def create_module(self, spec): + def create_module( + self, spec: importlib.machinery.ModuleSpec + ) -> Optional[types.ModuleType]: return None # default behaviour is fine - def exec_module(self, module): + def exec_module(self, module: types.ModuleType) -> None: + assert module.__spec__ is not None + assert module.__spec__.origin is not None fn = Path(module.__spec__.origin) state = self.config._store[assertstate_key] @@ -151,7 +167,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) state.trace("found cached rewritten pyc for {}".format(fn)) exec(co, module.__dict__) - def _early_rewrite_bailout(self, name, state): + def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool: """This is a fast way to get out of rewriting modules. Profiling has shown that the call to PathFinder.find_spec (inside of @@ -190,7 +206,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) state.trace("early skip of rewriting module: {}".format(name)) return True - def _should_rewrite(self, name, fn, state): + def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool: # always rewrite conftest files if os.path.basename(fn) == "conftest.py": state.trace("rewriting conftest file: {!r}".format(fn)) @@ -213,7 +229,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) return self._is_marked_for_rewrite(name, state) - def _is_marked_for_rewrite(self, name: str, state): + def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool: try: return self._marked_for_rewrite_cache[name] except KeyError: @@ -246,7 +262,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) self._must_rewrite.update(names) self._marked_for_rewrite_cache.clear() - def _warn_already_imported(self, name): + def _warn_already_imported(self, name: str) -> None: from _pytest.warning_types import PytestAssertRewriteWarning from _pytest.warnings import _issue_warning_captured @@ -258,13 +274,15 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) stacklevel=5, ) - def get_data(self, pathname): + def get_data(self, pathname: Union[str, bytes]) -> bytes: """Optional PEP302 get_data API.""" with open(pathname, "rb") as f: return f.read() -def _write_pyc_fp(fp, source_stat, co): +def _write_pyc_fp( + fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType +) -> None: # Technically, we don't have to have the same pyc format as # (C)Python, since these "pycs" should never be seen by builtin # import. However, there's little reason deviate. @@ -280,7 +298,12 @@ def _write_pyc_fp(fp, source_stat, co): if sys.platform == "win32": from atomicwrites import atomic_write - def _write_pyc(state, co, source_stat, pyc): + def _write_pyc( + state: "AssertionState", + co: types.CodeType, + source_stat: os.stat_result, + pyc: Path, + ) -> bool: try: with atomic_write(fspath(pyc), mode="wb", overwrite=True) as fp: _write_pyc_fp(fp, source_stat, co) @@ -295,7 +318,12 @@ if sys.platform == "win32": else: - def _write_pyc(state, co, source_stat, pyc): + def _write_pyc( + state: "AssertionState", + co: types.CodeType, + source_stat: os.stat_result, + pyc: Path, + ) -> bool: proc_pyc = "{}.{}".format(pyc, os.getpid()) try: fp = open(proc_pyc, "wb") @@ -319,19 +347,21 @@ else: return True -def _rewrite_test(fn, config): +def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]: """read and rewrite *fn* and return the code object.""" - fn = fspath(fn) - stat = os.stat(fn) - with open(fn, "rb") as f: + fn_ = fspath(fn) + stat = os.stat(fn_) + with open(fn_, "rb") as f: source = f.read() - tree = ast.parse(source, filename=fn) - rewrite_asserts(tree, source, fn, config) - co = compile(tree, fn, "exec", dont_inherit=True) + tree = ast.parse(source, filename=fn_) + rewrite_asserts(tree, source, fn_, config) + co = compile(tree, fn_, "exec", dont_inherit=True) return stat, co -def _read_pyc(source, pyc, trace=lambda x: None): +def _read_pyc( + source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None +) -> Optional[types.CodeType]: """Possibly read a pytest pyc containing rewritten code. Return rewritten code if successful or None if not. @@ -368,12 +398,17 @@ def _read_pyc(source, pyc, trace=lambda x: None): return co -def rewrite_asserts(mod, source, module_path=None, config=None): +def rewrite_asserts( + mod: ast.Module, + source: bytes, + module_path: Optional[str] = None, + config: Optional[Config] = None, +) -> None: """Rewrite the assert statements in mod.""" AssertionRewriter(module_path, config, source).run(mod) -def _saferepr(obj): +def _saferepr(obj: object) -> str: """Get a safe repr of an object for assertion error messages. The assertion formatting (util.format_explanation()) requires @@ -387,7 +422,7 @@ def _saferepr(obj): return saferepr(obj).replace("\n", "\\n") -def _format_assertmsg(obj): +def _format_assertmsg(obj: object) -> str: """Format the custom assertion message given. For strings this simply replaces newlines with '\n~' so that @@ -410,7 +445,7 @@ def _format_assertmsg(obj): return obj -def _should_repr_global_name(obj): +def _should_repr_global_name(obj: object) -> bool: if callable(obj): return False @@ -420,7 +455,7 @@ def _should_repr_global_name(obj): return True -def _format_boolop(explanations, is_or): +def _format_boolop(explanations, is_or: bool): explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")" if isinstance(explanation, str): return explanation.replace("%", "%%") @@ -428,8 +463,12 @@ def _format_boolop(explanations, is_or): return explanation.replace(b"%", b"%%") -def _call_reprcompare(ops, results, expls, each_obj): - # type: (Tuple[str, ...], Tuple[bool, ...], Tuple[str, ...], Tuple[object, ...]) -> str +def _call_reprcompare( + ops: Sequence[str], + results: Sequence[bool], + expls: Sequence[str], + each_obj: Sequence[object], +) -> str: for i, res, expl in zip(range(len(ops)), results, expls): try: done = not res @@ -607,7 +646,9 @@ class AssertionRewriter(ast.NodeVisitor): """ - def __init__(self, module_path, config, source): + def __init__( + self, module_path: Optional[str], config: Optional[Config], source: bytes + ) -> None: super().__init__() self.module_path = module_path self.config = config @@ -620,7 +661,7 @@ class AssertionRewriter(ast.NodeVisitor): self.source = source @functools.lru_cache(maxsize=1) - def _assert_expr_to_lineno(self): + def _assert_expr_to_lineno(self) -> Dict[int, str]: return _get_assertion_exprs(self.source) def run(self, mod: ast.Module) -> None: @@ -689,38 +730,38 @@ class AssertionRewriter(ast.NodeVisitor): nodes.append(field) @staticmethod - def is_rewrite_disabled(docstring): + def is_rewrite_disabled(docstring: str) -> bool: return "PYTEST_DONT_REWRITE" in docstring - def variable(self): + def variable(self) -> str: """Get a new variable.""" # Use a character invalid in python identifiers to avoid clashing. name = "@py_assert" + str(next(self.variable_counter)) self.variables.append(name) return name - def assign(self, expr): + def assign(self, expr: ast.expr) -> ast.Name: """Give *expr* a name.""" name = self.variable() self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr)) return ast.Name(name, ast.Load()) - def display(self, expr): + def display(self, expr: ast.expr) -> ast.expr: """Call saferepr on the expression.""" return self.helper("_saferepr", expr) - def helper(self, name, *args): + def helper(self, name: str, *args: ast.expr) -> ast.expr: """Call a helper in this module.""" py_name = ast.Name("@pytest_ar", ast.Load()) attr = ast.Attribute(py_name, name, ast.Load()) return ast.Call(attr, list(args), []) - def builtin(self, name): + def builtin(self, name: str) -> ast.Attribute: """Return the builtin called *name*.""" builtin_name = ast.Name("@py_builtins", ast.Load()) return ast.Attribute(builtin_name, name, ast.Load()) - def explanation_param(self, expr): + def explanation_param(self, expr: ast.expr) -> str: """Return a new named %-formatting placeholder for expr. This creates a %-formatting placeholder for expr in the @@ -733,7 +774,7 @@ class AssertionRewriter(ast.NodeVisitor): self.explanation_specifiers[specifier] = expr return "%(" + specifier + ")s" - def push_format_context(self): + def push_format_context(self) -> None: """Create a new formatting context. The format context is used for when an explanation wants to @@ -747,10 +788,10 @@ class AssertionRewriter(ast.NodeVisitor): self.explanation_specifiers = {} # type: Dict[str, ast.expr] self.stack.append(self.explanation_specifiers) - def pop_format_context(self, expl_expr): + def pop_format_context(self, expl_expr: ast.expr) -> ast.Name: """Format the %-formatted string with current format context. - The expl_expr should be an ast.Str instance constructed from + The expl_expr should be an str ast.expr instance constructed from the %-placeholders created by .explanation_param(). This will add the required code to format said string to .expl_stmts and return the ast.Name instance of the formatted string. @@ -768,13 +809,13 @@ class AssertionRewriter(ast.NodeVisitor): self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form)) return ast.Name(name, ast.Load()) - def generic_visit(self, node): + def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]: """Handle expressions we don't have custom code for.""" assert isinstance(node, ast.expr) res = self.assign(node) return res, self.explanation_param(self.display(res)) - def visit_Assert(self, assert_): + def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]: """Return the AST statements to replace the ast.Assert instance. This rewrites the test of an assertion to provide @@ -787,6 +828,8 @@ class AssertionRewriter(ast.NodeVisitor): from _pytest.warning_types import PytestAssertRewriteWarning import warnings + # TODO: This assert should not be needed. + assert self.module_path is not None warnings.warn_explicit( PytestAssertRewriteWarning( "assertion is always true, perhaps remove parentheses?" @@ -889,7 +932,7 @@ class AssertionRewriter(ast.NodeVisitor): set_location(stmt, assert_.lineno, assert_.col_offset) return self.statements - def visit_Name(self, name): + def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]: # Display the repr of the name if it's a local variable or # _should_repr_global_name() thinks it's acceptable. locs = ast.Call(self.builtin("locals"), [], []) @@ -899,7 +942,7 @@ class AssertionRewriter(ast.NodeVisitor): expr = ast.IfExp(test, self.display(name), ast.Str(name.id)) return name, self.explanation_param(expr) - def visit_BoolOp(self, boolop): + def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]: res_var = self.variable() expl_list = self.assign(ast.List([], ast.Load())) app = ast.Attribute(expl_list, "append", ast.Load()) @@ -934,13 +977,13 @@ class AssertionRewriter(ast.NodeVisitor): expl = self.pop_format_context(expl_template) return ast.Name(res_var, ast.Load()), self.explanation_param(expl) - def visit_UnaryOp(self, unary): + def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]: pattern = UNARY_MAP[unary.op.__class__] operand_res, operand_expl = self.visit(unary.operand) res = self.assign(ast.UnaryOp(unary.op, operand_res)) return res, pattern % (operand_expl,) - def visit_BinOp(self, binop): + def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]: symbol = BINOP_MAP[binop.op.__class__] left_expr, left_expl = self.visit(binop.left) right_expr, right_expl = self.visit(binop.right) @@ -948,7 +991,7 @@ class AssertionRewriter(ast.NodeVisitor): res = self.assign(ast.BinOp(left_expr, binop.op, right_expr)) return res, explanation - def visit_Call(self, call): + def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]: """ visit `ast.Call` nodes """ @@ -975,13 +1018,13 @@ class AssertionRewriter(ast.NodeVisitor): outer_expl = "{}\n{{{} = {}\n}}".format(res_expl, res_expl, expl) return res, outer_expl - def visit_Starred(self, starred): + def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]: # From Python 3.5, a Starred node can appear in a function call res, expl = self.visit(starred.value) new_starred = ast.Starred(res, starred.ctx) return new_starred, "*" + expl - def visit_Attribute(self, attr): + def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]: if not isinstance(attr.ctx, ast.Load): return self.generic_visit(attr) value, value_expl = self.visit(attr.value) @@ -991,7 +1034,7 @@ class AssertionRewriter(ast.NodeVisitor): expl = pat % (res_expl, res_expl, value_expl, attr.attr) return res, expl - def visit_Compare(self, comp: ast.Compare): + def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: self.push_format_context() left_res, left_expl = self.visit(comp.left) if isinstance(comp.left, (ast.Compare, ast.BoolOp)): @@ -1030,7 +1073,7 @@ class AssertionRewriter(ast.NodeVisitor): return res, self.explanation_param(self.pop_format_context(expl_call)) -def try_makedirs(cache_dir) -> bool: +def try_makedirs(cache_dir: Path) -> bool: """Attempts to create the given directory and sub-directories exist, returns True if successful or it already exists""" try: diff --git a/src/_pytest/assertion/truncate.py b/src/_pytest/assertion/truncate.py index d97b05b44..fb2bf9c8e 100644 --- a/src/_pytest/assertion/truncate.py +++ b/src/_pytest/assertion/truncate.py @@ -5,13 +5,20 @@ Current default behaviour is to truncate assertion explanations at ~8 terminal lines, unless running in "-vv" mode or running on CI. """ import os +from typing import List +from typing import Optional + +from _pytest.nodes import Item + DEFAULT_MAX_LINES = 8 DEFAULT_MAX_CHARS = 8 * 80 USAGE_MSG = "use '-vv' to show" -def truncate_if_required(explanation, item, max_length=None): +def truncate_if_required( + explanation: List[str], item: Item, max_length: Optional[int] = None +) -> List[str]: """ Truncate this assertion explanation if the given test item is eligible. """ @@ -20,7 +27,7 @@ def truncate_if_required(explanation, item, max_length=None): return explanation -def _should_truncate_item(item): +def _should_truncate_item(item: Item) -> bool: """ Whether or not this test item is eligible for truncation. """ @@ -28,13 +35,17 @@ def _should_truncate_item(item): return verbose < 2 and not _running_on_ci() -def _running_on_ci(): +def _running_on_ci() -> bool: """Check if we're currently running on a CI system.""" env_vars = ["CI", "BUILD_NUMBER"] return any(var in os.environ for var in env_vars) -def _truncate_explanation(input_lines, max_lines=None, max_chars=None): +def _truncate_explanation( + input_lines: List[str], + max_lines: Optional[int] = None, + max_chars: Optional[int] = None, +) -> List[str]: """ Truncate given list of strings that makes up the assertion explanation. @@ -73,7 +84,7 @@ def _truncate_explanation(input_lines, max_lines=None, max_chars=None): return truncated_explanation -def _truncate_by_char_count(input_lines, max_chars): +def _truncate_by_char_count(input_lines: List[str], max_chars: int) -> List[str]: # Check if truncation required if len("".join(input_lines)) <= max_chars: return input_lines diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 7bc853e82..212c631ef 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -952,7 +952,8 @@ class TestAssertionRewriteHookDetails: state = AssertionState(config, "rewrite") source_path = str(tmpdir.ensure("source.py")) pycpath = tmpdir.join("pyc").strpath - assert _write_pyc(state, [1], os.stat(source_path), pycpath) + co = compile("1", "f.py", "single") + assert _write_pyc(state, co, os.stat(source_path), pycpath) if sys.platform == "win32": from contextlib import contextmanager @@ -974,7 +975,7 @@ class TestAssertionRewriteHookDetails: monkeypatch.setattr("os.rename", raise_oserror) - assert not _write_pyc(state, [1], os.stat(source_path), pycpath) + assert not _write_pyc(state, co, os.stat(source_path), pycpath) def test_resources_provider_for_loader(self, testdir): """