Remove astor and reproduce the original assertion expression (#5512)
Remove astor and reproduce the original assertion expression
This commit is contained in:
commit
73d918db55
|
@ -1 +0,0 @@
|
|||
pytest now also depends on the `astor <https://pypi.org/project/astor/>`__ package.
|
1
setup.py
1
setup.py
|
@ -13,7 +13,6 @@ INSTALL_REQUIRES = [
|
|||
"pluggy>=0.12,<1.0",
|
||||
"importlib-metadata>=0.12",
|
||||
"wcwidth",
|
||||
"astor",
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -1,16 +1,18 @@
|
|||
"""Rewrite assertion AST to produce nice error messages"""
|
||||
import ast
|
||||
import errno
|
||||
import functools
|
||||
import importlib.machinery
|
||||
import importlib.util
|
||||
import io
|
||||
import itertools
|
||||
import marshal
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
import tokenize
|
||||
import types
|
||||
|
||||
import astor
|
||||
import atomicwrites
|
||||
|
||||
from _pytest._io.saferepr import saferepr
|
||||
|
@ -285,7 +287,7 @@ def _rewrite_test(fn, config):
|
|||
with open(fn, "rb") as f:
|
||||
source = f.read()
|
||||
tree = ast.parse(source, filename=fn)
|
||||
rewrite_asserts(tree, fn, config)
|
||||
rewrite_asserts(tree, source, fn, config)
|
||||
co = compile(tree, fn, "exec", dont_inherit=True)
|
||||
return stat, co
|
||||
|
||||
|
@ -327,9 +329,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
|
|||
return co
|
||||
|
||||
|
||||
def rewrite_asserts(mod, module_path=None, config=None):
|
||||
def rewrite_asserts(mod, source, module_path=None, config=None):
|
||||
"""Rewrite the assert statements in mod."""
|
||||
AssertionRewriter(module_path, config).run(mod)
|
||||
AssertionRewriter(module_path, config, source).run(mod)
|
||||
|
||||
|
||||
def _saferepr(obj):
|
||||
|
@ -457,6 +459,59 @@ def set_location(node, lineno, col_offset):
|
|||
return node
|
||||
|
||||
|
||||
def _get_assertion_exprs(src: bytes): # -> Dict[int, str]
|
||||
"""Returns a mapping from {lineno: "assertion test expression"}"""
|
||||
ret = {}
|
||||
|
||||
depth = 0
|
||||
lines = []
|
||||
assert_lineno = None
|
||||
seen_lines = set()
|
||||
|
||||
def _write_and_reset() -> None:
|
||||
nonlocal depth, lines, assert_lineno, seen_lines
|
||||
ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
|
||||
depth = 0
|
||||
lines = []
|
||||
assert_lineno = None
|
||||
seen_lines = set()
|
||||
|
||||
tokens = tokenize.tokenize(io.BytesIO(src).readline)
|
||||
for tp, src, (lineno, offset), _, line in tokens:
|
||||
if tp == tokenize.NAME and src == "assert":
|
||||
assert_lineno = lineno
|
||||
elif assert_lineno is not None:
|
||||
# keep track of depth for the assert-message `,` lookup
|
||||
if tp == tokenize.OP and src in "([{":
|
||||
depth += 1
|
||||
elif tp == tokenize.OP and src in ")]}":
|
||||
depth -= 1
|
||||
|
||||
if not lines:
|
||||
lines.append(line[offset:])
|
||||
seen_lines.add(lineno)
|
||||
# a non-nested comma separates the expression from the message
|
||||
elif depth == 0 and tp == tokenize.OP and src == ",":
|
||||
# one line assert with message
|
||||
if lineno in seen_lines and len(lines) == 1:
|
||||
offset_in_trimmed = offset + len(lines[-1]) - len(line)
|
||||
lines[-1] = lines[-1][:offset_in_trimmed]
|
||||
# multi-line assert with message
|
||||
elif lineno in seen_lines:
|
||||
lines[-1] = lines[-1][:offset]
|
||||
# multi line assert with escapd newline before message
|
||||
else:
|
||||
lines.append(line[:offset])
|
||||
_write_and_reset()
|
||||
elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
|
||||
_write_and_reset()
|
||||
elif lines and lineno not in seen_lines:
|
||||
lines.append(line)
|
||||
seen_lines.add(lineno)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
class AssertionRewriter(ast.NodeVisitor):
|
||||
"""Assertion rewriting implementation.
|
||||
|
||||
|
@ -511,7 +566,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, module_path, config):
|
||||
def __init__(self, module_path, config, source):
|
||||
super().__init__()
|
||||
self.module_path = module_path
|
||||
self.config = config
|
||||
|
@ -521,6 +576,11 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
)
|
||||
else:
|
||||
self.enable_assertion_pass_hook = False
|
||||
self.source = source
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _assert_expr_to_lineno(self):
|
||||
return _get_assertion_exprs(self.source)
|
||||
|
||||
def run(self, mod):
|
||||
"""Find all assert statements in *mod* and rewrite them."""
|
||||
|
@ -738,7 +798,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
|
||||
# Passed
|
||||
fmt_pass = self.helper("_format_explanation", msg)
|
||||
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
|
||||
orig = self._assert_expr_to_lineno()[assert_.lineno]
|
||||
hook_call_pass = ast.Expr(
|
||||
self.helper(
|
||||
"_call_assertion_pass",
|
||||
|
|
|
@ -13,6 +13,7 @@ import py
|
|||
import _pytest._code
|
||||
import pytest
|
||||
from _pytest.assertion import util
|
||||
from _pytest.assertion.rewrite import _get_assertion_exprs
|
||||
from _pytest.assertion.rewrite import AssertionRewritingHook
|
||||
from _pytest.assertion.rewrite import PYTEST_TAG
|
||||
from _pytest.assertion.rewrite import rewrite_asserts
|
||||
|
@ -31,7 +32,7 @@ def teardown_module(mod):
|
|||
|
||||
def rewrite(src):
|
||||
tree = ast.parse(src)
|
||||
rewrite_asserts(tree)
|
||||
rewrite_asserts(tree, src.encode())
|
||||
return tree
|
||||
|
||||
|
||||
|
@ -1292,10 +1293,10 @@ class TestEarlyRewriteBailout:
|
|||
"""
|
||||
p = testdir.makepyfile(
|
||||
**{
|
||||
"tests/file.py": """
|
||||
def test_simple_failure():
|
||||
assert 1 + 1 == 3
|
||||
"""
|
||||
"tests/file.py": """\
|
||||
def test_simple_failure():
|
||||
assert 1 + 1 == 3
|
||||
"""
|
||||
}
|
||||
)
|
||||
testdir.syspathinsert(p.dirpath())
|
||||
|
@ -1315,19 +1316,19 @@ class TestEarlyRewriteBailout:
|
|||
|
||||
testdir.makepyfile(
|
||||
**{
|
||||
"test_setup_nonexisting_cwd.py": """
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
"test_setup_nonexisting_cwd.py": """\
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
d = tempfile.mkdtemp()
|
||||
os.chdir(d)
|
||||
shutil.rmtree(d)
|
||||
""",
|
||||
"test_test.py": """
|
||||
def test():
|
||||
pass
|
||||
""",
|
||||
d = tempfile.mkdtemp()
|
||||
os.chdir(d)
|
||||
shutil.rmtree(d)
|
||||
""",
|
||||
"test_test.py": """\
|
||||
def test():
|
||||
pass
|
||||
""",
|
||||
}
|
||||
)
|
||||
result = testdir.runpytest()
|
||||
|
@ -1339,23 +1340,22 @@ class TestAssertionPass:
|
|||
config = testdir.parseconfig()
|
||||
assert config.getini("enable_assertion_pass_hook") is False
|
||||
|
||||
def test_hook_call(self, testdir):
|
||||
@pytest.fixture
|
||||
def flag_on(self, testdir):
|
||||
testdir.makeini("[pytest]\nenable_assertion_pass_hook = True\n")
|
||||
|
||||
@pytest.fixture
|
||||
def hook_on(self, testdir):
|
||||
testdir.makeconftest(
|
||||
"""
|
||||
"""\
|
||||
def pytest_assertion_pass(item, lineno, orig, expl):
|
||||
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
|
||||
"""
|
||||
)
|
||||
|
||||
testdir.makeini(
|
||||
"""
|
||||
[pytest]
|
||||
enable_assertion_pass_hook = True
|
||||
"""
|
||||
)
|
||||
|
||||
def test_hook_call(self, testdir, flag_on, hook_on):
|
||||
testdir.makepyfile(
|
||||
"""
|
||||
"""\
|
||||
def test_simple():
|
||||
a=1
|
||||
b=2
|
||||
|
@ -1371,10 +1371,21 @@ class TestAssertionPass:
|
|||
)
|
||||
result = testdir.runpytest()
|
||||
result.stdout.fnmatch_lines(
|
||||
"*Assertion Passed: a + b == c + d (1 + 2) == (3 + 0) at line 7*"
|
||||
"*Assertion Passed: a+b == c+d (1 + 2) == (3 + 0) at line 7*"
|
||||
)
|
||||
|
||||
def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch):
|
||||
def test_hook_call_with_parens(self, testdir, flag_on, hook_on):
|
||||
testdir.makepyfile(
|
||||
"""\
|
||||
def f(): return 1
|
||||
def test():
|
||||
assert f()
|
||||
"""
|
||||
)
|
||||
result = testdir.runpytest()
|
||||
result.stdout.fnmatch_lines("*Assertion Passed: f() 1")
|
||||
|
||||
def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch, flag_on):
|
||||
"""Assertion pass should not be called (and hence formatting should
|
||||
not occur) if there is no hook declared for pytest_assertion_pass"""
|
||||
|
||||
|
@ -1385,15 +1396,8 @@ class TestAssertionPass:
|
|||
_pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass
|
||||
)
|
||||
|
||||
testdir.makeini(
|
||||
"""
|
||||
[pytest]
|
||||
enable_assertion_pass_hook = True
|
||||
"""
|
||||
)
|
||||
|
||||
testdir.makepyfile(
|
||||
"""
|
||||
"""\
|
||||
def test_simple():
|
||||
a=1
|
||||
b=2
|
||||
|
@ -1418,21 +1422,14 @@ class TestAssertionPass:
|
|||
)
|
||||
|
||||
testdir.makeconftest(
|
||||
"""
|
||||
"""\
|
||||
def pytest_assertion_pass(item, lineno, orig, expl):
|
||||
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
|
||||
"""
|
||||
)
|
||||
|
||||
testdir.makeini(
|
||||
"""
|
||||
[pytest]
|
||||
enable_assertion_pass_hook = False
|
||||
"""
|
||||
)
|
||||
|
||||
testdir.makepyfile(
|
||||
"""
|
||||
"""\
|
||||
def test_simple():
|
||||
a=1
|
||||
b=2
|
||||
|
@ -1444,3 +1441,90 @@ class TestAssertionPass:
|
|||
)
|
||||
result = testdir.runpytest()
|
||||
result.assert_outcomes(passed=1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("src", "expected"),
|
||||
(
|
||||
# fmt: off
|
||||
pytest.param(b"", {}, id="trivial"),
|
||||
pytest.param(
|
||||
b"def x(): assert 1\n",
|
||||
{1: "1"},
|
||||
id="assert statement not on own line",
|
||||
),
|
||||
pytest.param(
|
||||
b"def x():\n"
|
||||
b" assert 1\n"
|
||||
b" assert 1+2\n",
|
||||
{2: "1", 3: "1+2"},
|
||||
id="multiple assertions",
|
||||
),
|
||||
pytest.param(
|
||||
# changes in encoding cause the byte offsets to be different
|
||||
"# -*- coding: latin1\n"
|
||||
"def ÀÀÀÀÀ(): assert 1\n".encode("latin1"),
|
||||
{2: "1"},
|
||||
id="latin1 encoded on first line\n",
|
||||
),
|
||||
pytest.param(
|
||||
# using the default utf-8 encoding
|
||||
"def ÀÀÀÀÀ(): assert 1\n".encode(),
|
||||
{1: "1"},
|
||||
id="utf-8 encoded on first line",
|
||||
),
|
||||
pytest.param(
|
||||
b"def x():\n"
|
||||
b" assert (\n"
|
||||
b" 1 + 2 # comment\n"
|
||||
b" )\n",
|
||||
{2: "(\n 1 + 2 # comment\n )"},
|
||||
id="multi-line assertion",
|
||||
),
|
||||
pytest.param(
|
||||
b"def x():\n"
|
||||
b" assert y == [\n"
|
||||
b" 1, 2, 3\n"
|
||||
b" ]\n",
|
||||
{2: "y == [\n 1, 2, 3\n ]"},
|
||||
id="multi line assert with list continuation",
|
||||
),
|
||||
pytest.param(
|
||||
b"def x():\n"
|
||||
b" assert 1 + \\\n"
|
||||
b" 2\n",
|
||||
{2: "1 + \\\n 2"},
|
||||
id="backslash continuation",
|
||||
),
|
||||
pytest.param(
|
||||
b"def x():\n"
|
||||
b" assert x, y\n",
|
||||
{2: "x"},
|
||||
id="assertion with message",
|
||||
),
|
||||
pytest.param(
|
||||
b"def x():\n"
|
||||
b" assert (\n"
|
||||
b" f(1, 2, 3)\n"
|
||||
b" ), 'f did not work!'\n",
|
||||
{2: "(\n f(1, 2, 3)\n )"},
|
||||
id="assertion with message, test spanning multiple lines",
|
||||
),
|
||||
pytest.param(
|
||||
b"def x():\n"
|
||||
b" assert \\\n"
|
||||
b" x\\\n"
|
||||
b" , 'failure message'\n",
|
||||
{2: "x"},
|
||||
id="escaped newlines plus message",
|
||||
),
|
||||
pytest.param(
|
||||
b"def x(): assert 5",
|
||||
{1: "5"},
|
||||
id="no newline at end of file",
|
||||
),
|
||||
# fmt: on
|
||||
),
|
||||
)
|
||||
def test_get_assertion_exprs(src, expected):
|
||||
assert _get_assertion_exprs(src) == expected
|
||||
|
|
Loading…
Reference in New Issue