new assertion debugger which rewrites asserts before they are run
This commit is contained in:
parent
95ddd5059f
commit
c742e47de0
|
@ -0,0 +1,295 @@
|
||||||
|
"""Rewrite assertion AST to produce nice error messages"""
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import collections
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import py
|
||||||
|
|
||||||
|
|
||||||
|
def rewrite_asserts(mod):
|
||||||
|
"""Rewrite the assert statements in mod."""
|
||||||
|
AssertionRewriter().run(mod)
|
||||||
|
|
||||||
|
|
||||||
|
_saferepr = py.io.saferepr
|
||||||
|
_format_explanation = py.code._format_explanation
|
||||||
|
|
||||||
|
def _format_boolop(operands, explanations, is_or):
|
||||||
|
show_explanations = []
|
||||||
|
for operand, expl in zip(operands, explanations):
|
||||||
|
show_explanations.append(expl)
|
||||||
|
if operand == is_or:
|
||||||
|
break
|
||||||
|
return "(" + (is_or and " or " or " and ").join(show_explanations) + ")"
|
||||||
|
|
||||||
|
def _call_reprcompare(ops, results, expls, each_obj):
|
||||||
|
for i, res, expl in zip(range(len(ops)), results, expls):
|
||||||
|
if not res:
|
||||||
|
break
|
||||||
|
if py.code._reprcompare is not None:
|
||||||
|
custom = py.code._reprcompare(ops[i], each_obj[i], each_obj[i + 1])
|
||||||
|
if custom is not None:
|
||||||
|
return custom
|
||||||
|
return expl
|
||||||
|
|
||||||
|
|
||||||
|
unary_map = {
|
||||||
|
ast.Not : "not %s",
|
||||||
|
ast.Invert : "~%s",
|
||||||
|
ast.USub : "-%s",
|
||||||
|
ast.UAdd : "+%s"
|
||||||
|
}
|
||||||
|
|
||||||
|
binop_map = {
|
||||||
|
ast.BitOr : "|",
|
||||||
|
ast.BitXor : "^",
|
||||||
|
ast.BitAnd : "&",
|
||||||
|
ast.LShift : "<<",
|
||||||
|
ast.RShift : ">>",
|
||||||
|
ast.Add : "+",
|
||||||
|
ast.Sub : "-",
|
||||||
|
ast.Mult : "*",
|
||||||
|
ast.Div : "/",
|
||||||
|
ast.FloorDiv : "//",
|
||||||
|
ast.Mod : "%",
|
||||||
|
ast.Eq : "==",
|
||||||
|
ast.NotEq : "!=",
|
||||||
|
ast.Lt : "<",
|
||||||
|
ast.LtE : "<=",
|
||||||
|
ast.Gt : ">",
|
||||||
|
ast.GtE : ">=",
|
||||||
|
ast.Pow : "**",
|
||||||
|
ast.Is : "is",
|
||||||
|
ast.IsNot : "is not",
|
||||||
|
ast.In : "in",
|
||||||
|
ast.NotIn : "not in"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AssertionRewriter(ast.NodeVisitor):
|
||||||
|
|
||||||
|
def run(self, mod):
|
||||||
|
"""Find all assert statements in *mod* and rewrite them."""
|
||||||
|
if not mod.body:
|
||||||
|
# Nothing to do.
|
||||||
|
return
|
||||||
|
# Insert some special imports at top but after any docstrings.
|
||||||
|
aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
|
||||||
|
ast.alias("py", "@pylib"),
|
||||||
|
ast.alias("_pytest.assertrewrite", "@pytest_ar")]
|
||||||
|
imports = [ast.Import([alias], lineno=0, col_offset=0)
|
||||||
|
for alias in aliases]
|
||||||
|
pos = 0
|
||||||
|
if isinstance(mod.body[0], ast.Str):
|
||||||
|
pos = 1
|
||||||
|
mod.body[pos:pos] = imports
|
||||||
|
# Collect asserts.
|
||||||
|
asserts = []
|
||||||
|
nodes = collections.deque([mod])
|
||||||
|
while nodes:
|
||||||
|
node = nodes.popleft()
|
||||||
|
for name, field in ast.iter_fields(node):
|
||||||
|
if isinstance(field, list):
|
||||||
|
for i, child in enumerate(field):
|
||||||
|
if isinstance(child, ast.Assert):
|
||||||
|
asserts.append((field, i, child))
|
||||||
|
elif isinstance(child, ast.AST):
|
||||||
|
nodes.append(child)
|
||||||
|
elif (isinstance(field, ast.AST) and
|
||||||
|
# Don't recurse into expressions as they can't contain
|
||||||
|
# asserts.
|
||||||
|
not isinstance(field, ast.expr)):
|
||||||
|
nodes.append(field)
|
||||||
|
# Transform asserts.
|
||||||
|
for parent, pos, assert_ in asserts:
|
||||||
|
parent[pos:pos + 1] = self.visit(assert_)
|
||||||
|
|
||||||
|
def assign(self, expr):
|
||||||
|
"""Give *expr* a name."""
|
||||||
|
# Use a character invalid in python identifiers to avoid clashing.
|
||||||
|
name = "@py_assert" + str(next(self.variable_counter))
|
||||||
|
self.variables.add(name)
|
||||||
|
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
|
||||||
|
return ast.Name(name, ast.Load())
|
||||||
|
|
||||||
|
def display(self, expr):
|
||||||
|
"""Call py.io.saferepr on the expression."""
|
||||||
|
return self.helper("saferepr", expr)
|
||||||
|
|
||||||
|
def helper(self, name, *args):
|
||||||
|
"""Call a helper in this module."""
|
||||||
|
py_name = ast.Name("@pytest_ar", ast.Load())
|
||||||
|
attr = ast.Attribute(py_name, "_" + name, ast.Load())
|
||||||
|
return ast.Call(attr, list(args), [], None, None)
|
||||||
|
|
||||||
|
def builtin(self, name):
|
||||||
|
"""Return the builtin called *name*."""
|
||||||
|
builtin_name = ast.Name("@py_builtins", ast.Load())
|
||||||
|
return ast.Attribute(builtin_name, name, ast.Load())
|
||||||
|
|
||||||
|
def explanation_param(self, expr):
|
||||||
|
specifier = "py" + str(next(self.variable_counter))
|
||||||
|
self.explanation_specifiers[specifier] = expr
|
||||||
|
return "%(" + specifier + ")s"
|
||||||
|
|
||||||
|
def push_format_context(self):
|
||||||
|
self.explanation_specifiers = {}
|
||||||
|
self.stack.append(self.explanation_specifiers)
|
||||||
|
|
||||||
|
def pop_format_context(self, expl_expr):
|
||||||
|
current = self.stack.pop()
|
||||||
|
if self.stack:
|
||||||
|
self.explanation_specifiers = self.stack[-1]
|
||||||
|
keys = [ast.Str(key) for key in current.keys()]
|
||||||
|
format_dict = ast.Dict(keys, current.values())
|
||||||
|
form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
|
||||||
|
name = "@py_format" + str(next(self.variable_counter))
|
||||||
|
self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form))
|
||||||
|
return ast.Name(name, ast.Load())
|
||||||
|
|
||||||
|
def generic_visit(self, node):
|
||||||
|
"""Handle expressions we don't have custom code for."""
|
||||||
|
assert isinstance(node, ast.expr)
|
||||||
|
res = self.assign(node)
|
||||||
|
return res, self.explanation_param(self.display(res))
|
||||||
|
|
||||||
|
def visit_Assert(self, assert_):
|
||||||
|
if assert_.msg:
|
||||||
|
# There's already a message. Don't mess with it.
|
||||||
|
return [assert_]
|
||||||
|
self.statements = []
|
||||||
|
self.variables = set()
|
||||||
|
self.variable_counter = itertools.count()
|
||||||
|
self.stack = []
|
||||||
|
self.on_failure = []
|
||||||
|
self.push_format_context()
|
||||||
|
# Rewrite assert into a bunch of statements.
|
||||||
|
top_condition, explanation = self.visit(assert_.test)
|
||||||
|
# Create failure message.
|
||||||
|
body = self.on_failure
|
||||||
|
negation = ast.UnaryOp(ast.Not(), top_condition)
|
||||||
|
self.statements.append(ast.If(negation, body, []))
|
||||||
|
explanation = "assert " + explanation
|
||||||
|
template = ast.Str(explanation)
|
||||||
|
msg = self.pop_format_context(template)
|
||||||
|
fmt = self.helper("format_explanation", msg)
|
||||||
|
body.append(ast.Assert(top_condition, fmt))
|
||||||
|
# Delete temporary variables.
|
||||||
|
names = [ast.Name(name, ast.Del()) for name in self.variables]
|
||||||
|
if names:
|
||||||
|
delete = ast.Delete(names)
|
||||||
|
self.statements.append(delete)
|
||||||
|
# Fix line numbers.
|
||||||
|
for stmt in self.statements:
|
||||||
|
stmt.lineno = assert_.lineno
|
||||||
|
stmt.col_offset = assert_.col_offset
|
||||||
|
ast.fix_missing_locations(stmt)
|
||||||
|
return self.statements
|
||||||
|
|
||||||
|
def visit_Name(self, name):
|
||||||
|
# Check if the name is local or not.
|
||||||
|
locs = ast.Call(self.builtin("locals"), [], [], None, None)
|
||||||
|
globs = ast.Call(self.builtin("globals"), [], [], None, None)
|
||||||
|
ops = [ast.In(), ast.IsNot()]
|
||||||
|
test = ast.Compare(ast.Str(name.id), ops, [locs, globs])
|
||||||
|
expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
|
||||||
|
return name, self.explanation_param(expr)
|
||||||
|
|
||||||
|
def visit_BoolOp(self, boolop):
|
||||||
|
operands = []
|
||||||
|
explanations = []
|
||||||
|
self.push_format_context()
|
||||||
|
for operand in boolop.values:
|
||||||
|
res, explanation = self.visit(operand)
|
||||||
|
operands.append(res)
|
||||||
|
explanations.append(explanation)
|
||||||
|
expls = ast.Tuple([ast.Str(expl) for expl in explanations], ast.Load())
|
||||||
|
is_or = ast.Num(isinstance(boolop.op, ast.Or))
|
||||||
|
expl_template = self.helper("format_boolop",
|
||||||
|
ast.Tuple(operands, ast.Load()), expls,
|
||||||
|
is_or)
|
||||||
|
expl = self.pop_format_context(expl_template)
|
||||||
|
res = self.assign(ast.BoolOp(boolop.op, operands))
|
||||||
|
return res, self.explanation_param(expl)
|
||||||
|
|
||||||
|
def visit_UnaryOp(self, unary):
|
||||||
|
pattern = unary_map[unary.op.__class__]
|
||||||
|
operand_res, operand_expl = self.visit(unary.operand)
|
||||||
|
res = self.assign(ast.UnaryOp(unary.op, operand_res))
|
||||||
|
return res, pattern % (operand_expl,)
|
||||||
|
|
||||||
|
def visit_BinOp(self, binop):
|
||||||
|
symbol = binop_map[binop.op.__class__]
|
||||||
|
left_expr, left_expl = self.visit(binop.left)
|
||||||
|
right_expr, right_expl = self.visit(binop.right)
|
||||||
|
explanation = "(%s %s %s)" % (left_expl, symbol, right_expl)
|
||||||
|
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
|
||||||
|
return res, explanation
|
||||||
|
|
||||||
|
def visit_Call(self, call):
|
||||||
|
new_func, func_expl = self.visit(call.func)
|
||||||
|
arg_expls = []
|
||||||
|
new_args = []
|
||||||
|
new_kwargs = []
|
||||||
|
new_star = new_kwarg = None
|
||||||
|
for arg in call.args:
|
||||||
|
res, expl = self.visit(arg)
|
||||||
|
new_args.append(res)
|
||||||
|
arg_expls.append(expl)
|
||||||
|
for keyword in call.keywords:
|
||||||
|
res, expl = self.visit(keyword.value)
|
||||||
|
new_kwargs.append(ast.keyword(keyword.arg, res))
|
||||||
|
arg_expls.append(keyword.arg + "=" + expl)
|
||||||
|
if call.starargs:
|
||||||
|
new_star, expl = self.visit(call.starargs)
|
||||||
|
arg_expls.append("*" + expl)
|
||||||
|
if call.kwargs:
|
||||||
|
new_kwarg, expl = self.visit(call.kwarg)
|
||||||
|
arg_expls.append("**" + expl)
|
||||||
|
expl = "%s(%s)" % (func_expl, ', '.join(arg_expls))
|
||||||
|
new_call = ast.Call(new_func, new_args, new_kwargs, new_star, new_kwarg)
|
||||||
|
res = self.assign(new_call)
|
||||||
|
res_expl = self.explanation_param(self.display(res))
|
||||||
|
outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl)
|
||||||
|
return res, outer_expl
|
||||||
|
|
||||||
|
def visit_Attribute(self, attr):
|
||||||
|
if not isinstance(attr.ctx, ast.Load):
|
||||||
|
return self.generic_visit(attr)
|
||||||
|
value, value_expl = self.visit(attr.value)
|
||||||
|
res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
|
||||||
|
res_expl = self.explanation_param(self.display(res))
|
||||||
|
pat = "%s\n{%s = %s.%s\n}"
|
||||||
|
expl = pat % (res_expl, res_expl, value_expl, attr.attr)
|
||||||
|
return res, expl
|
||||||
|
|
||||||
|
def visit_Compare(self, comp):
|
||||||
|
self.push_format_context()
|
||||||
|
left_res, left_expl = self.visit(comp.left)
|
||||||
|
res_variables = ["@py_assert" + str(next(self.variable_counter))
|
||||||
|
for i in range(len(comp.ops))]
|
||||||
|
load_names = [ast.Name(v, ast.Load()) for v in res_variables]
|
||||||
|
store_names = [ast.Name(v, ast.Store()) for v in res_variables]
|
||||||
|
it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
|
||||||
|
expls = []
|
||||||
|
syms = []
|
||||||
|
results = [left_res]
|
||||||
|
for i, op, next_operand in it:
|
||||||
|
next_res, next_expl = self.visit(next_operand)
|
||||||
|
results.append(next_res)
|
||||||
|
sym = binop_map[op.__class__]
|
||||||
|
syms.append(ast.Str(sym))
|
||||||
|
expl = "%s %s %s" % (left_expl, sym, next_expl)
|
||||||
|
expls.append(ast.Str(expl))
|
||||||
|
res_expr = ast.Compare(left_res, [op], [next_res])
|
||||||
|
self.statements.append(ast.Assign([store_names[i]], res_expr))
|
||||||
|
left_res, left_expl = next_res, next_expl
|
||||||
|
# Use py.code._reprcompare if that's available.
|
||||||
|
expl_call = self.helper("call_reprcompare", ast.Tuple(syms, ast.Load()),
|
||||||
|
ast.Tuple(load_names, ast.Load()),
|
||||||
|
ast.Tuple(expls, ast.Load()),
|
||||||
|
ast.Tuple(results, ast.Load()))
|
||||||
|
args = [ast.List(load_names, ast.Load())]
|
||||||
|
res = ast.Call(self.builtin("all"), args, [], None, None)
|
||||||
|
return res, self.explanation_param(self.pop_format_context(expl_call))
|
|
@ -0,0 +1,195 @@
|
||||||
|
import sys
|
||||||
|
import py
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
ast = pytest.importorskip("ast")
|
||||||
|
|
||||||
|
from _pytest.assertrewrite import rewrite_asserts
|
||||||
|
|
||||||
|
|
||||||
|
def setup_module(mod):
|
||||||
|
mod._old_reprcompare = py.code._reprcompare
|
||||||
|
py.code._reprcompare = None
|
||||||
|
|
||||||
|
def teardown_module(mod):
|
||||||
|
py.code._reprcompare = mod._old_reprcompare
|
||||||
|
del mod._old_reprcompare
|
||||||
|
|
||||||
|
|
||||||
|
def getmsg(f, extra_ns=None, must_pass=False):
|
||||||
|
"""Rewrite the assertions in f, run it, and get the failure message."""
|
||||||
|
src = '\n'.join(py.code.Code(f).source().lines)
|
||||||
|
mod = ast.parse(src)
|
||||||
|
rewrite_asserts(mod)
|
||||||
|
code = compile(mod, "<test>", "exec")
|
||||||
|
ns = {}
|
||||||
|
if extra_ns is not None:
|
||||||
|
ns.update(extra_ns)
|
||||||
|
exec code in ns
|
||||||
|
func = ns[f.__name__]
|
||||||
|
try:
|
||||||
|
func()
|
||||||
|
except AssertionError:
|
||||||
|
if must_pass:
|
||||||
|
pytest.fail("shouldn't have raised")
|
||||||
|
s = str(sys.exc_info()[1])
|
||||||
|
if not s.startswith("assert"):
|
||||||
|
return "AssertionError: " + s
|
||||||
|
return s
|
||||||
|
else:
|
||||||
|
if not must_pass:
|
||||||
|
pytest.fail("function didn't raise at all")
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssertionRewrite:
|
||||||
|
|
||||||
|
def test_name(self):
|
||||||
|
def f():
|
||||||
|
assert False
|
||||||
|
assert getmsg(f) == "assert False"
|
||||||
|
def f():
|
||||||
|
f = False
|
||||||
|
assert f
|
||||||
|
assert getmsg(f) == "assert False"
|
||||||
|
def f():
|
||||||
|
assert a_global
|
||||||
|
assert getmsg(f, {"a_global" : False}) == "assert a_global"
|
||||||
|
|
||||||
|
def test_assert_already_has_message(self):
|
||||||
|
def f():
|
||||||
|
assert False, "something bad!"
|
||||||
|
assert getmsg(f) == "AssertionError: something bad!"
|
||||||
|
|
||||||
|
def test_boolop(self):
|
||||||
|
def f():
|
||||||
|
f = g = False
|
||||||
|
assert f and g
|
||||||
|
assert getmsg(f) == "assert (False)"
|
||||||
|
def f():
|
||||||
|
f = True
|
||||||
|
g = False
|
||||||
|
assert f and g
|
||||||
|
assert getmsg(f) == "assert (True and False)"
|
||||||
|
def f():
|
||||||
|
f = False
|
||||||
|
g = True
|
||||||
|
assert f and g
|
||||||
|
assert getmsg(f) == "assert (False)"
|
||||||
|
def f():
|
||||||
|
f = g = False
|
||||||
|
assert f or g
|
||||||
|
assert getmsg(f) == "assert (False or False)"
|
||||||
|
def f():
|
||||||
|
f = True
|
||||||
|
g = False
|
||||||
|
assert f or g
|
||||||
|
getmsg(f, must_pass=True)
|
||||||
|
|
||||||
|
def test_short_circut_evaluation(self):
|
||||||
|
pytest.xfail("complicated fix; I'm not sure if it's important")
|
||||||
|
def f():
|
||||||
|
assert True or explode
|
||||||
|
getmsg(f, must_pass=True)
|
||||||
|
|
||||||
|
def test_unary_op(self):
|
||||||
|
def f():
|
||||||
|
x = True
|
||||||
|
assert not x
|
||||||
|
assert getmsg(f) == "assert not True"
|
||||||
|
def f():
|
||||||
|
x = 0
|
||||||
|
assert ~x + 1
|
||||||
|
assert getmsg(f) == "assert (~0 + 1)"
|
||||||
|
def f():
|
||||||
|
x = 3
|
||||||
|
assert -x + x
|
||||||
|
assert getmsg(f) == "assert (-3 + 3)"
|
||||||
|
def f():
|
||||||
|
x = 0
|
||||||
|
assert +x + x
|
||||||
|
assert getmsg(f) == "assert (+0 + 0)"
|
||||||
|
|
||||||
|
def test_binary_op(self):
|
||||||
|
def f():
|
||||||
|
x = 1
|
||||||
|
y = -1
|
||||||
|
assert x + y
|
||||||
|
assert getmsg(f) == "assert (1 + -1)"
|
||||||
|
|
||||||
|
def test_call(self):
|
||||||
|
def g(a=42, *args, **kwargs):
|
||||||
|
return False
|
||||||
|
ns = {"g" : g}
|
||||||
|
def f():
|
||||||
|
assert g()
|
||||||
|
assert getmsg(f, ns) == """assert False
|
||||||
|
+ where False = g()"""
|
||||||
|
def f():
|
||||||
|
assert g(1)
|
||||||
|
assert getmsg(f, ns) == """assert False
|
||||||
|
+ where False = g(1)"""
|
||||||
|
def f():
|
||||||
|
assert g(1, 2)
|
||||||
|
assert getmsg(f, ns) == """assert False
|
||||||
|
+ where False = g(1, 2)"""
|
||||||
|
def f():
|
||||||
|
assert g(1, g=42)
|
||||||
|
assert getmsg(f, ns) == """assert False
|
||||||
|
+ where False = g(1, g=42)"""
|
||||||
|
def f():
|
||||||
|
assert g(1, 3, g=23)
|
||||||
|
assert getmsg(f, ns) == """assert False
|
||||||
|
+ where False = g(1, 3, g=23)"""
|
||||||
|
|
||||||
|
def test_attribute(self):
|
||||||
|
class X(object):
|
||||||
|
g = 3
|
||||||
|
ns = {"X" : X, "x" : X()}
|
||||||
|
def f():
|
||||||
|
assert not x.g
|
||||||
|
assert getmsg(f, ns) == """assert not 3
|
||||||
|
+ where 3 = x.g"""
|
||||||
|
def f():
|
||||||
|
x.a = False
|
||||||
|
assert x.a
|
||||||
|
assert getmsg(f, ns) == """assert False
|
||||||
|
+ where False = x.a"""
|
||||||
|
|
||||||
|
def test_comparisons(self):
|
||||||
|
def f():
|
||||||
|
a, b = range(2)
|
||||||
|
assert b < a
|
||||||
|
assert getmsg(f) == """assert 1 < 0"""
|
||||||
|
def f():
|
||||||
|
a, b, c = range(3)
|
||||||
|
assert a > b > c
|
||||||
|
assert getmsg(f) == """assert 0 > 1"""
|
||||||
|
def f():
|
||||||
|
a, b, c = range(3)
|
||||||
|
assert a < b > c
|
||||||
|
assert getmsg(f) == """assert 1 > 2"""
|
||||||
|
def f():
|
||||||
|
a, b, c = range(3)
|
||||||
|
assert a < b <= c
|
||||||
|
getmsg(f, must_pass=True)
|
||||||
|
|
||||||
|
def test_len(self):
|
||||||
|
def f():
|
||||||
|
l = range(10)
|
||||||
|
assert len(l) == 11
|
||||||
|
assert getmsg(f).startswith("""assert 10 == 11
|
||||||
|
+ where 10 = len([""")
|
||||||
|
|
||||||
|
def test_custom_reprcompare(self, monkeypatch):
|
||||||
|
def my_reprcompare(op, left, right):
|
||||||
|
return "42"
|
||||||
|
monkeypatch.setattr(py.code, "_reprcompare", my_reprcompare)
|
||||||
|
def f():
|
||||||
|
assert 42 < 3
|
||||||
|
assert getmsg(f) == "assert 42"
|
||||||
|
def my_reprcompare(op, left, right):
|
||||||
|
return "%s %s %s" % (left, op, right)
|
||||||
|
monkeypatch.setattr(py.code, "_reprcompare", my_reprcompare)
|
||||||
|
def f():
|
||||||
|
assert 1 < 3 < 5 <= 4 < 7
|
||||||
|
assert getmsg(f) == "assert 5 <= 4"
|
Loading…
Reference in New Issue