give initial imports a reasonable lineno

This commit is contained in:
Benjamin Peterson 2011-05-24 17:21:58 -05:00
parent 7ba8fee3dc
commit 9c4f6791e5
2 changed files with 11 additions and 2 deletions

View File

@ -92,18 +92,21 @@ class AssertionRewriter(ast.NodeVisitor):
aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"), aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
ast.alias("py", "@pylib"), ast.alias("py", "@pylib"),
ast.alias("_pytest.assertrewrite", "@pytest_ar")] ast.alias("_pytest.assertrewrite", "@pytest_ar")]
imports = [ast.Import([alias], lineno=0, col_offset=0)
for alias in aliases]
expect_docstring = True expect_docstring = True
pos = 0 pos = 0
lineno = 0
for item in mod.body: for item in mod.body:
if (expect_docstring and isinstance(item, ast.Expr) and if (expect_docstring and isinstance(item, ast.Expr) and
isinstance(item.value, ast.Str)): isinstance(item.value, ast.Str)):
lineno += len(item.value.s.splitlines()) - 1
expect_docstring = False expect_docstring = False
elif (not isinstance(item, ast.ImportFrom) or item.level > 0 and elif (not isinstance(item, ast.ImportFrom) or item.level > 0 and
item.identifier != "__future__"): item.identifier != "__future__"):
lineno = item.lineno
break break
pos += 1 pos += 1
imports = [ast.Import([alias], lineno=lineno, col_offset=0)
for alias in aliases]
mod.body[pos:pos] = imports mod.body[pos:pos] = imports
# Collect asserts. # Collect asserts.
nodes = collections.deque([mod]) nodes = collections.deque([mod])

View File

@ -54,12 +54,16 @@ class TestAssertionRewrite:
assert isinstance(m.body[0].value, ast.Str) assert isinstance(m.body[0].value, ast.Str)
for imp in m.body[1:4]: for imp in m.body[1:4]:
assert isinstance(imp, ast.Import) assert isinstance(imp, ast.Import)
assert imp.lineno == 2
assert imp.col_offset == 0
assert isinstance(m.body[4], ast.Assign) assert isinstance(m.body[4], ast.Assign)
s = """from __future__ import with_statement\nother_stuff""" s = """from __future__ import with_statement\nother_stuff"""
m = rewrite(s) m = rewrite(s)
assert isinstance(m.body[0], ast.ImportFrom) assert isinstance(m.body[0], ast.ImportFrom)
for imp in m.body[1:4]: for imp in m.body[1:4]:
assert isinstance(imp, ast.Import) assert isinstance(imp, ast.Import)
assert imp.lineno == 2
assert imp.col_offset == 0
assert isinstance(m.body[4], ast.Expr) assert isinstance(m.body[4], ast.Expr)
s = """'doc string'\nfrom __future__ import with_statement\nother""" s = """'doc string'\nfrom __future__ import with_statement\nother"""
m = rewrite(s) m = rewrite(s)
@ -68,6 +72,8 @@ class TestAssertionRewrite:
assert isinstance(m.body[1], ast.ImportFrom) assert isinstance(m.body[1], ast.ImportFrom)
for imp in m.body[2:5]: for imp in m.body[2:5]:
assert isinstance(imp, ast.Import) assert isinstance(imp, ast.Import)
assert imp.lineno == 3
assert imp.col_offset == 0
assert isinstance(m.body[5], ast.Expr) assert isinstance(m.body[5], ast.Expr)
def test_name(self): def test_name(self):