Now dependent on command line option.

This commit is contained in:
Victor Maryama 2019-06-25 19:49:05 +02:00
parent cfbfa53f2b
commit 4db5488ed8
5 changed files with 112 additions and 38 deletions

View File

@ -24,6 +24,14 @@ def pytest_addoption(parser):
expression information.""", expression information.""",
) )
group = parser.getgroup("experimental")
group.addoption(
"--enable-assertion-pass-hook",
action="store_true",
help="Enables the pytest_assertion_pass hook."
"Make sure to delete any previously generated pyc cache files.",
)
def register_assert_rewrite(*names): def register_assert_rewrite(*names):
"""Register one or more module names to be rewritten on import. """Register one or more module names to be rewritten on import.

View File

@ -745,6 +745,7 @@ class AssertionRewriter(ast.NodeVisitor):
format_dict = ast.Dict(keys, list(current.values())) format_dict = ast.Dict(keys, list(current.values()))
form = ast.BinOp(expl_expr, ast.Mod(), format_dict) form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
name = "@py_format" + str(next(self.variable_counter)) name = "@py_format" + str(next(self.variable_counter))
if getattr(self.config._ns, "enable_assertion_pass_hook", False):
self.format_variables.append(name) self.format_variables.append(name)
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form)) self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
return ast.Name(name, ast.Load()) return ast.Name(name, ast.Load())
@ -780,7 +781,10 @@ class AssertionRewriter(ast.NodeVisitor):
self.statements = [] self.statements = []
self.variables = [] self.variables = []
self.variable_counter = itertools.count() self.variable_counter = itertools.count()
if getattr(self.config._ns, "enable_assertion_pass_hook", False):
self.format_variables = [] self.format_variables = []
self.stack = [] self.stack = []
self.expl_stmts = [] self.expl_stmts = []
self.push_format_context() self.push_format_context()
@ -793,6 +797,8 @@ class AssertionRewriter(ast.NodeVisitor):
top_condition, module_path=self.module_path, lineno=assert_.lineno top_condition, module_path=self.module_path, lineno=assert_.lineno
) )
) )
if getattr(self.config._ns, "enable_assertion_pass_hook", False):
### Experimental pytest_assertion_pass hook
negation = ast.UnaryOp(ast.Not(), top_condition) negation = ast.UnaryOp(ast.Not(), top_condition)
msg = self.pop_format_context(ast.Str(explanation)) msg = self.pop_format_context(ast.Str(explanation))
if assert_.msg: if assert_.msg:
@ -812,7 +818,10 @@ class AssertionRewriter(ast.NodeVisitor):
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")") orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
hook_call_pass = ast.Expr( hook_call_pass = ast.Expr(
self.helper( self.helper(
"_call_assertion_pass", ast.Num(assert_.lineno), ast.Str(orig), fmt_pass "_call_assertion_pass",
ast.Num(assert_.lineno),
ast.Str(orig),
fmt_pass,
) )
) )
@ -825,9 +834,31 @@ class AssertionRewriter(ast.NodeVisitor):
self.statements.extend(self.expl_stmts) self.statements.extend(self.expl_stmts)
self.statements.append(main_test) self.statements.append(main_test)
if self.format_variables: if self.format_variables:
variables = [ast.Name(name, ast.Store()) for name in self.format_variables] variables = [
ast.Name(name, ast.Store()) for name in self.format_variables
]
clear_format = ast.Assign(variables, _NameConstant(None)) clear_format = ast.Assign(variables, _NameConstant(None))
self.statements.append(clear_format) self.statements.append(clear_format)
else:
### Original assertion rewriting
# Create failure message.
body = self.expl_stmts
negation = ast.UnaryOp(ast.Not(), top_condition)
self.statements.append(ast.If(negation, body, []))
if assert_.msg:
assertmsg = self.helper("_format_assertmsg", assert_.msg)
explanation = "\n>assert " + explanation
else:
assertmsg = ast.Str("")
explanation = "assert " + explanation
template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
msg = self.pop_format_context(template)
fmt = self.helper("_format_explanation", msg)
err_name = ast.Name("AssertionError", ast.Load())
exc = ast.Call(err_name, [fmt], [])
raise_ = ast.Raise(exc, None)
body.append(raise_)
# Clear temporary variables by setting them to None. # Clear temporary variables by setting them to None.
if self.variables: if self.variables:
variables = [ast.Name(name, ast.Store()) for name in self.variables] variables = [ast.Name(name, ast.Store()) for name in self.variables]

View File

@ -746,8 +746,10 @@ class Config:
and find all the installed plugins to mark them for rewriting and find all the installed plugins to mark them for rewriting
by the importhook. by the importhook.
""" """
ns, unknown_args = self._parser.parse_known_and_unknown_args(args) # Saving _ns so it can be used for other assertion rewriting purposes
mode = getattr(ns, "assertmode", "plain") # e.g. experimental assertion pass hook
self._ns, self._unknown_args = self._parser.parse_known_and_unknown_args(args)
mode = getattr(self._ns, "assertmode", "plain")
if mode == "rewrite": if mode == "rewrite":
try: try:
hook = _pytest.assertion.install_importhook(self) hook = _pytest.assertion.install_importhook(self)

View File

@ -503,6 +503,8 @@ def pytest_assertion_pass(item, lineno, orig, expl):
This hook is still *experimental*, so its parameters or even the hook itself might This hook is still *experimental*, so its parameters or even the hook itself might
be changed/removed without warning in any future pytest release. be changed/removed without warning in any future pytest release.
It should be enabled using the `--enable-assertion-pass-hook` command line option.
If you find this hook useful, please share your feedback opening an issue. If you find this hook useful, please share your feedback opening an issue.
""" """

View File

@ -1326,8 +1326,7 @@ class TestAssertionPass:
assert a+b == c+d assert a+b == c+d
""" """
) )
result = testdir.runpytest() result = testdir.runpytest("--enable-assertion-pass-hook")
print(testdir.tmpdir)
result.stdout.fnmatch_lines( result.stdout.fnmatch_lines(
"*Assertion Passed: a + b == c + d (1 + 2) == (3 + 0) at line 7*" "*Assertion Passed: a + b == c + d (1 + 2) == (3 + 0) at line 7*"
) )
@ -1343,6 +1342,38 @@ class TestAssertionPass:
_pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass _pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass
) )
testdir.makepyfile(
"""
def test_simple():
a=1
b=2
c=3
d=0
assert a+b == c+d
"""
)
result = testdir.runpytest("--enable-assertion-pass-hook")
result.assert_outcomes(passed=1)
def test_hook_not_called_without_cmd_option(self, testdir, monkeypatch):
"""Assertion pass should not be called (and hence formatting should
not occur) if there is no hook declared for pytest_assertion_pass"""
def raise_on_assertionpass(*_, **__):
raise Exception("Assertion passed called when it shouldn't!")
monkeypatch.setattr(
_pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass
)
testdir.makeconftest(
"""
def pytest_assertion_pass(item, lineno, orig, expl):
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
"""
)
testdir.makepyfile( testdir.makepyfile(
""" """
def test_simple(): def test_simple():