diff --git a/AUTHORS b/AUTHORS index 0395feceb..448b71d3a 100644 --- a/AUTHORS +++ b/AUTHORS @@ -12,6 +12,7 @@ Adam Uhlir Ahn Ki-Wook Akiomi Kamakura Alan Velasco +Alessio Izzo Alexander Johnson Alexander King Alexei Kozlenok diff --git a/changelog/10743.bugfix.rst b/changelog/10743.bugfix.rst new file mode 100644 index 000000000..ad5c63e80 --- /dev/null +++ b/changelog/10743.bugfix.rst @@ -0,0 +1 @@ +The assertion rewriting mechanism now works correctly when assertion expressions contain the walrus operator. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 42664add4..8b1823470 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -44,10 +44,14 @@ from _pytest.stash import StashKey if TYPE_CHECKING: from _pytest.assertion import AssertionState +if sys.version_info >= (3, 8): + namedExpr = ast.NamedExpr +else: + namedExpr = ast.Expr + assertstate_key = StashKey["AssertionState"]() - # pytest caches rewritten pycs in pycache dirs PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}" PYC_EXT = ".py" + (__debug__ and "c" or "o") @@ -635,8 +639,12 @@ class AssertionRewriter(ast.NodeVisitor): .push_format_context() and .pop_format_context() which allows to build another %-formatted string while already building one. - This state is reset on every new assert statement visited and used - by the other visitors. + :variables_overwrite: A dict filled with references to variables + that change value within an assert. This happens when a variable is + reassigned with the walrus operator + + This state, except the variables_overwrite, is reset on every new assert + statement visited and used by the other visitors. """ def __init__( @@ -652,6 +660,7 @@ class AssertionRewriter(ast.NodeVisitor): else: self.enable_assertion_pass_hook = False self.source = source + self.variables_overwrite: Dict[str, str] = {} def run(self, mod: ast.Module) -> None: """Find all assert statements in *mod* and rewrite them.""" @@ -666,7 +675,7 @@ class AssertionRewriter(ast.NodeVisitor): if doc is not None and self.is_rewrite_disabled(doc): return pos = 0 - lineno = 1 + item = None for item in mod.body: if ( expect_docstring @@ -937,6 +946,18 @@ class AssertionRewriter(ast.NodeVisitor): ast.copy_location(node, assert_) return self.statements + def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]: + # This method handles the 'walrus operator' repr of the target + # name if it's a local variable or _should_repr_global_name() + # thinks it's acceptable. + locs = ast.Call(self.builtin("locals"), [], []) + target_id = name.target.id # type: ignore[attr-defined] + inlocs = ast.Compare(ast.Str(target_id), [ast.In()], [locs]) + dorepr = self.helper("_should_repr_global_name", name) + test = ast.BoolOp(ast.Or(), [inlocs, dorepr]) + expr = ast.IfExp(test, self.display(name), ast.Str(target_id)) + return name, self.explanation_param(expr) + def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]: # Display the repr of the name if it's a local variable or # _should_repr_global_name() thinks it's acceptable. @@ -963,6 +984,20 @@ class AssertionRewriter(ast.NodeVisitor): # cond is set in a prior loop iteration below self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa self.expl_stmts = fail_inner + # Check if the left operand is a namedExpr and the value has already been visited + if ( + isinstance(v, ast.Compare) + and isinstance(v.left, namedExpr) + and v.left.target.id + in [ + ast_expr.id + for ast_expr in boolop.values[:i] + if hasattr(ast_expr, "id") + ] + ): + pytest_temp = self.variable() + self.variables_overwrite[v.left.target.id] = pytest_temp + v.left.target.id = pytest_temp self.push_format_context() res, expl = self.visit(v) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) @@ -1038,6 +1073,9 @@ class AssertionRewriter(ast.NodeVisitor): def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: self.push_format_context() + # We first check if we have overwritten a variable in the previous assert + if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite: + comp.left.id = self.variables_overwrite[comp.left.id] left_res, left_expl = self.visit(comp.left) if isinstance(comp.left, (ast.Compare, ast.BoolOp)): left_expl = f"({left_expl})" @@ -1049,6 +1087,13 @@ class AssertionRewriter(ast.NodeVisitor): syms = [] results = [left_res] for i, op, next_operand in it: + if ( + isinstance(next_operand, namedExpr) + and isinstance(left_res, ast.Name) + and next_operand.target.id == left_res.id + ): + next_operand.target.id = self.variable() + self.variables_overwrite[left_res.id] = next_operand.target.id next_res, next_expl = self.visit(next_operand) if isinstance(next_operand, (ast.Compare, ast.BoolOp)): next_expl = f"({next_expl})" @@ -1072,6 +1117,7 @@ class AssertionRewriter(ast.NodeVisitor): res: ast.expr = ast.BoolOp(ast.And(), load_names) else: res = load_names[0] + return res, self.explanation_param(self.pop_format_context(expl_call)) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 3c98392ed..8d9441403 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1265,6 +1265,177 @@ class TestIssue2121: result.stdout.fnmatch_lines(["*E*assert (1 + 1) == 3"]) +@pytest.mark.skipif( + sys.version_info < (3, 8), reason="walrus operator not available in py<38" +) +class TestIssue10743: + def test_assertion_walrus_operator(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + def my_func(before, after): + return before == after + + def change_value(value): + return value.lower() + + def test_walrus_conversion(): + a = "Hello" + assert not my_func(a, a := change_value(a)) + assert a == "hello" + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_walrus_operator_dont_rewrite(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + 'PYTEST_DONT_REWRITE' + def my_func(before, after): + return before == after + + def change_value(value): + return value.lower() + + def test_walrus_conversion_dont_rewrite(): + a = "Hello" + assert not my_func(a, a := change_value(a)) + assert a == "hello" + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_inline_walrus_operator(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + def my_func(before, after): + return before == after + + def test_walrus_conversion_inline(): + a = "Hello" + assert not my_func(a, a := a.lower()) + assert a == "hello" + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_inline_walrus_operator_reverse(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + def my_func(before, after): + return before == after + + def test_walrus_conversion_reverse(): + a = "Hello" + assert my_func(a := a.lower(), a) + assert a == 'hello' + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_walrus_no_variable_name_conflict( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def test_walrus_conversion_no_conflict(): + a = "Hello" + assert a == (b := a.lower()) + """ + ) + result = pytester.runpytest() + assert result.ret == 1 + result.stdout.fnmatch_lines(["*AssertionError: assert 'Hello' == 'hello'"]) + + def test_assertion_walrus_operator_true_assertion_and_changes_variable_value( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def test_walrus_conversion_succeed(): + a = "Hello" + assert a != (a := a.lower()) + assert a == 'hello' + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_walrus_operator_fail_assertion(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + def test_walrus_conversion_fails(): + a = "Hello" + assert a == (a := a.lower()) + """ + ) + result = pytester.runpytest() + assert result.ret == 1 + result.stdout.fnmatch_lines(["*AssertionError: assert 'Hello' == 'hello'"]) + + def test_assertion_walrus_operator_boolean_composite( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def test_walrus_operator_change_boolean_value(): + a = True + assert a and True and ((a := False) is False) and (a is False) and ((a := None) is None) + assert a is None + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_walrus_operator_compare_boolean_fails( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def test_walrus_operator_change_boolean_value(): + a = True + assert not (a and ((a := False) is False)) + """ + ) + result = pytester.runpytest() + assert result.ret == 1 + result.stdout.fnmatch_lines(["*assert not (True and False is False)"]) + + def test_assertion_walrus_operator_boolean_none_fails( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def test_walrus_operator_change_boolean_value(): + a = True + assert not (a and ((a := None) is None)) + """ + ) + result = pytester.runpytest() + assert result.ret == 1 + result.stdout.fnmatch_lines(["*assert not (True and None is None)"]) + + def test_assertion_walrus_operator_value_changes_cleared_after_each_test( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def test_walrus_operator_change_value(): + a = True + assert (a := None) is None + + def test_walrus_operator_not_override_value(): + a = True + assert a is True + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + @pytest.mark.skipif( sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems" )