diff --git a/_pytest/assertrewrite.py b/_pytest/assertrewrite.py index 9109eadc3..bc755123f 100644 --- a/_pytest/assertrewrite.py +++ b/_pytest/assertrewrite.py @@ -99,19 +99,22 @@ class AssertionRewriter(ast.NodeVisitor): node = nodes.popleft() for name, field in ast.iter_fields(node): if isinstance(field, list): + new = [] for i, child in enumerate(field): if isinstance(child, ast.Assert): + # Transform assert. + new.extend(self.visit(child)) asserts.append((field, i, child)) - elif isinstance(child, ast.AST): - nodes.append(child) + else: + new.append(child) + if isinstance(child, ast.AST): + nodes.append(child) + setattr(node, name, new) elif (isinstance(field, ast.AST) and # Don't recurse into expressions as they can't contain # asserts. not isinstance(field, ast.expr)): nodes.append(field) - # Transform asserts. - for parent, pos, assert_ in asserts: - parent[pos:pos + 1] = self.visit(assert_) def variable(self): """Get a new variable.""" diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 64274fb07..988f09264 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -196,6 +196,11 @@ class TestAssertionRewrite: a, b, c = range(3) assert a < b <= c getmsg(f, must_pass=True) + def f(): + a, b, c = range(3) + assert a < b + assert b < c + getmsg(f, must_pass=True) def test_len(self): def f():