correctly handle multiple asserts

This commit is contained in:
Benjamin Peterson 2011-05-19 18:56:48 -05:00
parent 9ac818fb5c
commit aae89cd021
2 changed files with 13 additions and 5 deletions

View File

@ -99,19 +99,22 @@ class AssertionRewriter(ast.NodeVisitor):
node = nodes.popleft()
for name, field in ast.iter_fields(node):
if isinstance(field, list):
new = []
for i, child in enumerate(field):
if isinstance(child, ast.Assert):
# Transform assert.
new.extend(self.visit(child))
asserts.append((field, i, child))
elif isinstance(child, ast.AST):
nodes.append(child)
else:
new.append(child)
if isinstance(child, ast.AST):
nodes.append(child)
setattr(node, name, new)
elif (isinstance(field, ast.AST) and
# Don't recurse into expressions as they can't contain
# asserts.
not isinstance(field, ast.expr)):
nodes.append(field)
# Transform asserts.
for parent, pos, assert_ in asserts:
parent[pos:pos + 1] = self.visit(assert_)
def variable(self):
"""Get a new variable."""

View File

@ -196,6 +196,11 @@ class TestAssertionRewrite:
a, b, c = range(3)
assert a < b <= c
getmsg(f, must_pass=True)
def f():
a, b, c = range(3)
assert a < b
assert b < c
getmsg(f, must_pass=True)
def test_len(self):
def f():