diff --git a/changelog/3457.feature.rst b/changelog/3457.feature.rst new file mode 100644 index 000000000..3f6765144 --- /dev/null +++ b/changelog/3457.feature.rst @@ -0,0 +1,2 @@ +Adds ``pytest_assertion_pass`` hook, called with assertion context information +(original asssertion statement and pytest explanation) whenever an assertion passes. diff --git a/setup.py b/setup.py index 4c87c6429..7d9532816 100644 --- a/setup.py +++ b/setup.py @@ -13,6 +13,7 @@ INSTALL_REQUIRES = [ "pluggy>=0.12,<1.0", "importlib-metadata>=0.12", "wcwidth", + "astor", ] diff --git a/src/_pytest/assertion/__init__.py b/src/_pytest/assertion/__init__.py index e52101c9f..b59b1bfdf 100644 --- a/src/_pytest/assertion/__init__.py +++ b/src/_pytest/assertion/__init__.py @@ -92,7 +92,7 @@ def pytest_collection(session): def pytest_runtest_setup(item): - """Setup the pytest_assertrepr_compare hook + """Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks The newinterpret and rewrite modules will use util._reprcompare if it exists to use custom reporting via the @@ -129,9 +129,15 @@ def pytest_runtest_setup(item): util._reprcompare = callbinrepr + if item.ihook.pytest_assertion_pass.get_hookimpls(): + def call_assertion_pass_hook(lineno, expl, orig): + item.ihook.pytest_assertion_pass(item=item, lineno=lineno, orig=orig, expl=expl) + util._assertion_pass = call_assertion_pass_hook + def pytest_runtest_teardown(item): util._reprcompare = None + util._assertion_pass = None def pytest_sessionfinish(session): diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index ce698f368..77dab5242 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -1,5 +1,6 @@ """Rewrite assertion AST to produce nice error messages""" import ast +import astor import errno import imp import itertools @@ -357,6 +358,11 @@ def _rewrite_test(config, fn): state.trace("failed to parse: {!r}".format(fn)) return None, None rewrite_asserts(tree, fn, config) + + # TODO: REMOVE, THIS IS ONLY FOR DEBUG + with open(f'{str(fn)+"bak"}', "w", encoding="utf-8") as f: + f.write(astor.to_source(tree)) + try: co = compile(tree, fn.strpath, "exec", dont_inherit=True) except SyntaxError: @@ -434,7 +440,7 @@ def _format_assertmsg(obj): # contains a newline it gets escaped, however if an object has a # .__repr__() which contains newlines it does not get escaped. # However in either case we want to preserve the newline. - replaces = [("\n", "\n~"), ("%", "%%")] + replaces = [("\n", "\n~")] if not isinstance(obj, str): obj = saferepr(obj) replaces.append(("\\n", "\n~")) @@ -478,6 +484,17 @@ def _call_reprcompare(ops, results, expls, each_obj): return expl +def _call_assertion_pass(lineno, orig, expl): + if util._assertion_pass is not None: + util._assertion_pass(lineno=lineno, orig=orig, expl=expl) + + +def _check_if_assertionpass_impl(): + """Checks if any plugins implement the pytest_assertion_pass hook + in order not to generate explanation unecessarily (might be expensive)""" + return True if util._assertion_pass else False + + unary_map = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"} binop_map = { @@ -550,7 +567,8 @@ class AssertionRewriter(ast.NodeVisitor): original assert statement: it rewrites the test of an assertion to provide intermediate values and replace it with an if statement which raises an assertion error with a detailed explanation in - case the expression is false. + case the expression is false and calls pytest_assertion_pass hook + if expression is true. For this .visit_Assert() uses the visitor pattern to visit all the AST nodes of the ast.Assert.test field, each visit call returning @@ -568,9 +586,10 @@ class AssertionRewriter(ast.NodeVisitor): by statements. Variables are created using .variable() and have the form of "@py_assert0". - :on_failure: The AST statements which will be executed if the - assertion test fails. This is the code which will construct - the failure message and raises the AssertionError. + :expl_stmts: The AST statements which will be executed to get + data from the assertion. This is the code which will construct + the detailed assertion message that is used in the AssertionError + or for the pytest_assertion_pass hook. :explanation_specifiers: A dict filled by .explanation_param() with %-formatting placeholders and their corresponding @@ -720,7 +739,7 @@ class AssertionRewriter(ast.NodeVisitor): The expl_expr should be an ast.Str instance constructed from the %-placeholders created by .explanation_param(). This will - add the required code to format said string to .on_failure and + add the required code to format said string to .expl_stmts and return the ast.Name instance of the formatted string. """ @@ -731,7 +750,8 @@ class AssertionRewriter(ast.NodeVisitor): format_dict = ast.Dict(keys, list(current.values())) form = ast.BinOp(expl_expr, ast.Mod(), format_dict) name = "@py_format" + str(next(self.variable_counter)) - self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form)) + self.format_variables.append(name) + self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form)) return ast.Name(name, ast.Load()) def generic_visit(self, node): @@ -765,8 +785,9 @@ class AssertionRewriter(ast.NodeVisitor): self.statements = [] self.variables = [] self.variable_counter = itertools.count() + self.format_variables = [] self.stack = [] - self.on_failure = [] + self.expl_stmts = [] self.push_format_context() # Rewrite assert into a bunch of statements. top_condition, explanation = self.visit(assert_.test) @@ -777,24 +798,46 @@ class AssertionRewriter(ast.NodeVisitor): top_condition, module_path=self.module_path, lineno=assert_.lineno ) ) - # Create failure message. - body = self.on_failure negation = ast.UnaryOp(ast.Not(), top_condition) - self.statements.append(ast.If(negation, body, [])) + msg = self.pop_format_context(ast.Str(explanation)) if assert_.msg: assertmsg = self.helper("_format_assertmsg", assert_.msg) - explanation = "\n>assert " + explanation + gluestr = "\n>assert " else: assertmsg = ast.Str("") - explanation = "assert " + explanation - template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation)) - msg = self.pop_format_context(template) - fmt = self.helper("_format_explanation", msg) + gluestr = "assert " + err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg) + err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation) err_name = ast.Name("AssertionError", ast.Load()) + fmt = self.helper("_format_explanation", err_msg) + fmt_pass = self.helper("_format_explanation", msg) exc = ast.Call(err_name, [fmt], []) - raise_ = ast.Raise(exc, None) + if sys.version_info[0] >= 3: + raise_ = ast.Raise(exc, None) + else: + raise_ = ast.Raise(exc, None, None) + # Call to hook when passes + orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")") + hook_call_pass = ast.Expr( + self.helper( + "_call_assertion_pass", ast.Num(assert_.lineno), ast.Str(orig), fmt_pass + ) + ) - body.append(raise_) + # If any hooks implement assert_pass hook + hook_impl_test = ast.If( + self.helper("_check_if_assertionpass_impl"), + [hook_call_pass], + [], + ) + main_test = ast.If(negation, [raise_], [hook_impl_test]) + + self.statements.extend(self.expl_stmts) + self.statements.append(main_test) + if self.format_variables: + variables = [ast.Name(name, ast.Store()) for name in self.format_variables] + clear_format = ast.Assign(variables, _NameConstant(None)) + self.statements.append(clear_format) # Clear temporary variables by setting them to None. if self.variables: variables = [ast.Name(name, ast.Store()) for name in self.variables] @@ -848,7 +891,7 @@ warn_explicit( app = ast.Attribute(expl_list, "append", ast.Load()) is_or = int(isinstance(boolop.op, ast.Or)) body = save = self.statements - fail_save = self.on_failure + fail_save = self.expl_stmts levels = len(boolop.values) - 1 self.push_format_context() # Process each operand, short-circuting if needed. @@ -856,14 +899,14 @@ warn_explicit( if i: fail_inner = [] # cond is set in a prior loop iteration below - self.on_failure.append(ast.If(cond, fail_inner, [])) # noqa - self.on_failure = fail_inner + self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa + self.expl_stmts = fail_inner self.push_format_context() res, expl = self.visit(v) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) expl_format = self.pop_format_context(ast.Str(expl)) call = ast.Call(app, [expl_format], []) - self.on_failure.append(ast.Expr(call)) + self.expl_stmts.append(ast.Expr(call)) if i < levels: cond = res if is_or: @@ -872,7 +915,7 @@ warn_explicit( self.statements.append(ast.If(cond, inner, [])) self.statements = body = inner self.statements = save - self.on_failure = fail_save + self.expl_stmts = fail_save expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or)) expl = self.pop_format_context(expl_template) return ast.Name(res_var, ast.Load()), self.explanation_param(expl) diff --git a/src/_pytest/assertion/util.py b/src/_pytest/assertion/util.py index 762e5761d..df4587985 100644 --- a/src/_pytest/assertion/util.py +++ b/src/_pytest/assertion/util.py @@ -12,6 +12,10 @@ from _pytest._io.saferepr import saferepr # DebugInterpreter. _reprcompare = None +# Works similarly as _reprcompare attribute. Is populated with the hook call +# when pytest_runtest_setup is called. +_assertion_pass = None + def format_explanation(explanation): """This formats an explanation diff --git a/src/_pytest/hookspec.py b/src/_pytest/hookspec.py index d40a36811..70301566c 100644 --- a/src/_pytest/hookspec.py +++ b/src/_pytest/hookspec.py @@ -485,6 +485,21 @@ def pytest_assertrepr_compare(config, op, left, right): """ +def pytest_assertion_pass(item, lineno, orig, expl): + """Process explanation when assertions are valid. + + Use this hook to do some processing after a passing assertion. + The original assertion information is available in the `orig` string + and the pytest introspected assertion information is available in the + `expl` string. + + :param _pytest.nodes.Item item: pytest item object of current test + :param int lineno: line number of the assert statement + :param string orig: string with original assertion + :param string expl: string with assert explanation + """ + + # ------------------------------------------------------------------------- # hooks for influencing reporting (invoked from _pytest_terminal) # ------------------------------------------------------------------------- diff --git a/testing/acceptance_test.py b/testing/acceptance_test.py index 60cc21c4a..ccb69dd79 100644 --- a/testing/acceptance_test.py +++ b/testing/acceptance_test.py @@ -1047,51 +1047,6 @@ def test_deferred_hook_checking(testdir): result.stdout.fnmatch_lines(["* 1 passed *"]) -def test_fixture_values_leak(testdir): - """Ensure that fixture objects are properly destroyed by the garbage collector at the end of their expected - life-times (#2981). - """ - testdir.makepyfile( - """ - import attr - import gc - import pytest - import weakref - - @attr.s - class SomeObj(object): - name = attr.ib() - - fix_of_test1_ref = None - session_ref = None - - @pytest.fixture(scope='session') - def session_fix(): - global session_ref - obj = SomeObj(name='session-fixture') - session_ref = weakref.ref(obj) - return obj - - @pytest.fixture - def fix(session_fix): - global fix_of_test1_ref - obj = SomeObj(name='local-fixture') - fix_of_test1_ref = weakref.ref(obj) - return obj - - def test1(fix): - assert fix_of_test1_ref() is fix - - def test2(): - gc.collect() - # fixture "fix" created during test1 must have been destroyed by now - assert fix_of_test1_ref() is None - """ - ) - result = testdir.runpytest() - result.stdout.fnmatch_lines(["* 2 passed *"]) - - def test_fixture_order_respects_scope(testdir): """Ensure that fixtures are created according to scope order, regression test for #2405 """ diff --git a/testing/fixture_values_leak_test.py b/testing/fixture_values_leak_test.py new file mode 100644 index 000000000..6f4c90d3e --- /dev/null +++ b/testing/fixture_values_leak_test.py @@ -0,0 +1,53 @@ +"""Ensure that fixture objects are properly destroyed by the garbage collector at the end of their expected +life-times (#2981). + +This comes from the old acceptance_test.py::test_fixture_values_leak(testdir): +This used pytester before but was not working when using pytest_assert_reprcompare +because pytester tracks hook calls and it would hold a reference (ParsedCall object), +preventing garbage collection + +, + 'op': 'is', + 'left': SomeObj(name='local-fixture'), + 'right': SomeObj(name='local-fixture')})> +""" +import attr +import gc +import pytest +import weakref + + +@attr.s +class SomeObj(object): + name = attr.ib() + + +fix_of_test1_ref = None +session_ref = None + + +@pytest.fixture(scope="session") +def session_fix(): + global session_ref + obj = SomeObj(name="session-fixture") + session_ref = weakref.ref(obj) + return obj + + +@pytest.fixture +def fix(session_fix): + global fix_of_test1_ref + obj = SomeObj(name="local-fixture") + fix_of_test1_ref = weakref.ref(obj) + return obj + + +def test1(fix): + assert fix_of_test1_ref() is fix + + +def test2(): + gc.collect() + # fixture "fix" created during test1 must have been destroyed by now + assert fix_of_test1_ref() is None diff --git a/testing/python/raises.py b/testing/python/raises.py index 89cef38f1..c9ede412a 100644 --- a/testing/python/raises.py +++ b/testing/python/raises.py @@ -202,6 +202,9 @@ class TestRaises: assert sys.exc_info() == (None, None, None) del t + # Make sure this does get updated in locals dict + # otherwise it could keep a reference + locals() # ensure the t instance is not stuck in a cyclic reference for o in gc.get_objects(): diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 0e6f42239..8304cf057 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1305,3 +1305,54 @@ class TestEarlyRewriteBailout: ) result = testdir.runpytest() result.stdout.fnmatch_lines(["* 1 passed in *"]) + + +class TestAssertionPass: + def test_hook_call(self, testdir): + testdir.makeconftest( + """ + def pytest_assertion_pass(item, lineno, orig, expl): + raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno)) + """ + ) + testdir.makepyfile( + """ + def test_simple(): + a=1 + b=2 + c=3 + d=0 + + assert a+b == c+d + """ + ) + result = testdir.runpytest() + print(testdir.tmpdir) + result.stdout.fnmatch_lines( + "*Assertion Passed: a + b == c + d (1 + 2) == (3 + 0) at line 7*" + ) + + def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch): + """Assertion pass should not be called (and hence formatting should + not occur) if there is no hook declared for pytest_assertion_pass""" + + def raise_on_assertionpass(*_, **__): + raise Exception("Assertion passed called when it shouldn't!") + + monkeypatch.setattr(_pytest.assertion.rewrite, + "_call_assertion_pass", + raise_on_assertionpass) + + testdir.makepyfile( + """ + def test_simple(): + a=1 + b=2 + c=3 + d=0 + + assert a+b == c+d + """ + ) + result = testdir.runpytest() + result.assert_outcomes(passed=1)