diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 988f09264..f6b74d97e 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -48,24 +48,27 @@ def getmsg(f, extra_ns=None, must_pass=False): class TestAssertionRewrite: def test_place_initial_imports(self): - s = """'Doc string'""" + s = """'Doc string'\nother = stuff""" m = rewrite(s) assert isinstance(m.body[0], ast.Expr) assert isinstance(m.body[0].value, ast.Str) - for imp in m.body[1:]: + for imp in m.body[1:4]: assert isinstance(imp, ast.Import) - s = """from __future__ import with_statement""" + assert isinstance(m.body[4], ast.Assign) + s = """from __future__ import with_statement\nother_stuff""" m = rewrite(s) assert isinstance(m.body[0], ast.ImportFrom) - for imp in m.body[1:]: + for imp in m.body[1:4]: assert isinstance(imp, ast.Import) - s = """'doc string'\nfrom __future__ import with_statement""" + assert isinstance(m.body[4], ast.Expr) + s = """'doc string'\nfrom __future__ import with_statement\nother""" m = rewrite(s) assert isinstance(m.body[0], ast.Expr) assert isinstance(m.body[0].value, ast.Str) assert isinstance(m.body[1], ast.ImportFrom) - for imp in m.body[2:]: + for imp in m.body[2:5]: assert isinstance(imp, ast.Import) + assert isinstance(m.body[5], ast.Expr) def test_name(self): def f():