implement assert debugging with builtin AST

--HG--
branch : trunk
This commit is contained in:
Benjamin Peterson 2009-08-28 18:44:20 -05:00
parent 3bdbb29c6f
commit e0e9953be2
2 changed files with 349 additions and 33 deletions

View File

@ -97,6 +97,42 @@ def enumsubclasses(cls):
yield cls
def _format_explanation(explanation):
# uck! See CallFunc for where \n{ and \n} escape sequences are used
raw_lines = (explanation or '').split('\n')
# escape newlines not followed by { and }
lines = [raw_lines[0]]
for l in raw_lines[1:]:
if l.startswith('{') or l.startswith('}'):
lines.append(l)
else:
lines[-1] += '\\n' + l
result = lines[:1]
stack = [0]
stackcnt = [0]
for line in lines[1:]:
if line.startswith('{'):
if stackcnt[-1]:
s = 'and '
else:
s = 'where '
stack.append(len(result))
stackcnt[-1] += 1
stackcnt.append(0)
result.append(' +' + ' '*(len(stack)-1) + s + line[1:])
else:
assert line.startswith('}')
stack.pop()
stackcnt.pop()
result[stack[-1]] += line[1:]
assert len(stack) == 1
return '\n'.join(result)
class Interpretable(View):
"""A parse tree node with a few extra methods."""
explanation = None
@ -132,36 +168,8 @@ class Interpretable(View):
raise Failure(self)
def nice_explanation(self):
# uck! See CallFunc for where \n{ and \n} escape sequences are used
raw_lines = (self.explanation or '').split('\n')
# escape newlines not followed by { and }
lines = [raw_lines[0]]
for l in raw_lines[1:]:
if l.startswith('{') or l.startswith('}'):
lines.append(l)
else:
lines[-1] += '\\n' + l
return _format_explanation(self.explanation)
result = lines[:1]
stack = [0]
stackcnt = [0]
for line in lines[1:]:
if line.startswith('{'):
if stackcnt[-1]:
s = 'and '
else:
s = 'where '
stack.append(len(result))
stackcnt[-1] += 1
stackcnt.append(0)
result.append(' +' + ' '*(len(stack)-1) + s + line[1:])
else:
assert line.startswith('}')
stack.pop()
stackcnt.pop()
result[stack[-1]] += line[1:]
assert len(stack) == 1
return '\n'.join(result)
class Name(Interpretable):
__view__ = ast.Name
@ -571,16 +579,20 @@ class AssertionError(BuiltinAssertionError):
args[0].__class__, id(args[0]))
else:
f = sys._getframe(1)
f = py.code.Frame(sys._getframe(1))
try:
source = py.code.Frame(f).statement
source = f.statement
source = str(source.deindent()).strip()
except py.error.ENOENT:
source = None
# this can also occur during reinterpretation, when the
# co_filename is set to "<run>".
if source:
self.msg = interpret(source, f, should_fail=True)
if sys.version_info >= (2, 6):
from py.__.code._assertionnew import interpret as do_interp
else:
do_interp = interpret
self.msg = do_interp(source, f, should_fail=True)
if not self.args:
self.args = (self.msg,)
else:

304
py/code/_assertionnew.py Normal file
View File

