Type annotate _pytest.assertion
This commit is contained in:
parent
30e3d473c4
commit
d95132178c
|
@ -46,7 +46,7 @@ def pytest_addoption(parser: Parser) -> None:
|
|||
)
|
||||
|
||||
|
||||
def register_assert_rewrite(*names) -> None:
|
||||
def register_assert_rewrite(*names: str) -> None:
|
||||
"""Register one or more module names to be rewritten on import.
|
||||
|
||||
This function will make sure that this module or all modules inside
|
||||
|
@ -75,27 +75,27 @@ def register_assert_rewrite(*names) -> None:
|
|||
class DummyRewriteHook:
|
||||
"""A no-op import hook for when rewriting is disabled."""
|
||||
|
||||
def mark_rewrite(self, *names):
|
||||
def mark_rewrite(self, *names: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class AssertionState:
|
||||
"""State for the assertion plugin."""
|
||||
|
||||
def __init__(self, config, mode):
|
||||
def __init__(self, config: Config, mode) -> None:
|
||||
self.mode = mode
|
||||
self.trace = config.trace.root.get("assertion")
|
||||
self.hook = None # type: Optional[rewrite.AssertionRewritingHook]
|
||||
|
||||
|
||||
def install_importhook(config):
|
||||
def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
|
||||
"""Try to install the rewrite hook, raise SystemError if it fails."""
|
||||
config._store[assertstate_key] = AssertionState(config, "rewrite")
|
||||
config._store[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config)
|
||||
sys.meta_path.insert(0, hook)
|
||||
config._store[assertstate_key].trace("installed rewrite import hook")
|
||||
|
||||
def undo():
|
||||
def undo() -> None:
|
||||
hook = config._store[assertstate_key].hook
|
||||
if hook is not None and hook in sys.meta_path:
|
||||
sys.meta_path.remove(hook)
|
||||
|
|
|
@ -13,11 +13,15 @@ import struct
|
|||
import sys
|
||||
import tokenize
|
||||
import types
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import IO
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Set
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
from _pytest._io.saferepr import saferepr
|
||||
from _pytest._version import version
|
||||
|
@ -27,6 +31,8 @@ from _pytest.assertion.util import ( # noqa: F401
|
|||
)
|
||||
from _pytest.compat import fspath
|
||||
from _pytest.compat import TYPE_CHECKING
|
||||
from _pytest.config import Config
|
||||
from _pytest.main import Session
|
||||
from _pytest.pathlib import fnmatch_ex
|
||||
from _pytest.pathlib import Path
|
||||
from _pytest.pathlib import PurePath
|
||||
|
@ -48,13 +54,13 @@ PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
|
|||
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
|
||||
"""PEP302/PEP451 import hook which rewrites asserts."""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: Config) -> None:
|
||||
self.config = config
|
||||
try:
|
||||
self.fnpats = config.getini("python_files")
|
||||
except ValueError:
|
||||
self.fnpats = ["test_*.py", "*_test.py"]
|
||||
self.session = None
|
||||
self.session = None # type: Optional[Session]
|
||||
self._rewritten_names = set() # type: Set[str]
|
||||
self._must_rewrite = set() # type: Set[str]
|
||||
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
|
||||
|
@ -64,14 +70,19 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
|||
self._marked_for_rewrite_cache = {} # type: Dict[str, bool]
|
||||
self._session_paths_checked = False
|
||||
|
||||
def set_session(self, session):
|
||||
def set_session(self, session: Optional[Session]) -> None:
|
||||
self.session = session
|
||||
self._session_paths_checked = False
|
||||
|
||||
# Indirection so we can mock calls to find_spec originated from the hook during testing
|
||||
_find_spec = importlib.machinery.PathFinder.find_spec
|
||||
|
||||
def find_spec(self, name, path=None, target=None):
|
||||
def find_spec(
|
||||
self,
|
||||
name: str,
|
||||
path: Optional[Sequence[Union[str, bytes]]] = None,
|
||||
target: Optional[types.ModuleType] = None,
|
||||
) -> Optional[importlib.machinery.ModuleSpec]:
|
||||
if self._writing_pyc:
|
||||
return None
|
||||
state = self.config._store[assertstate_key]
|
||||
|
@ -79,7 +90,8 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
|||
return None
|
||||
state.trace("find_module called for: %s" % name)
|
||||
|
||||
spec = self._find_spec(name, path)
|
||||
# Type ignored because mypy is confused about the `self` binding here.
|
||||
spec = self._find_spec(name, path) # type: ignore
|
||||
if (
|
||||
# the import machinery could not find a file to import
|
||||
spec is None
|
||||
|
@ -108,10 +120,14 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
|||
submodule_search_locations=spec.submodule_search_locations,
|
||||
)
|
||||
|
||||
def create_module(self, spec):
|
||||
def create_module(
|
||||
self, spec: importlib.machinery.ModuleSpec
|
||||
) -> Optional[types.ModuleType]:
|
||||
return None # default behaviour is fine
|
||||
|
||||
def exec_module(self, module):
|
||||
def exec_module(self, module: types.ModuleType) -> None:
|
||||
assert module.__spec__ is not None
|
||||
assert module.__spec__.origin is not None
|
||||
fn = Path(module.__spec__.origin)
|
||||
state = self.config._store[assertstate_key]
|
||||
|
||||
|
@ -151,7 +167,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
|||
state.trace("found cached rewritten pyc for {}".format(fn))
|
||||
exec(co, module.__dict__)
|
||||
|
||||
def _early_rewrite_bailout(self, name, state):
|
||||
def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool:
|
||||
"""This is a fast way to get out of rewriting modules.
|
||||
|
||||
Profiling has shown that the call to PathFinder.find_spec (inside of
|
||||
|
@ -190,7 +206,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
|||
state.trace("early skip of rewriting module: {}".format(name))
|
||||
return True
|
||||
|
||||
def _should_rewrite(self, name, fn, state):
|
||||
def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
|
||||
# always rewrite conftest files
|
||||
if os.path.basename(fn) == "conftest.py":
|
||||
state.trace("rewriting conftest file: {!r}".format(fn))
|
||||
|
@ -213,7 +229,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
|||
|
||||
return self._is_marked_for_rewrite(name, state)
|
||||
|
||||
def _is_marked_for_rewrite(self, name: str, state):
|
||||
def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
|
||||
try:
|
||||
return self._marked_for_rewrite_cache[name]
|
||||
except KeyError:
|
||||
|
@ -246,7 +262,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
|||
self._must_rewrite.update(names)
|
||||
self._marked_for_rewrite_cache.clear()
|
||||
|
||||
def _warn_already_imported(self, name):
|
||||
def _warn_already_imported(self, name: str) -> None:
|
||||
from _pytest.warning_types import PytestAssertRewriteWarning
|
||||
from _pytest.warnings import _issue_warning_captured
|
||||
|
||||
|
@ -258,13 +274,15 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
|||
stacklevel=5,
|
||||
)
|
||||
|
||||
def get_data(self, pathname):
|
||||
def get_data(self, pathname: Union[str, bytes]) -> bytes:
|
||||
"""Optional PEP302 get_data API."""
|
||||
with open(pathname, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def _write_pyc_fp(fp, source_stat, co):
|
||||
def _write_pyc_fp(
|
||||
fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
|
||||
) -> None:
|
||||
# Technically, we don't have to have the same pyc format as
|
||||
# (C)Python, since these "pycs" should never be seen by builtin
|
||||
# import. However, there's little reason deviate.
|
||||
|
@ -280,7 +298,12 @@ def _write_pyc_fp(fp, source_stat, co):
|
|||
if sys.platform == "win32":
|
||||
from atomicwrites import atomic_write
|
||||
|
||||
def _write_pyc(state, co, source_stat, pyc):
|
||||
def _write_pyc(
|
||||
state: "AssertionState",
|
||||
co: types.CodeType,
|
||||
source_stat: os.stat_result,
|
||||
pyc: Path,
|
||||
) -> bool:
|
||||
try:
|
||||
with atomic_write(fspath(pyc), mode="wb", overwrite=True) as fp:
|
||||
_write_pyc_fp(fp, source_stat, co)
|
||||
|
@ -295,7 +318,12 @@ if sys.platform == "win32":
|
|||
|
||||
else:
|
||||
|
||||
def _write_pyc(state, co, source_stat, pyc):
|
||||
def _write_pyc(
|
||||
state: "AssertionState",
|
||||
co: types.CodeType,
|
||||
source_stat: os.stat_result,
|
||||
pyc: Path,
|
||||
) -> bool:
|
||||
proc_pyc = "{}.{}".format(pyc, os.getpid())
|
||||
try:
|
||||
fp = open(proc_pyc, "wb")
|
||||
|
@ -319,19 +347,21 @@ else:
|
|||
return True
|
||||
|
||||
|
||||
def _rewrite_test(fn, config):
|
||||
def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
|
||||
"""read and rewrite *fn* and return the code object."""
|
||||
fn = fspath(fn)
|
||||
stat = os.stat(fn)
|
||||
with open(fn, "rb") as f:
|
||||
fn_ = fspath(fn)
|
||||
stat = os.stat(fn_)
|
||||
with open(fn_, "rb") as f:
|
||||
source = f.read()
|
||||
tree = ast.parse(source, filename=fn)
|
||||
rewrite_asserts(tree, source, fn, config)
|
||||
co = compile(tree, fn, "exec", dont_inherit=True)
|
||||
tree = ast.parse(source, filename=fn_)
|
||||
rewrite_asserts(tree, source, fn_, config)
|
||||
co = compile(tree, fn_, "exec", dont_inherit=True)
|
||||
return stat, co
|
||||
|
||||
|
||||
def _read_pyc(source, pyc, trace=lambda x: None):
|
||||
def _read_pyc(
|
||||
source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None
|
||||
) -> Optional[types.CodeType]:
|
||||
"""Possibly read a pytest pyc containing rewritten code.
|
||||
|
||||
Return rewritten code if successful or None if not.
|
||||
|
@ -368,12 +398,17 @@ def _read_pyc(source, pyc, trace=lambda x: None):
|
|||
return co
|
||||
|
||||
|
||||
def rewrite_asserts(mod, source, module_path=None, config=None):
|
||||
def rewrite_asserts(
|
||||
mod: ast.Module,
|
||||
source: bytes,
|
||||
module_path: Optional[str] = None,
|
||||
config: Optional[Config] = None,
|
||||
) -> None:
|
||||
"""Rewrite the assert statements in mod."""
|
||||
AssertionRewriter(module_path, config, source).run(mod)
|
||||
|
||||
|
||||
def _saferepr(obj):
|
||||
def _saferepr(obj: object) -> str:
|
||||
"""Get a safe repr of an object for assertion error messages.
|
||||
|
||||
The assertion formatting (util.format_explanation()) requires
|
||||
|
@ -387,7 +422,7 @@ def _saferepr(obj):
|
|||
return saferepr(obj).replace("\n", "\\n")
|
||||
|
||||
|
||||
def _format_assertmsg(obj):
|
||||
def _format_assertmsg(obj: object) -> str:
|
||||
"""Format the custom assertion message given.
|
||||
|
||||
For strings this simply replaces newlines with '\n~' so that
|
||||
|
@ -410,7 +445,7 @@ def _format_assertmsg(obj):
|
|||
return obj
|
||||
|
||||
|
||||
def _should_repr_global_name(obj):
|
||||
def _should_repr_global_name(obj: object) -> bool:
|
||||
if callable(obj):
|
||||
return False
|
||||
|
||||
|
@ -420,7 +455,7 @@ def _should_repr_global_name(obj):
|
|||
return True
|
||||
|
||||
|
||||
def _format_boolop(explanations, is_or):
|
||||
def _format_boolop(explanations, is_or: bool):
|
||||
explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
|
||||
if isinstance(explanation, str):
|
||||
return explanation.replace("%", "%%")
|
||||
|
@ -428,8 +463,12 @@ def _format_boolop(explanations, is_or):
|
|||
return explanation.replace(b"%", b"%%")
|
||||
|
||||
|
||||
def _call_reprcompare(ops, results, expls, each_obj):
|
||||
# type: (Tuple[str, ...], Tuple[bool, ...], Tuple[str, ...], Tuple[object, ...]) -> str
|
||||
def _call_reprcompare(
|
||||
ops: Sequence[str],
|
||||
results: Sequence[bool],
|
||||
expls: Sequence[str],
|
||||
each_obj: Sequence[object],
|
||||
) -> str:
|
||||
for i, res, expl in zip(range(len(ops)), results, expls):
|
||||
try:
|
||||
done = not res
|
||||
|
@ -607,7 +646,9 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, module_path, config, source):
|
||||
def __init__(
|
||||
self, module_path: Optional[str], config: Optional[Config], source: bytes
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.module_path = module_path
|
||||
self.config = config
|
||||
|
@ -620,7 +661,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
self.source = source
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _assert_expr_to_lineno(self):
|
||||
def _assert_expr_to_lineno(self) -> Dict[int, str]:
|
||||
return _get_assertion_exprs(self.source)
|
||||
|
||||
def run(self, mod: ast.Module) -> None:
|
||||
|
@ -689,38 +730,38 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
nodes.append(field)
|
||||
|
||||
@staticmethod
|
||||
def is_rewrite_disabled(docstring):
|
||||
def is_rewrite_disabled(docstring: str) -> bool:
|
||||
return "PYTEST_DONT_REWRITE" in docstring
|
||||
|
||||
def variable(self):
|
||||
def variable(self) -> str:
|
||||
"""Get a new variable."""
|
||||
# Use a character invalid in python identifiers to avoid clashing.
|
||||
name = "@py_assert" + str(next(self.variable_counter))
|
||||
self.variables.append(name)
|
||||
return name
|
||||
|
||||
def assign(self, expr):
|
||||
def assign(self, expr: ast.expr) -> ast.Name:
|
||||
"""Give *expr* a name."""
|
||||
name = self.variable()
|
||||
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
|
||||
return ast.Name(name, ast.Load())
|
||||
|
||||
def display(self, expr):
|
||||
def display(self, expr: ast.expr) -> ast.expr:
|
||||
"""Call saferepr on the expression."""
|
||||
return self.helper("_saferepr", expr)
|
||||
|
||||
def helper(self, name, *args):
|
||||
def helper(self, name: str, *args: ast.expr) -> ast.expr:
|
||||
"""Call a helper in this module."""
|
||||
py_name = ast.Name("@pytest_ar", ast.Load())
|
||||
attr = ast.Attribute(py_name, name, ast.Load())
|
||||
return ast.Call(attr, list(args), [])
|
||||
|
||||
def builtin(self, name):
|
||||
def builtin(self, name: str) -> ast.Attribute:
|
||||
"""Return the builtin called *name*."""
|
||||
builtin_name = ast.Name("@py_builtins", ast.Load())
|
||||
return ast.Attribute(builtin_name, name, ast.Load())
|
||||
|
||||
def explanation_param(self, expr):
|
||||
def explanation_param(self, expr: ast.expr) -> str:
|
||||
"""Return a new named %-formatting placeholder for expr.
|
||||
|
||||
This creates a %-formatting placeholder for expr in the
|
||||
|
@ -733,7 +774,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
self.explanation_specifiers[specifier] = expr
|
||||
return "%(" + specifier + ")s"
|
||||
|
||||
def push_format_context(self):
|
||||
def push_format_context(self) -> None:
|
||||
"""Create a new formatting context.
|
||||
|
||||
The format context is used for when an explanation wants to
|
||||
|
@ -747,10 +788,10 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
self.explanation_specifiers = {} # type: Dict[str, ast.expr]
|
||||
self.stack.append(self.explanation_specifiers)
|
||||
|
||||
def pop_format_context(self, expl_expr):
|
||||
def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
|
||||
"""Format the %-formatted string with current format context.
|
||||
|
||||
The expl_expr should be an ast.Str instance constructed from
|
||||
The expl_expr should be an str ast.expr instance constructed from
|
||||
the %-placeholders created by .explanation_param(). This will
|
||||
add the required code to format said string to .expl_stmts and
|
||||
return the ast.Name instance of the formatted string.
|
||||
|
@ -768,13 +809,13 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
|
||||
return ast.Name(name, ast.Load())
|
||||
|
||||
def generic_visit(self, node):
|
||||
def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]:
|
||||
"""Handle expressions we don't have custom code for."""
|
||||
assert isinstance(node, ast.expr)
|
||||
res = self.assign(node)
|
||||
return res, self.explanation_param(self.display(res))
|
||||
|
||||
def visit_Assert(self, assert_):
|
||||
def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
|
||||
"""Return the AST statements to replace the ast.Assert instance.
|
||||
|
||||
This rewrites the test of an assertion to provide
|
||||
|
@ -787,6 +828,8 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
from _pytest.warning_types import PytestAssertRewriteWarning
|
||||
import warnings
|
||||
|
||||
# TODO: This assert should not be needed.
|
||||
assert self.module_path is not None
|
||||
warnings.warn_explicit(
|
||||
PytestAssertRewriteWarning(
|
||||
"assertion is always true, perhaps remove parentheses?"
|
||||
|
@ -889,7 +932,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
set_location(stmt, assert_.lineno, assert_.col_offset)
|
||||
return self.statements
|
||||
|
||||
def visit_Name(self, name):
|
||||
def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
|
||||
# Display the repr of the name if it's a local variable or
|
||||
# _should_repr_global_name() thinks it's acceptable.
|
||||
locs = ast.Call(self.builtin("locals"), [], [])
|
||||
|
@ -899,7 +942,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
|
||||
return name, self.explanation_param(expr)
|
||||
|
||||
def visit_BoolOp(self, boolop):
|
||||
def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
|
||||
res_var = self.variable()
|
||||
expl_list = self.assign(ast.List([], ast.Load()))
|
||||
app = ast.Attribute(expl_list, "append", ast.Load())
|
||||
|
@ -934,13 +977,13 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
expl = self.pop_format_context(expl_template)
|
||||
return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
|
||||
|
||||
def visit_UnaryOp(self, unary):
|
||||
def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]:
|
||||
pattern = UNARY_MAP[unary.op.__class__]
|
||||
operand_res, operand_expl = self.visit(unary.operand)
|
||||
res = self.assign(ast.UnaryOp(unary.op, operand_res))
|
||||
return res, pattern % (operand_expl,)
|
||||
|
||||
def visit_BinOp(self, binop):
|
||||
def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]:
|
||||
symbol = BINOP_MAP[binop.op.__class__]
|
||||
left_expr, left_expl = self.visit(binop.left)
|
||||
right_expr, right_expl = self.visit(binop.right)
|
||||
|
@ -948,7 +991,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
|
||||
return res, explanation
|
||||
|
||||
def visit_Call(self, call):
|
||||
def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
|
||||
"""
|
||||
visit `ast.Call` nodes
|
||||
"""
|
||||
|
@ -975,13 +1018,13 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
outer_expl = "{}\n{{{} = {}\n}}".format(res_expl, res_expl, expl)
|
||||
return res, outer_expl
|
||||
|
||||
def visit_Starred(self, starred):
|
||||
def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]:
|
||||
# From Python 3.5, a Starred node can appear in a function call
|
||||
res, expl = self.visit(starred.value)
|
||||
new_starred = ast.Starred(res, starred.ctx)
|
||||
return new_starred, "*" + expl
|
||||
|
||||
def visit_Attribute(self, attr):
|
||||
def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
|
||||
if not isinstance(attr.ctx, ast.Load):
|
||||
return self.generic_visit(attr)
|
||||
value, value_expl = self.visit(attr.value)
|
||||
|
@ -991,7 +1034,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
expl = pat % (res_expl, res_expl, value_expl, attr.attr)
|
||||
return res, expl
|
||||
|
||||
def visit_Compare(self, comp: ast.Compare):
|
||||
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
|
||||
self.push_format_context()
|
||||
left_res, left_expl = self.visit(comp.left)
|
||||
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
|
||||
|
@ -1030,7 +1073,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
return res, self.explanation_param(self.pop_format_context(expl_call))
|
||||
|
||||
|
||||
def try_makedirs(cache_dir) -> bool:
|
||||
def try_makedirs(cache_dir: Path) -> bool:
|
||||
"""Attempts to create the given directory and sub-directories exist, returns True if
|
||||
successful or it already exists"""
|
||||
try:
|
||||
|
|
|
@ -5,13 +5,20 @@ Current default behaviour is to truncate assertion explanations at
|
|||
~8 terminal lines, unless running in "-vv" mode or running on CI.
|
||||
"""
|
||||
import os
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from _pytest.nodes import Item
|
||||
|
||||
|
||||
DEFAULT_MAX_LINES = 8
|
||||
DEFAULT_MAX_CHARS = 8 * 80
|
||||
USAGE_MSG = "use '-vv' to show"
|
||||
|
||||
|
||||
def truncate_if_required(explanation, item, max_length=None):
|
||||
def truncate_if_required(
|
||||
explanation: List[str], item: Item, max_length: Optional[int] = None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Truncate this assertion explanation if the given test item is eligible.
|
||||
"""
|
||||
|
@ -20,7 +27,7 @@ def truncate_if_required(explanation, item, max_length=None):
|
|||
return explanation
|
||||
|
||||
|
||||
def _should_truncate_item(item):
|
||||
def _should_truncate_item(item: Item) -> bool:
|
||||
"""
|
||||
Whether or not this test item is eligible for truncation.
|
||||
"""
|
||||
|
@ -28,13 +35,17 @@ def _should_truncate_item(item):
|
|||
return verbose < 2 and not _running_on_ci()
|
||||
|
||||
|
||||
def _running_on_ci():
|
||||
def _running_on_ci() -> bool:
|
||||
"""Check if we're currently running on a CI system."""
|
||||
env_vars = ["CI", "BUILD_NUMBER"]
|
||||
return any(var in os.environ for var in env_vars)
|
||||
|
||||
|
||||
def _truncate_explanation(input_lines, max_lines=None, max_chars=None):
|
||||
def _truncate_explanation(
|
||||
input_lines: List[str],
|
||||
max_lines: Optional[int] = None,
|
||||
max_chars: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Truncate given list of strings that makes up the assertion explanation.
|
||||
|
||||
|
@ -73,7 +84,7 @@ def _truncate_explanation(input_lines, max_lines=None, max_chars=None):
|
|||
return truncated_explanation
|
||||
|
||||
|
||||
def _truncate_by_char_count(input_lines, max_chars):
|
||||
def _truncate_by_char_count(input_lines: List[str], max_chars: int) -> List[str]:
|
||||
# Check if truncation required
|
||||
if len("".join(input_lines)) <= max_chars:
|
||||
return input_lines
|
||||
|
|
|
@ -952,7 +952,8 @@ class TestAssertionRewriteHookDetails:
|
|||
state = AssertionState(config, "rewrite")
|
||||
source_path = str(tmpdir.ensure("source.py"))
|
||||
pycpath = tmpdir.join("pyc").strpath
|
||||
assert _write_pyc(state, [1], os.stat(source_path), pycpath)
|
||||
co = compile("1", "f.py", "single")
|
||||
assert _write_pyc(state, co, os.stat(source_path), pycpath)
|
||||
|
||||
if sys.platform == "win32":
|
||||
from contextlib import contextmanager
|
||||
|
@ -974,7 +975,7 @@ class TestAssertionRewriteHookDetails:
|
|||
|
||||
monkeypatch.setattr("os.rename", raise_oserror)
|
||||
|
||||
assert not _write_pyc(state, [1], os.stat(source_path), pycpath)
|
||||
assert not _write_pyc(state, co, os.stat(source_path), pycpath)
|
||||
|
||||
def test_resources_provider_for_loader(self, testdir):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue