place assertion imports after __future__ statements and docstrings
This commit is contained in:
parent
c742e47de0
commit
9e6dfaefd9
|
@ -74,15 +74,23 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
if not mod.body:
|
if not mod.body:
|
||||||
# Nothing to do.
|
# Nothing to do.
|
||||||
return
|
return
|
||||||
# Insert some special imports at top but after any docstrings.
|
# Insert some special imports at top but after any docstrings and
|
||||||
|
# __future__ imports.
|
||||||
aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
|
aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
|
||||||
ast.alias("py", "@pylib"),
|
ast.alias("py", "@pylib"),
|
||||||
ast.alias("_pytest.assertrewrite", "@pytest_ar")]
|
ast.alias("_pytest.assertrewrite", "@pytest_ar")]
|
||||||
imports = [ast.Import([alias], lineno=0, col_offset=0)
|
imports = [ast.Import([alias], lineno=0, col_offset=0)
|
||||||
for alias in aliases]
|
for alias in aliases]
|
||||||
|
expect_docstring = True
|
||||||
pos = 0
|
pos = 0
|
||||||
if isinstance(mod.body[0], ast.Str):
|
for item in mod.body:
|
||||||
pos = 1
|
if (expect_docstring and isinstance(item, ast.Expr) and
|
||||||
|
isinstance(item.value, ast.Str)):
|
||||||
|
expect_docstring = False
|
||||||
|
elif (not isinstance(item, ast.ImportFrom) or item.level > 0 and
|
||||||
|
item.identifier != "__future__"):
|
||||||
|
break
|
||||||
|
pos += 1
|
||||||
mod.body[pos:pos] = imports
|
mod.body[pos:pos] = imports
|
||||||
# Collect asserts.
|
# Collect asserts.
|
||||||
asserts = []
|
asserts = []
|
||||||
|
|
|
@ -16,11 +16,15 @@ def teardown_module(mod):
|
||||||
del mod._old_reprcompare
|
del mod._old_reprcompare
|
||||||
|
|
||||||
|
|
||||||
|
def rewrite(src):
|
||||||
|
tree = ast.parse(src)
|
||||||
|
rewrite_asserts(tree)
|
||||||
|
return tree
|
||||||
|
|
||||||
def getmsg(f, extra_ns=None, must_pass=False):
|
def getmsg(f, extra_ns=None, must_pass=False):
|
||||||
"""Rewrite the assertions in f, run it, and get the failure message."""
|
"""Rewrite the assertions in f, run it, and get the failure message."""
|
||||||
src = '\n'.join(py.code.Code(f).source().lines)
|
src = '\n'.join(py.code.Code(f).source().lines)
|
||||||
mod = ast.parse(src)
|
mod = rewrite(src)
|
||||||
rewrite_asserts(mod)
|
|
||||||
code = compile(mod, "<test>", "exec")
|
code = compile(mod, "<test>", "exec")
|
||||||
ns = {}
|
ns = {}
|
||||||
if extra_ns is not None:
|
if extra_ns is not None:
|
||||||
|
@ -43,6 +47,26 @@ def getmsg(f, extra_ns=None, must_pass=False):
|
||||||
|
|
||||||
class TestAssertionRewrite:
|
class TestAssertionRewrite:
|
||||||
|
|
||||||
|
def test_place_initial_imports(self):
|
||||||
|
s = """'Doc string'"""
|
||||||
|
m = rewrite(s)
|
||||||
|
assert isinstance(m.body[0], ast.Expr)
|
||||||
|
assert isinstance(m.body[0].value, ast.Str)
|
||||||
|
for imp in m.body[1:]:
|
||||||
|
assert isinstance(imp, ast.Import)
|
||||||
|
s = """from __future__ import with_statement"""
|
||||||
|
m = rewrite(s)
|
||||||
|
assert isinstance(m.body[0], ast.ImportFrom)
|
||||||
|
for imp in m.body[1:]:
|
||||||
|
assert isinstance(imp, ast.Import)
|
||||||
|
s = """'doc string'\nfrom __future__ import with_statement"""
|
||||||
|
m = rewrite(s)
|
||||||
|
assert isinstance(m.body[0], ast.Expr)
|
||||||
|
assert isinstance(m.body[0].value, ast.Str)
|
||||||
|
assert isinstance(m.body[1], ast.ImportFrom)
|
||||||
|
for imp in m.body[2:]:
|
||||||
|
assert isinstance(imp, ast.Import)
|
||||||
|
|
||||||
def test_name(self):
|
def test_name(self):
|
||||||
def f():
|
def f():
|
||||||
assert False
|
assert False
|
||||||
|
|
Loading…
Reference in New Issue