Fix assert rewriting with assignment expressions (#11414)

Fixes #11239
This commit is contained in:
Marc Mueller 2023-09-09 14:09:31 +02:00 committed by GitHub
parent dd7beb39d6
commit 7259e8db98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 63 additions and 14 deletions

View File

@ -235,6 +235,7 @@ Maho
Maik Figura Maik Figura
Mandeep Bhutani Mandeep Bhutani
Manuel Krebber Manuel Krebber
Marc Mueller
Marc Schlaich Marc Schlaich
Marcelo Duarte Trevisani Marcelo Duarte Trevisani
Marcin Bachry Marcin Bachry

View File

@ -0,0 +1 @@
Fixed ``:=`` in asserts impacting unrelated test cases.

View File

@ -13,6 +13,7 @@ import struct
import sys import sys
import tokenize import tokenize
import types import types
from collections import defaultdict
from pathlib import Path from pathlib import Path
from pathlib import PurePath from pathlib import PurePath
from typing import Callable from typing import Callable
@ -45,6 +46,10 @@ if TYPE_CHECKING:
from _pytest.assertion import AssertionState from _pytest.assertion import AssertionState
class Sentinel:
pass
assertstate_key = StashKey["AssertionState"]() assertstate_key = StashKey["AssertionState"]()
# pytest caches rewritten pycs in pycache dirs # pytest caches rewritten pycs in pycache dirs
@ -52,6 +57,9 @@ PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
PYC_EXT = ".py" + (__debug__ and "c" or "o") PYC_EXT = ".py" + (__debug__ and "c" or "o")
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
# Special marker that denotes we have just left a scope definition
_SCOPE_END_MARKER = Sentinel()
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader): class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
"""PEP302/PEP451 import hook which rewrites asserts.""" """PEP302/PEP451 import hook which rewrites asserts."""
@ -634,6 +642,8 @@ class AssertionRewriter(ast.NodeVisitor):
.push_format_context() and .pop_format_context() which allows .push_format_context() and .pop_format_context() which allows
to build another %-formatted string while already building one. to build another %-formatted string while already building one.
:scope: A tuple containing the current scope used for variables_overwrite.
:variables_overwrite: A dict filled with references to variables :variables_overwrite: A dict filled with references to variables
that change value within an assert. This happens when a variable is that change value within an assert. This happens when a variable is
reassigned with the walrus operator reassigned with the walrus operator
@ -655,7 +665,10 @@ class AssertionRewriter(ast.NodeVisitor):
else: else:
self.enable_assertion_pass_hook = False self.enable_assertion_pass_hook = False
self.source = source self.source = source
self.variables_overwrite: Dict[str, str] = {} self.scope: tuple[ast.AST, ...] = ()
self.variables_overwrite: defaultdict[
tuple[ast.AST, ...], Dict[str, str]
] = defaultdict(dict)
def run(self, mod: ast.Module) -> None: def run(self, mod: ast.Module) -> None:
"""Find all assert statements in *mod* and rewrite them.""" """Find all assert statements in *mod* and rewrite them."""
@ -719,9 +732,17 @@ class AssertionRewriter(ast.NodeVisitor):
mod.body[pos:pos] = imports mod.body[pos:pos] = imports
# Collect asserts. # Collect asserts.
nodes: List[ast.AST] = [mod] self.scope = (mod,)
nodes: List[Union[ast.AST, Sentinel]] = [mod]
while nodes: while nodes:
node = nodes.pop() node = nodes.pop()
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
self.scope = tuple((*self.scope, node))
nodes.append(_SCOPE_END_MARKER)
if node == _SCOPE_END_MARKER:
self.scope = self.scope[:-1]
continue
assert isinstance(node, ast.AST)
for name, field in ast.iter_fields(node): for name, field in ast.iter_fields(node):
if isinstance(field, list): if isinstance(field, list):
new: List[ast.AST] = [] new: List[ast.AST] = []
@ -992,7 +1013,7 @@ class AssertionRewriter(ast.NodeVisitor):
] ]
): ):
pytest_temp = self.variable() pytest_temp = self.variable()
self.variables_overwrite[ self.variables_overwrite[self.scope][
v.left.target.id v.left.target.id
] = v.left # type:ignore[assignment] ] = v.left # type:ignore[assignment]
v.left.target.id = pytest_temp v.left.target.id = pytest_temp
@ -1035,17 +1056,20 @@ class AssertionRewriter(ast.NodeVisitor):
new_args = [] new_args = []
new_kwargs = [] new_kwargs = []
for arg in call.args: for arg in call.args:
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite: if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get(
arg = self.variables_overwrite[arg.id] # type:ignore[assignment] self.scope, {}
):
arg = self.variables_overwrite[self.scope][
arg.id
] # type:ignore[assignment]
res, expl = self.visit(arg) res, expl = self.visit(arg)
arg_expls.append(expl) arg_expls.append(expl)
new_args.append(res) new_args.append(res)
for keyword in call.keywords: for keyword in call.keywords:
if ( if isinstance(
isinstance(keyword.value, ast.Name) keyword.value, ast.Name
and keyword.value.id in self.variables_overwrite ) and keyword.value.id in self.variables_overwrite.get(self.scope, {}):
): keyword.value = self.variables_overwrite[self.scope][
keyword.value = self.variables_overwrite[
keyword.value.id keyword.value.id
] # type:ignore[assignment] ] # type:ignore[assignment]
res, expl = self.visit(keyword.value) res, expl = self.visit(keyword.value)
@ -1081,12 +1105,14 @@ class AssertionRewriter(ast.NodeVisitor):
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
self.push_format_context() self.push_format_context()
# We first check if we have overwritten a variable in the previous assert # We first check if we have overwritten a variable in the previous assert
if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite: if isinstance(
comp.left = self.variables_overwrite[ comp.left, ast.Name
) and comp.left.id in self.variables_overwrite.get(self.scope, {}):
comp.left = self.variables_overwrite[self.scope][
comp.left.id comp.left.id
] # type:ignore[assignment] ] # type:ignore[assignment]
if isinstance(comp.left, ast.NamedExpr): if isinstance(comp.left, ast.NamedExpr):
self.variables_overwrite[ self.variables_overwrite[self.scope][
comp.left.target.id comp.left.target.id
] = comp.left # type:ignore[assignment] ] = comp.left # type:ignore[assignment]
left_res, left_expl = self.visit(comp.left) left_res, left_expl = self.visit(comp.left)
@ -1106,7 +1132,7 @@ class AssertionRewriter(ast.NodeVisitor):
and next_operand.target.id == left_res.id and next_operand.target.id == left_res.id
): ):
next_operand.target.id = self.variable() next_operand.target.id = self.variable()
self.variables_overwrite[ self.variables_overwrite[self.scope][
left_res.id left_res.id
] = next_operand # type:ignore[assignment] ] = next_operand # type:ignore[assignment]
next_res, next_expl = self.visit(next_operand) next_res, next_expl = self.visit(next_operand)

View File

@ -1543,6 +1543,27 @@ class TestIssue11028:
result.stdout.fnmatch_lines(["*assert 4 > 5", "*where 5 = add_one(4)"]) result.stdout.fnmatch_lines(["*assert 4 > 5", "*where 5 = add_one(4)"])
class TestIssue11239:
def test_assertion_walrus_different_test_cases(self, pytester: Pytester) -> None:
"""Regression for (#11239)
Walrus operator rewriting would leak to separate test cases if they used the same variables.
"""
pytester.makepyfile(
"""
def test_1():
state = {"x": 2}.get("x")
assert state is not None
def test_2():
db = {"x": 2}
assert (state := db.get("x")) is not None
"""
)
result = pytester.runpytest()
assert result.ret == 0
@pytest.mark.skipif( @pytest.mark.skipif(
sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems" sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems"
) )