DRY in TestAssertionRewrite

This commit is contained in:
Miro Hrončok 2018-06-04 13:45:35 +02:00
parent 9a6fa33c69
commit 39ebdab1bc
1 changed files with 15 additions and 25 deletions

View File

@ -65,20 +65,27 @@ def getmsg(f, extra_ns=None, must_pass=False):
pytest.fail("function didn't raise at all") pytest.fail("function didn't raise at all")
def python_version_has_docstring_in_module_node(): def adjust_body_for_new_docstring_in_module_node(m):
"""Module docstrings in 3.8 are part of Module node. """Module docstrings in 3.8 are part of Module node.
This was briefly in 3.7 as well but got reverted in beta 5. This was briefly in 3.7 as well but got reverted in beta 5.
It's not in the body so we remove it so the following body items have
the same indexes on all Python versions:
TODO: TODO:
We have a complicated sys.version_info if in here to ease testing on We have a complicated sys.version_info if in here to ease testing on
various Python 3.7 versions, but we should remove the 3.7 check after various Python 3.7 versions, but we should remove the 3.7 check after
3.7 is released as stable to make this check more straightforward. 3.7 is released as stable to make this check more straightforward.
""" """
return ( if (
sys.version_info < (3, 8) sys.version_info < (3, 8)
and not ((3, 7) <= sys.version_info <= (3, 7, 0, "beta", 4)) and not ((3, 7) <= sys.version_info <= (3, 7, 0, "beta", 4))
) ):
assert len(m.body) > 1
assert isinstance(m.body[0], ast.Expr)
assert isinstance(m.body[0].value, ast.Str)
del m.body[0]
class TestAssertionRewrite(object): class TestAssertionRewrite(object):
@ -86,13 +93,7 @@ class TestAssertionRewrite(object):
def test_place_initial_imports(self): def test_place_initial_imports(self):
s = """'Doc string'\nother = stuff""" s = """'Doc string'\nother = stuff"""
m = rewrite(s) m = rewrite(s)
# Module docstrings in some new Python versions are part of Module node adjust_body_for_new_docstring_in_module_node(m)
# It's not in the body so we remove it so the following body items have
# the same indexes on all Python versions:
if python_version_has_docstring_in_module_node():
assert isinstance(m.body[0], ast.Expr)
assert isinstance(m.body[0].value, ast.Str)
del m.body[0]
for imp in m.body[0:2]: for imp in m.body[0:2]:
assert isinstance(imp, ast.Import) assert isinstance(imp, ast.Import)
assert imp.lineno == 2 assert imp.lineno == 2
@ -108,10 +109,7 @@ class TestAssertionRewrite(object):
assert isinstance(m.body[3], ast.Expr) assert isinstance(m.body[3], ast.Expr)
s = """'doc string'\nfrom __future__ import with_statement""" s = """'doc string'\nfrom __future__ import with_statement"""
m = rewrite(s) m = rewrite(s)
if python_version_has_docstring_in_module_node(): adjust_body_for_new_docstring_in_module_node(m)
assert isinstance(m.body[0], ast.Expr)
assert isinstance(m.body[0].value, ast.Str)
del m.body[0]
assert isinstance(m.body[0], ast.ImportFrom) assert isinstance(m.body[0], ast.ImportFrom)
for imp in m.body[1:3]: for imp in m.body[1:3]:
assert isinstance(imp, ast.Import) assert isinstance(imp, ast.Import)
@ -119,10 +117,7 @@ class TestAssertionRewrite(object):
assert imp.col_offset == 0 assert imp.col_offset == 0
s = """'doc string'\nfrom __future__ import with_statement\nother""" s = """'doc string'\nfrom __future__ import with_statement\nother"""
m = rewrite(s) m = rewrite(s)
if python_version_has_docstring_in_module_node(): adjust_body_for_new_docstring_in_module_node(m)
assert isinstance(m.body[0], ast.Expr)
assert isinstance(m.body[0].value, ast.Str)
del m.body[0]
assert isinstance(m.body[0], ast.ImportFrom) assert isinstance(m.body[0], ast.ImportFrom)
for imp in m.body[1:3]: for imp in m.body[1:3]:
assert isinstance(imp, ast.Import) assert isinstance(imp, ast.Import)
@ -140,13 +135,8 @@ class TestAssertionRewrite(object):
def test_dont_rewrite(self): def test_dont_rewrite(self):
s = """'PYTEST_DONT_REWRITE'\nassert 14""" s = """'PYTEST_DONT_REWRITE'\nassert 14"""
m = rewrite(s) m = rewrite(s)
if python_version_has_docstring_in_module_node(): adjust_body_for_new_docstring_in_module_node(m)
assert len(m.body) == 2 assert len(m.body) == 1
assert isinstance(m.body[0], ast.Expr)
assert isinstance(m.body[0].value, ast.Str)
del m.body[0]
else:
assert len(m.body) == 1
assert m.body[0].msg is None assert m.body[0].msg is None
def test_dont_rewrite_plugin(self, testdir): def test_dont_rewrite_plugin(self, testdir):