parent
dd7beb39d6
commit
7259e8db98
1
AUTHORS
1
AUTHORS
|
@ -235,6 +235,7 @@ Maho
|
||||||
Maik Figura
|
Maik Figura
|
||||||
Mandeep Bhutani
|
Mandeep Bhutani
|
||||||
Manuel Krebber
|
Manuel Krebber
|
||||||
|
Marc Mueller
|
||||||
Marc Schlaich
|
Marc Schlaich
|
||||||
Marcelo Duarte Trevisani
|
Marcelo Duarte Trevisani
|
||||||
Marcin Bachry
|
Marcin Bachry
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Fixed ``:=`` in asserts impacting unrelated test cases.
|
|
@ -13,6 +13,7 @@ import struct
|
||||||
import sys
|
import sys
|
||||||
import tokenize
|
import tokenize
|
||||||
import types
|
import types
|
||||||
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pathlib import PurePath
|
from pathlib import PurePath
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
@ -45,6 +46,10 @@ if TYPE_CHECKING:
|
||||||
from _pytest.assertion import AssertionState
|
from _pytest.assertion import AssertionState
|
||||||
|
|
||||||
|
|
||||||
|
class Sentinel:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
assertstate_key = StashKey["AssertionState"]()
|
assertstate_key = StashKey["AssertionState"]()
|
||||||
|
|
||||||
# pytest caches rewritten pycs in pycache dirs
|
# pytest caches rewritten pycs in pycache dirs
|
||||||
|
@ -52,6 +57,9 @@ PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
|
||||||
PYC_EXT = ".py" + (__debug__ and "c" or "o")
|
PYC_EXT = ".py" + (__debug__ and "c" or "o")
|
||||||
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
|
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
|
||||||
|
|
||||||
|
# Special marker that denotes we have just left a scope definition
|
||||||
|
_SCOPE_END_MARKER = Sentinel()
|
||||||
|
|
||||||
|
|
||||||
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
|
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
|
||||||
"""PEP302/PEP451 import hook which rewrites asserts."""
|
"""PEP302/PEP451 import hook which rewrites asserts."""
|
||||||
|
@ -634,6 +642,8 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
.push_format_context() and .pop_format_context() which allows
|
.push_format_context() and .pop_format_context() which allows
|
||||||
to build another %-formatted string while already building one.
|
to build another %-formatted string while already building one.
|
||||||
|
|
||||||
|
:scope: A tuple containing the current scope used for variables_overwrite.
|
||||||
|
|
||||||
:variables_overwrite: A dict filled with references to variables
|
:variables_overwrite: A dict filled with references to variables
|
||||||
that change value within an assert. This happens when a variable is
|
that change value within an assert. This happens when a variable is
|
||||||
reassigned with the walrus operator
|
reassigned with the walrus operator
|
||||||
|
@ -655,7 +665,10 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
else:
|
else:
|
||||||
self.enable_assertion_pass_hook = False
|
self.enable_assertion_pass_hook = False
|
||||||
self.source = source
|
self.source = source
|
||||||
self.variables_overwrite: Dict[str, str] = {}
|
self.scope: tuple[ast.AST, ...] = ()
|
||||||
|
self.variables_overwrite: defaultdict[
|
||||||
|
tuple[ast.AST, ...], Dict[str, str]
|
||||||
|
] = defaultdict(dict)
|
||||||
|
|
||||||
def run(self, mod: ast.Module) -> None:
|
def run(self, mod: ast.Module) -> None:
|
||||||
"""Find all assert statements in *mod* and rewrite them."""
|
"""Find all assert statements in *mod* and rewrite them."""
|
||||||
|
@ -719,9 +732,17 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
mod.body[pos:pos] = imports
|
mod.body[pos:pos] = imports
|
||||||
|
|
||||||
# Collect asserts.
|
# Collect asserts.
|
||||||
nodes: List[ast.AST] = [mod]
|
self.scope = (mod,)
|
||||||
|
nodes: List[Union[ast.AST, Sentinel]] = [mod]
|
||||||
while nodes:
|
while nodes:
|
||||||
node = nodes.pop()
|
node = nodes.pop()
|
||||||
|
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
||||||
|
self.scope = tuple((*self.scope, node))
|
||||||
|
nodes.append(_SCOPE_END_MARKER)
|
||||||
|
if node == _SCOPE_END_MARKER:
|
||||||
|
self.scope = self.scope[:-1]
|
||||||
|
continue
|
||||||
|
assert isinstance(node, ast.AST)
|
||||||
for name, field in ast.iter_fields(node):
|
for name, field in ast.iter_fields(node):
|
||||||
if isinstance(field, list):
|
if isinstance(field, list):
|
||||||
new: List[ast.AST] = []
|
new: List[ast.AST] = []
|
||||||
|
@ -992,7 +1013,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
pytest_temp = self.variable()
|
pytest_temp = self.variable()
|
||||||
self.variables_overwrite[
|
self.variables_overwrite[self.scope][
|
||||||
v.left.target.id
|
v.left.target.id
|
||||||
] = v.left # type:ignore[assignment]
|
] = v.left # type:ignore[assignment]
|
||||||
v.left.target.id = pytest_temp
|
v.left.target.id = pytest_temp
|
||||||
|
@ -1035,17 +1056,20 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
new_args = []
|
new_args = []
|
||||||
new_kwargs = []
|
new_kwargs = []
|
||||||
for arg in call.args:
|
for arg in call.args:
|
||||||
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite:
|
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get(
|
||||||
arg = self.variables_overwrite[arg.id] # type:ignore[assignment]
|
self.scope, {}
|
||||||
|
):
|
||||||
|
arg = self.variables_overwrite[self.scope][
|
||||||
|
arg.id
|
||||||
|
] # type:ignore[assignment]
|
||||||
res, expl = self.visit(arg)
|
res, expl = self.visit(arg)
|
||||||
arg_expls.append(expl)
|
arg_expls.append(expl)
|
||||||
new_args.append(res)
|
new_args.append(res)
|
||||||
for keyword in call.keywords:
|
for keyword in call.keywords:
|
||||||
if (
|
if isinstance(
|
||||||
isinstance(keyword.value, ast.Name)
|
keyword.value, ast.Name
|
||||||
and keyword.value.id in self.variables_overwrite
|
) and keyword.value.id in self.variables_overwrite.get(self.scope, {}):
|
||||||
):
|
keyword.value = self.variables_overwrite[self.scope][
|
||||||
keyword.value = self.variables_overwrite[
|
|
||||||
keyword.value.id
|
keyword.value.id
|
||||||
] # type:ignore[assignment]
|
] # type:ignore[assignment]
|
||||||
res, expl = self.visit(keyword.value)
|
res, expl = self.visit(keyword.value)
|
||||||
|
@ -1081,12 +1105,14 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
|
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
|
||||||
self.push_format_context()
|
self.push_format_context()
|
||||||
# We first check if we have overwritten a variable in the previous assert
|
# We first check if we have overwritten a variable in the previous assert
|
||||||
if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite:
|
if isinstance(
|
||||||
comp.left = self.variables_overwrite[
|
comp.left, ast.Name
|
||||||
|
) and comp.left.id in self.variables_overwrite.get(self.scope, {}):
|
||||||
|
comp.left = self.variables_overwrite[self.scope][
|
||||||
comp.left.id
|
comp.left.id
|
||||||
] # type:ignore[assignment]
|
] # type:ignore[assignment]
|
||||||
if isinstance(comp.left, ast.NamedExpr):
|
if isinstance(comp.left, ast.NamedExpr):
|
||||||
self.variables_overwrite[
|
self.variables_overwrite[self.scope][
|
||||||
comp.left.target.id
|
comp.left.target.id
|
||||||
] = comp.left # type:ignore[assignment]
|
] = comp.left # type:ignore[assignment]
|
||||||
left_res, left_expl = self.visit(comp.left)
|
left_res, left_expl = self.visit(comp.left)
|
||||||
|
@ -1106,7 +1132,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
and next_operand.target.id == left_res.id
|
and next_operand.target.id == left_res.id
|
||||||
):
|
):
|
||||||
next_operand.target.id = self.variable()
|
next_operand.target.id = self.variable()
|
||||||
self.variables_overwrite[
|
self.variables_overwrite[self.scope][
|
||||||
left_res.id
|
left_res.id
|
||||||
] = next_operand # type:ignore[assignment]
|
] = next_operand # type:ignore[assignment]
|
||||||
next_res, next_expl = self.visit(next_operand)
|
next_res, next_expl = self.visit(next_operand)
|
||||||
|
|
|
@ -1543,6 +1543,27 @@ class TestIssue11028:
|
||||||
result.stdout.fnmatch_lines(["*assert 4 > 5", "*where 5 = add_one(4)"])
|
result.stdout.fnmatch_lines(["*assert 4 > 5", "*where 5 = add_one(4)"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssue11239:
|
||||||
|
def test_assertion_walrus_different_test_cases(self, pytester: Pytester) -> None:
|
||||||
|
"""Regression for (#11239)
|
||||||
|
|
||||||
|
Walrus operator rewriting would leak to separate test cases if they used the same variables.
|
||||||
|
"""
|
||||||
|
pytester.makepyfile(
|
||||||
|
"""
|
||||||
|
def test_1():
|
||||||
|
state = {"x": 2}.get("x")
|
||||||
|
assert state is not None
|
||||||
|
|
||||||
|
def test_2():
|
||||||
|
db = {"x": 2}
|
||||||
|
assert (state := db.get("x")) is not None
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
result = pytester.runpytest()
|
||||||
|
assert result.ret == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems"
|
sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems"
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue