Now dependent on command line option.
This commit is contained in:
parent
cfbfa53f2b
commit
4db5488ed8
|
@ -24,6 +24,14 @@ def pytest_addoption(parser):
|
||||||
expression information.""",
|
expression information.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
group = parser.getgroup("experimental")
|
||||||
|
group.addoption(
|
||||||
|
"--enable-assertion-pass-hook",
|
||||||
|
action="store_true",
|
||||||
|
help="Enables the pytest_assertion_pass hook."
|
||||||
|
"Make sure to delete any previously generated pyc cache files.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_assert_rewrite(*names):
|
def register_assert_rewrite(*names):
|
||||||
"""Register one or more module names to be rewritten on import.
|
"""Register one or more module names to be rewritten on import.
|
||||||
|
|
|
@ -745,6 +745,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
format_dict = ast.Dict(keys, list(current.values()))
|
format_dict = ast.Dict(keys, list(current.values()))
|
||||||
form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
|
form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
|
||||||
name = "@py_format" + str(next(self.variable_counter))
|
name = "@py_format" + str(next(self.variable_counter))
|
||||||
|
if getattr(self.config._ns, "enable_assertion_pass_hook", False):
|
||||||
self.format_variables.append(name)
|
self.format_variables.append(name)
|
||||||
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
|
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
|
||||||
return ast.Name(name, ast.Load())
|
return ast.Name(name, ast.Load())
|
||||||
|
@ -780,7 +781,10 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
self.statements = []
|
self.statements = []
|
||||||
self.variables = []
|
self.variables = []
|
||||||
self.variable_counter = itertools.count()
|
self.variable_counter = itertools.count()
|
||||||
|
|
||||||
|
if getattr(self.config._ns, "enable_assertion_pass_hook", False):
|
||||||
self.format_variables = []
|
self.format_variables = []
|
||||||
|
|
||||||
self.stack = []
|
self.stack = []
|
||||||
self.expl_stmts = []
|
self.expl_stmts = []
|
||||||
self.push_format_context()
|
self.push_format_context()
|
||||||
|
@ -793,6 +797,8 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
top_condition, module_path=self.module_path, lineno=assert_.lineno
|
top_condition, module_path=self.module_path, lineno=assert_.lineno
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if getattr(self.config._ns, "enable_assertion_pass_hook", False):
|
||||||
|
### Experimental pytest_assertion_pass hook
|
||||||
negation = ast.UnaryOp(ast.Not(), top_condition)
|
negation = ast.UnaryOp(ast.Not(), top_condition)
|
||||||
msg = self.pop_format_context(ast.Str(explanation))
|
msg = self.pop_format_context(ast.Str(explanation))
|
||||||
if assert_.msg:
|
if assert_.msg:
|
||||||
|
@ -812,7 +818,10 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
|
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
|
||||||
hook_call_pass = ast.Expr(
|
hook_call_pass = ast.Expr(
|
||||||
self.helper(
|
self.helper(
|
||||||
"_call_assertion_pass", ast.Num(assert_.lineno), ast.Str(orig), fmt_pass
|
"_call_assertion_pass",
|
||||||
|
ast.Num(assert_.lineno),
|
||||||
|
ast.Str(orig),
|
||||||
|
fmt_pass,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -825,9 +834,31 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
self.statements.extend(self.expl_stmts)
|
self.statements.extend(self.expl_stmts)
|
||||||
self.statements.append(main_test)
|
self.statements.append(main_test)
|
||||||
if self.format_variables:
|
if self.format_variables:
|
||||||
variables = [ast.Name(name, ast.Store()) for name in self.format_variables]
|
variables = [
|
||||||
|
ast.Name(name, ast.Store()) for name in self.format_variables
|
||||||
|
]
|
||||||
clear_format = ast.Assign(variables, _NameConstant(None))
|
clear_format = ast.Assign(variables, _NameConstant(None))
|
||||||
self.statements.append(clear_format)
|
self.statements.append(clear_format)
|
||||||
|
else:
|
||||||
|
### Original assertion rewriting
|
||||||
|
# Create failure message.
|
||||||
|
body = self.expl_stmts
|
||||||
|
negation = ast.UnaryOp(ast.Not(), top_condition)
|
||||||
|
self.statements.append(ast.If(negation, body, []))
|
||||||
|
if assert_.msg:
|
||||||
|
assertmsg = self.helper("_format_assertmsg", assert_.msg)
|
||||||
|
explanation = "\n>assert " + explanation
|
||||||
|
else:
|
||||||
|
assertmsg = ast.Str("")
|
||||||
|
explanation = "assert " + explanation
|
||||||
|
template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
|
||||||
|
msg = self.pop_format_context(template)
|
||||||
|
fmt = self.helper("_format_explanation", msg)
|
||||||
|
err_name = ast.Name("AssertionError", ast.Load())
|
||||||
|
exc = ast.Call(err_name, [fmt], [])
|
||||||
|
raise_ = ast.Raise(exc, None)
|
||||||
|
|
||||||
|
body.append(raise_)
|
||||||
# Clear temporary variables by setting them to None.
|
# Clear temporary variables by setting them to None.
|
||||||
if self.variables:
|
if self.variables:
|
||||||
variables = [ast.Name(name, ast.Store()) for name in self.variables]
|
variables = [ast.Name(name, ast.Store()) for name in self.variables]
|
||||||
|
|
|
@ -746,8 +746,10 @@ class Config:
|
||||||
and find all the installed plugins to mark them for rewriting
|
and find all the installed plugins to mark them for rewriting
|
||||||
by the importhook.
|
by the importhook.
|
||||||
"""
|
"""
|
||||||
ns, unknown_args = self._parser.parse_known_and_unknown_args(args)
|
# Saving _ns so it can be used for other assertion rewriting purposes
|
||||||
mode = getattr(ns, "assertmode", "plain")
|
# e.g. experimental assertion pass hook
|
||||||
|
self._ns, self._unknown_args = self._parser.parse_known_and_unknown_args(args)
|
||||||
|
mode = getattr(self._ns, "assertmode", "plain")
|
||||||
if mode == "rewrite":
|
if mode == "rewrite":
|
||||||
try:
|
try:
|
||||||
hook = _pytest.assertion.install_importhook(self)
|
hook = _pytest.assertion.install_importhook(self)
|
||||||
|
|
|
@ -503,6 +503,8 @@ def pytest_assertion_pass(item, lineno, orig, expl):
|
||||||
This hook is still *experimental*, so its parameters or even the hook itself might
|
This hook is still *experimental*, so its parameters or even the hook itself might
|
||||||
be changed/removed without warning in any future pytest release.
|
be changed/removed without warning in any future pytest release.
|
||||||
|
|
||||||
|
It should be enabled using the `--enable-assertion-pass-hook` command line option.
|
||||||
|
|
||||||
If you find this hook useful, please share your feedback opening an issue.
|
If you find this hook useful, please share your feedback opening an issue.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
@ -1326,8 +1326,7 @@ class TestAssertionPass:
|
||||||
assert a+b == c+d
|
assert a+b == c+d
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
result = testdir.runpytest()
|
result = testdir.runpytest("--enable-assertion-pass-hook")
|
||||||
print(testdir.tmpdir)
|
|
||||||
result.stdout.fnmatch_lines(
|
result.stdout.fnmatch_lines(
|
||||||
"*Assertion Passed: a + b == c + d (1 + 2) == (3 + 0) at line 7*"
|
"*Assertion Passed: a + b == c + d (1 + 2) == (3 + 0) at line 7*"
|
||||||
)
|
)
|
||||||
|
@ -1343,6 +1342,38 @@ class TestAssertionPass:
|
||||||
_pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass
|
_pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass
|
||||||
)
|
)
|
||||||
|
|
||||||
|
testdir.makepyfile(
|
||||||
|
"""
|
||||||
|
def test_simple():
|
||||||
|
a=1
|
||||||
|
b=2
|
||||||
|
c=3
|
||||||
|
d=0
|
||||||
|
|
||||||
|
assert a+b == c+d
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
result = testdir.runpytest("--enable-assertion-pass-hook")
|
||||||
|
result.assert_outcomes(passed=1)
|
||||||
|
|
||||||
|
def test_hook_not_called_without_cmd_option(self, testdir, monkeypatch):
|
||||||
|
"""Assertion pass should not be called (and hence formatting should
|
||||||
|
not occur) if there is no hook declared for pytest_assertion_pass"""
|
||||||
|
|
||||||
|
def raise_on_assertionpass(*_, **__):
|
||||||
|
raise Exception("Assertion passed called when it shouldn't!")
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
_pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass
|
||||||
|
)
|
||||||
|
|
||||||
|
testdir.makeconftest(
|
||||||
|
"""
|
||||||
|
def pytest_assertion_pass(item, lineno, orig, expl):
|
||||||
|
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
testdir.makepyfile(
|
testdir.makepyfile(
|
||||||
"""
|
"""
|
||||||
def test_simple():
|
def test_simple():
|
||||||
|
|
Loading…
Reference in New Issue