Fix walrus operator support in assertion rewriting (#10758)
Closes #10743
This commit is contained in:
parent
a869141b3d
commit
6e478b0947
1
AUTHORS
1
AUTHORS
|
@ -12,6 +12,7 @@ Adam Uhlir
|
||||||
Ahn Ki-Wook
|
Ahn Ki-Wook
|
||||||
Akiomi Kamakura
|
Akiomi Kamakura
|
||||||
Alan Velasco
|
Alan Velasco
|
||||||
|
Alessio Izzo
|
||||||
Alexander Johnson
|
Alexander Johnson
|
||||||
Alexander King
|
Alexander King
|
||||||
Alexei Kozlenok
|
Alexei Kozlenok
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
The assertion rewriting mechanism now works correctly when assertion expressions contain the walrus operator.
|
|
@ -44,10 +44,14 @@ from _pytest.stash import StashKey
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from _pytest.assertion import AssertionState
|
from _pytest.assertion import AssertionState
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 8):
|
||||||
|
namedExpr = ast.NamedExpr
|
||||||
|
else:
|
||||||
|
namedExpr = ast.Expr
|
||||||
|
|
||||||
|
|
||||||
assertstate_key = StashKey["AssertionState"]()
|
assertstate_key = StashKey["AssertionState"]()
|
||||||
|
|
||||||
|
|
||||||
# pytest caches rewritten pycs in pycache dirs
|
# pytest caches rewritten pycs in pycache dirs
|
||||||
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
|
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
|
||||||
PYC_EXT = ".py" + (__debug__ and "c" or "o")
|
PYC_EXT = ".py" + (__debug__ and "c" or "o")
|
||||||
|
@ -635,8 +639,12 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
.push_format_context() and .pop_format_context() which allows
|
.push_format_context() and .pop_format_context() which allows
|
||||||
to build another %-formatted string while already building one.
|
to build another %-formatted string while already building one.
|
||||||
|
|
||||||
This state is reset on every new assert statement visited and used
|
:variables_overwrite: A dict filled with references to variables
|
||||||
by the other visitors.
|
that change value within an assert. This happens when a variable is
|
||||||
|
reassigned with the walrus operator
|
||||||
|
|
||||||
|
This state, except the variables_overwrite, is reset on every new assert
|
||||||
|
statement visited and used by the other visitors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -652,6 +660,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
else:
|
else:
|
||||||
self.enable_assertion_pass_hook = False
|
self.enable_assertion_pass_hook = False
|
||||||
self.source = source
|
self.source = source
|
||||||
|
self.variables_overwrite: Dict[str, str] = {}
|
||||||
|
|
||||||
def run(self, mod: ast.Module) -> None:
|
def run(self, mod: ast.Module) -> None:
|
||||||
"""Find all assert statements in *mod* and rewrite them."""
|
"""Find all assert statements in *mod* and rewrite them."""
|
||||||
|
@ -666,7 +675,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
if doc is not None and self.is_rewrite_disabled(doc):
|
if doc is not None and self.is_rewrite_disabled(doc):
|
||||||
return
|
return
|
||||||
pos = 0
|
pos = 0
|
||||||
lineno = 1
|
item = None
|
||||||
for item in mod.body:
|
for item in mod.body:
|
||||||
if (
|
if (
|
||||||
expect_docstring
|
expect_docstring
|
||||||
|
@ -937,6 +946,18 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
ast.copy_location(node, assert_)
|
ast.copy_location(node, assert_)
|
||||||
return self.statements
|
return self.statements
|
||||||
|
|
||||||
|
def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]:
|
||||||
|
# This method handles the 'walrus operator' repr of the target
|
||||||
|
# name if it's a local variable or _should_repr_global_name()
|
||||||
|
# thinks it's acceptable.
|
||||||
|
locs = ast.Call(self.builtin("locals"), [], [])
|
||||||
|
target_id = name.target.id # type: ignore[attr-defined]
|
||||||
|
inlocs = ast.Compare(ast.Str(target_id), [ast.In()], [locs])
|
||||||
|
dorepr = self.helper("_should_repr_global_name", name)
|
||||||
|
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
|
||||||
|
expr = ast.IfExp(test, self.display(name), ast.Str(target_id))
|
||||||
|
return name, self.explanation_param(expr)
|
||||||
|
|
||||||
def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
|
def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
|
||||||
# Display the repr of the name if it's a local variable or
|
# Display the repr of the name if it's a local variable or
|
||||||
# _should_repr_global_name() thinks it's acceptable.
|
# _should_repr_global_name() thinks it's acceptable.
|
||||||
|
@ -963,6 +984,20 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
# cond is set in a prior loop iteration below
|
# cond is set in a prior loop iteration below
|
||||||
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
|
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
|
||||||
self.expl_stmts = fail_inner
|
self.expl_stmts = fail_inner
|
||||||
|
# Check if the left operand is a namedExpr and the value has already been visited
|
||||||
|
if (
|
||||||
|
isinstance(v, ast.Compare)
|
||||||
|
and isinstance(v.left, namedExpr)
|
||||||
|
and v.left.target.id
|
||||||
|
in [
|
||||||
|
ast_expr.id
|
||||||
|
for ast_expr in boolop.values[:i]
|
||||||
|
if hasattr(ast_expr, "id")
|
||||||
|
]
|
||||||
|
):
|
||||||
|
pytest_temp = self.variable()
|
||||||
|
self.variables_overwrite[v.left.target.id] = pytest_temp
|
||||||
|
v.left.target.id = pytest_temp
|
||||||
self.push_format_context()
|
self.push_format_context()
|
||||||
res, expl = self.visit(v)
|
res, expl = self.visit(v)
|
||||||
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
|
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
|
||||||
|
@ -1038,6 +1073,9 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
|
|
||||||
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
|
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
|
||||||
self.push_format_context()
|
self.push_format_context()
|
||||||
|
# We first check if we have overwritten a variable in the previous assert
|
||||||
|
if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite:
|
||||||
|
comp.left.id = self.variables_overwrite[comp.left.id]
|
||||||
left_res, left_expl = self.visit(comp.left)
|
left_res, left_expl = self.visit(comp.left)
|
||||||
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
|
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
|
||||||
left_expl = f"({left_expl})"
|
left_expl = f"({left_expl})"
|
||||||
|
@ -1049,6 +1087,13 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
syms = []
|
syms = []
|
||||||
results = [left_res]
|
results = [left_res]
|
||||||
for i, op, next_operand in it:
|
for i, op, next_operand in it:
|
||||||
|
if (
|
||||||
|
isinstance(next_operand, namedExpr)
|
||||||
|
and isinstance(left_res, ast.Name)
|
||||||
|
and next_operand.target.id == left_res.id
|
||||||
|
):
|
||||||
|
next_operand.target.id = self.variable()
|
||||||
|
self.variables_overwrite[left_res.id] = next_operand.target.id
|
||||||
next_res, next_expl = self.visit(next_operand)
|
next_res, next_expl = self.visit(next_operand)
|
||||||
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
|
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
|
||||||
next_expl = f"({next_expl})"
|
next_expl = f"({next_expl})"
|
||||||
|
@ -1072,6 +1117,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
res: ast.expr = ast.BoolOp(ast.And(), load_names)
|
res: ast.expr = ast.BoolOp(ast.And(), load_names)
|
||||||
else:
|
else:
|
||||||
res = load_names[0]
|
res = load_names[0]
|
||||||
|
|
||||||
return res, self.explanation_param(self.pop_format_context(expl_call))
|
return res, self.explanation_param(self.pop_format_context(expl_call))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1265,6 +1265,177 @@ class TestIssue2121:
|
||||||
result.stdout.fnmatch_lines(["*E*assert (1 + 1) == 3"])
|
result.stdout.fnmatch_lines(["*E*assert (1 + 1) == 3"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 8), reason="walrus operator not available in py<38"
|
||||||
|
)
|
||||||
|
class TestIssue10743:
|
||||||
|
def test_assertion_walrus_operator(self, pytester: Pytester) -> None:
|
||||||
|
pytester.makepyfile(
|
||||||
|
"""
|
||||||
|
def my_func(before, after):
|
||||||
|
return before == after
|
||||||
|
|
||||||
|
def change_value(value):
|
||||||
|
return value.lower()
|
||||||
|
|
||||||
|
def test_walrus_conversion():
|
||||||
|
a = "Hello"
|
||||||
|
assert not my_func(a, a := change_value(a))
|
||||||
|
assert a == "hello"
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
result = pytester.runpytest()
|
||||||
|
assert result.ret == 0
|
||||||
|
|
||||||
|
def test_assertion_walrus_operator_dont_rewrite(self, pytester: Pytester) -> None:
|
||||||
|
pytester.makepyfile(
|
||||||
|
"""
|
||||||
|
'PYTEST_DONT_REWRITE'
|
||||||
|
def my_func(before, after):
|
||||||
|
return before == after
|
||||||
|
|
||||||
|
def change_value(value):
|
||||||
|
return value.lower()
|
||||||
|
|
||||||
|
def test_walrus_conversion_dont_rewrite():
|
||||||
|
a = "Hello"
|
||||||
|
assert not my_func(a, a := change_value(a))
|
||||||
|
assert a == "hello"
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
result = pytester.runpytest()
|
||||||
|
assert result.ret == 0
|
||||||
|
|
||||||
|
def test_assertion_inline_walrus_operator(self, pytester: Pytester) -> None:
|
||||||
|
pytester.makepyfile(
|
||||||
|
"""
|
||||||
|
def my_func(before, after):
|
||||||
|
return before == after
|
||||||
|
|
||||||
|
def test_walrus_conversion_inline():
|
||||||
|
a = "Hello"
|
||||||
|
assert not my_func(a, a := a.lower())
|
||||||
|
assert a == "hello"
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
result = pytester.runpytest()
|
||||||
|
assert result.ret == 0
|
||||||
|
|
||||||
|
def test_assertion_inline_walrus_operator_reverse(self, pytester: Pytester) -> None:
|
||||||
|
pytester.makepyfile(
|
||||||
|
"""
|
||||||
|
def my_func(before, after):
|
||||||
|
return before == after
|
||||||
|
|
||||||
|
def test_walrus_conversion_reverse():
|
||||||
|
a = "Hello"
|
||||||
|
assert my_func(a := a.lower(), a)
|
||||||
|
assert a == 'hello'
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
result = pytester.runpytest()
|
||||||
|
assert result.ret == 0
|
||||||
|
|
||||||
|
def test_assertion_walrus_no_variable_name_conflict(
|
||||||
|
self, pytester: Pytester
|
||||||
|
) -> None:
|
||||||
|
pytester.makepyfile(
|
||||||
|
"""
|
||||||
|
def test_walrus_conversion_no_conflict():
|
||||||
|
a = "Hello"
|
||||||
|
assert a == (b := a.lower())
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
result = pytester.runpytest()
|
||||||
|
assert result.ret == 1
|
||||||
|
result.stdout.fnmatch_lines(["*AssertionError: assert 'Hello' == 'hello'"])
|
||||||
|
|
||||||
|
def test_assertion_walrus_operator_true_assertion_and_changes_variable_value(
|
||||||
|
self, pytester: Pytester
|
||||||
|
) -> None:
|
||||||
|
pytester.makepyfile(
|
||||||
|
"""
|
||||||
|
def test_walrus_conversion_succeed():
|
||||||
|
a = "Hello"
|
||||||
|
assert a != (a := a.lower())
|
||||||
|
assert a == 'hello'
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
result = pytester.runpytest()
|
||||||
|
assert result.ret == 0
|
||||||
|
|
||||||
|
def test_assertion_walrus_operator_fail_assertion(self, pytester: Pytester) -> None:
|
||||||
|
pytester.makepyfile(
|
||||||
|
"""
|
||||||
|
def test_walrus_conversion_fails():
|
||||||
|
a = "Hello"
|
||||||
|
assert a == (a := a.lower())
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
result = pytester.runpytest()
|
||||||
|
assert result.ret == 1
|
||||||
|
result.stdout.fnmatch_lines(["*AssertionError: assert 'Hello' == 'hello'"])
|
||||||
|
|
||||||
|
def test_assertion_walrus_operator_boolean_composite(
|
||||||
|
self, pytester: Pytester
|
||||||
|
) -> None:
|
||||||
|
pytester.makepyfile(
|
||||||
|
"""
|
||||||
|
def test_walrus_operator_change_boolean_value():
|
||||||
|
a = True
|
||||||
|
assert a and True and ((a := False) is False) and (a is False) and ((a := None) is None)
|
||||||
|
assert a is None
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
result = pytester.runpytest()
|
||||||
|
assert result.ret == 0
|
||||||
|
|
||||||
|
def test_assertion_walrus_operator_compare_boolean_fails(
|
||||||
|
self, pytester: Pytester
|
||||||
|
) -> None:
|
||||||
|
pytester.makepyfile(
|
||||||
|
"""
|
||||||
|
def test_walrus_operator_change_boolean_value():
|
||||||
|
a = True
|
||||||
|
assert not (a and ((a := False) is False))
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
result = pytester.runpytest()
|
||||||
|
assert result.ret == 1
|
||||||
|
result.stdout.fnmatch_lines(["*assert not (True and False is False)"])
|
||||||
|
|
||||||
|
def test_assertion_walrus_operator_boolean_none_fails(
|
||||||
|
self, pytester: Pytester
|
||||||
|
) -> None:
|
||||||
|
pytester.makepyfile(
|
||||||
|
"""
|
||||||
|
def test_walrus_operator_change_boolean_value():
|
||||||
|
a = True
|
||||||
|
assert not (a and ((a := None) is None))
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
result = pytester.runpytest()
|
||||||
|
assert result.ret == 1
|
||||||
|
result.stdout.fnmatch_lines(["*assert not (True and None is None)"])
|
||||||
|
|
||||||
|
def test_assertion_walrus_operator_value_changes_cleared_after_each_test(
|
||||||
|
self, pytester: Pytester
|
||||||
|
) -> None:
|
||||||
|
pytester.makepyfile(
|
||||||
|
"""
|
||||||
|
def test_walrus_operator_change_value():
|
||||||
|
a = True
|
||||||
|
assert (a := None) is None
|
||||||
|
|
||||||
|
def test_walrus_operator_not_override_value():
|
||||||
|
a = True
|
||||||
|
assert a is True
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
result = pytester.runpytest()
|
||||||
|
assert result.ret == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems"
|
sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems"
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue