Assertion passed hook

This commit is contained in:
Victor Maryama 2019-06-24 16:09:39 +02:00
parent 3d01dd3adf
commit 9a89783fbb
10 changed files with 202 additions and 69 deletions

View File

@ -0,0 +1,2 @@
Adds ``pytest_assertion_pass`` hook, called with assertion context information
(original asssertion statement and pytest explanation) whenever an assertion passes.

View File

@ -13,6 +13,7 @@ INSTALL_REQUIRES = [
"pluggy>=0.12,<1.0",
"importlib-metadata>=0.12",
"wcwidth",
"astor",
]

View File

@ -92,7 +92,7 @@ def pytest_collection(session):
def pytest_runtest_setup(item):
"""Setup the pytest_assertrepr_compare hook
"""Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks
The newinterpret and rewrite modules will use util._reprcompare if
it exists to use custom reporting via the
@ -129,9 +129,15 @@ def pytest_runtest_setup(item):
util._reprcompare = callbinrepr
if item.ihook.pytest_assertion_pass.get_hookimpls():
def call_assertion_pass_hook(lineno, expl, orig):
item.ihook.pytest_assertion_pass(item=item, lineno=lineno, orig=orig, expl=expl)
util._assertion_pass = call_assertion_pass_hook
def pytest_runtest_teardown(item):
util._reprcompare = None
util._assertion_pass = None
def pytest_sessionfinish(session):

View File

