Merge pull request #9163 from bluetech/rewrite-end-lineno

rewrite: fixup end_lineno, end_col_offset of rewritten asserts
This commit is contained in:
Ran Benita 2021-10-05 16:46:29 +03:00 committed by GitHub
commit 54811b24e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 32 additions and 15 deletions

View File

@ -0,0 +1 @@
The end line number and end column offset are now properly set for rewritten assert statements.

View File

@ -19,6 +19,7 @@ from typing import Callable
from typing import Dict
from typing import IO
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
@ -539,19 +540,11 @@ BINOP_MAP = {
}
def set_location(node, lineno, col_offset):
"""Set node location information recursively."""
def _fix(node, lineno, col_offset):
if "lineno" in node._attributes:
node.lineno = lineno
if "col_offset" in node._attributes:
node.col_offset = col_offset
for child in ast.iter_child_nodes(node):
_fix(child, lineno, col_offset)
_fix(node, lineno, col_offset)
return node
def traverse_node(node: ast.AST) -> Iterator[ast.AST]:
"""Recursively yield node and all its children in depth-first order."""
yield node
for child in ast.iter_child_nodes(node):
yield from traverse_node(child)
@functools.lru_cache(maxsize=1)
@ -954,9 +947,10 @@ class AssertionRewriter(ast.NodeVisitor):
variables = [ast.Name(name, ast.Store()) for name in self.variables]
clear = ast.Assign(variables, ast.NameConstant(None))
self.statements.append(clear)
# Fix line numbers.
# Fix locations (line numbers/column offsets).
for stmt in self.statements:
set_location(stmt, assert_.lineno, assert_.col_offset)
for node in traverse_node(stmt):
ast.copy_location(node, assert_)
return self.statements
def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:

View File

@ -111,6 +111,28 @@ class TestAssertionRewrite:
assert imp.col_offset == 0
assert isinstance(m.body[3], ast.Expr)
def test_location_is_set(self) -> None:
s = textwrap.dedent(
"""
assert False, (
"Ouch"
)
"""
)
m = rewrite(s)
for node in m.body:
if isinstance(node, ast.Import):
continue
for n in [node, *ast.iter_child_nodes(node)]:
assert n.lineno == 3
assert n.col_offset == 0
if sys.version_info >= (3, 8):
assert n.end_lineno == 6
assert n.end_col_offset == 3
def test_dont_rewrite(self) -> None:
s = """'PYTEST_DONT_REWRITE'\nassert 14"""
m = rewrite(s)