Remove astor and reproduce the original assertion expression
This commit is contained in:
parent
3c9b46f781
commit
7ee244476a
|
@ -1 +0,0 @@
|
||||||
pytest now also depends on the `astor <https://pypi.org/project/astor/>`__ package.
|
|
1
setup.py
1
setup.py
|
@ -13,7 +13,6 @@ INSTALL_REQUIRES = [
|
||||||
"pluggy>=0.12,<1.0",
|
"pluggy>=0.12,<1.0",
|
||||||
"importlib-metadata>=0.12",
|
"importlib-metadata>=0.12",
|
||||||
"wcwidth",
|
"wcwidth",
|
||||||
"astor",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,16 +1,18 @@
|
||||||
"""Rewrite assertion AST to produce nice error messages"""
|
"""Rewrite assertion AST to produce nice error messages"""
|
||||||
import ast
|
import ast
|
||||||
import errno
|
import errno
|
||||||
|
import functools
|
||||||
import importlib.machinery
|
import importlib.machinery
|
||||||
import importlib.util
|
import importlib.util
|
||||||
|
import io
|
||||||
import itertools
|
import itertools
|
||||||
import marshal
|
import marshal
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
|
import tokenize
|
||||||
import types
|
import types
|
||||||
|
|
||||||
import astor
|
|
||||||
import atomicwrites
|
import atomicwrites
|
||||||
|
|
||||||
from _pytest._io.saferepr import saferepr
|
from _pytest._io.saferepr import saferepr
|
||||||
|
@ -285,7 +287,7 @@ def _rewrite_test(fn, config):
|
||||||
with open(fn, "rb") as f:
|
with open(fn, "rb") as f:
|
||||||
source = f.read()
|
source = f.read()
|
||||||
tree = ast.parse(source, filename=fn)
|
tree = ast.parse(source, filename=fn)
|
||||||
rewrite_asserts(tree, fn, config)
|
rewrite_asserts(tree, source, fn, config)
|
||||||
co = compile(tree, fn, "exec", dont_inherit=True)
|
co = compile(tree, fn, "exec", dont_inherit=True)
|
||||||
return stat, co
|
return stat, co
|
||||||
|
|
||||||
|
@ -327,9 +329,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
|
||||||
return co
|
return co
|
||||||
|
|
||||||
|
|
||||||
def rewrite_asserts(mod, module_path=None, config=None):
|
def rewrite_asserts(mod, source, module_path=None, config=None):
|
||||||
"""Rewrite the assert statements in mod."""
|
"""Rewrite the assert statements in mod."""
|
||||||
AssertionRewriter(module_path, config).run(mod)
|
AssertionRewriter(module_path, config, source).run(mod)
|
||||||
|
|
||||||
|
|
||||||
def _saferepr(obj):
|
def _saferepr(obj):
|
||||||
|
@ -457,6 +459,59 @@ def set_location(node, lineno, col_offset):
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
def _get_assertion_exprs(src: bytes): # -> Dict[int, str]
|
||||||
|
"""Returns a mapping from {lineno: "assertion test expression"}"""
|
||||||
|
ret = {}
|
||||||
|
|
||||||
|
depth = 0
|
||||||
|
lines = []
|
||||||
|
assert_lineno = None
|
||||||
|
seen_lines = set()
|
||||||
|
|
||||||
|
def _write_and_reset() -> None:
|
||||||
|
nonlocal depth, lines, assert_lineno, seen_lines
|
||||||
|
ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
|
||||||
|
depth = 0
|
||||||
|
lines = []
|
||||||
|
assert_lineno = None
|
||||||
|
seen_lines = set()
|
||||||
|
|
||||||
|
tokens = tokenize.tokenize(io.BytesIO(src).readline)
|
||||||
|
for tp, src, (lineno, offset), _, line in tokens:
|
||||||
|
if tp == tokenize.NAME and src == "assert":
|
||||||
|
assert_lineno = lineno
|
||||||
|
elif assert_lineno is not None:
|
||||||
|
# keep track of depth for the assert-message `,` lookup
|
||||||
|
if tp == tokenize.OP and src in "([{":
|
||||||
|
depth += 1
|
||||||
|
elif tp == tokenize.OP and src in ")]}":
|
||||||
|
depth -= 1
|
||||||
|
|
||||||
|
if not lines:
|
||||||
|
lines.append(line[offset:])
|
||||||
|
seen_lines.add(lineno)
|
||||||
|
# a non-nested comma separates the expression from the message
|
||||||
|
elif depth == 0 and tp == tokenize.OP and src == ",":
|
||||||
|
# one line assert with message
|
||||||
|
if lineno in seen_lines and len(lines) == 1:
|
||||||
|
offset_in_trimmed = offset + len(lines[-1]) - len(line)
|
||||||
|
lines[-1] = lines[-1][:offset_in_trimmed]
|
||||||
|
# multi-line assert with message
|
||||||
|
elif lineno in seen_lines:
|
||||||
|
lines[-1] = lines[-1][:offset]
|
||||||
|
# multi line assert with escapd newline before message
|
||||||
|
else:
|
||||||
|
lines.append(line[:offset])
|
||||||
|
_write_and_reset()
|
||||||
|
elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
|
||||||
|
_write_and_reset()
|
||||||
|
elif lines and lineno not in seen_lines:
|
||||||
|
lines.append(line)
|
||||||
|
seen_lines.add(lineno)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class AssertionRewriter(ast.NodeVisitor):
|
class AssertionRewriter(ast.NodeVisitor):
|
||||||
"""Assertion rewriting implementation.
|
"""Assertion rewriting implementation.
|
||||||
|
|
||||||
|
@ -511,7 +566,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, module_path, config):
|
def __init__(self, module_path, config, source):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.module_path = module_path
|
self.module_path = module_path
|
||||||
self.config = config
|
self.config = config
|
||||||
|
@ -521,6 +576,11 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.enable_assertion_pass_hook = False
|
self.enable_assertion_pass_hook = False
|
||||||
|
self.source = source
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=1)
|
||||||
|
def _assert_expr_to_lineno(self):
|
||||||
|
return _get_assertion_exprs(self.source)
|
||||||
|
|
||||||
def run(self, mod):
|
def run(self, mod):
|
||||||
"""Find all assert statements in *mod* and rewrite them."""
|
"""Find all assert statements in *mod* and rewrite them."""
|
||||||
|
@ -738,7 +798,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
|
|
||||||
# Passed
|
# Passed
|
||||||
fmt_pass = self.helper("_format_explanation", msg)
|
fmt_pass = self.helper("_format_explanation", msg)
|
||||||
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
|
orig = self._assert_expr_to_lineno()[assert_.lineno]
|
||||||
hook_call_pass = ast.Expr(
|
hook_call_pass = ast.Expr(
|
||||||
self.helper(
|
self.helper(
|
||||||
"_call_assertion_pass",
|
"_call_assertion_pass",
|
||||||
|
|
|
@ -13,6 +13,7 @@ import py
|
||||||
import _pytest._code
|
import _pytest._code
|
||||||
import pytest
|
import pytest
|
||||||
from _pytest.assertion import util
|
from _pytest.assertion import util
|
||||||
|
from _pytest.assertion.rewrite import _get_assertion_exprs
|
||||||
from _pytest.assertion.rewrite import AssertionRewritingHook
|
from _pytest.assertion.rewrite import AssertionRewritingHook
|
||||||
from _pytest.assertion.rewrite import PYTEST_TAG
|
from _pytest.assertion.rewrite import PYTEST_TAG
|
||||||
from _pytest.assertion.rewrite import rewrite_asserts
|
from _pytest.assertion.rewrite import rewrite_asserts
|
||||||
|
@ -31,7 +32,7 @@ def teardown_module(mod):
|
||||||
|
|
||||||
def rewrite(src):
|
def rewrite(src):
|
||||||
tree = ast.parse(src)
|
tree = ast.parse(src)
|
||||||
rewrite_asserts(tree)
|
rewrite_asserts(tree, src.encode())
|
||||||
return tree
|
return tree
|
||||||
|
|
||||||
|
|
||||||
|
@ -1292,10 +1293,10 @@ class TestEarlyRewriteBailout:
|
||||||
"""
|
"""
|
||||||
p = testdir.makepyfile(
|
p = testdir.makepyfile(
|
||||||
**{
|
**{
|
||||||
"tests/file.py": """
|
"tests/file.py": """\
|
||||||
def test_simple_failure():
|
def test_simple_failure():
|
||||||
assert 1 + 1 == 3
|
assert 1 + 1 == 3
|
||||||
"""
|
"""
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
testdir.syspathinsert(p.dirpath())
|
testdir.syspathinsert(p.dirpath())
|
||||||
|
@ -1315,19 +1316,19 @@ class TestEarlyRewriteBailout:
|
||||||
|
|
||||||
testdir.makepyfile(
|
testdir.makepyfile(
|
||||||
**{
|
**{
|
||||||
"test_setup_nonexisting_cwd.py": """
|
"test_setup_nonexisting_cwd.py": """\
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
d = tempfile.mkdtemp()
|
d = tempfile.mkdtemp()
|
||||||
os.chdir(d)
|
os.chdir(d)
|
||||||
shutil.rmtree(d)
|
shutil.rmtree(d)
|
||||||
""",
|
""",
|
||||||
"test_test.py": """
|
"test_test.py": """\
|
||||||
def test():
|
def test():
|
||||||
pass
|
pass
|
||||||
""",
|
""",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
result = testdir.runpytest()
|
result = testdir.runpytest()
|
||||||
|
@ -1339,23 +1340,22 @@ class TestAssertionPass:
|
||||||
config = testdir.parseconfig()
|
config = testdir.parseconfig()
|
||||||
assert config.getini("enable_assertion_pass_hook") is False
|
assert config.getini("enable_assertion_pass_hook") is False
|
||||||
|
|
||||||
def test_hook_call(self, testdir):
|
@pytest.fixture
|
||||||
|
def flag_on(self, testdir):
|
||||||
|
testdir.makeini("[pytest]\nenable_assertion_pass_hook = True\n")
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def hook_on(self, testdir):
|
||||||
testdir.makeconftest(
|
testdir.makeconftest(
|
||||||
"""
|
"""\
|
||||||
def pytest_assertion_pass(item, lineno, orig, expl):
|
def pytest_assertion_pass(item, lineno, orig, expl):
|
||||||
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
|
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
testdir.makeini(
|
def test_hook_call(self, testdir, flag_on, hook_on):
|
||||||
"""
|
|
||||||
[pytest]
|
|
||||||
enable_assertion_pass_hook = True
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
testdir.makepyfile(
|
testdir.makepyfile(
|
||||||
"""
|
"""\
|
||||||
def test_simple():
|
def test_simple():
|
||||||
a=1
|
a=1
|
||||||
b=2
|
b=2
|
||||||
|
@ -1371,10 +1371,21 @@ class TestAssertionPass:
|
||||||
)
|
)
|
||||||
result = testdir.runpytest()
|
result = testdir.runpytest()
|
||||||
result.stdout.fnmatch_lines(
|
result.stdout.fnmatch_lines(
|
||||||
"*Assertion Passed: a + b == c + d (1 + 2) == (3 + 0) at line 7*"
|
"*Assertion Passed: a+b == c+d (1 + 2) == (3 + 0) at line 7*"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch):
|
def test_hook_call_with_parens(self, testdir, flag_on, hook_on):
|
||||||
|
testdir.makepyfile(
|
||||||
|
"""\
|
||||||
|
def f(): return 1
|
||||||
|
def test():
|
||||||
|
assert f()
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
result = testdir.runpytest()
|
||||||
|
result.stdout.fnmatch_lines("*Assertion Passed: f() 1")
|
||||||
|
|
||||||
|
def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch, flag_on):
|
||||||
"""Assertion pass should not be called (and hence formatting should
|
"""Assertion pass should not be called (and hence formatting should
|
||||||
not occur) if there is no hook declared for pytest_assertion_pass"""
|
not occur) if there is no hook declared for pytest_assertion_pass"""
|
||||||
|
|
||||||
|
@ -1385,15 +1396,8 @@ class TestAssertionPass:
|
||||||
_pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass
|
_pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass
|
||||||
)
|
)
|
||||||
|
|
||||||
testdir.makeini(
|
|
||||||
"""
|
|
||||||
[pytest]
|
|
||||||
enable_assertion_pass_hook = True
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
testdir.makepyfile(
|
testdir.makepyfile(
|
||||||
"""
|
"""\
|
||||||
def test_simple():
|
def test_simple():
|
||||||
a=1
|
a=1
|
||||||
b=2
|
b=2
|
||||||
|
@ -1418,21 +1422,14 @@ class TestAssertionPass:
|
||||||
)
|
)
|
||||||
|
|
||||||
testdir.makeconftest(
|
testdir.makeconftest(
|
||||||
"""
|
"""\
|
||||||
def pytest_assertion_pass(item, lineno, orig, expl):
|
def pytest_assertion_pass(item, lineno, orig, expl):
|
||||||
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
|
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
testdir.makeini(
|
|
||||||
"""
|
|
||||||
[pytest]
|
|
||||||
enable_assertion_pass_hook = False
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
testdir.makepyfile(
|
testdir.makepyfile(
|
||||||
"""
|
"""\
|
||||||
def test_simple():
|
def test_simple():
|
||||||
a=1
|
a=1
|
||||||
b=2
|
b=2
|
||||||
|
@ -1444,3 +1441,90 @@ class TestAssertionPass:
|
||||||
)
|
)
|
||||||
result = testdir.runpytest()
|
result = testdir.runpytest()
|
||||||
result.assert_outcomes(passed=1)
|
result.assert_outcomes(passed=1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("src", "expected"),
|
||||||
|
(
|
||||||
|
# fmt: off
|
||||||
|
pytest.param(b"", {}, id="trivial"),
|
||||||
|
pytest.param(
|
||||||
|
b"def x(): assert 1\n",
|
||||||
|
{1: "1"},
|
||||||
|
id="assert statement not on own line",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
b"def x():\n"
|
||||||
|
b" assert 1\n"
|
||||||
|
b" assert 1+2\n",
|
||||||
|
{2: "1", 3: "1+2"},
|
||||||
|
id="multiple assertions",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
# changes in encoding cause the byte offsets to be different
|
||||||
|
"# -*- coding: latin1\n"
|
||||||
|
"def ÀÀÀÀÀ(): assert 1\n".encode("latin1"),
|
||||||
|
{2: "1"},
|
||||||
|
id="latin1 encoded on first line\n",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
# using the default utf-8 encoding
|
||||||
|
"def ÀÀÀÀÀ(): assert 1\n".encode(),
|
||||||
|
{1: "1"},
|
||||||
|
id="utf-8 encoded on first line",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
b"def x():\n"
|
||||||
|
b" assert (\n"
|
||||||
|
b" 1 + 2 # comment\n"
|
||||||
|
b" )\n",
|
||||||
|
{2: "(\n 1 + 2 # comment\n )"},
|
||||||
|
id="multi-line assertion",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
b"def x():\n"
|
||||||
|
b" assert y == [\n"
|
||||||
|
b" 1, 2, 3\n"
|
||||||
|
b" ]\n",
|
||||||
|
{2: "y == [\n 1, 2, 3\n ]"},
|
||||||
|
id="multi line assert with list continuation",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
b"def x():\n"
|
||||||
|
b" assert 1 + \\\n"
|
||||||
|
b" 2\n",
|
||||||
|
{2: "1 + \\\n 2"},
|
||||||
|
id="backslash continuation",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
b"def x():\n"
|
||||||
|
b" assert x, y\n",
|
||||||
|
{2: "x"},
|
||||||
|
id="assertion with message",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
b"def x():\n"
|
||||||
|
b" assert (\n"
|
||||||
|
b" f(1, 2, 3)\n"
|
||||||
|
b" ), 'f did not work!'\n",
|
||||||
|
{2: "(\n f(1, 2, 3)\n )"},
|
||||||
|
id="assertion with message, test spanning multiple lines",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
b"def x():\n"
|
||||||
|
b" assert \\\n"
|
||||||
|
b" x\\\n"
|
||||||
|
b" , 'failure message'\n",
|
||||||
|
{2: "x"},
|
||||||
|
id="escaped newlines plus message",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
b"def x(): assert 5",
|
||||||
|
{1: "5"},
|
||||||
|
id="no newline at end of file",
|
||||||
|
),
|
||||||
|
# fmt: on
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def test_get_assertion_exprs(src, expected):
|
||||||
|
assert _get_assertion_exprs(src) == expected
|
||||||
|
|
Loading…
Reference in New Issue