@ -1,5 +1,6 @@
"""Rewrite assertion AST to produce nice error messages"""
import ast
import astor
import errno
import imp
import itertools
@ -357,6 +358,11 @@ def _rewrite_test(config, fn):
state.trace("failed to parse: {!r}".format(fn))
return None, None
rewrite_asserts(tree, fn, config)
# TODO: REMOVE, THIS IS ONLY FOR DEBUG
with open(f'{str(fn)+"bak"}', "w", encoding="utf-8") as f:
f.write(astor.to_source(tree))
try:
co = compile(tree, fn.strpath, "exec", dont_inherit=True)
except SyntaxError:
@ -434,7 +440,7 @@ def _format_assertmsg(obj):
# contains a newline it gets escaped, however if an object has a
# .__repr__() which contains newlines it does not get escaped.
# However in either case we want to preserve the newline.
replaces = [("\n", "\n~"), ("%", "%%")]
replaces = [("\n", "\n~")]
if not isinstance(obj, str):
obj = saferepr(obj)
replaces.append(("\\n", "\n~"))
@ -478,6 +484,17 @@ def _call_reprcompare(ops, results, expls, each_obj):
return expl
def _call_assertion_pass(lineno, orig, expl):
if util._assertion_pass is not None:
util._assertion_pass(lineno=lineno, orig=orig, expl=expl)
def _check_if_assertionpass_impl():
"""Checks if any plugins implement the pytest_assertion_pass hook
in order not to generate explanation unecessarily (might be expensive)"""
return True if util._assertion_pass else False
unary_map = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}
binop_map = {
@ -550,7 +567,8 @@ class AssertionRewriter(ast.NodeVisitor):
original assert statement: it rewrites the test of an assertion
to provide intermediate values and replace it with an if statement
which raises an assertion error with a detailed explanation in
case the expression is false.
case the expression is false and calls pytest_assertion_pass hook
if expression is true.
For this .visit_Assert() uses the visitor pattern to visit all the
AST nodes of the ast.Assert.test field, each visit call returning
@ -568,9 +586,10 @@ class AssertionRewriter(ast.NodeVisitor):
by statements. Variables are created using .variable() and
have the form of "@py_assert0".
:on_failure: The AST statements which will be executed if the
assertion test fails. This is the code which will construct
the failure message and raises the AssertionError.
:expl_stmts: The AST statements which will be executed to get
data from the assertion. This is the code which will construct
the detailed assertion message that is used in the AssertionError
or for the pytest_assertion_pass hook.
:explanation_specifiers: A dict filled by .explanation_param()
with %-formatting placeholders and their corresponding
@ -720,7 +739,7 @@ class AssertionRewriter(ast.NodeVisitor):
The expl_expr should be an ast.Str instance constructed from
the %-placeholders created by .explanation_param(). This will
add the required code to format said string to .on_failure and
add the required code to format said string to .expl_stmts and
return the ast.Name instance of the formatted string.
"""
@ -731,7 +750,8 @@ class AssertionRewriter(ast.NodeVisitor):
format_dict = ast.Dict(keys, list(current.values()))
form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
name = "@py_format" + str(next(self.variable_counter))
self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form))
self.format_variables.append(name)
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
return ast.Name(name, ast.Load())
def generic_visit(self, node):
@ -765,8 +785,9 @@ class AssertionRewriter(ast.NodeVisitor):
self.statements = []
self.variables = []
self.variable_counter = itertools.count()
self.format_variables = []
self.stack = []
self.on_failure = []
self.expl_stmts = []
self.push_format_context()
# Rewrite assert into a bunch of statements.
top_condition, explanation = self.visit(assert_.test)
@ -777,24 +798,46 @@ class AssertionRewriter(ast.NodeVisitor):
top_condition, module_path=self.module_path, lineno=assert_.lineno
)
)
# Create failure message.
body = self.on_failure
negation = ast.UnaryOp(ast.Not(), top_condition)
self.statements.append(ast.If(negation, body, []))
msg = self.pop_format_context(ast.Str(explanation))
if assert_.msg:
assertmsg = self.helper("_format_assertmsg", assert_.msg)
explanation = "\n>assert " + explanation
gluestr = "\n>assert "
else:
assertmsg = ast.Str("")
explanation = "assert " + explanation
template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
msg = self.pop_format_context(template)
fmt = self.helper("_format_explanation", msg)
gluestr = "assert "
err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
err_name = ast.Name("AssertionError", ast.Load())
fmt = self.helper("_format_explanation", err_msg)
fmt_pass = self.helper("_format_explanation", msg)
exc = ast.Call(err_name, [fmt], [])
if sys.version_info[0] >= 3:
raise_ = ast.Raise(exc, None)
else:
raise_ = ast.Raise(exc, None, None)
# Call to hook when passes
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
hook_call_pass = ast.Expr(
self.helper(
"_call_assertion_pass", ast.Num(assert_.lineno), ast.Str(orig), fmt_pass
)
)
body.append(raise_)
# If any hooks implement assert_pass hook
hook_impl_test = ast.If(
self.helper("_check_if_assertionpass_impl"),
[hook_call_pass],
[],
)
main_test = ast.If(negation, [raise_], [hook_impl_test])
self.statements.extend(self.expl_stmts)
self.statements.append(main_test)
if self.format_variables:
variables = [ast.Name(name, ast.Store()) for name in self.format_variables]
clear_format = ast.Assign(variables, _NameConstant(None))
self.statements.append(clear_format)
# Clear temporary variables by setting them to None.
if self.variables:
variables = [ast.Name(name, ast.Store()) for name in self.variables]
@ -848,7 +891,7 @@ warn_explicit(
app = ast.Attribute(expl_list, "append", ast.Load())
is_or = int(isinstance(boolop.op, ast.Or))
body = save = self.statements
fail_save = self.on_failure
fail_save = self.expl_stmts
levels = len(boolop.values) - 1
self.push_format_context()
# Process each operand, short-circuting if needed.
@ -856,14 +899,14 @@ warn_explicit(
if i:
fail_inner = []
# cond is set in a prior loop iteration below
self.on_failure.append(ast.If(cond, fail_inner, [])) # noqa
self.on_failure = fail_inner
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
self.expl_stmts = fail_inner
self.push_format_context()
res, expl = self.visit(v)
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
expl_format = self.pop_format_context(ast.Str(expl))
call = ast.Call(app, [expl_format], [])
self.on_failure.append(ast.Expr(call))
self.expl_stmts.append(ast.Expr(call))
if i < levels:
cond = res
if is_or:
@ -872,7 +915,7 @@ warn_explicit(
self.statements.append(ast.If(cond, inner, []))
self.statements = body = inner
self.statements = save
self.on_failure = fail_save
self.expl_stmts = fail_save
expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
expl = self.pop_format_context(expl_template)
return ast.Name(res_var, ast.Load()), self.explanation_param(expl)

View File

@ -12,6 +12,10 @@ from _pytest._io.saferepr import saferepr
# DebugInterpreter.
_reprcompare = None
# Works similarly as _reprcompare attribute. Is populated with the hook call
# when pytest_runtest_setup is called.
_assertion_pass = None
def format_explanation(explanation):
"""This formats an explanation

View File

@ -485,6 +485,21 @@ def pytest_assertrepr_compare(config, op, left, right):
"""
def pytest_assertion_pass(item, lineno, orig, expl):
"""Process explanation when assertions are valid.
Use this hook to do some processing after a passing assertion.
The original assertion information is available in the `orig` string
and the pytest introspected assertion information is available in the
`expl` string.
:param _pytest.nodes.Item item: pytest item object of current test
:param int lineno: line number of the assert statement
:param string orig: string with original assertion
:param string expl: string with assert explanation
"""
# -------------------------------------------------------------------------
# hooks for influencing reporting (invoked from _pytest_terminal)
# -------------------------------------------------------------------------

View File

@ -1047,51 +1047,6 @@ def test_deferred_hook_checking(testdir):
result.stdout.fnmatch_lines(["* 1 passed *"])
def test_fixture_values_leak(testdir):
"""Ensure that fixture objects are properly destroyed by the garbage collector at the end of their expected
life-times (#2981).
"""
testdir.makepyfile(
"""
import attr
import gc
import pytest
import weakref
@attr.s
class SomeObj(object):
name = attr.ib()
fix_of_test1_ref = None
session_ref = None
@pytest.fixture(scope='session')
def session_fix():
global session_ref
obj = SomeObj(name='session-fixture')
session_ref = weakref.ref(obj)
return obj
@pytest.fixture
def fix(session_fix):
global fix_of_test1_ref
obj = SomeObj(name='local-fixture')
fix_of_test1_ref = weakref.ref(obj)
return obj
def test1(fix):
assert fix_of_test1_ref() is fix
def test2():
gc.collect()
# fixture "fix" created during test1 must have been destroyed by now
assert fix_of_test1_ref() is None
"""
)
result = testdir.runpytest()
result.stdout.fnmatch_lines(["* 2 passed *"])
def test_fixture_order_respects_scope(testdir):
"""Ensure that fixtures are created according to scope order, regression test for #2405
"""

View File

@ -0,0 +1,53 @@
"""Ensure that fixture objects are properly destroyed by the garbage collector at the end of their expected
life-times (#2981).
This comes from the old acceptance_test.py::test_fixture_values_leak(testdir):
This used pytester before but was not working when using pytest_assert_reprcompare
because pytester tracks hook calls and it would hold a reference (ParsedCall object),
preventing garbage collection
<ParsedCall 'pytest_assertrepr_compare'(**{
'config': <_pytest.config.Config object at 0x0000019C18D1C2B0>,
'op': 'is',
'left': SomeObj(name='local-fixture'),
'right': SomeObj(name='local-fixture')})>
"""
import attr
import gc
import pytest
import weakref
@attr.s
class SomeObj(object):
name = attr.ib()
fix_of_test1_ref = None
session_ref = None
@pytest.fixture(scope="session")
def session_fix():
global session_ref
obj = SomeObj(name="session-fixture")
session_ref = weakref.ref(obj)
return obj
@pytest.fixture
def fix(session_fix):
global fix_of_test1_ref
obj = SomeObj(name="local-fixture")
fix_of_test1_ref = weakref.ref(obj)
return obj
def test1(fix):
assert fix_of_test1_ref() is fix
def test2():
gc.collect()
# fixture "fix" created during test1 must have been destroyed by now
assert fix_of_test1_ref() is None

View File

@ -202,6 +202,9 @@ class TestRaises:
assert sys.exc_info() == (None, None, None)
del t
# Make sure this does get updated in locals dict
# otherwise it could keep a reference
locals()
# ensure the t instance is not stuck in a cyclic reference
for o in gc.get_objects():

View File

@ -1305,3 +1305,54 @@ class TestEarlyRewriteBailout:
)
result = testdir.runpytest()
result.stdout.fnmatch_lines(["* 1 passed in *"])
class TestAssertionPass:
def test_hook_call(self, testdir):
testdir.makeconftest(
"""
def pytest_assertion_pass(item, lineno, orig, expl):
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
"""
)
testdir.makepyfile(
"""
def test_simple():
a=1
b=2
c=3
d=0
assert a+b == c+d
"""
)
result = testdir.runpytest()
print(testdir.tmpdir)
result.stdout.fnmatch_lines(
"*Assertion Passed: a + b == c + d (1 + 2) == (3 + 0) at line 7*"
)
def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch):
"""Assertion pass should not be called (and hence formatting should
not occur) if there is no hook declared for pytest_assertion_pass"""
def raise_on_assertionpass(*_, **__):
raise Exception("Assertion passed called when it shouldn't!")
monkeypatch.setattr(_pytest.assertion.rewrite,
"_call_assertion_pass",
raise_on_assertionpass)
testdir.makepyfile(
"""
def test_simple():
a=1
b=2
c=3
d=0
assert a+b == c+d
"""
)
result = testdir.runpytest()
result.assert_outcomes(passed=1)