Assertion passed hook
This commit is contained in:
parent
3d01dd3adf
commit
9a89783fbb
|
@ -0,0 +1,2 @@
|
|||
Adds ``pytest_assertion_pass`` hook, called with assertion context information
|
||||
(original asssertion statement and pytest explanation) whenever an assertion passes.
|
1
setup.py
1
setup.py
|
@ -13,6 +13,7 @@ INSTALL_REQUIRES = [
|
|||
"pluggy>=0.12,<1.0",
|
||||
"importlib-metadata>=0.12",
|
||||
"wcwidth",
|
||||
"astor",
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -92,7 +92,7 @@ def pytest_collection(session):
|
|||
|
||||
|
||||
def pytest_runtest_setup(item):
|
||||
"""Setup the pytest_assertrepr_compare hook
|
||||
"""Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks
|
||||
|
||||
The newinterpret and rewrite modules will use util._reprcompare if
|
||||
it exists to use custom reporting via the
|
||||
|
@ -129,9 +129,15 @@ def pytest_runtest_setup(item):
|
|||
|
||||
util._reprcompare = callbinrepr
|
||||
|
||||
if item.ihook.pytest_assertion_pass.get_hookimpls():
|
||||
def call_assertion_pass_hook(lineno, expl, orig):
|
||||
item.ihook.pytest_assertion_pass(item=item, lineno=lineno, orig=orig, expl=expl)
|
||||
util._assertion_pass = call_assertion_pass_hook
|
||||
|
||||
|
||||
def pytest_runtest_teardown(item):
|
||||
util._reprcompare = None
|
||||
util._assertion_pass = None
|
||||
|
||||
|
||||
def pytest_sessionfinish(session):
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""Rewrite assertion AST to produce nice error messages"""
|
||||
import ast
|
||||
import astor
|
||||
import errno
|
||||
import imp
|
||||
import itertools
|
||||
|
@ -357,6 +358,11 @@ def _rewrite_test(config, fn):
|
|||
state.trace("failed to parse: {!r}".format(fn))
|
||||
return None, None
|
||||
rewrite_asserts(tree, fn, config)
|
||||
|
||||
# TODO: REMOVE, THIS IS ONLY FOR DEBUG
|
||||
with open(f'{str(fn)+"bak"}', "w", encoding="utf-8") as f:
|
||||
f.write(astor.to_source(tree))
|
||||
|
||||
try:
|
||||
co = compile(tree, fn.strpath, "exec", dont_inherit=True)
|
||||
except SyntaxError:
|
||||
|
@ -434,7 +440,7 @@ def _format_assertmsg(obj):
|
|||
# contains a newline it gets escaped, however if an object has a
|
||||
# .__repr__() which contains newlines it does not get escaped.
|
||||
# However in either case we want to preserve the newline.
|
||||
replaces = [("\n", "\n~"), ("%", "%%")]
|
||||
replaces = [("\n", "\n~")]
|
||||
if not isinstance(obj, str):
|
||||
obj = saferepr(obj)
|
||||
replaces.append(("\\n", "\n~"))
|
||||
|
@ -478,6 +484,17 @@ def _call_reprcompare(ops, results, expls, each_obj):
|
|||
return expl
|
||||
|
||||
|
||||
def _call_assertion_pass(lineno, orig, expl):
|
||||
if util._assertion_pass is not None:
|
||||
util._assertion_pass(lineno=lineno, orig=orig, expl=expl)
|
||||
|
||||
|
||||
def _check_if_assertionpass_impl():
|
||||
"""Checks if any plugins implement the pytest_assertion_pass hook
|
||||
in order not to generate explanation unecessarily (might be expensive)"""
|
||||
return True if util._assertion_pass else False
|
||||
|
||||
|
||||
unary_map = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}
|
||||
|
||||
binop_map = {
|
||||
|
@ -550,7 +567,8 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
original assert statement: it rewrites the test of an assertion
|
||||
to provide intermediate values and replace it with an if statement
|
||||
which raises an assertion error with a detailed explanation in
|
||||
case the expression is false.
|
||||
case the expression is false and calls pytest_assertion_pass hook
|
||||
if expression is true.
|
||||
|
||||
For this .visit_Assert() uses the visitor pattern to visit all the
|
||||
AST nodes of the ast.Assert.test field, each visit call returning
|
||||
|
@ -568,9 +586,10 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
by statements. Variables are created using .variable() and
|
||||
have the form of "@py_assert0".
|
||||
|
||||
:on_failure: The AST statements which will be executed if the
|
||||
assertion test fails. This is the code which will construct
|
||||
the failure message and raises the AssertionError.
|
||||
:expl_stmts: The AST statements which will be executed to get
|
||||
data from the assertion. This is the code which will construct
|
||||
the detailed assertion message that is used in the AssertionError
|
||||
or for the pytest_assertion_pass hook.
|
||||
|
||||
:explanation_specifiers: A dict filled by .explanation_param()
|
||||
with %-formatting placeholders and their corresponding
|
||||
|
@ -720,7 +739,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
|
||||
The expl_expr should be an ast.Str instance constructed from
|
||||
the %-placeholders created by .explanation_param(). This will
|
||||
add the required code to format said string to .on_failure and
|
||||
add the required code to format said string to .expl_stmts and
|
||||
return the ast.Name instance of the formatted string.
|
||||
|
||||
"""
|
||||
|
@ -731,7 +750,8 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
format_dict = ast.Dict(keys, list(current.values()))
|
||||
form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
|
||||
name = "@py_format" + str(next(self.variable_counter))
|
||||
self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form))
|
||||
self.format_variables.append(name)
|
||||
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
|
||||
return ast.Name(name, ast.Load())
|
||||
|
||||
def generic_visit(self, node):
|
||||
|
@ -765,8 +785,9 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
self.statements = []
|
||||
self.variables = []
|
||||
self.variable_counter = itertools.count()
|
||||
self.format_variables = []
|
||||
self.stack = []
|
||||
self.on_failure = []
|
||||
self.expl_stmts = []
|
||||
self.push_format_context()
|
||||
# Rewrite assert into a bunch of statements.
|
||||
top_condition, explanation = self.visit(assert_.test)
|
||||
|
@ -777,24 +798,46 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
top_condition, module_path=self.module_path, lineno=assert_.lineno
|
||||
)
|
||||
)
|
||||
# Create failure message.
|
||||
body = self.on_failure
|
||||
negation = ast.UnaryOp(ast.Not(), top_condition)
|
||||
self.statements.append(ast.If(negation, body, []))
|
||||
msg = self.pop_format_context(ast.Str(explanation))
|
||||
if assert_.msg:
|
||||
assertmsg = self.helper("_format_assertmsg", assert_.msg)
|
||||
explanation = "\n>assert " + explanation
|
||||
gluestr = "\n>assert "
|
||||
else:
|
||||
assertmsg = ast.Str("")
|
||||
explanation = "assert " + explanation
|
||||
template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
|
||||
msg = self.pop_format_context(template)
|
||||
fmt = self.helper("_format_explanation", msg)
|
||||
gluestr = "assert "
|
||||
err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
|
||||
err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
|
||||
err_name = ast.Name("AssertionError", ast.Load())
|
||||
fmt = self.helper("_format_explanation", err_msg)
|
||||
fmt_pass = self.helper("_format_explanation", msg)
|
||||
exc = ast.Call(err_name, [fmt], [])
|
||||
if sys.version_info[0] >= 3:
|
||||
raise_ = ast.Raise(exc, None)
|
||||
else:
|
||||
raise_ = ast.Raise(exc, None, None)
|
||||
# Call to hook when passes
|
||||
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
|
||||
hook_call_pass = ast.Expr(
|
||||
self.helper(
|
||||
"_call_assertion_pass", ast.Num(assert_.lineno), ast.Str(orig), fmt_pass
|
||||
)
|
||||
)
|
||||
|
||||
body.append(raise_)
|
||||
# If any hooks implement assert_pass hook
|
||||
hook_impl_test = ast.If(
|
||||
self.helper("_check_if_assertionpass_impl"),
|
||||
[hook_call_pass],
|
||||
[],
|
||||
)
|
||||
main_test = ast.If(negation, [raise_], [hook_impl_test])
|
||||
|
||||
self.statements.extend(self.expl_stmts)
|
||||
self.statements.append(main_test)
|
||||
if self.format_variables:
|
||||
variables = [ast.Name(name, ast.Store()) for name in self.format_variables]
|
||||
clear_format = ast.Assign(variables, _NameConstant(None))
|
||||
self.statements.append(clear_format)
|
||||
# Clear temporary variables by setting them to None.
|
||||
if self.variables:
|
||||
variables = [ast.Name(name, ast.Store()) for name in self.variables]
|
||||
|
@ -848,7 +891,7 @@ warn_explicit(
|
|||
app = ast.Attribute(expl_list, "append", ast.Load())
|
||||
is_or = int(isinstance(boolop.op, ast.Or))
|
||||
body = save = self.statements
|
||||
fail_save = self.on_failure
|
||||
fail_save = self.expl_stmts
|
||||
levels = len(boolop.values) - 1
|
||||
self.push_format_context()
|
||||
# Process each operand, short-circuting if needed.
|
||||
|
@ -856,14 +899,14 @@ warn_explicit(
|
|||
if i:
|
||||
fail_inner = []
|
||||
# cond is set in a prior loop iteration below
|
||||
self.on_failure.append(ast.If(cond, fail_inner, [])) # noqa
|
||||
self.on_failure = fail_inner
|
||||
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
|
||||
self.expl_stmts = fail_inner
|
||||
self.push_format_context()
|
||||
res, expl = self.visit(v)
|
||||
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
|
||||
expl_format = self.pop_format_context(ast.Str(expl))
|
||||
call = ast.Call(app, [expl_format], [])
|
||||
self.on_failure.append(ast.Expr(call))
|
||||
self.expl_stmts.append(ast.Expr(call))
|
||||
if i < levels:
|
||||
cond = res
|
||||
if is_or:
|
||||
|
@ -872,7 +915,7 @@ warn_explicit(
|
|||
self.statements.append(ast.If(cond, inner, []))
|
||||
self.statements = body = inner
|
||||
self.statements = save
|
||||
self.on_failure = fail_save
|
||||
self.expl_stmts = fail_save
|
||||
expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
|
||||
expl = self.pop_format_context(expl_template)
|
||||
return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
|
||||
|
|
|
@ -12,6 +12,10 @@ from _pytest._io.saferepr import saferepr
|
|||
# DebugInterpreter.
|
||||
_reprcompare = None
|
||||
|
||||
# Works similarly as _reprcompare attribute. Is populated with the hook call
|
||||
# when pytest_runtest_setup is called.
|
||||
_assertion_pass = None
|
||||
|
||||
|
||||
def format_explanation(explanation):
|
||||
"""This formats an explanation
|
||||
|
|
|
@ -485,6 +485,21 @@ def pytest_assertrepr_compare(config, op, left, right):
|
|||
"""
|
||||
|
||||
|
||||
def pytest_assertion_pass(item, lineno, orig, expl):
|
||||
"""Process explanation when assertions are valid.
|
||||
|
||||
Use this hook to do some processing after a passing assertion.
|
||||
The original assertion information is available in the `orig` string
|
||||
and the pytest introspected assertion information is available in the
|
||||
`expl` string.
|
||||
|
||||
:param _pytest.nodes.Item item: pytest item object of current test
|
||||
:param int lineno: line number of the assert statement
|
||||
:param string orig: string with original assertion
|
||||
:param string expl: string with assert explanation
|
||||
"""
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# hooks for influencing reporting (invoked from _pytest_terminal)
|
||||
# -------------------------------------------------------------------------
|
||||
|
|
|
@ -1047,51 +1047,6 @@ def test_deferred_hook_checking(testdir):
|
|||
result.stdout.fnmatch_lines(["* 1 passed *"])
|
||||
|
||||
|
||||
def test_fixture_values_leak(testdir):
|
||||
"""Ensure that fixture objects are properly destroyed by the garbage collector at the end of their expected
|
||||
life-times (#2981).
|
||||
"""
|
||||
testdir.makepyfile(
|
||||
"""
|
||||
import attr
|
||||
import gc
|
||||
import pytest
|
||||
import weakref
|
||||
|
||||
@attr.s
|
||||
class SomeObj(object):
|
||||
name = attr.ib()
|
||||
|
||||
fix_of_test1_ref = None
|
||||
session_ref = None
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def session_fix():
|
||||
global session_ref
|
||||
obj = SomeObj(name='session-fixture')
|
||||
session_ref = weakref.ref(obj)
|
||||
return obj
|
||||
|
||||
@pytest.fixture
|
||||
def fix(session_fix):
|
||||
global fix_of_test1_ref
|
||||
obj = SomeObj(name='local-fixture')
|
||||
fix_of_test1_ref = weakref.ref(obj)
|
||||
return obj
|
||||
|
||||
def test1(fix):
|
||||
assert fix_of_test1_ref() is fix
|
||||
|
||||
def test2():
|
||||
gc.collect()
|
||||
# fixture "fix" created during test1 must have been destroyed by now
|
||||
assert fix_of_test1_ref() is None
|
||||
"""
|
||||
)
|
||||
result = testdir.runpytest()
|
||||
result.stdout.fnmatch_lines(["* 2 passed *"])
|
||||
|
||||
|
||||
def test_fixture_order_respects_scope(testdir):
|
||||
"""Ensure that fixtures are created according to scope order, regression test for #2405
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
"""Ensure that fixture objects are properly destroyed by the garbage collector at the end of their expected
|
||||
life-times (#2981).
|
||||
|
||||
This comes from the old acceptance_test.py::test_fixture_values_leak(testdir):
|
||||
This used pytester before but was not working when using pytest_assert_reprcompare
|
||||
because pytester tracks hook calls and it would hold a reference (ParsedCall object),
|
||||
preventing garbage collection
|
||||
|
||||
<ParsedCall 'pytest_assertrepr_compare'(**{
|
||||
'config': <_pytest.config.Config object at 0x0000019C18D1C2B0>,
|
||||
'op': 'is',
|
||||
'left': SomeObj(name='local-fixture'),
|
||||
'right': SomeObj(name='local-fixture')})>
|
||||
"""
|
||||
import attr
|
||||
import gc
|
||||
import pytest
|
||||
import weakref
|
||||
|
||||
|
||||
@attr.s
|
||||
class SomeObj(object):
|
||||
name = attr.ib()
|
||||
|
||||
|
||||
fix_of_test1_ref = None
|
||||
session_ref = None
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def session_fix():
|
||||
global session_ref
|
||||
obj = SomeObj(name="session-fixture")
|
||||
session_ref = weakref.ref(obj)
|
||||
return obj
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fix(session_fix):
|
||||
global fix_of_test1_ref
|
||||
obj = SomeObj(name="local-fixture")
|
||||
fix_of_test1_ref = weakref.ref(obj)
|
||||
return obj
|
||||
|
||||
|
||||
def test1(fix):
|
||||
assert fix_of_test1_ref() is fix
|
||||
|
||||
|
||||
def test2():
|
||||
gc.collect()
|
||||
# fixture "fix" created during test1 must have been destroyed by now
|
||||
assert fix_of_test1_ref() is None
|
|
@ -202,6 +202,9 @@ class TestRaises:
|
|||
assert sys.exc_info() == (None, None, None)
|
||||
|
||||
del t
|
||||
# Make sure this does get updated in locals dict
|
||||
# otherwise it could keep a reference
|
||||
locals()
|
||||
|
||||
# ensure the t instance is not stuck in a cyclic reference
|
||||
for o in gc.get_objects():
|
||||
|
|
|
@ -1305,3 +1305,54 @@ class TestEarlyRewriteBailout:
|
|||
)
|
||||
result = testdir.runpytest()
|
||||
result.stdout.fnmatch_lines(["* 1 passed in *"])
|
||||
|
||||
|
||||
class TestAssertionPass:
|
||||
def test_hook_call(self, testdir):
|
||||
testdir.makeconftest(
|
||||
"""
|
||||
def pytest_assertion_pass(item, lineno, orig, expl):
|
||||
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
|
||||
"""
|
||||
)
|
||||
testdir.makepyfile(
|
||||
"""
|
||||
def test_simple():
|
||||
a=1
|
||||
b=2
|
||||
c=3
|
||||
d=0
|
||||
|
||||
assert a+b == c+d
|
||||
"""
|
||||
)
|
||||
result = testdir.runpytest()
|
||||
print(testdir.tmpdir)
|
||||
result.stdout.fnmatch_lines(
|
||||
"*Assertion Passed: a + b == c + d (1 + 2) == (3 + 0) at line 7*"
|
||||
)
|
||||
|
||||
def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch):
|
||||
"""Assertion pass should not be called (and hence formatting should
|
||||
not occur) if there is no hook declared for pytest_assertion_pass"""
|
||||
|
||||
def raise_on_assertionpass(*_, **__):
|
||||
raise Exception("Assertion passed called when it shouldn't!")
|
||||
|
||||
monkeypatch.setattr(_pytest.assertion.rewrite,
|
||||
"_call_assertion_pass",
|
||||
raise_on_assertionpass)
|
||||
|
||||
testdir.makepyfile(
|
||||
"""
|
||||
def test_simple():
|
||||
a=1
|
||||
b=2
|
||||
c=3
|
||||
d=0
|
||||
|
||||
assert a+b == c+d
|
||||
"""
|
||||
)
|
||||
result = testdir.runpytest()
|
||||
result.assert_outcomes(passed=1)
|
||||
|
|
Loading…
Reference in New Issue