diff --git a/_pytest/assertion/rewrite.py b/_pytest/assertion/rewrite.py index 992002b81..c3966340a 100644 --- a/_pytest/assertion/rewrite.py +++ b/_pytest/assertion/rewrite.py @@ -595,23 +595,26 @@ class AssertionRewriter(ast.NodeVisitor): # docstrings and __future__ imports. aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"), ast.alias("_pytest.assertion.rewrite", "@pytest_ar")] - expect_docstring = True + doc = getattr(mod, "docstring", None) + expect_docstring = doc is None + if doc is not None and self.is_rewrite_disabled(doc): + return pos = 0 - lineno = 0 + lineno = 1 for item in mod.body: if (expect_docstring and isinstance(item, ast.Expr) and isinstance(item.value, ast.Str)): doc = item.value.s - if "PYTEST_DONT_REWRITE" in doc: - # The module has disabled assertion rewriting. + if self.is_rewrite_disabled(doc): return - lineno += len(doc) - 1 expect_docstring = False elif (not isinstance(item, ast.ImportFrom) or item.level > 0 or item.module != "__future__"): lineno = item.lineno break pos += 1 + else: + lineno = item.lineno imports = [ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases] mod.body[pos:pos] = imports @@ -637,6 +640,9 @@ class AssertionRewriter(ast.NodeVisitor): not isinstance(field, ast.expr)): nodes.append(field) + def is_rewrite_disabled(self, docstring): + return "PYTEST_DONT_REWRITE" in docstring + def variable(self): """Get a new variable.""" # Use a character invalid in python identifiers to avoid clashing. diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 2d61b7440..02270e157 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -65,13 +65,18 @@ class TestAssertionRewrite(object): def test_place_initial_imports(self): s = """'Doc string'\nother = stuff""" m = rewrite(s) - assert isinstance(m.body[0], ast.Expr) - assert isinstance(m.body[0].value, ast.Str) - for imp in m.body[1:3]: + # Module docstrings in 3.7 are part of Module node, it's not in the body + # so we remove it so the following body items have the same indexes on + # all Python versions + if sys.version_info < (3, 7): + assert isinstance(m.body[0], ast.Expr) + assert isinstance(m.body[0].value, ast.Str) + del m.body[0] + for imp in m.body[0:2]: assert isinstance(imp, ast.Import) assert imp.lineno == 2 assert imp.col_offset == 0 - assert isinstance(m.body[3], ast.Assign) + assert isinstance(m.body[2], ast.Assign) s = """from __future__ import with_statement\nother_stuff""" m = rewrite(s) assert isinstance(m.body[0], ast.ImportFrom) @@ -80,16 +85,29 @@ class TestAssertionRewrite(object): assert imp.lineno == 2 assert imp.col_offset == 0 assert isinstance(m.body[3], ast.Expr) + s = """'doc string'\nfrom __future__ import with_statement""" + m = rewrite(s) + if sys.version_info < (3, 7): + assert isinstance(m.body[0], ast.Expr) + assert isinstance(m.body[0].value, ast.Str) + del m.body[0] + assert isinstance(m.body[0], ast.ImportFrom) + for imp in m.body[1:3]: + assert isinstance(imp, ast.Import) + assert imp.lineno == 2 + assert imp.col_offset == 0 s = """'doc string'\nfrom __future__ import with_statement\nother""" m = rewrite(s) - assert isinstance(m.body[0], ast.Expr) - assert isinstance(m.body[0].value, ast.Str) - assert isinstance(m.body[1], ast.ImportFrom) - for imp in m.body[2:4]: + if sys.version_info < (3, 7): + assert isinstance(m.body[0], ast.Expr) + assert isinstance(m.body[0].value, ast.Str) + del m.body[0] + assert isinstance(m.body[0], ast.ImportFrom) + for imp in m.body[1:3]: assert isinstance(imp, ast.Import) assert imp.lineno == 3 assert imp.col_offset == 0 - assert isinstance(m.body[4], ast.Expr) + assert isinstance(m.body[3], ast.Expr) s = """from . import relative\nother_stuff""" m = rewrite(s) for imp in m.body[0:2]: @@ -101,10 +119,14 @@ class TestAssertionRewrite(object): def test_dont_rewrite(self): s = """'PYTEST_DONT_REWRITE'\nassert 14""" m = rewrite(s) - assert len(m.body) == 2 - assert isinstance(m.body[0].value, ast.Str) - assert isinstance(m.body[1], ast.Assert) - assert m.body[1].msg is None + if sys.version_info < (3, 7): + assert len(m.body) == 2 + assert isinstance(m.body[0], ast.Expr) + assert isinstance(m.body[0].value, ast.Str) + del m.body[0] + else: + assert len(m.body) == 1 + assert m.body[0].msg is None def test_name(self): def f():