Type annotate _pytest.assertion

This commit is contained in:
Ran Benita 2020-05-01 14:40:15 +03:00
parent 30e3d473c4
commit d95132178c
4 changed files with 120 additions and 65 deletions

View File

@ -46,7 +46,7 @@ def pytest_addoption(parser: Parser) -> None:
) )
def register_assert_rewrite(*names) -> None: def register_assert_rewrite(*names: str) -> None:
"""Register one or more module names to be rewritten on import. """Register one or more module names to be rewritten on import.
This function will make sure that this module or all modules inside This function will make sure that this module or all modules inside
@ -75,27 +75,27 @@ def register_assert_rewrite(*names) -> None:
class DummyRewriteHook: class DummyRewriteHook:
"""A no-op import hook for when rewriting is disabled.""" """A no-op import hook for when rewriting is disabled."""
def mark_rewrite(self, *names): def mark_rewrite(self, *names: str) -> None:
pass pass
class AssertionState: class AssertionState:
"""State for the assertion plugin.""" """State for the assertion plugin."""
def __init__(self, config, mode): def __init__(self, config: Config, mode) -> None:
self.mode = mode self.mode = mode
self.trace = config.trace.root.get("assertion") self.trace = config.trace.root.get("assertion")
self.hook = None # type: Optional[rewrite.AssertionRewritingHook] self.hook = None # type: Optional[rewrite.AssertionRewritingHook]
def install_importhook(config): def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
"""Try to install the rewrite hook, raise SystemError if it fails.""" """Try to install the rewrite hook, raise SystemError if it fails."""
config._store[assertstate_key] = AssertionState(config, "rewrite") config._store[assertstate_key] = AssertionState(config, "rewrite")
config._store[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config) config._store[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config)
sys.meta_path.insert(0, hook) sys.meta_path.insert(0, hook)
config._store[assertstate_key].trace("installed rewrite import hook") config._store[assertstate_key].trace("installed rewrite import hook")
def undo(): def undo() -> None:
hook = config._store[assertstate_key].hook hook = config._store[assertstate_key].hook
if hook is not None and hook in sys.meta_path: if hook is not None and hook in sys.meta_path:
sys.meta_path.remove(hook) sys.meta_path.remove(hook)

View File

