diff --git a/_pytest/assertion/rewrite.py b/_pytest/assertion/rewrite.py index 06687a0c8..c3966340a 100644 --- a/_pytest/assertion/rewrite.py +++ b/_pytest/assertion/rewrite.py @@ -600,20 +600,21 @@ class AssertionRewriter(ast.NodeVisitor): if doc is not None and self.is_rewrite_disabled(doc): return pos = 0 - lineno = 0 + lineno = 1 for item in mod.body: if (expect_docstring and isinstance(item, ast.Expr) and isinstance(item.value, ast.Str)): doc = item.value.s if self.is_rewrite_disabled(doc): return - lineno += len(doc) - 1 expect_docstring = False elif (not isinstance(item, ast.ImportFrom) or item.level > 0 or item.module != "__future__"): lineno = item.lineno break pos += 1 + else: + lineno = item.lineno imports = [ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases] mod.body[pos:pos] = imports diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index c935a7862..45c0c7b16 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -82,6 +82,17 @@ class TestAssertionRewrite(object): assert imp.lineno == 2 assert imp.col_offset == 0 assert isinstance(m.body[3], ast.Expr) + s = """'doc string'\nfrom __future__ import with_statement""" + m = rewrite(s) + if sys.version_info < (3, 7): + assert isinstance(m.body[0], ast.Expr) + assert isinstance(m.body[0].value, ast.Str) + del m.body[0] + assert isinstance(m.body[0], ast.ImportFrom) + for imp in m.body[1:3]: + assert isinstance(imp, ast.Import) + assert imp.lineno == 2 + assert imp.col_offset == 0 s = """'doc string'\nfrom __future__ import with_statement\nother""" m = rewrite(s) if sys.version_info < (3, 7):