Remove astor and reproduce the original assertion expression

This commit is contained in:
Anthony Sottile 2019-06-27 19:11:20 -07:00
parent 3c9b46f781
commit 7ee244476a
4 changed files with 196 additions and 54 deletions

View File

@ -1 +0,0 @@
pytest now also depends on the `astor <https://pypi.org/project/astor/>`__ package.

View File

@ -13,7 +13,6 @@ INSTALL_REQUIRES = [
"pluggy>=0.12,<1.0",
"importlib-metadata>=0.12",
"wcwidth",
"astor",
]

View File

@ -1,16 +1,18 @@
"""Rewrite assertion AST to produce nice error messages"""
import ast
import errno
import functools
import importlib.machinery
import importlib.util
import io
import itertools
import marshal
import os
import struct
import sys
import tokenize
import types
import astor
import atomicwrites
from _pytest._io.saferepr import saferepr
@ -285,7 +287,7 @@ def _rewrite_test(fn, config):
with open(fn, "rb") as f:
source = f.read()
tree = ast.parse(source, filename=fn)
rewrite_asserts(tree, fn, config)
rewrite_asserts(tree, source, fn, config)
co = compile(tree, fn, "exec", dont_inherit=True)
return stat, co
@ -327,9 +329,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
return co
def rewrite_asserts(mod, module_path=None, config=None):
def rewrite_asserts(mod, source, module_path=None, config=None):
"""Rewrite the assert statements in mod."""
AssertionRewriter(module_path, config).run(mod)
AssertionRewriter(module_path, config, source).run(mod)
def _saferepr(obj):
@ -457,6 +459,59 @@ def set_location(node, lineno, col_offset):
return node
def _get_assertion_exprs(src: bytes): # -> Dict[int, str]
"""Returns a mapping from {lineno: "assertion test expression"}"""
ret = {}
depth = 0
lines = []
assert_lineno = None
seen_lines = set()
def _write_and_reset() -> None:
nonlocal depth, lines, assert_lineno, seen_lines
ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
depth = 0
lines = []
assert_lineno = None
seen_lines = set()
tokens = tokenize.tokenize(io.BytesIO(src).readline)
for tp, src, (lineno, offset), _, line in tokens:
if tp == tokenize.NAME and src == "assert":
assert_lineno = lineno
elif assert_lineno is not None:
# keep track of depth for the assert-message `,` lookup
if tp == tokenize.OP and src in "([{":
depth += 1
elif tp == tokenize.OP and src in ")]}":
depth -= 1
if not lines:
lines.append(line[offset:])
seen_lines.add(lineno)
# a non-nested comma separates the expression from the message
elif depth == 0 and tp == tokenize.OP and src == ",":
# one line assert with message
if lineno in seen_lines and len(lines) == 1:
offset_in_trimmed = offset + len(lines[-1]) - len(line)
lines[-1] = lines[-1][:offset_in_trimmed]
# multi-line assert with message
elif lineno in seen_lines:
lines[-1] = lines[-1][:offset]
# multi line assert with escapd newline before message
else:
lines.append(line[:offset])
_write_and_reset()
elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
_write_and_reset()
elif lines and lineno not in seen_lines:
lines.append(line)
seen_lines.add(lineno)
return ret
class AssertionRewriter(ast.NodeVisitor):
"""Assertion rewriting implementation.
@ -511,7 +566,7 @@ class AssertionRewriter(ast.NodeVisitor):
"""
def __init__(self, module_path, config):
def __init__(self, module_path, config, source):
super().__init__()
self.module_path = module_path
self.config = config
@ -521,6 +576,11 @@ class AssertionRewriter(ast.NodeVisitor):
)
else:
self.enable_assertion_pass_hook = False
self.source = source
@functools.lru_cache(maxsize=1)
def _assert_expr_to_lineno(self):
return _get_assertion_exprs(self.source)
def run(self, mod):
"""Find all assert statements in *mod* and rewrite them."""
@ -738,7 +798,7 @@ class AssertionRewriter(ast.NodeVisitor):
# Passed
fmt_pass = self.helper("_format_explanation", msg)
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
orig = self._assert_expr_to_lineno()[assert_.lineno]
hook_call_pass = ast.Expr(
self.helper(
"_call_assertion_pass",

View File

@ -13,6 +13,7 @@ import py
import _pytest._code
import pytest
from _pytest.assertion import util
from _pytest.assertion.rewrite import _get_assertion_exprs
from _pytest.assertion.rewrite import AssertionRewritingHook
from _pytest.assertion.rewrite import PYTEST_TAG
from _pytest.assertion.rewrite import rewrite_asserts
@ -31,7 +32,7 @@ def teardown_module(mod):
def rewrite(src):
tree = ast.parse(src)
rewrite_asserts(tree)
rewrite_asserts(tree, src.encode())
return tree
@ -1292,10 +1293,10 @@ class TestEarlyRewriteBailout:
"""
p = testdir.makepyfile(
**{
"tests/file.py": """
def test_simple_failure():
assert 1 + 1 == 3
"""
"tests/file.py": """\
def test_simple_failure():
assert 1 + 1 == 3
"""
}
)
testdir.syspathinsert(p.dirpath())
@ -1315,19 +1316,19 @@ class TestEarlyRewriteBailout:
testdir.makepyfile(
**{
"test_setup_nonexisting_cwd.py": """
import os
import shutil
import tempfile
"test_setup_nonexisting_cwd.py": """\
import os
import shutil
import tempfile
d = tempfile.mkdtemp()
os.chdir(d)
shutil.rmtree(d)
""",
"test_test.py": """
def test():
pass
""",
d = tempfile.mkdtemp()
os.chdir(d)
shutil.rmtree(d)
""",
"test_test.py": """\
def test():
pass
""",
}
)
result = testdir.runpytest()
@ -1339,23 +1340,22 @@ class TestAssertionPass:
config = testdir.parseconfig()
assert config.getini("enable_assertion_pass_hook") is False
def test_hook_call(self, testdir):
@pytest.fixture
def flag_on(self, testdir):
testdir.makeini("[pytest]\nenable_assertion_pass_hook = True\n")
@pytest.fixture
def hook_on(self, testdir):
testdir.makeconftest(
"""
"""\
def pytest_assertion_pass(item, lineno, orig, expl):
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
"""
)
testdir.makeini(
"""
[pytest]
enable_assertion_pass_hook = True
"""
)
def test_hook_call(self, testdir, flag_on, hook_on):
testdir.makepyfile(
"""
"""\
def test_simple():
a=1
b=2
@ -1371,10 +1371,21 @@ class TestAssertionPass:
)
result = testdir.runpytest()
result.stdout.fnmatch_lines(
"*Assertion Passed: a + b == c + d (1 + 2) == (3 + 0) at line 7*"
"*Assertion Passed: a+b == c+d (1 + 2) == (3 + 0) at line 7*"
)
def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch):
def test_hook_call_with_parens(self, testdir, flag_on, hook_on):
testdir.makepyfile(
"""\
def f(): return 1
def test():
assert f()
"""
)
result = testdir.runpytest()
result.stdout.fnmatch_lines("*Assertion Passed: f() 1")
def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch, flag_on):
"""Assertion pass should not be called (and hence formatting should
not occur) if there is no hook declared for pytest_assertion_pass"""
@ -1385,15 +1396,8 @@ class TestAssertionPass:
_pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass
)
testdir.makeini(
"""
[pytest]
enable_assertion_pass_hook = True
"""
)
testdir.makepyfile(
"""
"""\
def test_simple():
a=1
b=2
@ -1418,21 +1422,14 @@ class TestAssertionPass:
)
testdir.makeconftest(
"""
"""\
def pytest_assertion_pass(item, lineno, orig, expl):
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
"""
)
testdir.makeini(
"""
[pytest]
enable_assertion_pass_hook = False
"""
)
testdir.makepyfile(
"""
"""\
def test_simple():
a=1
b=2
@ -1444,3 +1441,90 @@ class TestAssertionPass:
)
result = testdir.runpytest()
result.assert_outcomes(passed=1)
@pytest.mark.parametrize(
("src", "expected"),
(
# fmt: off
pytest.param(b"", {}, id="trivial"),
pytest.param(
b"def x(): assert 1\n",
{1: "1"},
id="assert statement not on own line",
),
pytest.param(
b"def x():\n"
b" assert 1\n"
b" assert 1+2\n",
{2: "1", 3: "1+2"},
id="multiple assertions",
),
pytest.param(
# changes in encoding cause the byte offsets to be different
"# -*- coding: latin1\n"
"def ÀÀÀÀÀ(): assert 1\n".encode("latin1"),
{2: "1"},
id="latin1 encoded on first line\n",
),
pytest.param(
# using the default utf-8 encoding
"def ÀÀÀÀÀ(): assert 1\n".encode(),
{1: "1"},
id="utf-8 encoded on first line",
),
pytest.param(
b"def x():\n"
b" assert (\n"
b" 1 + 2 # comment\n"
b" )\n",
{2: "(\n 1 + 2 # comment\n )"},
id="multi-line assertion",
),
pytest.param(
b"def x():\n"
b" assert y == [\n"
b" 1, 2, 3\n"
b" ]\n",
{2: "y == [\n 1, 2, 3\n ]"},
id="multi line assert with list continuation",
),
pytest.param(
b"def x():\n"
b" assert 1 + \\\n"
b" 2\n",
{2: "1 + \\\n 2"},
id="backslash continuation",
),
pytest.param(
b"def x():\n"
b" assert x, y\n",
{2: "x"},
id="assertion with message",
),
pytest.param(
b"def x():\n"
b" assert (\n"
b" f(1, 2, 3)\n"
b" ), 'f did not work!'\n",
{2: "(\n f(1, 2, 3)\n )"},
id="assertion with message, test spanning multiple lines",
),
pytest.param(
b"def x():\n"
b" assert \\\n"
b" x\\\n"
b" , 'failure message'\n",
{2: "x"},
id="escaped newlines plus message",
),
pytest.param(
b"def x(): assert 5",
{1: "5"},
id="no newline at end of file",
),
# fmt: on
),
)
def test_get_assertion_exprs(src, expected):
assert _get_assertion_exprs(src) == expected