@ -13,11 +13,15 @@ import struct
import sys import sys
import tokenize import tokenize
import types import types
from typing import Callable
from typing import Dict from typing import Dict
from typing import IO
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Sequence
from typing import Set from typing import Set
from typing import Tuple from typing import Tuple
from typing import Union
from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr
from _pytest._version import version from _pytest._version import version
@ -27,6 +31,8 @@ from _pytest.assertion.util import ( # noqa: F401
) )
from _pytest.compat import fspath from _pytest.compat import fspath
from _pytest.compat import TYPE_CHECKING from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.main import Session
from _pytest.pathlib import fnmatch_ex from _pytest.pathlib import fnmatch_ex
from _pytest.pathlib import Path from _pytest.pathlib import Path
from _pytest.pathlib import PurePath from _pytest.pathlib import PurePath
@ -48,13 +54,13 @@ PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader): class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
"""PEP302/PEP451 import hook which rewrites asserts.""" """PEP302/PEP451 import hook which rewrites asserts."""
def __init__(self, config): def __init__(self, config: Config) -> None:
self.config = config self.config = config
try: try:
self.fnpats = config.getini("python_files") self.fnpats = config.getini("python_files")
except ValueError: except ValueError:
self.fnpats = ["test_*.py", "*_test.py"] self.fnpats = ["test_*.py", "*_test.py"]
self.session = None self.session = None # type: Optional[Session]
self._rewritten_names = set() # type: Set[str] self._rewritten_names = set() # type: Set[str]
self._must_rewrite = set() # type: Set[str] self._must_rewrite = set() # type: Set[str]
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file, # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
@ -64,14 +70,19 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
self._marked_for_rewrite_cache = {} # type: Dict[str, bool] self._marked_for_rewrite_cache = {} # type: Dict[str, bool]
self._session_paths_checked = False self._session_paths_checked = False
def set_session(self, session): def set_session(self, session: Optional[Session]) -> None:
self.session = session self.session = session
self._session_paths_checked = False self._session_paths_checked = False
# Indirection so we can mock calls to find_spec originated from the hook during testing # Indirection so we can mock calls to find_spec originated from the hook during testing
_find_spec = importlib.machinery.PathFinder.find_spec _find_spec = importlib.machinery.PathFinder.find_spec
def find_spec(self, name, path=None, target=None): def find_spec(
self,
name: str,
path: Optional[Sequence[Union[str, bytes]]] = None,
target: Optional[types.ModuleType] = None,
) -> Optional[importlib.machinery.ModuleSpec]:
if self._writing_pyc: if self._writing_pyc:
return None return None
state = self.config._store[assertstate_key] state = self.config._store[assertstate_key]
@ -79,7 +90,8 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
return None return None
state.trace("find_module called for: %s" % name) state.trace("find_module called for: %s" % name)
spec = self._find_spec(name, path) # Type ignored because mypy is confused about the `self` binding here.
spec = self._find_spec(name, path) # type: ignore
if ( if (
# the import machinery could not find a file to import # the import machinery could not find a file to import
spec is None spec is None
@ -108,10 +120,14 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
submodule_search_locations=spec.submodule_search_locations, submodule_search_locations=spec.submodule_search_locations,
) )
def create_module(self, spec): def create_module(
self, spec: importlib.machinery.ModuleSpec
) -> Optional[types.ModuleType]:
return None # default behaviour is fine return None # default behaviour is fine
def exec_module(self, module): def exec_module(self, module: types.ModuleType) -> None:
assert module.__spec__ is not None
assert module.__spec__.origin is not None
fn = Path(module.__spec__.origin) fn = Path(module.__spec__.origin)
state = self.config._store[assertstate_key] state = self.config._store[assertstate_key]
@ -151,7 +167,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
state.trace("found cached rewritten pyc for {}".format(fn)) state.trace("found cached rewritten pyc for {}".format(fn))
exec(co, module.__dict__) exec(co, module.__dict__)
def _early_rewrite_bailout(self, name, state): def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool:
"""This is a fast way to get out of rewriting modules. """This is a fast way to get out of rewriting modules.
Profiling has shown that the call to PathFinder.find_spec (inside of Profiling has shown that the call to PathFinder.find_spec (inside of
@ -190,7 +206,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
state.trace("early skip of rewriting module: {}".format(name)) state.trace("early skip of rewriting module: {}".format(name))
return True return True
def _should_rewrite(self, name, fn, state): def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
# always rewrite conftest files # always rewrite conftest files
if os.path.basename(fn) == "conftest.py": if os.path.basename(fn) == "conftest.py":
state.trace("rewriting conftest file: {!r}".format(fn)) state.trace("rewriting conftest file: {!r}".format(fn))
@ -213,7 +229,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
return self._is_marked_for_rewrite(name, state) return self._is_marked_for_rewrite(name, state)
def _is_marked_for_rewrite(self, name: str, state): def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
try: try:
return self._marked_for_rewrite_cache[name] return self._marked_for_rewrite_cache[name]
except KeyError: except KeyError:
@ -246,7 +262,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
self._must_rewrite.update(names) self._must_rewrite.update(names)
self._marked_for_rewrite_cache.clear() self._marked_for_rewrite_cache.clear()
def _warn_already_imported(self, name): def _warn_already_imported(self, name: str) -> None:
from _pytest.warning_types import PytestAssertRewriteWarning from _pytest.warning_types import PytestAssertRewriteWarning
from _pytest.warnings import _issue_warning_captured from _pytest.warnings import _issue_warning_captured
@ -258,13 +274,15 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
stacklevel=5, stacklevel=5,
) )
def get_data(self, pathname): def get_data(self, pathname: Union[str, bytes]) -> bytes:
"""Optional PEP302 get_data API.""" """Optional PEP302 get_data API."""
with open(pathname, "rb") as f: with open(pathname, "rb") as f:
return f.read() return f.read()
def _write_pyc_fp(fp, source_stat, co): def _write_pyc_fp(
fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
) -> None:
# Technically, we don't have to have the same pyc format as # Technically, we don't have to have the same pyc format as
# (C)Python, since these "pycs" should never be seen by builtin # (C)Python, since these "pycs" should never be seen by builtin
# import. However, there's little reason deviate. # import. However, there's little reason deviate.
@ -280,7 +298,12 @@ def _write_pyc_fp(fp, source_stat, co):
if sys.platform == "win32": if sys.platform == "win32":
from atomicwrites import atomic_write from atomicwrites import atomic_write
def _write_pyc(state, co, source_stat, pyc): def _write_pyc(
state: "AssertionState",
co: types.CodeType,
source_stat: os.stat_result,
pyc: Path,
) -> bool:
try: try:
with atomic_write(fspath(pyc), mode="wb", overwrite=True) as fp: with atomic_write(fspath(pyc), mode="wb", overwrite=True) as fp:
_write_pyc_fp(fp, source_stat, co) _write_pyc_fp(fp, source_stat, co)
@ -295,7 +318,12 @@ if sys.platform == "win32":
else: else:
def _write_pyc(state, co, source_stat, pyc): def _write_pyc(
state: "AssertionState",
co: types.CodeType,
source_stat: os.stat_result,
pyc: Path,
) -> bool:
proc_pyc = "{}.{}".format(pyc, os.getpid()) proc_pyc = "{}.{}".format(pyc, os.getpid())
try: try:
fp = open(proc_pyc, "wb") fp = open(proc_pyc, "wb")
@ -319,19 +347,21 @@ else:
return True return True
def _rewrite_test(fn, config): def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
"""read and rewrite *fn* and return the code object.""" """read and rewrite *fn* and return the code object."""
fn = fspath(fn) fn_ = fspath(fn)
stat = os.stat(fn) stat = os.stat(fn_)
with open(fn, "rb") as f: with open(fn_, "rb") as f:
source = f.read() source = f.read()
tree = ast.parse(source, filename=fn) tree = ast.parse(source, filename=fn_)
rewrite_asserts(tree, source, fn, config) rewrite_asserts(tree, source, fn_, config)
co = compile(tree, fn, "exec", dont_inherit=True) co = compile(tree, fn_, "exec", dont_inherit=True)
return stat, co return stat, co
def _read_pyc(source, pyc, trace=lambda x: None): def _read_pyc(
source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None
) -> Optional[types.CodeType]:
"""Possibly read a pytest pyc containing rewritten code. """Possibly read a pytest pyc containing rewritten code.
Return rewritten code if successful or None if not. Return rewritten code if successful or None if not.
@ -368,12 +398,17 @@ def _read_pyc(source, pyc, trace=lambda x: None):
return co return co
def rewrite_asserts(mod, source, module_path=None, config=None): def rewrite_asserts(
mod: ast.Module,
source: bytes,
module_path: Optional[str] = None,
config: Optional[Config] = None,
) -> None:
"""Rewrite the assert statements in mod.""" """Rewrite the assert statements in mod."""
AssertionRewriter(module_path, config, source).run(mod) AssertionRewriter(module_path, config, source).run(mod)
def _saferepr(obj): def _saferepr(obj: object) -> str:
"""Get a safe repr of an object for assertion error messages. """Get a safe repr of an object for assertion error messages.
The assertion formatting (util.format_explanation()) requires The assertion formatting (util.format_explanation()) requires
@ -387,7 +422,7 @@ def _saferepr(obj):
return saferepr(obj).replace("\n", "\\n") return saferepr(obj).replace("\n", "\\n")
def _format_assertmsg(obj): def _format_assertmsg(obj: object) -> str:
"""Format the custom assertion message given. """Format the custom assertion message given.
For strings this simply replaces newlines with '\n~' so that For strings this simply replaces newlines with '\n~' so that
@ -410,7 +445,7 @@ def _format_assertmsg(obj):
return obj return obj
def _should_repr_global_name(obj): def _should_repr_global_name(obj: object) -> bool:
if callable(obj): if callable(obj):
return False return False
@ -420,7 +455,7 @@ def _should_repr_global_name(obj):
return True return True
def _format_boolop(explanations, is_or): def _format_boolop(explanations, is_or: bool):
explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")" explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
if isinstance(explanation, str): if isinstance(explanation, str):
return explanation.replace("%", "%%") return explanation.replace("%", "%%")
@ -428,8 +463,12 @@ def _format_boolop(explanations, is_or):
return explanation.replace(b"%", b"%%") return explanation.replace(b"%", b"%%")
def _call_reprcompare(ops, results, expls, each_obj): def _call_reprcompare(
# type: (Tuple[str, ...], Tuple[bool, ...], Tuple[str, ...], Tuple[object, ...]) -> str ops: Sequence[str],
results: Sequence[bool],
expls: Sequence[str],
each_obj: Sequence[object],
) -> str:
for i, res, expl in zip(range(len(ops)), results, expls): for i, res, expl in zip(range(len(ops)), results, expls):
try: try:
done = not res done = not res
@ -607,7 +646,9 @@ class AssertionRewriter(ast.NodeVisitor):
""" """
def __init__(self, module_path, config, source): def __init__(
self, module_path: Optional[str], config: Optional[Config], source: bytes
) -> None:
super().__init__() super().__init__()
self.module_path = module_path self.module_path = module_path
self.config = config self.config = config
@ -620,7 +661,7 @@ class AssertionRewriter(ast.NodeVisitor):
self.source = source self.source = source
@functools.lru_cache(maxsize=1) @functools.lru_cache(maxsize=1)
def _assert_expr_to_lineno(self): def _assert_expr_to_lineno(self) -> Dict[int, str]:
return _get_assertion_exprs(self.source) return _get_assertion_exprs(self.source)
def run(self, mod: ast.Module) -> None: def run(self, mod: ast.Module) -> None:
@ -689,38 +730,38 @@ class AssertionRewriter(ast.NodeVisitor):
nodes.append(field) nodes.append(field)
@staticmethod @staticmethod
def is_rewrite_disabled(docstring): def is_rewrite_disabled(docstring: str) -> bool:
return "PYTEST_DONT_REWRITE" in docstring return "PYTEST_DONT_REWRITE" in docstring
def variable(self): def variable(self) -> str:
"""Get a new variable.""" """Get a new variable."""
# Use a character invalid in python identifiers to avoid clashing. # Use a character invalid in python identifiers to avoid clashing.
name = "@py_assert" + str(next(self.variable_counter)) name = "@py_assert" + str(next(self.variable_counter))
self.variables.append(name) self.variables.append(name)
return name return name
def assign(self, expr): def assign(self, expr: ast.expr) -> ast.Name:
"""Give *expr* a name.""" """Give *expr* a name."""
name = self.variable() name = self.variable()
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr)) self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
return ast.Name(name, ast.Load()) return ast.Name(name, ast.Load())
def display(self, expr): def display(self, expr: ast.expr) -> ast.expr:
"""Call saferepr on the expression.""" """Call saferepr on the expression."""
return self.helper("_saferepr", expr) return self.helper("_saferepr", expr)
def helper(self, name, *args): def helper(self, name: str, *args: ast.expr) -> ast.expr:
"""Call a helper in this module.""" """Call a helper in this module."""
py_name = ast.Name("@pytest_ar", ast.Load()) py_name = ast.Name("@pytest_ar", ast.Load())
attr = ast.Attribute(py_name, name, ast.Load()) attr = ast.Attribute(py_name, name, ast.Load())
return ast.Call(attr, list(args), []) return ast.Call(attr, list(args), [])
def builtin(self, name): def builtin(self, name: str) -> ast.Attribute:
"""Return the builtin called *name*.""" """Return the builtin called *name*."""
builtin_name = ast.Name("@py_builtins", ast.Load()) builtin_name = ast.Name("@py_builtins", ast.Load())
return ast.Attribute(builtin_name, name, ast.Load()) return ast.Attribute(builtin_name, name, ast.Load())
def explanation_param(self, expr): def explanation_param(self, expr: ast.expr) -> str:
"""Return a new named %-formatting placeholder for expr. """Return a new named %-formatting placeholder for expr.
This creates a %-formatting placeholder for expr in the This creates a %-formatting placeholder for expr in the
@ -733,7 +774,7 @@ class AssertionRewriter(ast.NodeVisitor):
self.explanation_specifiers[specifier] = expr self.explanation_specifiers[specifier] = expr
return "%(" + specifier + ")s" return "%(" + specifier + ")s"
def push_format_context(self): def push_format_context(self) -> None:
"""Create a new formatting context. """Create a new formatting context.
The format context is used for when an explanation wants to The format context is used for when an explanation wants to
@ -747,10 +788,10 @@ class AssertionRewriter(ast.NodeVisitor):
self.explanation_specifiers = {} # type: Dict[str, ast.expr] self.explanation_specifiers = {} # type: Dict[str, ast.expr]
self.stack.append(self.explanation_specifiers) self.stack.append(self.explanation_specifiers)
def pop_format_context(self, expl_expr): def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
"""Format the %-formatted string with current format context. """Format the %-formatted string with current format context.
The expl_expr should be an ast.Str instance constructed from The expl_expr should be an str ast.expr instance constructed from
the %-placeholders created by .explanation_param(). This will the %-placeholders created by .explanation_param(). This will
add the required code to format said string to .expl_stmts and add the required code to format said string to .expl_stmts and
return the ast.Name instance of the formatted string. return the ast.Name instance of the formatted string.
@ -768,13 +809,13 @@ class AssertionRewriter(ast.NodeVisitor):
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form)) self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
return ast.Name(name, ast.Load()) return ast.Name(name, ast.Load())
def generic_visit(self, node): def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]:
"""Handle expressions we don't have custom code for.""" """Handle expressions we don't have custom code for."""
assert isinstance(node, ast.expr) assert isinstance(node, ast.expr)
res = self.assign(node) res = self.assign(node)
return res, self.explanation_param(self.display(res)) return res, self.explanation_param(self.display(res))
def visit_Assert(self, assert_): def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
"""Return the AST statements to replace the ast.Assert instance. """Return the AST statements to replace the ast.Assert instance.
This rewrites the test of an assertion to provide This rewrites the test of an assertion to provide
@ -787,6 +828,8 @@ class AssertionRewriter(ast.NodeVisitor):
from _pytest.warning_types import PytestAssertRewriteWarning from _pytest.warning_types import PytestAssertRewriteWarning
import warnings import warnings
# TODO: This assert should not be needed.
assert self.module_path is not None
warnings.warn_explicit( warnings.warn_explicit(
PytestAssertRewriteWarning( PytestAssertRewriteWarning(
"assertion is always true, perhaps remove parentheses?" "assertion is always true, perhaps remove parentheses?"
@ -889,7 +932,7 @@ class AssertionRewriter(ast.NodeVisitor):
set_location(stmt, assert_.lineno, assert_.col_offset) set_location(stmt, assert_.lineno, assert_.col_offset)
return self.statements return self.statements
def visit_Name(self, name): def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
# Display the repr of the name if it's a local variable or # Display the repr of the name if it's a local variable or
# _should_repr_global_name() thinks it's acceptable. # _should_repr_global_name() thinks it's acceptable.
locs = ast.Call(self.builtin("locals"), [], []) locs = ast.Call(self.builtin("locals"), [], [])
@ -899,7 +942,7 @@ class AssertionRewriter(ast.NodeVisitor):
expr = ast.IfExp(test, self.display(name), ast.Str(name.id)) expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
return name, self.explanation_param(expr) return name, self.explanation_param(expr)
def visit_BoolOp(self, boolop): def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
res_var = self.variable() res_var = self.variable()
expl_list = self.assign(ast.List([], ast.Load())) expl_list = self.assign(ast.List([], ast.Load()))
app = ast.Attribute(expl_list, "append", ast.Load()) app = ast.Attribute(expl_list, "append", ast.Load())
@ -934,13 +977,13 @@ class AssertionRewriter(ast.NodeVisitor):
expl = self.pop_format_context(expl_template) expl = self.pop_format_context(expl_template)
return ast.Name(res_var, ast.Load()), self.explanation_param(expl) return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
def visit_UnaryOp(self, unary): def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]:
pattern = UNARY_MAP[unary.op.__class__] pattern = UNARY_MAP[unary.op.__class__]
operand_res, operand_expl = self.visit(unary.operand) operand_res, operand_expl = self.visit(unary.operand)
res = self.assign(ast.UnaryOp(unary.op, operand_res)) res = self.assign(ast.UnaryOp(unary.op, operand_res))
return res, pattern % (operand_expl,) return res, pattern % (operand_expl,)
def visit_BinOp(self, binop): def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]:
symbol = BINOP_MAP[binop.op.__class__] symbol = BINOP_MAP[binop.op.__class__]
left_expr, left_expl = self.visit(binop.left) left_expr, left_expl = self.visit(binop.left)
right_expr, right_expl = self.visit(binop.right) right_expr, right_expl = self.visit(binop.right)
@ -948,7 +991,7 @@ class AssertionRewriter(ast.NodeVisitor):
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr)) res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
return res, explanation return res, explanation
def visit_Call(self, call): def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
""" """
visit `ast.Call` nodes visit `ast.Call` nodes
""" """
@ -975,13 +1018,13 @@ class AssertionRewriter(ast.NodeVisitor):
outer_expl = "{}\n{{{} = {}\n}}".format(res_expl, res_expl, expl) outer_expl = "{}\n{{{} = {}\n}}".format(res_expl, res_expl, expl)
return res, outer_expl return res, outer_expl
def visit_Starred(self, starred): def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]:
# From Python 3.5, a Starred node can appear in a function call # From Python 3.5, a Starred node can appear in a function call
res, expl = self.visit(starred.value) res, expl = self.visit(starred.value)
new_starred = ast.Starred(res, starred.ctx) new_starred = ast.Starred(res, starred.ctx)
return new_starred, "*" + expl return new_starred, "*" + expl
def visit_Attribute(self, attr): def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
if not isinstance(attr.ctx, ast.Load): if not isinstance(attr.ctx, ast.Load):
return self.generic_visit(attr) return self.generic_visit(attr)
value, value_expl = self.visit(attr.value) value, value_expl = self.visit(attr.value)
@ -991,7 +1034,7 @@ class AssertionRewriter(ast.NodeVisitor):
expl = pat % (res_expl, res_expl, value_expl, attr.attr) expl = pat % (res_expl, res_expl, value_expl, attr.attr)
return res, expl return res, expl
def visit_Compare(self, comp: ast.Compare): def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
self.push_format_context() self.push_format_context()
left_res, left_expl = self.visit(comp.left) left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, (ast.Compare, ast.BoolOp)): if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
@ -1030,7 +1073,7 @@ class AssertionRewriter(ast.NodeVisitor):
return res, self.explanation_param(self.pop_format_context(expl_call)) return res, self.explanation_param(self.pop_format_context(expl_call))
def try_makedirs(cache_dir) -> bool: def try_makedirs(cache_dir: Path) -> bool:
"""Attempts to create the given directory and sub-directories exist, returns True if """Attempts to create the given directory and sub-directories exist, returns True if
successful or it already exists""" successful or it already exists"""
try: try:

