Type annotate _pytest.assertion
This commit is contained in:
parent
30e3d473c4
commit
d95132178c
|
@ -46,7 +46,7 @@ def pytest_addoption(parser: Parser) -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_assert_rewrite(*names) -> None:
|
def register_assert_rewrite(*names: str) -> None:
|
||||||
"""Register one or more module names to be rewritten on import.
|
"""Register one or more module names to be rewritten on import.
|
||||||
|
|
||||||
This function will make sure that this module or all modules inside
|
This function will make sure that this module or all modules inside
|
||||||
|
@ -75,27 +75,27 @@ def register_assert_rewrite(*names) -> None:
|
||||||
class DummyRewriteHook:
|
class DummyRewriteHook:
|
||||||
"""A no-op import hook for when rewriting is disabled."""
|
"""A no-op import hook for when rewriting is disabled."""
|
||||||
|
|
||||||
def mark_rewrite(self, *names):
|
def mark_rewrite(self, *names: str) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AssertionState:
|
class AssertionState:
|
||||||
"""State for the assertion plugin."""
|
"""State for the assertion plugin."""
|
||||||
|
|
||||||
def __init__(self, config, mode):
|
def __init__(self, config: Config, mode) -> None:
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
self.trace = config.trace.root.get("assertion")
|
self.trace = config.trace.root.get("assertion")
|
||||||
self.hook = None # type: Optional[rewrite.AssertionRewritingHook]
|
self.hook = None # type: Optional[rewrite.AssertionRewritingHook]
|
||||||
|
|
||||||
|
|
||||||
def install_importhook(config):
|
def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
|
||||||
"""Try to install the rewrite hook, raise SystemError if it fails."""
|
"""Try to install the rewrite hook, raise SystemError if it fails."""
|
||||||
config._store[assertstate_key] = AssertionState(config, "rewrite")
|
config._store[assertstate_key] = AssertionState(config, "rewrite")
|
||||||
config._store[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config)
|
config._store[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config)
|
||||||
sys.meta_path.insert(0, hook)
|
sys.meta_path.insert(0, hook)
|
||||||
config._store[assertstate_key].trace("installed rewrite import hook")
|
config._store[assertstate_key].trace("installed rewrite import hook")
|
||||||
|
|
||||||
def undo():
|
def undo() -> None:
|
||||||
hook = config._store[assertstate_key].hook
|
hook = config._store[assertstate_key].hook
|
||||||
if hook is not None and hook in sys.meta_path:
|
if hook is not None and hook in sys.meta_path:
|
||||||
sys.meta_path.remove(hook)
|
sys.meta_path.remove(hook)
|
||||||
|
|
|
@ -13,11 +13,15 @@ import struct
|
||||||
import sys
|
import sys
|
||||||
import tokenize
|
import tokenize
|
||||||
import types
|
import types
|
||||||
|
from typing import Callable
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
from typing import IO
|
||||||
from typing import List
|
from typing import List
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from typing import Sequence
|
||||||
from typing import Set
|
from typing import Set
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from _pytest._io.saferepr import saferepr
|
from _pytest._io.saferepr import saferepr
|
||||||
from _pytest._version import version
|
from _pytest._version import version
|
||||||
|
@ -27,6 +31,8 @@ from _pytest.assertion.util import ( # noqa: F401
|
||||||
)
|
)
|
||||||
from _pytest.compat import fspath
|
from _pytest.compat import fspath
|
||||||
from _pytest.compat import TYPE_CHECKING
|
from _pytest.compat import TYPE_CHECKING
|
||||||
|
from _pytest.config import Config
|
||||||
|
from _pytest.main import Session
|
||||||
from _pytest.pathlib import fnmatch_ex
|
from _pytest.pathlib import fnmatch_ex
|
||||||
from _pytest.pathlib import Path
|
from _pytest.pathlib import Path
|
||||||
from _pytest.pathlib import PurePath
|
from _pytest.pathlib import PurePath
|
||||||
|
@ -48,13 +54,13 @@ PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
|
||||||
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
|
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
|
||||||
"""PEP302/PEP451 import hook which rewrites asserts."""
|
"""PEP302/PEP451 import hook which rewrites asserts."""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config: Config) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
try:
|
try:
|
||||||
self.fnpats = config.getini("python_files")
|
self.fnpats = config.getini("python_files")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
self.fnpats = ["test_*.py", "*_test.py"]
|
self.fnpats = ["test_*.py", "*_test.py"]
|
||||||
self.session = None
|
self.session = None # type: Optional[Session]
|
||||||
self._rewritten_names = set() # type: Set[str]
|
self._rewritten_names = set() # type: Set[str]
|
||||||
self._must_rewrite = set() # type: Set[str]
|
self._must_rewrite = set() # type: Set[str]
|
||||||
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
|
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
|
||||||
|
@ -64,14 +70,19 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||||
self._marked_for_rewrite_cache = {} # type: Dict[str, bool]
|
self._marked_for_rewrite_cache = {} # type: Dict[str, bool]
|
||||||
self._session_paths_checked = False
|
self._session_paths_checked = False
|
||||||
|
|
||||||
def set_session(self, session):
|
def set_session(self, session: Optional[Session]) -> None:
|
||||||
self.session = session
|
self.session = session
|
||||||
self._session_paths_checked = False
|
self._session_paths_checked = False
|
||||||
|
|
||||||
# Indirection so we can mock calls to find_spec originated from the hook during testing
|
# Indirection so we can mock calls to find_spec originated from the hook during testing
|
||||||
_find_spec = importlib.machinery.PathFinder.find_spec
|
_find_spec = importlib.machinery.PathFinder.find_spec
|
||||||
|
|
||||||
def find_spec(self, name, path=None, target=None):
|
def find_spec(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
path: Optional[Sequence[Union[str, bytes]]] = None,
|
||||||
|
target: Optional[types.ModuleType] = None,
|
||||||
|
) -> Optional[importlib.machinery.ModuleSpec]:
|
||||||
if self._writing_pyc:
|
if self._writing_pyc:
|
||||||
return None
|
return None
|
||||||
state = self.config._store[assertstate_key]
|
state = self.config._store[assertstate_key]
|
||||||
|
@ -79,7 +90,8 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||||
return None
|
return None
|
||||||
state.trace("find_module called for: %s" % name)
|
state.trace("find_module called for: %s" % name)
|
||||||
|
|
||||||
spec = self._find_spec(name, path)
|
# Type ignored because mypy is confused about the `self` binding here.
|
||||||
|
spec = self._find_spec(name, path) # type: ignore
|
||||||
if (
|
if (
|
||||||
# the import machinery could not find a file to import
|
# the import machinery could not find a file to import
|
||||||
spec is None
|
spec is None
|
||||||
|
@ -108,10 +120,14 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||||
submodule_search_locations=spec.submodule_search_locations,
|
submodule_search_locations=spec.submodule_search_locations,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_module(self, spec):
|
def create_module(
|
||||||
|
self, spec: importlib.machinery.ModuleSpec
|
||||||
|
) -> Optional[types.ModuleType]:
|
||||||
return None # default behaviour is fine
|
return None # default behaviour is fine
|
||||||
|
|
||||||
def exec_module(self, module):
|
def exec_module(self, module: types.ModuleType) -> None:
|
||||||
|
assert module.__spec__ is not None
|
||||||
|
assert module.__spec__.origin is not None
|
||||||
fn = Path(module.__spec__.origin)
|
fn = Path(module.__spec__.origin)
|
||||||
state = self.config._store[assertstate_key]
|
state = self.config._store[assertstate_key]
|
||||||
|
|
||||||
|
@ -151,7 +167,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||||
state.trace("found cached rewritten pyc for {}".format(fn))
|
state.trace("found cached rewritten pyc for {}".format(fn))
|
||||||
exec(co, module.__dict__)
|
exec(co, module.__dict__)
|
||||||
|
|
||||||
def _early_rewrite_bailout(self, name, state):
|
def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool:
|
||||||
"""This is a fast way to get out of rewriting modules.
|
"""This is a fast way to get out of rewriting modules.
|
||||||
|
|
||||||
Profiling has shown that the call to PathFinder.find_spec (inside of
|
Profiling has shown that the call to PathFinder.find_spec (inside of
|
||||||
|
@ -190,7 +206,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||||
state.trace("early skip of rewriting module: {}".format(name))
|
state.trace("early skip of rewriting module: {}".format(name))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _should_rewrite(self, name, fn, state):
|
def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
|
||||||
# always rewrite conftest files
|
# always rewrite conftest files
|
||||||
if os.path.basename(fn) == "conftest.py":
|
if os.path.basename(fn) == "conftest.py":
|
||||||
state.trace("rewriting conftest file: {!r}".format(fn))
|
state.trace("rewriting conftest file: {!r}".format(fn))
|
||||||
|
@ -213,7 +229,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||||
|
|
||||||
return self._is_marked_for_rewrite(name, state)
|
return self._is_marked_for_rewrite(name, state)
|
||||||
|
|
||||||
def _is_marked_for_rewrite(self, name: str, state):
|
def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
|
||||||
try:
|
try:
|
||||||
return self._marked_for_rewrite_cache[name]
|
return self._marked_for_rewrite_cache[name]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
@ -246,7 +262,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||||
self._must_rewrite.update(names)
|
self._must_rewrite.update(names)
|
||||||
self._marked_for_rewrite_cache.clear()
|
self._marked_for_rewrite_cache.clear()
|
||||||
|
|
||||||
def _warn_already_imported(self, name):
|
def _warn_already_imported(self, name: str) -> None:
|
||||||
from _pytest.warning_types import PytestAssertRewriteWarning
|
from _pytest.warning_types import PytestAssertRewriteWarning
|
||||||
from _pytest.warnings import _issue_warning_captured
|
from _pytest.warnings import _issue_warning_captured
|
||||||
|
|
||||||
|
@ -258,13 +274,15 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||||
stacklevel=5,
|
stacklevel=5,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_data(self, pathname):
|
def get_data(self, pathname: Union[str, bytes]) -> bytes:
|
||||||
"""Optional PEP302 get_data API."""
|
"""Optional PEP302 get_data API."""
|
||||||
with open(pathname, "rb") as f:
|
with open(pathname, "rb") as f:
|
||||||
return f.read()
|
return f.read()
|
||||||
|
|
||||||
|
|
||||||
def _write_pyc_fp(fp, source_stat, co):
|
def _write_pyc_fp(
|
||||||
|
fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
|
||||||
|
) -> None:
|
||||||
# Technically, we don't have to have the same pyc format as
|
# Technically, we don't have to have the same pyc format as
|
||||||
# (C)Python, since these "pycs" should never be seen by builtin
|
# (C)Python, since these "pycs" should never be seen by builtin
|
||||||
# import. However, there's little reason deviate.
|
# import. However, there's little reason deviate.
|
||||||
|
@ -280,7 +298,12 @@ def _write_pyc_fp(fp, source_stat, co):
|
||||||
if sys.platform == "win32":
|
if sys.platform == "win32":
|
||||||
from atomicwrites import atomic_write
|
from atomicwrites import atomic_write
|
||||||
|
|
||||||
def _write_pyc(state, co, source_stat, pyc):
|
def _write_pyc(
|
||||||
|
state: "AssertionState",
|
||||||
|
co: types.CodeType,
|
||||||
|
source_stat: os.stat_result,
|
||||||
|
pyc: Path,
|
||||||
|
) -> bool:
|
||||||
try:
|
try:
|
||||||
with atomic_write(fspath(pyc), mode="wb", overwrite=True) as fp:
|
with atomic_write(fspath(pyc), mode="wb", overwrite=True) as fp:
|
||||||
_write_pyc_fp(fp, source_stat, co)
|
_write_pyc_fp(fp, source_stat, co)
|
||||||
|
@ -295,7 +318,12 @@ if sys.platform == "win32":
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def _write_pyc(state, co, source_stat, pyc):
|
def _write_pyc(
|
||||||
|
state: "AssertionState",
|
||||||
|
co: types.CodeType,
|
||||||
|
source_stat: os.stat_result,
|
||||||
|
pyc: Path,
|
||||||
|
) -> bool:
|
||||||
proc_pyc = "{}.{}".format(pyc, os.getpid())
|
proc_pyc = "{}.{}".format(pyc, os.getpid())
|
||||||
try:
|
try:
|
||||||
fp = open(proc_pyc, "wb")
|
fp = open(proc_pyc, "wb")
|
||||||
|
@ -319,19 +347,21 @@ else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _rewrite_test(fn, config):
|
def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
|
||||||
"""read and rewrite *fn* and return the code object."""
|
"""read and rewrite *fn* and return the code object."""
|
||||||
fn = fspath(fn)
|
fn_ = fspath(fn)
|
||||||
stat = os.stat(fn)
|
stat = os.stat(fn_)
|
||||||
with open(fn, "rb") as f:
|
with open(fn_, "rb") as f:
|
||||||
source = f.read()
|
source = f.read()
|
||||||
tree = ast.parse(source, filename=fn)
|
tree = ast.parse(source, filename=fn_)
|
||||||
rewrite_asserts(tree, source, fn, config)
|
rewrite_asserts(tree, source, fn_, config)
|
||||||
co = compile(tree, fn, "exec", dont_inherit=True)
|
co = compile(tree, fn_, "exec", dont_inherit=True)
|
||||||
return stat, co
|
return stat, co
|
||||||
|
|
||||||
|
|
||||||
def _read_pyc(source, pyc, trace=lambda x: None):
|
def _read_pyc(
|
||||||
|
source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None
|
||||||
|
) -> Optional[types.CodeType]:
|
||||||
"""Possibly read a pytest pyc containing rewritten code.
|
"""Possibly read a pytest pyc containing rewritten code.
|
||||||
|
|
||||||
Return rewritten code if successful or None if not.
|
Return rewritten code if successful or None if not.
|
||||||
|
@ -368,12 +398,17 @@ def _read_pyc(source, pyc, trace=lambda x: None):
|
||||||
return co
|
return co
|
||||||
|
|
||||||
|
|
||||||
def rewrite_asserts(mod, source, module_path=None, config=None):
|
def rewrite_asserts(
|
||||||
|
mod: ast.Module,
|
||||||
|
source: bytes,
|
||||||
|
module_path: Optional[str] = None,
|
||||||
|
config: Optional[Config] = None,
|
||||||
|
) -> None:
|
||||||
"""Rewrite the assert statements in mod."""
|
"""Rewrite the assert statements in mod."""
|
||||||
AssertionRewriter(module_path, config, source).run(mod)
|
AssertionRewriter(module_path, config, source).run(mod)
|
||||||
|
|
||||||
|
|
||||||
def _saferepr(obj):
|
def _saferepr(obj: object) -> str:
|
||||||
"""Get a safe repr of an object for assertion error messages.
|
"""Get a safe repr of an object for assertion error messages.
|
||||||
|
|
||||||
The assertion formatting (util.format_explanation()) requires
|
The assertion formatting (util.format_explanation()) requires
|
||||||
|
@ -387,7 +422,7 @@ def _saferepr(obj):
|
||||||
return saferepr(obj).replace("\n", "\\n")
|
return saferepr(obj).replace("\n", "\\n")
|
||||||
|
|
||||||
|
|
||||||
def _format_assertmsg(obj):
|
def _format_assertmsg(obj: object) -> str:
|
||||||
"""Format the custom assertion message given.
|
"""Format the custom assertion message given.
|
||||||
|
|
||||||
For strings this simply replaces newlines with '\n~' so that
|
For strings this simply replaces newlines with '\n~' so that
|
||||||
|
@ -410,7 +445,7 @@ def _format_assertmsg(obj):
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
def _should_repr_global_name(obj):
|
def _should_repr_global_name(obj: object) -> bool:
|
||||||
if callable(obj):
|
if callable(obj):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -420,7 +455,7 @@ def _should_repr_global_name(obj):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _format_boolop(explanations, is_or):
|
def _format_boolop(explanations, is_or: bool):
|
||||||
explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
|
explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
|
||||||
if isinstance(explanation, str):
|
if isinstance(explanation, str):
|
||||||
return explanation.replace("%", "%%")
|
return explanation.replace("%", "%%")
|
||||||
|
@ -428,8 +463,12 @@ def _format_boolop(explanations, is_or):
|
||||||
return explanation.replace(b"%", b"%%")
|
return explanation.replace(b"%", b"%%")
|
||||||
|
|
||||||
|
|
||||||
def _call_reprcompare(ops, results, expls, each_obj):
|
def _call_reprcompare(
|
||||||
# type: (Tuple[str, ...], Tuple[bool, ...], Tuple[str, ...], Tuple[object, ...]) -> str
|
ops: Sequence[str],
|
||||||
|
results: Sequence[bool],
|
||||||
|
expls: Sequence[str],
|
||||||
|
each_obj: Sequence[object],
|
||||||
|
) -> str:
|
||||||
for i, res, expl in zip(range(len(ops)), results, expls):
|
for i, res, expl in zip(range(len(ops)), results, expls):
|
||||||
try:
|
try:
|
||||||
done = not res
|
done = not res
|
||||||
|
@ -607,7 +646,9 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, module_path, config, source):
|
def __init__(
|
||||||
|
self, module_path: Optional[str], config: Optional[Config], source: bytes
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.module_path = module_path
|
self.module_path = module_path
|
||||||
self.config = config
|
self.config = config
|
||||||
|
@ -620,7 +661,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
self.source = source
|
self.source = source
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=1)
|
@functools.lru_cache(maxsize=1)
|
||||||
def _assert_expr_to_lineno(self):
|
def _assert_expr_to_lineno(self) -> Dict[int, str]:
|
||||||
return _get_assertion_exprs(self.source)
|
return _get_assertion_exprs(self.source)
|
||||||
|
|
||||||
def run(self, mod: ast.Module) -> None:
|
def run(self, mod: ast.Module) -> None:
|
||||||
|
@ -689,38 +730,38 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
nodes.append(field)
|
nodes.append(field)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_rewrite_disabled(docstring):
|
def is_rewrite_disabled(docstring: str) -> bool:
|
||||||
return "PYTEST_DONT_REWRITE" in docstring
|
return "PYTEST_DONT_REWRITE" in docstring
|
||||||
|
|
||||||
def variable(self):
|
def variable(self) -> str:
|
||||||
"""Get a new variable."""
|
"""Get a new variable."""
|
||||||
# Use a character invalid in python identifiers to avoid clashing.
|
# Use a character invalid in python identifiers to avoid clashing.
|
||||||
name = "@py_assert" + str(next(self.variable_counter))
|
name = "@py_assert" + str(next(self.variable_counter))
|
||||||
self.variables.append(name)
|
self.variables.append(name)
|
||||||
return name
|
return name
|
||||||
|
|
||||||
def assign(self, expr):
|
def assign(self, expr: ast.expr) -> ast.Name:
|
||||||
"""Give *expr* a name."""
|
"""Give *expr* a name."""
|
||||||
name = self.variable()
|
name = self.variable()
|
||||||
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
|
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
|
||||||
return ast.Name(name, ast.Load())
|
return ast.Name(name, ast.Load())
|
||||||
|
|
||||||
def display(self, expr):
|
def display(self, expr: ast.expr) -> ast.expr:
|
||||||
"""Call saferepr on the expression."""
|
"""Call saferepr on the expression."""
|
||||||
return self.helper("_saferepr", expr)
|
return self.helper("_saferepr", expr)
|
||||||
|
|
||||||
def helper(self, name, *args):
|
def helper(self, name: str, *args: ast.expr) -> ast.expr:
|
||||||
"""Call a helper in this module."""
|
"""Call a helper in this module."""
|
||||||
py_name = ast.Name("@pytest_ar", ast.Load())
|
py_name = ast.Name("@pytest_ar", ast.Load())
|
||||||
attr = ast.Attribute(py_name, name, ast.Load())
|
attr = ast.Attribute(py_name, name, ast.Load())
|
||||||
return ast.Call(attr, list(args), [])
|
return ast.Call(attr, list(args), [])
|
||||||
|
|
||||||
def builtin(self, name):
|
def builtin(self, name: str) -> ast.Attribute:
|
||||||
"""Return the builtin called *name*."""
|
"""Return the builtin called *name*."""
|
||||||
builtin_name = ast.Name("@py_builtins", ast.Load())
|
builtin_name = ast.Name("@py_builtins", ast.Load())
|
||||||
return ast.Attribute(builtin_name, name, ast.Load())
|
return ast.Attribute(builtin_name, name, ast.Load())
|
||||||
|
|
||||||
def explanation_param(self, expr):
|
def explanation_param(self, expr: ast.expr) -> str:
|
||||||
"""Return a new named %-formatting placeholder for expr.
|
"""Return a new named %-formatting placeholder for expr.
|
||||||
|
|
||||||
This creates a %-formatting placeholder for expr in the
|
This creates a %-formatting placeholder for expr in the
|
||||||
|
@ -733,7 +774,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
self.explanation_specifiers[specifier] = expr
|
self.explanation_specifiers[specifier] = expr
|
||||||
return "%(" + specifier + ")s"
|
return "%(" + specifier + ")s"
|
||||||
|
|
||||||
def push_format_context(self):
|
def push_format_context(self) -> None:
|
||||||
"""Create a new formatting context.
|
"""Create a new formatting context.
|
||||||
|
|
||||||
The format context is used for when an explanation wants to
|
The format context is used for when an explanation wants to
|
||||||
|
@ -747,10 +788,10 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
self.explanation_specifiers = {} # type: Dict[str, ast.expr]
|
self.explanation_specifiers = {} # type: Dict[str, ast.expr]
|
||||||
self.stack.append(self.explanation_specifiers)
|
self.stack.append(self.explanation_specifiers)
|
||||||
|
|
||||||
def pop_format_context(self, expl_expr):
|
def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
|
||||||
"""Format the %-formatted string with current format context.
|
"""Format the %-formatted string with current format context.
|
||||||
|
|
||||||
The expl_expr should be an ast.Str instance constructed from
|
The expl_expr should be an str ast.expr instance constructed from
|
||||||
the %-placeholders created by .explanation_param(). This will
|
the %-placeholders created by .explanation_param(). This will
|
||||||
add the required code to format said string to .expl_stmts and
|
add the required code to format said string to .expl_stmts and
|
||||||
return the ast.Name instance of the formatted string.
|
return the ast.Name instance of the formatted string.
|
||||||
|
@ -768,13 +809,13 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
|
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
|
||||||
return ast.Name(name, ast.Load())
|
return ast.Name(name, ast.Load())
|
||||||
|
|
||||||
def generic_visit(self, node):
|
def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]:
|
||||||
"""Handle expressions we don't have custom code for."""
|
"""Handle expressions we don't have custom code for."""
|
||||||
assert isinstance(node, ast.expr)
|
assert isinstance(node, ast.expr)
|
||||||
res = self.assign(node)
|
res = self.assign(node)
|
||||||
return res, self.explanation_param(self.display(res))
|
return res, self.explanation_param(self.display(res))
|
||||||
|
|
||||||
def visit_Assert(self, assert_):
|
def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
|
||||||
"""Return the AST statements to replace the ast.Assert instance.
|
"""Return the AST statements to replace the ast.Assert instance.
|
||||||
|
|
||||||
This rewrites the test of an assertion to provide
|
This rewrites the test of an assertion to provide
|
||||||
|
@ -787,6 +828,8 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
from _pytest.warning_types import PytestAssertRewriteWarning
|
from _pytest.warning_types import PytestAssertRewriteWarning
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
# TODO: This assert should not be needed.
|
||||||
|
assert self.module_path is not None
|
||||||
warnings.warn_explicit(
|
warnings.warn_explicit(
|
||||||
PytestAssertRewriteWarning(
|
PytestAssertRewriteWarning(
|
||||||
"assertion is always true, perhaps remove parentheses?"
|
"assertion is always true, perhaps remove parentheses?"
|
||||||
|
@ -889,7 +932,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
set_location(stmt, assert_.lineno, assert_.col_offset)
|
set_location(stmt, assert_.lineno, assert_.col_offset)
|
||||||
return self.statements
|
return self.statements
|
||||||
|
|
||||||
def visit_Name(self, name):
|
def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
|
||||||
# Display the repr of the name if it's a local variable or
|
# Display the repr of the name if it's a local variable or
|
||||||
# _should_repr_global_name() thinks it's acceptable.
|
# _should_repr_global_name() thinks it's acceptable.
|
||||||
locs = ast.Call(self.builtin("locals"), [], [])
|
locs = ast.Call(self.builtin("locals"), [], [])
|
||||||
|
@ -899,7 +942,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
|
expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
|
||||||
return name, self.explanation_param(expr)
|
return name, self.explanation_param(expr)
|
||||||
|
|
||||||
def visit_BoolOp(self, boolop):
|
def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
|
||||||
res_var = self.variable()
|
res_var = self.variable()
|
||||||
expl_list = self.assign(ast.List([], ast.Load()))
|
expl_list = self.assign(ast.List([], ast.Load()))
|
||||||
app = ast.Attribute(expl_list, "append", ast.Load())
|
app = ast.Attribute(expl_list, "append", ast.Load())
|
||||||
|
@ -934,13 +977,13 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
expl = self.pop_format_context(expl_template)
|
expl = self.pop_format_context(expl_template)
|
||||||
return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
|
return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
|
||||||
|
|
||||||
def visit_UnaryOp(self, unary):
|
def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]:
|
||||||
pattern = UNARY_MAP[unary.op.__class__]
|
pattern = UNARY_MAP[unary.op.__class__]
|
||||||
operand_res, operand_expl = self.visit(unary.operand)
|
operand_res, operand_expl = self.visit(unary.operand)
|
||||||
res = self.assign(ast.UnaryOp(unary.op, operand_res))
|
res = self.assign(ast.UnaryOp(unary.op, operand_res))
|
||||||
return res, pattern % (operand_expl,)
|
return res, pattern % (operand_expl,)
|
||||||
|
|
||||||
def visit_BinOp(self, binop):
|
def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]:
|
||||||
symbol = BINOP_MAP[binop.op.__class__]
|
symbol = BINOP_MAP[binop.op.__class__]
|
||||||
left_expr, left_expl = self.visit(binop.left)
|
left_expr, left_expl = self.visit(binop.left)
|
||||||
right_expr, right_expl = self.visit(binop.right)
|
right_expr, right_expl = self.visit(binop.right)
|
||||||
|
@ -948,7 +991,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
|
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
|
||||||
return res, explanation
|
return res, explanation
|
||||||
|
|
||||||
def visit_Call(self, call):
|
def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
|
||||||
"""
|
"""
|
||||||
visit `ast.Call` nodes
|
visit `ast.Call` nodes
|
||||||
"""
|
"""
|
||||||
|
@ -975,13 +1018,13 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
outer_expl = "{}\n{{{} = {}\n}}".format(res_expl, res_expl, expl)
|
outer_expl = "{}\n{{{} = {}\n}}".format(res_expl, res_expl, expl)
|
||||||
return res, outer_expl
|
return res, outer_expl
|
||||||
|
|
||||||
def visit_Starred(self, starred):
|
def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]:
|
||||||
# From Python 3.5, a Starred node can appear in a function call
|
# From Python 3.5, a Starred node can appear in a function call
|
||||||
res, expl = self.visit(starred.value)
|
res, expl = self.visit(starred.value)
|
||||||
new_starred = ast.Starred(res, starred.ctx)
|
new_starred = ast.Starred(res, starred.ctx)
|
||||||
return new_starred, "*" + expl
|
return new_starred, "*" + expl
|
||||||
|
|
||||||
def visit_Attribute(self, attr):
|
def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
|
||||||
if not isinstance(attr.ctx, ast.Load):
|
if not isinstance(attr.ctx, ast.Load):
|
||||||
return self.generic_visit(attr)
|
return self.generic_visit(attr)
|
||||||
value, value_expl = self.visit(attr.value)
|
value, value_expl = self.visit(attr.value)
|
||||||
|
@ -991,7 +1034,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
expl = pat % (res_expl, res_expl, value_expl, attr.attr)
|
expl = pat % (res_expl, res_expl, value_expl, attr.attr)
|
||||||
return res, expl
|
return res, expl
|
||||||
|
|
||||||
def visit_Compare(self, comp: ast.Compare):
|
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
|
||||||
self.push_format_context()
|
self.push_format_context()
|
||||||
left_res, left_expl = self.visit(comp.left)
|
left_res, left_expl = self.visit(comp.left)
|
||||||
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
|
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
|
||||||
|
@ -1030,7 +1073,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
return res, self.explanation_param(self.pop_format_context(expl_call))
|
return res, self.explanation_param(self.pop_format_context(expl_call))
|
||||||
|
|
||||||
|
|
||||||
def try_makedirs(cache_dir) -> bool:
|
def try_makedirs(cache_dir: Path) -> bool:
|
||||||
"""Attempts to create the given directory and sub-directories exist, returns True if
|
"""Attempts to create the given directory and sub-directories exist, returns True if
|
||||||
successful or it already exists"""
|
successful or it already exists"""
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -5,13 +5,20 @@ Current default behaviour is to truncate assertion explanations at
|
||||||
~8 terminal lines, unless running in "-vv" mode or running on CI.
|
~8 terminal lines, unless running in "-vv" mode or running on CI.
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from _pytest.nodes import Item
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_MAX_LINES = 8
|
DEFAULT_MAX_LINES = 8
|
||||||
DEFAULT_MAX_CHARS = 8 * 80
|
DEFAULT_MAX_CHARS = 8 * 80
|
||||||
USAGE_MSG = "use '-vv' to show"
|
USAGE_MSG = "use '-vv' to show"
|
||||||
|
|
||||||
|
|
||||||
def truncate_if_required(explanation, item, max_length=None):
|
def truncate_if_required(
|
||||||
|
explanation: List[str], item: Item, max_length: Optional[int] = None
|
||||||
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Truncate this assertion explanation if the given test item is eligible.
|
Truncate this assertion explanation if the given test item is eligible.
|
||||||
"""
|
"""
|
||||||
|
@ -20,7 +27,7 @@ def truncate_if_required(explanation, item, max_length=None):
|
||||||
return explanation
|
return explanation
|
||||||
|
|
||||||
|
|
||||||
def _should_truncate_item(item):
|
def _should_truncate_item(item: Item) -> bool:
|
||||||
"""
|
"""
|
||||||
Whether or not this test item is eligible for truncation.
|
Whether or not this test item is eligible for truncation.
|
||||||
"""
|
"""
|
||||||
|
@ -28,13 +35,17 @@ def _should_truncate_item(item):
|
||||||
return verbose < 2 and not _running_on_ci()
|
return verbose < 2 and not _running_on_ci()
|
||||||
|
|
||||||
|
|
||||||
def _running_on_ci():
|
def _running_on_ci() -> bool:
|
||||||
"""Check if we're currently running on a CI system."""
|
"""Check if we're currently running on a CI system."""
|
||||||
env_vars = ["CI", "BUILD_NUMBER"]
|
env_vars = ["CI", "BUILD_NUMBER"]
|
||||||
return any(var in os.environ for var in env_vars)
|
return any(var in os.environ for var in env_vars)
|
||||||
|
|
||||||
|
|
||||||
def _truncate_explanation(input_lines, max_lines=None, max_chars=None):
|
def _truncate_explanation(
|
||||||
|
input_lines: List[str],
|
||||||
|
max_lines: Optional[int] = None,
|
||||||
|
max_chars: Optional[int] = None,
|
||||||
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Truncate given list of strings that makes up the assertion explanation.
|
Truncate given list of strings that makes up the assertion explanation.
|
||||||
|
|
||||||
|
@ -73,7 +84,7 @@ def _truncate_explanation(input_lines, max_lines=None, max_chars=None):
|
||||||
return truncated_explanation
|
return truncated_explanation
|
||||||
|
|
||||||
|
|
||||||
def _truncate_by_char_count(input_lines, max_chars):
|
def _truncate_by_char_count(input_lines: List[str], max_chars: int) -> List[str]:
|
||||||
# Check if truncation required
|
# Check if truncation required
|
||||||
if len("".join(input_lines)) <= max_chars:
|
if len("".join(input_lines)) <= max_chars:
|
||||||
return input_lines
|
return input_lines
|
||||||
|
|
|
@ -952,7 +952,8 @@ class TestAssertionRewriteHookDetails:
|
||||||
state = AssertionState(config, "rewrite")
|
state = AssertionState(config, "rewrite")
|
||||||
source_path = str(tmpdir.ensure("source.py"))
|
source_path = str(tmpdir.ensure("source.py"))
|
||||||
pycpath = tmpdir.join("pyc").strpath
|
pycpath = tmpdir.join("pyc").strpath
|
||||||
assert _write_pyc(state, [1], os.stat(source_path), pycpath)
|
co = compile("1", "f.py", "single")
|
||||||
|
assert _write_pyc(state, co, os.stat(source_path), pycpath)
|
||||||
|
|
||||||
if sys.platform == "win32":
|
if sys.platform == "win32":
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
@ -974,7 +975,7 @@ class TestAssertionRewriteHookDetails:
|
||||||
|
|
||||||
monkeypatch.setattr("os.rename", raise_oserror)
|
monkeypatch.setattr("os.rename", raise_oserror)
|
||||||
|
|
||||||
assert not _write_pyc(state, [1], os.stat(source_path), pycpath)
|
assert not _write_pyc(state, co, os.stat(source_path), pycpath)
|
||||||
|
|
||||||
def test_resources_provider_for_loader(self, testdir):
|
def test_resources_provider_for_loader(self, testdir):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue