Features assertion pass hook (#3479)

Features assertion pass hook
This commit is contained in:
Bruno Oliveira 2019-06-26 21:14:19 -03:00 committed by GitHub
commit 37fb50a3ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 301 additions and 41 deletions

View File

@ -0,0 +1,4 @@
New `pytest_assertion_pass <https://docs.pytest.org/en/latest/reference.html#_pytest.hookspec.pytest_assertion_pass>`__
hook, called with context information when an assertion *passes*.
This hook is still **experimental** so use it with caution.

View File

@ -0,0 +1 @@
pytest now also depends on the `astor <https://pypi.org/project/astor/>`__ package.

View File

@ -665,15 +665,14 @@ Session related reporting hooks:
.. autofunction:: pytest_fixture_post_finalizer .. autofunction:: pytest_fixture_post_finalizer
.. autofunction:: pytest_warning_captured .. autofunction:: pytest_warning_captured
And here is the central hook for reporting about Central hook for reporting about test execution:
test execution:
.. autofunction:: pytest_runtest_logreport .. autofunction:: pytest_runtest_logreport
You can also use this hook to customize assertion representation for some Assertion related hooks:
types:
.. autofunction:: pytest_assertrepr_compare .. autofunction:: pytest_assertrepr_compare
.. autofunction:: pytest_assertion_pass
Debugging/Interaction hooks Debugging/Interaction hooks

View File

@ -13,6 +13,7 @@ INSTALL_REQUIRES = [
"pluggy>=0.12,<1.0", "pluggy>=0.12,<1.0",
"importlib-metadata>=0.12", "importlib-metadata>=0.12",
"wcwidth", "wcwidth",
"astor",
] ]

View File

@ -23,6 +23,13 @@ def pytest_addoption(parser):
test modules on import to provide assert test modules on import to provide assert
expression information.""", expression information.""",
) )
parser.addini(
"enable_assertion_pass_hook",
type="bool",
default=False,
help="Enables the pytest_assertion_pass hook."
"Make sure to delete any previously generated pyc cache files.",
)
def register_assert_rewrite(*names): def register_assert_rewrite(*names):
@ -92,7 +99,7 @@ def pytest_collection(session):
def pytest_runtest_setup(item): def pytest_runtest_setup(item):
"""Setup the pytest_assertrepr_compare hook """Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks
The newinterpret and rewrite modules will use util._reprcompare if The newinterpret and rewrite modules will use util._reprcompare if
it exists to use custom reporting via the it exists to use custom reporting via the
@ -129,9 +136,19 @@ def pytest_runtest_setup(item):
util._reprcompare = callbinrepr util._reprcompare = callbinrepr
if item.ihook.pytest_assertion_pass.get_hookimpls():
def call_assertion_pass_hook(lineno, expl, orig):
item.ihook.pytest_assertion_pass(
item=item, lineno=lineno, orig=orig, expl=expl
)
util._assertion_pass = call_assertion_pass_hook
def pytest_runtest_teardown(item): def pytest_runtest_teardown(item):
util._reprcompare = None util._reprcompare = None
util._assertion_pass = None
def pytest_sessionfinish(session): def pytest_sessionfinish(session):

View File