View File

@ -5,13 +5,20 @@ Current default behaviour is to truncate assertion explanations at
~8 terminal lines, unless running in "-vv" mode or running on CI. ~8 terminal lines, unless running in "-vv" mode or running on CI.
""" """
import os import os
from typing import List
from typing import Optional
from _pytest.nodes import Item
DEFAULT_MAX_LINES = 8 DEFAULT_MAX_LINES = 8
DEFAULT_MAX_CHARS = 8 * 80 DEFAULT_MAX_CHARS = 8 * 80
USAGE_MSG = "use '-vv' to show" USAGE_MSG = "use '-vv' to show"
def truncate_if_required(explanation, item, max_length=None): def truncate_if_required(
explanation: List[str], item: Item, max_length: Optional[int] = None
) -> List[str]:
""" """
Truncate this assertion explanation if the given test item is eligible. Truncate this assertion explanation if the given test item is eligible.
""" """
@ -20,7 +27,7 @@ def truncate_if_required(explanation, item, max_length=None):
return explanation return explanation
def _should_truncate_item(item): def _should_truncate_item(item: Item) -> bool:
""" """
Whether or not this test item is eligible for truncation. Whether or not this test item is eligible for truncation.
""" """
@ -28,13 +35,17 @@ def _should_truncate_item(item):
return verbose < 2 and not _running_on_ci() return verbose < 2 and not _running_on_ci()
def _running_on_ci(): def _running_on_ci() -> bool:
"""Check if we're currently running on a CI system.""" """Check if we're currently running on a CI system."""
env_vars = ["CI", "BUILD_NUMBER"] env_vars = ["CI", "BUILD_NUMBER"]
return any(var in os.environ for var in env_vars) return any(var in os.environ for var in env_vars)
def _truncate_explanation(input_lines, max_lines=None, max_chars=None): def _truncate_explanation(
input_lines: List[str],
max_lines: Optional[int] = None,
max_chars: Optional[int] = None,
) -> List[str]:
""" """
Truncate given list of strings that makes up the assertion explanation. Truncate given list of strings that makes up the assertion explanation.
@ -73,7 +84,7 @@ def _truncate_explanation(input_lines, max_lines=None, max_chars=None):
return truncated_explanation return truncated_explanation
def _truncate_by_char_count(input_lines, max_chars): def _truncate_by_char_count(input_lines: List[str], max_chars: int) -> List[str]:
# Check if truncation required # Check if truncation required
if len("".join(input_lines)) <= max_chars: if len("".join(input_lines)) <= max_chars:
return input_lines return input_lines

View File

@ -952,7 +952,8 @@ class TestAssertionRewriteHookDetails:
state = AssertionState(config, "rewrite") state = AssertionState(config, "rewrite")
source_path = str(tmpdir.ensure("source.py")) source_path = str(tmpdir.ensure("source.py"))
pycpath = tmpdir.join("pyc").strpath pycpath = tmpdir.join("pyc").strpath
assert _write_pyc(state, [1], os.stat(source_path), pycpath) co = compile("1", "f.py", "single")
assert _write_pyc(state, co, os.stat(source_path), pycpath)
if sys.platform == "win32": if sys.platform == "win32":
from contextlib import contextmanager from contextlib import contextmanager
@ -974,7 +975,7 @@ class TestAssertionRewriteHookDetails:
monkeypatch.setattr("os.rename", raise_oserror) monkeypatch.setattr("os.rename", raise_oserror)
assert not _write_pyc(state, [1], os.stat(source_path), pycpath) assert not _write_pyc(state, co, os.stat(source_path), pycpath)
def test_resources_provider_for_loader(self, testdir): def test_resources_provider_for_loader(self, testdir):
""" """