correctly handle multiple asserts
This commit is contained in:
parent
9ac818fb5c
commit
aae89cd021
|
@ -99,19 +99,22 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
node = nodes.popleft()
|
||||
for name, field in ast.iter_fields(node):
|
||||
if isinstance(field, list):
|
||||
new = []
|
||||
for i, child in enumerate(field):
|
||||
if isinstance(child, ast.Assert):
|
||||
# Transform assert.
|
||||
new.extend(self.visit(child))
|
||||
asserts.append((field, i, child))
|
||||
elif isinstance(child, ast.AST):
|
||||
nodes.append(child)
|
||||
else:
|
||||
new.append(child)
|
||||
if isinstance(child, ast.AST):
|
||||
nodes.append(child)
|
||||
setattr(node, name, new)
|
||||
elif (isinstance(field, ast.AST) and
|
||||
# Don't recurse into expressions as they can't contain
|
||||
# asserts.
|
||||
not isinstance(field, ast.expr)):
|
||||
nodes.append(field)
|
||||
# Transform asserts.
|
||||
for parent, pos, assert_ in asserts:
|
||||
parent[pos:pos + 1] = self.visit(assert_)
|
||||
|
||||
def variable(self):
|
||||
"""Get a new variable."""
|
||||
|
|
|
@ -196,6 +196,11 @@ class TestAssertionRewrite:
|
|||
a, b, c = range(3)
|
||||
assert a < b <= c
|
||||
getmsg(f, must_pass=True)
|
||||
def f():
|
||||
a, b, c = range(3)
|
||||
assert a < b
|
||||
assert b < c
|
||||
getmsg(f, must_pass=True)
|
||||
|
||||
def test_len(self):
|
||||
def f():
|
||||
|
|
Loading…
Reference in New Issue