Type annotate _pytest.assertion

This commit is contained in:
Ran Benita 2020-05-01 14:40:15 +03:00
parent 30e3d473c4
commit d95132178c
4 changed files with 120 additions and 65 deletions

View File

@ -46,7 +46,7 @@ def pytest_addoption(parser: Parser) -> None:
)
def register_assert_rewrite(*names) -> None:
def register_assert_rewrite(*names: str) -> None:
"""Register one or more module names to be rewritten on import.
This function will make sure that this module or all modules inside
@ -75,27 +75,27 @@ def register_assert_rewrite(*names) -> None:
class DummyRewriteHook:
"""A no-op import hook for when rewriting is disabled."""
def mark_rewrite(self, *names):
def mark_rewrite(self, *names: str) -> None:
pass
class AssertionState:
"""State for the assertion plugin."""
def __init__(self, config, mode):
def __init__(self, config: Config, mode) -> None:
self.mode = mode
self.trace = config.trace.root.get("assertion")
self.hook = None # type: Optional[rewrite.AssertionRewritingHook]
def install_importhook(config):
def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
"""Try to install the rewrite hook, raise SystemError if it fails."""
config._store[assertstate_key] = AssertionState(config, "rewrite")
config._store[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config)
sys.meta_path.insert(0, hook)
config._store[assertstate_key].trace("installed rewrite import hook")
def undo():
def undo() -> None:
hook = config._store[assertstate_key].hook
if hook is not None and hook in sys.meta_path:
sys.meta_path.remove(hook)

View File

@ -13,11 +13,15 @@ import struct
import sys
import tokenize
import types
from typing import Callable
from typing import Dict
from typing import IO
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Union
from _pytest._io.saferepr import saferepr
from _pytest._version import version
@ -27,6 +31,8 @@ from _pytest.assertion.util import ( # noqa: F401
)
from _pytest.compat import fspath
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.main import Session
from _pytest.pathlib import fnmatch_ex
from _pytest.pathlib import Path
from _pytest.pathlib import PurePath
@ -48,13 +54,13 @@ PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
"""PEP302/PEP451 import hook which rewrites asserts."""
def __init__(self, config):
def __init__(self, config: Config) -> None:
self.config = config
try:
self.fnpats = config.getini("python_files")
except ValueError:
self.fnpats = ["test_*.py", "*_test.py"]
self.session = None
self.session = None # type: Optional[Session]
self._rewritten_names = set() # type: Set[str]
self._must_rewrite = set() # type: Set[str]
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
@ -64,14 +70,19 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
self._marked_for_rewrite_cache = {} # type: Dict[str, bool]
self._session_paths_checked = False
def set_session(self, session):
def set_session(self, session: Optional[Session]) -> None:
self.session = session
self._session_paths_checked = False
# Indirection so we can mock calls to find_spec originated from the hook during testing
_find_spec = importlib.machinery.PathFinder.find_spec
def find_spec(self, name, path=None, target=None):
def find_spec(
self,
name: str,
path: Optional[Sequence[Union[str, bytes]]] = None,
target: Optional[types.ModuleType] = None,
) -> Optional[importlib.machinery.ModuleSpec]:
if self._writing_pyc:
return None
state = self.config._store[assertstate_key]
@ -79,7 +90,8 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
return None
state.trace("find_module called for: %s" % name)
spec = self._find_spec(name, path)
# Type ignored because mypy is confused about the `self` binding here.
spec = self._find_spec(name, path) # type: ignore
if (
# the import machinery could not find a file to import
spec is None
@ -108,10 +120,14 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
submodule_search_locations=spec.submodule_search_locations,
)
def create_module(self, spec):
def create_module(
self, spec: importlib.machinery.ModuleSpec
) -> Optional[types.ModuleType]:
return None # default behaviour is fine
def exec_module(self, module):
def exec_module(self, module: types.ModuleType) -> None:
assert module.__spec__ is not None
assert module.__spec__.origin is not None
fn = Path(module.__spec__.origin)
state = self.config._store[assertstate_key]
@ -151,7 +167,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
state.trace("found cached rewritten pyc for {}".format(fn))
exec(co, module.__dict__)
def _early_rewrite_bailout(self, name, state):
def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool:
"""This is a fast way to get out of rewriting modules.
Profiling has shown that the call to PathFinder.find_spec (inside of
@ -190,7 +206,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
state.trace("early skip of rewriting module: {}".format(name))
return True
def _should_rewrite(self, name, fn, state):
def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
# always rewrite conftest files
if os.path.basename(fn) == "conftest.py":
state.trace("rewriting conftest file: {!r}".format(fn))
@ -213,7 +229,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
return self._is_marked_for_rewrite(name, state)
def _is_marked_for_rewrite(self, name: str, state):
def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
try:
return self._marked_for_rewrite_cache[name]
except KeyError:
@ -246,7 +262,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
self._must_rewrite.update(names)
self._marked_for_rewrite_cache.clear()
def _warn_already_imported(self, name):
def _warn_already_imported(self, name: str) -> None:
from _pytest.warning_types import PytestAssertRewriteWarning
from _pytest.warnings import _issue_warning_captured
@ -258,13 +274,15 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
stacklevel=5,
)
def get_data(self, pathname):
def get_data(self, pathname: Union[str, bytes]) -> bytes:
"""Optional PEP302 get_data API."""
with open(pathname, "rb") as f:
return f.read()
def _write_pyc_fp(fp, source_stat, co):
def _write_pyc_fp(
fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
) -> None:
# Technically, we don't have to have the same pyc format as
# (C)Python, since these "pycs" should never be seen by builtin
# import. However, there's little reason deviate.
@ -280,7 +298,12 @@ def _write_pyc_fp(fp, source_stat, co):
if sys.platform == "win32":
from atomicwrites import atomic_write
def _write_pyc(state, co, source_stat, pyc):
def _write_pyc(
state: "AssertionState",
co: types.CodeType,
source_stat: os.stat_result,
pyc: Path,
) -> bool:
try:
with atomic_write(fspath(pyc), mode="wb", overwrite=True) as fp:
_write_pyc_fp(fp, source_stat, co)
@ -295,7 +318,12 @@ if sys.platform == "win32":
else:
def _write_pyc(state, co, source_stat, pyc):
def _write_pyc(
state: "AssertionState",
co: types.CodeType,
source_stat: os.stat_result,
pyc: Path,
) -> bool:
proc_pyc = "{}.{}".format(pyc, os.getpid())
try:
fp = open(proc_pyc, "wb")
@ -319,19 +347,21 @@ else:
return True
def _rewrite_test(fn, config):
def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
"""read and rewrite *fn* and return the code object."""
fn = fspath(fn)
stat = os.stat(fn)
with open(fn, "rb") as f:
fn_ = fspath(fn)
stat = os.stat(fn_)
with open(fn_, "rb") as f:
source = f.read()
tree = ast.parse(source, filename=fn)
rewrite_asserts(tree, source, fn, config)
co = compile(tree, fn, "exec", dont_inherit=True)
tree = ast.parse(source, filename=fn_)
rewrite_asserts(tree, source, fn_, config)
co = compile(tree, fn_, "exec", dont_inherit=True)
return stat, co
def _read_pyc(source, pyc, trace=lambda x: None):
def _read_pyc(
source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None
) -> Optional[types.CodeType]:
"""Possibly read a pytest pyc containing rewritten code.
Return rewritten code if successful or None if not.
@ -368,12 +398,17 @@ def _read_pyc(source, pyc, trace=lambda x: None):
return co
def rewrite_asserts(mod, source, module_path=None, config=None):
def rewrite_asserts(
mod: ast.Module,
source: bytes,
module_path: Optional[str] = None,
config: Optional[Config] = None,
) -> None:
"""Rewrite the assert statements in mod."""
AssertionRewriter(module_path, config, source).run(mod)
def _saferepr(obj):
def _saferepr(obj: object) -> str:
"""Get a safe repr of an object for assertion error messages.
The assertion formatting (util.format_explanation()) requires
@ -387,7 +422,7 @@ def _saferepr(obj):
return saferepr(obj).replace("\n", "\\n")
def _format_assertmsg(obj):
def _format_assertmsg(obj: object) -> str:
"""Format the custom assertion message given.
For strings this simply replaces newlines with '\n~' so that
@ -410,7 +445,7 @@ def _format_assertmsg(obj):
return obj
def _should_repr_global_name(obj):
def _should_repr_global_name(obj: object) -> bool:
if callable(obj):
return False
@ -420,7 +455,7 @@ def _should_repr_global_name(obj):
return True
def _format_boolop(explanations, is_or):
def _format_boolop(explanations, is_or: bool):
explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
if isinstance(explanation, str):
return explanation.replace("%", "%%")
@ -428,8 +463,12 @@ def _format_boolop(explanations, is_or):
return explanation.replace(b"%", b"%%")
def _call_reprcompare(ops, results, expls, each_obj):
# type: (Tuple[str, ...], Tuple[bool, ...], Tuple[str, ...], Tuple[object, ...]) -> str
def _call_reprcompare(
ops: Sequence[str],
results: Sequence[bool],
expls: Sequence[str],
each_obj: Sequence[object],
) -> str:
for i, res, expl in zip(range(len(ops)), results, expls):
try:
done = not res
@ -607,7 +646,9 @@ class AssertionRewriter(ast.NodeVisitor):
"""
def __init__(self, module_path, config, source):
def __init__(
self, module_path: Optional[str], config: Optional[Config], source: bytes
) -> None:
super().__init__()
self.module_path = module_path
self.config = config
@ -620,7 +661,7 @@ class AssertionRewriter(ast.NodeVisitor):
self.source = source
@functools.lru_cache(maxsize=1)
def _assert_expr_to_lineno(self):
def _assert_expr_to_lineno(self) -> Dict[int, str]:
return _get_assertion_exprs(self.source)
def run(self, mod: ast.Module) -> None:
@ -689,38 +730,38 @@ class AssertionRewriter(ast.NodeVisitor):
nodes.append(field)
@staticmethod
def is_rewrite_disabled(docstring):
def is_rewrite_disabled(docstring: str) -> bool:
return "PYTEST_DONT_REWRITE" in docstring
def variable(self):
def variable(self) -> str:
"""Get a new variable."""
# Use a character invalid in python identifiers to avoid clashing.
name = "@py_assert" + str(next(self.variable_counter))
self.variables.append(name)
return name
def assign(self, expr):
def assign(self, expr: ast.expr) -> ast.Name:
"""Give *expr* a name."""
name = self.variable()
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
return ast.Name(name, ast.Load())
def display(self, expr):
def display(self, expr: ast.expr) -> ast.expr:
"""Call saferepr on the expression."""
return self.helper("_saferepr", expr)
def helper(self, name, *args):
def helper(self, name: str, *args: ast.expr) -> ast.expr:
"""Call a helper in this module."""
py_name = ast.Name("@pytest_ar", ast.Load())
attr = ast.Attribute(py_name, name, ast.Load())
return ast.Call(attr, list(args), [])
def builtin(self, name):
def builtin(self, name: str) -> ast.Attribute:
"""Return the builtin called *name*."""
builtin_name = ast.Name("@py_builtins", ast.Load())
return ast.Attribute(builtin_name, name, ast.Load())
def explanation_param(self, expr):
def explanation_param(self, expr: ast.expr) -> str:
"""Return a new named %-formatting placeholder for expr.
This creates a %-formatting placeholder for expr in the
@ -733,7 +774,7 @@ class AssertionRewriter(ast.NodeVisitor):
self.explanation_specifiers[specifier] = expr
return "%(" + specifier + ")s"
def push_format_context(self):
def push_format_context(self) -> None:
"""Create a new formatting context.
The format context is used for when an explanation wants to
@ -747,10 +788,10 @@ class AssertionRewriter(ast.NodeVisitor):
self.explanation_specifiers = {} # type: Dict[str, ast.expr]
self.stack.append(self.explanation_specifiers)
def pop_format_context(self, expl_expr):
def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
"""Format the %-formatted string with current format context.
The expl_expr should be an ast.Str instance constructed from
The expl_expr should be an str ast.expr instance constructed from
the %-placeholders created by .explanation_param(). This will
add the required code to format said string to .expl_stmts and
return the ast.Name instance of the formatted string.
@ -768,13 +809,13 @@ class AssertionRewriter(ast.NodeVisitor):
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
return ast.Name(name, ast.Load())
def generic_visit(self, node):
def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]:
"""Handle expressions we don't have custom code for."""
assert isinstance(node, ast.expr)
res = self.assign(node)
return res, self.explanation_param(self.display(res))
def visit_Assert(self, assert_):
def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
"""Return the AST statements to replace the ast.Assert instance.
This rewrites the test of an assertion to provide
@ -787,6 +828,8 @@ class AssertionRewriter(ast.NodeVisitor):
from _pytest.warning_types import PytestAssertRewriteWarning
import warnings
# TODO: This assert should not be needed.
assert self.module_path is not None
warnings.warn_explicit(
PytestAssertRewriteWarning(
"assertion is always true, perhaps remove parentheses?"
@ -889,7 +932,7 @@ class AssertionRewriter(ast.NodeVisitor):
set_location(stmt, assert_.lineno, assert_.col_offset)
return self.statements
def visit_Name(self, name):
def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
# Display the repr of the name if it's a local variable or
# _should_repr_global_name() thinks it's acceptable.
locs = ast.Call(self.builtin("locals"), [], [])
@ -899,7 +942,7 @@ class AssertionRewriter(ast.NodeVisitor):
expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
return name, self.explanation_param(expr)
def visit_BoolOp(self, boolop):
def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
res_var = self.variable()
expl_list = self.assign(ast.List([], ast.Load()))
app = ast.Attribute(expl_list, "append", ast.Load())
@ -934,13 +977,13 @@ class AssertionRewriter(ast.NodeVisitor):
expl = self.pop_format_context(expl_template)
return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
def visit_UnaryOp(self, unary):
def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]:
pattern = UNARY_MAP[unary.op.__class__]
operand_res, operand_expl = self.visit(unary.operand)
res = self.assign(ast.UnaryOp(unary.op, operand_res))
return res, pattern % (operand_expl,)
def visit_BinOp(self, binop):
def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]:
symbol = BINOP_MAP[binop.op.__class__]
left_expr, left_expl = self.visit(binop.left)
right_expr, right_expl = self.visit(binop.right)
@ -948,7 +991,7 @@ class AssertionRewriter(ast.NodeVisitor):
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
return res, explanation
def visit_Call(self, call):
def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
"""
visit `ast.Call` nodes
"""
@ -975,13 +1018,13 @@ class AssertionRewriter(ast.NodeVisitor):
outer_expl = "{}\n{{{} = {}\n}}".format(res_expl, res_expl, expl)
return res, outer_expl
def visit_Starred(self, starred):
def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]:
# From Python 3.5, a Starred node can appear in a function call
res, expl = self.visit(starred.value)
new_starred = ast.Starred(res, starred.ctx)
return new_starred, "*" + expl
def visit_Attribute(self, attr):
def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
if not isinstance(attr.ctx, ast.Load):
return self.generic_visit(attr)
value, value_expl = self.visit(attr.value)
@ -991,7 +1034,7 @@ class AssertionRewriter(ast.NodeVisitor):
expl = pat % (res_expl, res_expl, value_expl, attr.attr)
return res, expl
def visit_Compare(self, comp: ast.Compare):
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
self.push_format_context()
left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
@ -1030,7 +1073,7 @@ class AssertionRewriter(ast.NodeVisitor):
return res, self.explanation_param(self.pop_format_context(expl_call))
def try_makedirs(cache_dir) -> bool:
def try_makedirs(cache_dir: Path) -> bool:
"""Attempts to create the given directory and sub-directories exist, returns True if
successful or it already exists"""
try:

View File

@ -5,13 +5,20 @@ Current default behaviour is to truncate assertion explanations at
~8 terminal lines, unless running in "-vv" mode or running on CI.
"""
import os
from typing import List
from typing import Optional
from _pytest.nodes import Item
DEFAULT_MAX_LINES = 8
DEFAULT_MAX_CHARS = 8 * 80
USAGE_MSG = "use '-vv' to show"
def truncate_if_required(explanation, item, max_length=None):
def truncate_if_required(
explanation: List[str], item: Item, max_length: Optional[int] = None
) -> List[str]:
"""
Truncate this assertion explanation if the given test item is eligible.
"""
@ -20,7 +27,7 @@ def truncate_if_required(explanation, item, max_length=None):
return explanation
def _should_truncate_item(item):
def _should_truncate_item(item: Item) -> bool:
"""
Whether or not this test item is eligible for truncation.
"""
@ -28,13 +35,17 @@ def _should_truncate_item(item):
return verbose < 2 and not _running_on_ci()
def _running_on_ci():
def _running_on_ci() -> bool:
"""Check if we're currently running on a CI system."""
env_vars = ["CI", "BUILD_NUMBER"]
return any(var in os.environ for var in env_vars)
def _truncate_explanation(input_lines, max_lines=None, max_chars=None):
def _truncate_explanation(
input_lines: List[str],
max_lines: Optional[int] = None,
max_chars: Optional[int] = None,
) -> List[str]:
"""
Truncate given list of strings that makes up the assertion explanation.
@ -73,7 +84,7 @@ def _truncate_explanation(input_lines, max_lines=None, max_chars=None):
return truncated_explanation
def _truncate_by_char_count(input_lines, max_chars):
def _truncate_by_char_count(input_lines: List[str], max_chars: int) -> List[str]:
# Check if truncation required
if len("".join(input_lines)) <= max_chars:
return input_lines

View File

@ -952,7 +952,8 @@ class TestAssertionRewriteHookDetails:
state = AssertionState(config, "rewrite")
source_path = str(tmpdir.ensure("source.py"))
pycpath = tmpdir.join("pyc").strpath
assert _write_pyc(state, [1], os.stat(source_path), pycpath)
co = compile("1", "f.py", "single")
assert _write_pyc(state, co, os.stat(source_path), pycpath)
if sys.platform == "win32":
from contextlib import contextmanager
@ -974,7 +975,7 @@ class TestAssertionRewriteHookDetails:
monkeypatch.setattr("os.rename", raise_oserror)
assert not _write_pyc(state, [1], os.stat(source_path), pycpath)
assert not _write_pyc(state, co, os.stat(source_path), pycpath)
def test_resources_provider_for_loader(self, testdir):
"""