@ -10,6 +10,7 @@ import struct
import sys import sys
import types import types
import astor
import atomicwrites import atomicwrites
from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr
@ -134,7 +135,7 @@ class AssertionRewritingHook:
co = _read_pyc(fn, pyc, state.trace) co = _read_pyc(fn, pyc, state.trace)
if co is None: if co is None:
state.trace("rewriting {!r}".format(fn)) state.trace("rewriting {!r}".format(fn))
source_stat, co = _rewrite_test(fn) source_stat, co = _rewrite_test(fn, self.config)
if write: if write:
self._writing_pyc = True self._writing_pyc = True
try: try:
@ -278,13 +279,13 @@ def _write_pyc(state, co, source_stat, pyc):
return True return True
def _rewrite_test(fn): def _rewrite_test(fn, config):
"""read and rewrite *fn* and return the code object.""" """read and rewrite *fn* and return the code object."""
stat = os.stat(fn) stat = os.stat(fn)
with open(fn, "rb") as f: with open(fn, "rb") as f:
source = f.read() source = f.read()
tree = ast.parse(source, filename=fn) tree = ast.parse(source, filename=fn)
rewrite_asserts(tree, fn) rewrite_asserts(tree, fn, config)
co = compile(tree, fn, "exec", dont_inherit=True) co = compile(tree, fn, "exec", dont_inherit=True)
return stat, co return stat, co
@ -326,9 +327,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
return co return co
def rewrite_asserts(mod, module_path=None): def rewrite_asserts(mod, module_path=None, config=None):
"""Rewrite the assert statements in mod.""" """Rewrite the assert statements in mod."""
AssertionRewriter(module_path).run(mod) AssertionRewriter(module_path, config).run(mod)
def _saferepr(obj): def _saferepr(obj):
@ -401,6 +402,17 @@ def _call_reprcompare(ops, results, expls, each_obj):
return expl return expl
def _call_assertion_pass(lineno, orig, expl):
if util._assertion_pass is not None:
util._assertion_pass(lineno=lineno, orig=orig, expl=expl)
def _check_if_assertion_pass_impl():
"""Checks if any plugins implement the pytest_assertion_pass hook
in order not to generate explanation unecessarily (might be expensive)"""
return True if util._assertion_pass else False
unary_map = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"} unary_map = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}
binop_map = { binop_map = {
@ -473,7 +485,8 @@ class AssertionRewriter(ast.NodeVisitor):
original assert statement: it rewrites the test of an assertion original assert statement: it rewrites the test of an assertion
to provide intermediate values and replace it with an if statement to provide intermediate values and replace it with an if statement
which raises an assertion error with a detailed explanation in which raises an assertion error with a detailed explanation in
case the expression is false. case the expression is false and calls pytest_assertion_pass hook
if expression is true.
For this .visit_Assert() uses the visitor pattern to visit all the For this .visit_Assert() uses the visitor pattern to visit all the
AST nodes of the ast.Assert.test field, each visit call returning AST nodes of the ast.Assert.test field, each visit call returning
@ -491,9 +504,10 @@ class AssertionRewriter(ast.NodeVisitor):
by statements. Variables are created using .variable() and by statements. Variables are created using .variable() and
have the form of "@py_assert0". have the form of "@py_assert0".
:on_failure: The AST statements which will be executed if the :expl_stmts: The AST statements which will be executed to get
assertion test fails. This is the code which will construct data from the assertion. This is the code which will construct
the failure message and raises the AssertionError. the detailed assertion message that is used in the AssertionError
or for the pytest_assertion_pass hook.
:explanation_specifiers: A dict filled by .explanation_param() :explanation_specifiers: A dict filled by .explanation_param()
with %-formatting placeholders and their corresponding with %-formatting placeholders and their corresponding
@ -509,9 +523,16 @@ class AssertionRewriter(ast.NodeVisitor):
""" """
def __init__(self, module_path): def __init__(self, module_path, config):
super().__init__() super().__init__()
self.module_path = module_path self.module_path = module_path
self.config = config
if config is not None:
self.enable_assertion_pass_hook = config.getini(
"enable_assertion_pass_hook"
)
else:
self.enable_assertion_pass_hook = False
def run(self, mod): def run(self, mod):
"""Find all assert statements in *mod* and rewrite them.""" """Find all assert statements in *mod* and rewrite them."""
@ -642,7 +663,7 @@ class AssertionRewriter(ast.NodeVisitor):
The expl_expr should be an ast.Str instance constructed from The expl_expr should be an ast.Str instance constructed from
the %-placeholders created by .explanation_param(). This will the %-placeholders created by .explanation_param(). This will
add the required code to format said string to .on_failure and add the required code to format said string to .expl_stmts and
return the ast.Name instance of the formatted string. return the ast.Name instance of the formatted string.
""" """
@ -653,7 +674,9 @@ class AssertionRewriter(ast.NodeVisitor):
format_dict = ast.Dict(keys, list(current.values())) format_dict = ast.Dict(keys, list(current.values()))
form = ast.BinOp(expl_expr, ast.Mod(), format_dict) form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
name = "@py_format" + str(next(self.variable_counter)) name = "@py_format" + str(next(self.variable_counter))
self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form)) if self.enable_assertion_pass_hook:
self.format_variables.append(name)
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
return ast.Name(name, ast.Load()) return ast.Name(name, ast.Load())
def generic_visit(self, node): def generic_visit(self, node):
@ -687,8 +710,12 @@ class AssertionRewriter(ast.NodeVisitor):
self.statements = [] self.statements = []
self.variables = [] self.variables = []
self.variable_counter = itertools.count() self.variable_counter = itertools.count()
if self.enable_assertion_pass_hook:
self.format_variables = []
self.stack = [] self.stack = []
self.on_failure = [] self.expl_stmts = []
self.push_format_context() self.push_format_context()
# Rewrite assert into a bunch of statements. # Rewrite assert into a bunch of statements.
top_condition, explanation = self.visit(assert_.test) top_condition, explanation = self.visit(assert_.test)
@ -699,24 +726,77 @@ class AssertionRewriter(ast.NodeVisitor):
top_condition, module_path=self.module_path, lineno=assert_.lineno top_condition, module_path=self.module_path, lineno=assert_.lineno
) )
) )
# Create failure message.
body = self.on_failure
negation = ast.UnaryOp(ast.Not(), top_condition)
self.statements.append(ast.If(negation, body, []))
if assert_.msg:
assertmsg = self.helper("_format_assertmsg", assert_.msg)
explanation = "\n>assert " + explanation
else:
assertmsg = ast.Str("")
explanation = "assert " + explanation
template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
msg = self.pop_format_context(template)
fmt = self.helper("_format_explanation", msg)
err_name = ast.Name("AssertionError", ast.Load())
exc = ast.Call(err_name, [fmt], [])
raise_ = ast.Raise(exc, None)
body.append(raise_) if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook
negation = ast.UnaryOp(ast.Not(), top_condition)
msg = self.pop_format_context(ast.Str(explanation))
# Failed
if assert_.msg:
assertmsg = self.helper("_format_assertmsg", assert_.msg)
gluestr = "\n>assert "
else:
assertmsg = ast.Str("")
gluestr = "assert "
err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
err_name = ast.Name("AssertionError", ast.Load())
fmt = self.helper("_format_explanation", err_msg)
exc = ast.Call(err_name, [fmt], [])
raise_ = ast.Raise(exc, None)
statements_fail = []
statements_fail.extend(self.expl_stmts)
statements_fail.append(raise_)
# Passed
fmt_pass = self.helper("_format_explanation", msg)
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
hook_call_pass = ast.Expr(
self.helper(
"_call_assertion_pass",
ast.Num(assert_.lineno),
ast.Str(orig),
fmt_pass,
)
)
# If any hooks implement assert_pass hook
hook_impl_test = ast.If(
self.helper("_check_if_assertion_pass_impl"),
self.expl_stmts + [hook_call_pass],
[],
)
statements_pass = [hook_impl_test]
# Test for assertion condition
main_test = ast.If(negation, statements_fail, statements_pass)
self.statements.append(main_test)
if self.format_variables:
variables = [
ast.Name(name, ast.Store()) for name in self.format_variables
]
clear_format = ast.Assign(variables, _NameConstant(None))
self.statements.append(clear_format)
else: # Original assertion rewriting
# Create failure message.
body = self.expl_stmts
negation = ast.UnaryOp(ast.Not(), top_condition)
self.statements.append(ast.If(negation, body, []))
if assert_.msg:
assertmsg = self.helper("_format_assertmsg", assert_.msg)
explanation = "\n>assert " + explanation
else:
assertmsg = ast.Str("")
explanation = "assert " + explanation
template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
msg = self.pop_format_context(template)
fmt = self.helper("_format_explanation", msg)
err_name = ast.Name("AssertionError", ast.Load())
exc = ast.Call(err_name, [fmt], [])
raise_ = ast.Raise(exc, None)
body.append(raise_)
# Clear temporary variables by setting them to None. # Clear temporary variables by setting them to None.
if self.variables: if self.variables:
variables = [ast.Name(name, ast.Store()) for name in self.variables] variables = [ast.Name(name, ast.Store()) for name in self.variables]
@ -770,7 +850,7 @@ warn_explicit(
app = ast.Attribute(expl_list, "append", ast.Load()) app = ast.Attribute(expl_list, "append", ast.Load())
is_or = int(isinstance(boolop.op, ast.Or)) is_or = int(isinstance(boolop.op, ast.Or))
body = save = self.statements body = save = self.statements
fail_save = self.on_failure fail_save = self.expl_stmts
levels = len(boolop.values) - 1 levels = len(boolop.values) - 1
self.push_format_context() self.push_format_context()
# Process each operand, short-circuiting if needed. # Process each operand, short-circuiting if needed.
@ -778,14 +858,14 @@ warn_explicit(
if i: if i:
fail_inner = [] fail_inner = []
# cond is set in a prior loop iteration below # cond is set in a prior loop iteration below
self.on_failure.append(ast.If(cond, fail_inner, [])) # noqa self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
self.on_failure = fail_inner self.expl_stmts = fail_inner
self.push_format_context() self.push_format_context()
res, expl = self.visit(v) res, expl = self.visit(v)
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
expl_format = self.pop_format_context(ast.Str(expl)) expl_format = self.pop_format_context(ast.Str(expl))
call = ast.Call(app, [expl_format], []) call = ast.Call(app, [expl_format], [])
self.on_failure.append(ast.Expr(call)) self.expl_stmts.append(ast.Expr(call))
if i < levels: if i < levels:
cond = res cond = res
if is_or: if is_or:
@ -794,7 +874,7 @@ warn_explicit(
self.statements.append(ast.If(cond, inner, [])) self.statements.append(ast.If(cond, inner, []))
self.statements = body = inner self.statements = body = inner
self.statements = save self.statements = save
self.on_failure = fail_save self.expl_stmts = fail_save
expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or)) expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
expl = self.pop_format_context(expl_template) expl = self.pop_format_context(expl_template)
return ast.Name(res_var, ast.Load()), self.explanation_param(expl) return ast.Name(res_var, ast.Load()), self.explanation_param(expl)

View File

@ -12,6 +12,10 @@ from _pytest._io.saferepr import saferepr
# DebugInterpreter. # DebugInterpreter.
_reprcompare = None _reprcompare = None
# Works similarly as _reprcompare attribute. Is populated with the hook call
# when pytest_runtest_setup is called.
_assertion_pass = None
def format_explanation(explanation): def format_explanation(explanation):
"""This formats an explanation """This formats an explanation

View File

@ -485,6 +485,42 @@ def pytest_assertrepr_compare(config, op, left, right):
""" """
def pytest_assertion_pass(item, lineno, orig, expl):
"""
**(Experimental)**
Hook called whenever an assertion *passes*.
Use this hook to do some processing after a passing assertion.
The original assertion information is available in the `orig` string
and the pytest introspected assertion information is available in the
`expl` string.
This hook must be explicitly enabled by the ``enable_assertion_pass_hook``
ini-file option:
.. code-block:: ini
[pytest]
enable_assertion_pass_hook=true
You need to **clean the .pyc** files in your project directory and interpreter libraries
when enabling this option, as assertions will require to be re-written.
:param _pytest.nodes.Item item: pytest item object of current test
:param int lineno: line number of the assert statement
:param string orig: string with original assertion
:param string expl: string with assert explanation
.. note::
This hook is **experimental**, so its parameters or even the hook itself might
be changed/removed without warning in any future pytest release.
If you find this hook useful, please share your feedback opening an issue.
"""
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# hooks for influencing reporting (invoked from _pytest_terminal) # hooks for influencing reporting (invoked from _pytest_terminal)
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------

View File

@ -1101,7 +1101,10 @@ def test_fixture_values_leak(testdir):
assert fix_of_test1_ref() is None assert fix_of_test1_ref() is None
""" """
) )
result = testdir.runpytest() # Running on subprocess does not activate the HookRecorder
# which holds itself a reference to objects in case of the
# pytest_assert_reprcompare hook
result = testdir.runpytest_subprocess()
result.stdout.fnmatch_lines(["* 2 passed *"]) result.stdout.fnmatch_lines(["* 2 passed *"])

