place assertion imports after __future__ statements and docstrings

This commit is contained in:
Benjamin Peterson 2011-05-19 16:53:13 -05:00
parent c742e47de0
commit 9e6dfaefd9
2 changed files with 37 additions and 5 deletions

View File

@ -74,15 +74,23 @@ class AssertionRewriter(ast.NodeVisitor):
if not mod.body: if not mod.body:
# Nothing to do. # Nothing to do.
return return
# Insert some special imports at top but after any docstrings. # Insert some special imports at top but after any docstrings and
# __future__ imports.
aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"), aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
ast.alias("py", "@pylib"), ast.alias("py", "@pylib"),
ast.alias("_pytest.assertrewrite", "@pytest_ar")] ast.alias("_pytest.assertrewrite", "@pytest_ar")]
imports = [ast.Import([alias], lineno=0, col_offset=0) imports = [ast.Import([alias], lineno=0, col_offset=0)
for alias in aliases] for alias in aliases]
expect_docstring = True
pos = 0 pos = 0
if isinstance(mod.body[0], ast.Str): for item in mod.body:
pos = 1 if (expect_docstring and isinstance(item, ast.Expr) and
isinstance(item.value, ast.Str)):
expect_docstring = False
elif (not isinstance(item, ast.ImportFrom) or item.level > 0 and
item.identifier != "__future__"):
break
pos += 1
mod.body[pos:pos] = imports mod.body[pos:pos] = imports
# Collect asserts. # Collect asserts.
asserts = [] asserts = []

View File

@ -16,11 +16,15 @@ def teardown_module(mod):
del mod._old_reprcompare del mod._old_reprcompare
def rewrite(src):
tree = ast.parse(src)
rewrite_asserts(tree)
return tree
def getmsg(f, extra_ns=None, must_pass=False): def getmsg(f, extra_ns=None, must_pass=False):
"""Rewrite the assertions in f, run it, and get the failure message.""" """Rewrite the assertions in f, run it, and get the failure message."""
src = '\n'.join(py.code.Code(f).source().lines) src = '\n'.join(py.code.Code(f).source().lines)
mod = ast.parse(src) mod = rewrite(src)
rewrite_asserts(mod)
code = compile(mod, "<test>", "exec") code = compile(mod, "<test>", "exec")
ns = {} ns = {}
if extra_ns is not None: if extra_ns is not None:
@ -43,6 +47,26 @@ def getmsg(f, extra_ns=None, must_pass=False):
class TestAssertionRewrite: class TestAssertionRewrite:
def test_place_initial_imports(self):
s = """'Doc string'"""
m = rewrite(s)
assert isinstance(m.body[0], ast.Expr)
assert isinstance(m.body[0].value, ast.Str)
for imp in m.body[1:]:
assert isinstance(imp, ast.Import)
s = """from __future__ import with_statement"""
m = rewrite(s)
assert isinstance(m.body[0], ast.ImportFrom)
for imp in m.body[1:]:
assert isinstance(imp, ast.Import)
s = """'doc string'\nfrom __future__ import with_statement"""
m = rewrite(s)
assert isinstance(m.body[0], ast.Expr)
assert isinstance(m.body[0].value, ast.Str)
assert isinstance(m.body[1], ast.ImportFrom)
for imp in m.body[2:]:
assert isinstance(imp, ast.Import)
def test_name(self): def test_name(self):
def f(): def f():
assert False assert False