diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 25e1ce075..68fe8fd09 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -803,10 +803,12 @@ class AssertionRewriter(ast.NodeVisitor): top_condition, module_path=self.module_path, lineno=assert_.lineno ) ) - if self.enable_assertion_pass_hook: - ### Experimental pytest_assertion_pass hook + + if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook negation = ast.UnaryOp(ast.Not(), top_condition) msg = self.pop_format_context(ast.Str(explanation)) + + # Failed if assert_.msg: assertmsg = self.helper("_format_assertmsg", assert_.msg) gluestr = "\n>assert " @@ -817,10 +819,14 @@ class AssertionRewriter(ast.NodeVisitor): err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation) err_name = ast.Name("AssertionError", ast.Load()) fmt = self.helper("_format_explanation", err_msg) - fmt_pass = self.helper("_format_explanation", msg) exc = ast.Call(err_name, [fmt], []) raise_ = ast.Raise(exc, None) - # Call to hook when passes + statements_fail = [] + statements_fail.extend(self.expl_stmts) + statements_fail.append(raise_) + + # Passed + fmt_pass = self.helper("_format_explanation", msg) orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")") hook_call_pass = ast.Expr( self.helper( @@ -830,14 +836,16 @@ class AssertionRewriter(ast.NodeVisitor): fmt_pass, ) ) - # If any hooks implement assert_pass hook hook_impl_test = ast.If( self.helper("_check_if_assertionpass_impl"), [hook_call_pass], [] ) - main_test = ast.If(negation, [raise_], [hook_impl_test]) + statements_pass = [] + statements_pass.extend(self.expl_stmts) + statements_pass.append(hook_impl_test) - self.statements.extend(self.expl_stmts) + # Test for assertion condition + main_test = ast.If(negation, statements_fail, statements_pass) self.statements.append(main_test) if self.format_variables: variables = [ @@ -845,8 +853,8 @@ class AssertionRewriter(ast.NodeVisitor): ] clear_format = ast.Assign(variables, _NameConstant(None)) self.statements.append(clear_format) - else: - ### Original assertion rewriting + + else: # Original assertion rewriting # Create failure message. body = self.expl_stmts negation = ast.UnaryOp(ast.Not(), top_condition) @@ -865,6 +873,7 @@ class AssertionRewriter(ast.NodeVisitor): raise_ = ast.Raise(exc, None) body.append(raise_) + # Clear temporary variables by setting them to None. if self.variables: variables = [ast.Name(name, ast.Store()) for name in self.variables]