View File

@ -202,6 +202,9 @@ class TestRaises:
assert sys.exc_info() == (None, None, None) assert sys.exc_info() == (None, None, None)
del t del t
# Make sure this does get updated in locals dict
# otherwise it could keep a reference
locals()
# ensure the t instance is not stuck in a cyclic reference # ensure the t instance is not stuck in a cyclic reference
for o in gc.get_objects(): for o in gc.get_objects():

View File

@ -1332,3 +1332,115 @@ class TestEarlyRewriteBailout:
) )
result = testdir.runpytest() result = testdir.runpytest()
result.stdout.fnmatch_lines(["* 1 passed in *"]) result.stdout.fnmatch_lines(["* 1 passed in *"])
class TestAssertionPass:
def test_option_default(self, testdir):
config = testdir.parseconfig()
assert config.getini("enable_assertion_pass_hook") is False
def test_hook_call(self, testdir):
testdir.makeconftest(
"""
def pytest_assertion_pass(item, lineno, orig, expl):
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
"""
)
testdir.makeini(
"""
[pytest]
enable_assertion_pass_hook = True
"""
)
testdir.makepyfile(
"""
def test_simple():
a=1
b=2
c=3
d=0
assert a+b == c+d
# cover failing assertions with a message
def test_fails():
assert False, "assert with message"
"""
)
result = testdir.runpytest()
result.stdout.fnmatch_lines(
"*Assertion Passed: a + b == c + d (1 + 2) == (3 + 0) at line 7*"
)
def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch):
"""Assertion pass should not be called (and hence formatting should
not occur) if there is no hook declared for pytest_assertion_pass"""
def raise_on_assertionpass(*_, **__):
raise Exception("Assertion passed called when it shouldn't!")
monkeypatch.setattr(
_pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass
)
testdir.makeini(
"""
[pytest]
enable_assertion_pass_hook = True
"""
)
testdir.makepyfile(
"""
def test_simple():
a=1
b=2
c=3
d=0
assert a+b == c+d
"""
)
result = testdir.runpytest()
result.assert_outcomes(passed=1)
def test_hook_not_called_without_cmd_option(self, testdir, monkeypatch):
"""Assertion pass should not be called (and hence formatting should
not occur) if there is no hook declared for pytest_assertion_pass"""
def raise_on_assertionpass(*_, **__):
raise Exception("Assertion passed called when it shouldn't!")
monkeypatch.setattr(
_pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass
)
testdir.makeconftest(
"""
def pytest_assertion_pass(item, lineno, orig, expl):
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
"""
)
testdir.makeini(
"""
[pytest]
enable_assertion_pass_hook = False
"""
)
testdir.makepyfile(
"""
def test_simple():
a=1
b=2
c=3
d=0
assert a+b == c+d
"""
)
result = testdir.runpytest()
result.assert_outcomes(passed=1)