diff --git a/_pytest/assertrewrite.py b/_pytest/assertrewrite.py index 4aca62d5c..14ddc93ce 100644 --- a/_pytest/assertrewrite.py +++ b/_pytest/assertrewrite.py @@ -74,15 +74,23 @@ class AssertionRewriter(ast.NodeVisitor): if not mod.body: # Nothing to do. return - # Insert some special imports at top but after any docstrings. + # Insert some special imports at top but after any docstrings and + # __future__ imports. aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"), ast.alias("py", "@pylib"), ast.alias("_pytest.assertrewrite", "@pytest_ar")] imports = [ast.Import([alias], lineno=0, col_offset=0) for alias in aliases] + expect_docstring = True pos = 0 - if isinstance(mod.body[0], ast.Str): - pos = 1 + for item in mod.body: + if (expect_docstring and isinstance(item, ast.Expr) and + isinstance(item.value, ast.Str)): + expect_docstring = False + elif (not isinstance(item, ast.ImportFrom) or item.level > 0 and + item.identifier != "__future__"): + break + pos += 1 mod.body[pos:pos] = imports # Collect asserts. asserts = [] diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 052a1cd27..64274fb07 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -16,11 +16,15 @@ def teardown_module(mod): del mod._old_reprcompare +def rewrite(src): + tree = ast.parse(src) + rewrite_asserts(tree) + return tree + def getmsg(f, extra_ns=None, must_pass=False): """Rewrite the assertions in f, run it, and get the failure message.""" src = '\n'.join(py.code.Code(f).source().lines) - mod = ast.parse(src) - rewrite_asserts(mod) + mod = rewrite(src) code = compile(mod, "", "exec") ns = {} if extra_ns is not None: @@ -43,6 +47,26 @@ def getmsg(f, extra_ns=None, must_pass=False): class TestAssertionRewrite: + def test_place_initial_imports(self): + s = """'Doc string'""" + m = rewrite(s) + assert isinstance(m.body[0], ast.Expr) + assert isinstance(m.body[0].value, ast.Str) + for imp in m.body[1:]: + assert isinstance(imp, ast.Import) + s = """from __future__ import with_statement""" + m = rewrite(s) + assert isinstance(m.body[0], ast.ImportFrom) + for imp in m.body[1:]: + assert isinstance(imp, ast.Import) + s = """'doc string'\nfrom __future__ import with_statement""" + m = rewrite(s) + assert isinstance(m.body[0], ast.Expr) + assert isinstance(m.body[0].value, ast.Str) + assert isinstance(m.body[1], ast.ImportFrom) + for imp in m.body[2:]: + assert isinstance(imp, ast.Import) + def test_name(self): def f(): assert False