@ -0,0 +1,304 @@
"""
Like _assertion.py but using builtin AST. It should replace _assertion.py
eventually.
"""
import sys
import ast
import py
from py.__.code._assertion import _format_explanation, BuiltinAssertionError
class Failure(Exception):
"""Error found while interpreting AST."""
def __init__(self, explanation=""):
self.cause = sys.exc_info()
self.explanation = explanation
def interpret(source, frame, should_fail=False):
mod = ast.parse(source)
visitor = DebugInterpreter(frame)
try:
visitor.visit(mod)
except Failure as failure:
return getfailure(failure)
if should_fail:
return ("(assertion failed, but when it was re-run for "
"printing intermediate values, it did not fail. Suggestions: "
"compute assert expression before the assert or use --nomagic)")
def run(offending_line, frame=None):
if frame is None:
frame = py.code.Frame(sys._getframe(1))
return interpret(offending_line, frame)
def getfailure(failure):
explanation = _format_explanation(failure.explanation)
value = failure.cause[1]
if str(value):
lines = explanation.splitlines()
lines[0] += " << {0}".format(value)
explanation = "\n".join(lines)
text = "{0}: {1}".format(failure.cause[0].__name__, explanation)
if text.startswith("AssertionError: assert "):
text = text[16:]
return text
operator_map = {
ast.BitOr : "|",
ast.BitXor : "^",
ast.BitAnd : "&",
ast.LShift : "<<",
ast.RShift : ">>",
ast.Add : "+",
ast.Sub : "-",
ast.Mult : "*",
ast.Div : "/",
ast.FloorDiv : "//",
ast.Mod : "%",
ast.Eq : "==",
ast.NotEq : "!=",
ast.Lt : "<",
ast.LtE : "<=",
ast.Gt : ">",
ast.GtE : ">=",
}
unary_map = {
ast.Not : "not {0}",
ast.Invert : "~{0}",
ast.USub : "-{0}",
ast.UAdd : "+{0}"
}
class DebugInterpreter(ast.NodeVisitor):
"""Interpret AST nodes to gleam useful debugging information."""
def __init__(self, frame):
self.frame = frame
def generic_visit(self, node):
# Fallback when we don't have a special implementation.
if isinstance(node, ast.expr):
mod = ast.Expression(node)
co = self._compile(mod)
try:
result = self.frame.eval(co)
except Exception:
raise Failure()
explanation = self.frame.repr(result)
return explanation, result
elif isinstance(node, ast.stmt):
mod = ast.Module([node])
co = self._compile(mod, "exec")
try:
frame.exec_(co)
except Exception:
raise Failure()
return None, None
else:
raise AssertionError("can't handle {0}".format(node))
def _compile(self, source, mode="eval"):
return compile(source, "<assertion interpretation>", mode)
def visit_Module(self, mod):
for stmt in mod.body:
self.visit(stmt)
def visit_Name(self, name):
explanation, result = self.generic_visit(name)
# See if the name is local.
source = "{0!r} in locals() is not globals()".format(name.id)
co = self._compile(source)
try:
local = self.frame.eval(co)
except Exception, e:
# have to assume it isn't
local = False
if not local:
return name.id, result
return explanation, result
def visit_Compare(self, comp):
left = comp.left
left_explanation, left_result = self.visit(left)
got_result = False
for op, next_op in zip(comp.ops, comp.comparators):
if got_result and not result:
break
next_explanation, next_result = self.visit(next_op)
op_symbol = operator_map[op.__class__]
explanation = "{0} {1} {2}".format(left_explanation, op_symbol,
next_explanation)
source = "__exprinfo_left {0} __exprinfo_right".format(op_symbol)
co = self._compile(source)
try:
result = self.frame.eval(co, __exprinfo_left=left_result,
__exprinfo_right=next_result)
except Exception:
raise Failure(explanation)
else:
got_result = True
left_explanation, left_result = next_explanation, next_result
return explanation, result
def visit_BoolOp(self, boolop):
is_or = isinstance(boolop.op, ast.Or)
explanations = []
for operand in boolop.values:
explanation, result = self.visit(operand)
explanations.append(explanation)
if result == is_or:
break
name = " or " if is_or else " and "
explanation = "(" + name.join(explanations) + ")"
return explanation, result
def visit_UnaryOp(self, unary):
pattern = unary_map[unary.op.__class__]
operand_explanation, operand_result = self.visit(unary.operand)
explanation = pattern.format(operand_explanation)
co = self._compile(pattern.format("__exprinfo_expr"))
try:
result = self.frame.eval(co, __exprinfo_expr=operand_result)
except Exception:
raise Failure(explanation)
return explanation, result
def visit_BinOp(self, binop):
left_explanation, left_result = self.visit(binop.left)
right_explanation, right_result = self.visit(binop.right)
symbol = operator_map[binop.op.__class__]
explanation = "{0} {1} {2}".format(left_explanation, symbol,
right_explanation)
source = "__exprinfo_left {0} __exprinfo_right".format(symbol)
co = self._compile(source)
try:
result = self.frame.eval(co, __exprinfo_left=left_result,
__exprinfo_right=right_result)
except Exception:
raise Failure(explanation)
return explanation, result
def visit_Call(self, call):
func_explanation, func = self.visit(call.func)
arg_explanations = []
ns = {"__exprinfo_func" : func}
arguments = []
for arg in call.args:
arg_explanation, arg_result = self.visit(arg)
arg_name = "__exprinfo_{0}".format(len(ns))
ns[arg_name] = arg_result
arguments.append(arg_name)
arg_explanations.append(arg_explanation)
for keyword in call.keywords:
arg_explanation, arg_result = self.visit(keyword.value)
arg_name = "__exprinfo_{0}".format(len(ns))
ns[arg_name] = arg_result
keyword_source = "{0}={{0}}".format(keyword.id)
arguments.append(keyword_source.format(arg_name))
arg_explanations.append(keyword_source.format(arg_explanation))
if call.starargs:
arg_explanation, arg_result = self.visit(call.starargs)
arg_name = "__exprinfo_star"
ns[arg_name] = arg_result
arguments.append("*{0}".format(arg_name))
arg_explanations.append("*{0}".format(arg_explanation))
if call.kwargs:
arg_explanation, arg_result = self.visit(call.kwargs)
arg_name = "__exprinfo_kwds"
ns[arg_name] = arg_result
arguments.append("**{0}".format(arg_name))
arg_explanations.append("**{0}".format(arg_explanation))
args_explained = ", ".join(arg_explanations)
explanation = "{0}({1})".format(func_explanation, args_explained)
args = ", ".join(arguments)
source = "__exprinfo_func({0})".format(args)
co = self._compile(source)
try:
result = self.frame.eval(co, **ns)
except Exception:
raise Failure(explanation)
# Only show result explanation if it's not a builtin call or returns a
# bool.
if not isinstance(call.func, ast.Name) or \
not self._is_builtin_name(call.func):
source = "isinstance(__exprinfo_value, bool)"
co = self._compile(source)
try:
is_bool = self.frame.eval(co, __exprinfo_value=result)
except Exception:
is_bool = False
if not is_bool:
pattern = "{0}\n{{{0} = {1}\n}}"
rep = self.frame.repr(result)
explanation = pattern.format(rep, explanation)
return explanation, result
def _is_builtin_name(self, name):
pattern = "{0!r} not in globals() and {0!r} not in locals()"
source = pattern.format(name.id)
co = self._compile(source)
try:
return self.frame.eval(co)
except Exception:
return False
def visit_Attribute(self, attr):
if not isinstance(attr.ctx, ast.Load):
return self.generic_visit(attr)
source_explanation, source_result = self.visit(attr.value)
explanation = "{0}.{1}".format(source_explanation, attr.attr)
source = "__exprinfo_expr.{0}".format(attr.attr)
co = self._compile(source)
try:
result = self.frame.eval(co, __exprinfo_expr=source_result)
except Exception:
raise Failure(explanation)
# Check if the attr is from an instance.
source = "{0!r} in getattr(__exprinfo_expr, '__dict__', {{}})"
source = source.format(attr.attr)
co = self._compile(source)
try:
from_instance = self.frame.eval(co, __exprinfo_expr=source_result)
except Exception:
from_instance = True
if from_instance:
rep = self.frame.repr(result)
pattern = "{0}\n{{{0} = {1}\n}}"
explanation = pattern.format(rep, explanation)
return explanation, result
def visit_Assert(self, assrt):
test_explanation, test_result = self.visit(assrt.test)
if test_explanation.startswith("False\n{False =") and \
test_explanation.endswith("\n"):
test_explanation = test_explanation[15:-2]
explanation = "assert {0}".format(test_explanation)
if not test_result:
try:
raise BuiltinAssertionError
except Exception:
raise Failure(explanation)
return explanation, test_result
def visit_Assign(self, assign):
value_explanation, value_result = self.visit(assign.value)
explanation = "... = {0}".format(value_explanation)
name = ast.Name("__exprinfo_expr", ast.Load(), assign.value.lineno,
assign.value.col_offset)
new_assign = ast.Assign(assign.targets, name, assign.lineno,
assign.col_offset)
mod = ast.Module([new_assign])
co = self._compile(mod, "exec")
try:
self.frame.exec_(co, __exprinfo_expr=value_result)
except Exception:
raise Failure(explanation)
return explanation, value_result