Remove astor and reproduce the original assertion expression

This commit is contained in:
Anthony Sottile 2019-06-27 19:11:20 -07:00
parent 3c9b46f781
commit 7ee244476a
4 changed files with 196 additions and 54 deletions

View File

@ -1 +0,0 @@
pytest now also depends on the `astor <https://pypi.org/project/astor/>`__ package.

View File

@ -13,7 +13,6 @@ INSTALL_REQUIRES = [
"pluggy>=0.12,<1.0", "pluggy>=0.12,<1.0",
"importlib-metadata>=0.12", "importlib-metadata>=0.12",
"wcwidth", "wcwidth",
"astor",
] ]

View File

@ -1,16 +1,18 @@
"""Rewrite assertion AST to produce nice error messages""" """Rewrite assertion AST to produce nice error messages"""
import ast import ast
import errno import errno
import functools
import importlib.machinery import importlib.machinery
import importlib.util import importlib.util
import io
import itertools import itertools
import marshal import marshal
import os import os
import struct import struct
import sys import sys
import tokenize
import types import types
import astor
import atomicwrites import atomicwrites
from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr
@ -285,7 +287,7 @@ def _rewrite_test(fn, config):
with open(fn, "rb") as f: with open(fn, "rb") as f:
source = f.read() source = f.read()
tree = ast.parse(source, filename=fn) tree = ast.parse(source, filename=fn)
rewrite_asserts(tree, fn, config) rewrite_asserts(tree, source, fn, config)
co = compile(tree, fn, "exec", dont_inherit=True) co = compile(tree, fn, "exec", dont_inherit=True)
return stat, co return stat, co
@ -327,9 +329,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
return co return co
def rewrite_asserts(mod, module_path=None, config=None): def rewrite_asserts(mod, source, module_path=None, config=None):
"""Rewrite the assert statements in mod.""" """Rewrite the assert statements in mod."""
AssertionRewriter(module_path, config).run(mod) AssertionRewriter(module_path, config, source).run(mod)
def _saferepr(obj): def _saferepr(obj):
@ -457,6 +459,59 @@ def set_location(node, lineno, col_offset):
return node return node
def _get_assertion_exprs(src: bytes): # -> Dict[int, str]
"""Returns a mapping from {lineno: "assertion test expression"}"""
ret = {}
depth = 0
lines = []
assert_lineno = None
seen_lines = set()
def _write_and_reset() -> None:
nonlocal depth, lines, assert_lineno, seen_lines
ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
depth = 0
lines = []
assert_lineno = None
seen_lines = set()
tokens = tokenize.tokenize(io.BytesIO(src).readline)
for tp, src, (lineno, offset), _, line in tokens:
if tp == tokenize.NAME and src == "assert":
assert_lineno = lineno
elif assert_lineno is not None:
# keep track of depth for the assert-message `,` lookup
if tp == tokenize.OP and src in "([{":
depth += 1
elif tp == tokenize.OP and src in ")]}":
depth -= 1
if not lines:
lines.append(line[offset:])
seen_lines.add(lineno)
# a non-nested comma separates the expression from the message
elif depth == 0 and tp == tokenize.OP and src == ",":
# one line assert with message
if lineno in seen_lines and len(lines) == 1:
offset_in_trimmed = offset + len(lines[-1]) - len(line)
lines[-1] = lines[-1][:offset_in_trimmed]
# multi-line assert with message
elif lineno in seen_lines:
lines[-1] = lines[-1][:offset]
# multi line assert with escapd newline before message
else:
lines.append(line[:offset])
_write_and_reset()
elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
_write_and_reset()
elif lines and lineno not in seen_lines:
lines.append(line)
seen_lines.add(lineno)
return ret
class AssertionRewriter(ast.NodeVisitor): class AssertionRewriter(ast.NodeVisitor):
"""Assertion rewriting implementation. """Assertion rewriting implementation.
@ -511,7 +566,7 @@ class AssertionRewriter(ast.NodeVisitor):
""" """
def __init__(self, module_path, config): def __init__(self, module_path, config, source):
super().__init__() super().__init__()
self.module_path = module_path self.module_path = module_path
self.config = config self.config = config
@ -521,6 +576,11 @@ class AssertionRewriter(ast.NodeVisitor):
) )
else: else:
self.enable_assertion_pass_hook = False self.enable_assertion_pass_hook = False
self.source = source
@functools.lru_cache(maxsize=1)
def _assert_expr_to_lineno(self):
return _get_assertion_exprs(self.source)
def run(self, mod): def run(self, mod):
"""Find all assert statements in *mod* and rewrite them.""" """Find all assert statements in *mod* and rewrite them."""
@ -738,7 +798,7 @@ class AssertionRewriter(ast.NodeVisitor):
# Passed # Passed
fmt_pass = self.helper("_format_explanation", msg) fmt_pass = self.helper("_format_explanation", msg)
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")") orig = self._assert_expr_to_lineno()[assert_.lineno]
hook_call_pass = ast.Expr( hook_call_pass = ast.Expr(
self.helper( self.helper(
"_call_assertion_pass", "_call_assertion_pass",

View File

@ -13,6 +13,7 @@ import py
import _pytest._code import _pytest._code
import pytest import pytest
from _pytest.assertion import util from _pytest.assertion import util
from _pytest.assertion.rewrite import _get_assertion_exprs
from _pytest.assertion.rewrite import AssertionRewritingHook from _pytest.assertion.rewrite import AssertionRewritingHook
from _pytest.assertion.rewrite import PYTEST_TAG from _pytest.assertion.rewrite import PYTEST_TAG
from _pytest.assertion.rewrite import rewrite_asserts from _pytest.assertion.rewrite import rewrite_asserts
@ -31,7 +32,7 @@ def teardown_module(mod):
def rewrite(src): def rewrite(src):
tree = ast.parse(src) tree = ast.parse(src)
rewrite_asserts(tree) rewrite_asserts(tree, src.encode())
return tree return tree
@ -1292,10 +1293,10 @@ class TestEarlyRewriteBailout:
""" """
p = testdir.makepyfile( p = testdir.makepyfile(
**{ **{
"tests/file.py": """ "tests/file.py": """\
def test_simple_failure(): def test_simple_failure():
assert 1 + 1 == 3 assert 1 + 1 == 3
""" """
} }
) )
testdir.syspathinsert(p.dirpath()) testdir.syspathinsert(p.dirpath())
@ -1315,19 +1316,19 @@ class TestEarlyRewriteBailout:
testdir.makepyfile( testdir.makepyfile(
**{ **{
"test_setup_nonexisting_cwd.py": """ "test_setup_nonexisting_cwd.py": """\
import os import os
import shutil import shutil
import tempfile import tempfile
d = tempfile.mkdtemp() d = tempfile.mkdtemp()
os.chdir(d) os.chdir(d)
shutil.rmtree(d) shutil.rmtree(d)
""", """,
"test_test.py": """ "test_test.py": """\
def test(): def test():
pass pass
""", """,
} }
) )
result = testdir.runpytest() result = testdir.runpytest()
@ -1339,23 +1340,22 @@ class TestAssertionPass:
config = testdir.parseconfig() config = testdir.parseconfig()
assert config.getini("enable_assertion_pass_hook") is False assert config.getini("enable_assertion_pass_hook") is False
def test_hook_call(self, testdir): @pytest.fixture
def flag_on(self, testdir):
testdir.makeini("[pytest]\nenable_assertion_pass_hook = True\n")
@pytest.fixture
def hook_on(self, testdir):
testdir.makeconftest( testdir.makeconftest(
""" """\
def pytest_assertion_pass(item, lineno, orig, expl): def pytest_assertion_pass(item, lineno, orig, expl):
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno)) raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
""" """
) )
testdir.makeini( def test_hook_call(self, testdir, flag_on, hook_on):
"""
[pytest]
enable_assertion_pass_hook = True
"""
)
testdir.makepyfile( testdir.makepyfile(
""" """\
def test_simple(): def test_simple():
a=1 a=1
b=2 b=2
@ -1371,10 +1371,21 @@ class TestAssertionPass:
) )
result = testdir.runpytest() result = testdir.runpytest()
result.stdout.fnmatch_lines( result.stdout.fnmatch_lines(
"*Assertion Passed: a + b == c + d (1 + 2) == (3 + 0) at line 7*" "*Assertion Passed: a+b == c+d (1 + 2) == (3 + 0) at line 7*"
) )
def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch): def test_hook_call_with_parens(self, testdir, flag_on, hook_on):
testdir.makepyfile(
"""\
def f(): return 1
def test():
assert f()
"""
)
result = testdir.runpytest()
result.stdout.fnmatch_lines("*Assertion Passed: f() 1")
def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch, flag_on):
"""Assertion pass should not be called (and hence formatting should """Assertion pass should not be called (and hence formatting should
not occur) if there is no hook declared for pytest_assertion_pass""" not occur) if there is no hook declared for pytest_assertion_pass"""
@ -1385,15 +1396,8 @@ class TestAssertionPass:
_pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass _pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass
) )
testdir.makeini(
"""
[pytest]
enable_assertion_pass_hook = True
"""
)
testdir.makepyfile( testdir.makepyfile(
""" """\
def test_simple(): def test_simple():
a=1 a=1
b=2 b=2
@ -1418,21 +1422,14 @@ class TestAssertionPass:
) )
testdir.makeconftest( testdir.makeconftest(
""" """\
def pytest_assertion_pass(item, lineno, orig, expl): def pytest_assertion_pass(item, lineno, orig, expl):
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno)) raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
""" """
) )
testdir.makeini(
"""
[pytest]
enable_assertion_pass_hook = False
"""
)
testdir.makepyfile( testdir.makepyfile(
""" """\
def test_simple(): def test_simple():
a=1 a=1
b=2 b=2
@ -1444,3 +1441,90 @@ class TestAssertionPass:
) )
result = testdir.runpytest() result = testdir.runpytest()
result.assert_outcomes(passed=1) result.assert_outcomes(passed=1)
@pytest.mark.parametrize(
("src", "expected"),
(
# fmt: off
pytest.param(b"", {}, id="trivial"),
pytest.param(
b"def x(): assert 1\n",
{1: "1"},
id="assert statement not on own line",
),
pytest.param(
b"def x():\n"
b" assert 1\n"
b" assert 1+2\n",
{2: "1", 3: "1+2"},
id="multiple assertions",
),
pytest.param(
# changes in encoding cause the byte offsets to be different
"# -*- coding: latin1\n"
"def ÀÀÀÀÀ(): assert 1\n".encode("latin1"),
{2: "1"},
id="latin1 encoded on first line\n",
),
pytest.param(
# using the default utf-8 encoding
"def ÀÀÀÀÀ(): assert 1\n".encode(),
{1: "1"},
id="utf-8 encoded on first line",
),
pytest.param(
b"def x():\n"
b" assert (\n"
b" 1 + 2 # comment\n"
b" )\n",
{2: "(\n 1 + 2 # comment\n )"},
id="multi-line assertion",
),
pytest.param(
b"def x():\n"
b" assert y == [\n"
b" 1, 2, 3\n"
b" ]\n",
{2: "y == [\n 1, 2, 3\n ]"},
id="multi line assert with list continuation",
),
pytest.param(
b"def x():\n"
b" assert 1 + \\\n"
b" 2\n",
{2: "1 + \\\n 2"},
id="backslash continuation",
),
pytest.param(
b"def x():\n"
b" assert x, y\n",
{2: "x"},
id="assertion with message",
),
pytest.param(
b"def x():\n"
b" assert (\n"
b" f(1, 2, 3)\n"
b" ), 'f did not work!'\n",
{2: "(\n f(1, 2, 3)\n )"},
id="assertion with message, test spanning multiple lines",
),
pytest.param(
b"def x():\n"
b" assert \\\n"
b" x\\\n"
b" , 'failure message'\n",
{2: "x"},
id="escaped newlines plus message",
),
pytest.param(
b"def x(): assert 5",
{1: "5"},
id="no newline at end of file",
),
# fmt: on
),
)
def test_get_assertion_exprs(src, expected):
assert _get_assertion_exprs(src) == expected