""" Find intermediate evalutation results in assert statements through builtin AST. This should replace oldinterpret.py eventually. """ import sys import ast import py from _pytest.assertion import util from _pytest.assertion.reinterpret import BuiltinAssertionError if sys.platform.startswith("java"): # See http://bugs.jython.org/issue1497 _exprs = ("BoolOp", "BinOp", "UnaryOp", "Lambda", "IfExp", "Dict", "ListComp", "GeneratorExp", "Yield", "Compare", "Call", "Repr", "Num", "Str", "Attribute", "Subscript", "Name", "List", "Tuple") _stmts = ("FunctionDef", "ClassDef", "Return", "Delete", "Assign", "AugAssign", "Print", "For", "While", "If", "With", "Raise", "TryExcept", "TryFinally", "Assert", "Import", "ImportFrom", "Exec", "Global", "Expr", "Pass", "Break", "Continue") _expr_nodes = set(getattr(ast, name) for name in _exprs) _stmt_nodes = set(getattr(ast, name) for name in _stmts) def _is_ast_expr(node): return node.__class__ in _expr_nodes def _is_ast_stmt(node): return node.__class__ in _stmt_nodes else: def _is_ast_expr(node): return isinstance(node, ast.expr) def _is_ast_stmt(node): return isinstance(node, ast.stmt) class Failure(Exception): """Error found while interpreting AST.""" def __init__(self, explanation=""): self.cause = sys.exc_info() self.explanation = explanation def interpret(source, frame, should_fail=False): mod = ast.parse(source) visitor = DebugInterpreter(frame) try: visitor.visit(mod) except Failure: failure = sys.exc_info()[1] return getfailure(failure) if should_fail: return ("(assertion failed, but when it was re-run for " "printing intermediate values, it did not fail. Suggestions: " "compute assert expression before the assert or use --assert=plain)") def run(offending_line, frame=None): if frame is None: frame = py.code.Frame(sys._getframe(1)) return interpret(offending_line, frame) def getfailure(e): explanation = util.format_explanation(e.explanation) value = e.cause[1] if str(value): lines = explanation.split('\n') lines[0] += " << %s" % (value,) explanation = '\n'.join(lines) text = "%s: %s" % (e.cause[0].__name__, explanation) if text.startswith('AssertionError: assert '): text = text[16:] return text operator_map = { ast.BitOr : "|", ast.BitXor : "^", ast.BitAnd : "&", ast.LShift : "<<", ast.RShift : ">>", ast.Add : "+", ast.Sub : "-", ast.Mult : "*", ast.Div : "/", ast.FloorDiv : "//", ast.Mod : "%", ast.Eq : "==", ast.NotEq : "!=", ast.Lt : "<", ast.LtE : "<=", ast.Gt : ">", ast.GtE : ">=", ast.Pow : "**", ast.Is : "is", ast.IsNot : "is not", ast.In : "in", ast.NotIn : "not in" } unary_map = { ast.Not : "not %s", ast.Invert : "~%s", ast.USub : "-%s", ast.UAdd : "+%s" } class DebugInterpreter(ast.NodeVisitor): """Interpret AST nodes to gleam useful debugging information. """ def __init__(self, frame): self.frame = frame def generic_visit(self, node): # Fallback when we don't have a special implementation. if _is_ast_expr(node): mod = ast.Expression(node) co = self._compile(mod) try: result = self.frame.eval(co) except Exception: raise Failure() explanation = self.frame.repr(result) return explanation, result elif _is_ast_stmt(node): mod = ast.Module([node]) co = self._compile(mod, "exec") try: self.frame.exec_(co) except Exception: raise Failure() return None, None else: raise AssertionError("can't handle %s" %(node,)) def _compile(self, source, mode="eval"): return compile(source, "", mode) def visit_Expr(self, expr): return self.visit(expr.value) def visit_Module(self, mod): for stmt in mod.body: self.visit(stmt) def visit_Name(self, name): explanation, result = self.generic_visit(name) # See if the name is local. source = "%r in locals() is not globals()" % (name.id,) co = self._compile(source) try: local = self.frame.eval(co) except Exception: # have to assume it isn't local = None if local is None or not self.frame.is_true(local): return name.id, result return explanation, result def visit_Compare(self, comp): left = comp.left left_explanation, left_result = self.visit(left) for op, next_op in zip(comp.ops, comp.comparators): next_explanation, next_result = self.visit(next_op) op_symbol = operator_map[op.__class__] explanation = "%s %s %s" % (left_explanation, op_symbol, next_explanation) source = "__exprinfo_left %s __exprinfo_right" % (op_symbol,) co = self._compile(source) try: result = self.frame.eval(co, __exprinfo_left=left_result, __exprinfo_right=next_result) except Exception: raise Failure(explanation) try: if not self.frame.is_true(result): break except KeyboardInterrupt: raise except: break left_explanation, left_result = next_explanation, next_result if util._reprcompare is not None: res = util._reprcompare(op_symbol, left_result, next_result) if res: explanation = res return explanation, result def visit_BoolOp(self, boolop): is_or = isinstance(boolop.op, ast.Or) explanations = [] for operand in boolop.values: explanation, result = self.visit(operand) explanations.append(explanation) if result == is_or: break name = is_or and " or " or " and " explanation = "(" + name.join(explanations) + ")" return explanation, result def visit_UnaryOp(self, unary): pattern = unary_map[unary.op.__class__] operand_explanation, operand_result = self.visit(unary.operand) explanation = pattern % (operand_explanation,) co = self._compile(pattern % ("__exprinfo_expr",)) try: result = self.frame.eval(co, __exprinfo_expr=operand_result) except Exception: raise Failure(explanation) return explanation, result def visit_BinOp(self, binop): left_explanation, left_result = self.visit(binop.left) right_explanation, right_result = self.visit(binop.right) symbol = operator_map[binop.op.__class__] explanation = "(%s %s %s)" % (left_explanation, symbol, right_explanation) source = "__exprinfo_left %s __exprinfo_right" % (symbol,) co = self._compile(source) try: result = self.frame.eval(co, __exprinfo_left=left_result, __exprinfo_right=right_result) except Exception: raise Failure(explanation) return explanation, result def visit_Call(self, call): func_explanation, func = self.visit(call.func) arg_explanations = [] ns = {"__exprinfo_func" : func} arguments = [] for arg in call.args: arg_explanation, arg_result = self.visit(arg) arg_name = "__exprinfo_%s" % (len(ns),) ns[arg_name] = arg_result arguments.append(arg_name) arg_explanations.append(arg_explanation) for keyword in call.keywords: arg_explanation, arg_result = self.visit(keyword.value) arg_name = "__exprinfo_%s" % (len(ns),) ns[arg_name] = arg_result keyword_source = "%s=%%s" % (keyword.arg) arguments.append(keyword_source % (arg_name,)) arg_explanations.append(keyword_source % (arg_explanation,)) if call.starargs: arg_explanation, arg_result = self.visit(call.starargs) arg_name = "__exprinfo_star" ns[arg_name] = arg_result arguments.append("*%s" % (arg_name,)) arg_explanations.append("*%s" % (arg_explanation,)) if call.kwargs: arg_explanation, arg_result = self.visit(call.kwargs) arg_name = "__exprinfo_kwds" ns[arg_name] = arg_result arguments.append("**%s" % (arg_name,)) arg_explanations.append("**%s" % (arg_explanation,)) args_explained = ", ".join(arg_explanations) explanation = "%s(%s)" % (func_explanation, args_explained) args = ", ".join(arguments) source = "__exprinfo_func(%s)" % (args,) co = self._compile(source) try: result = self.frame.eval(co, **ns) except Exception: raise Failure(explanation) pattern = "%s\n{%s = %s\n}" rep = self.frame.repr(result) explanation = pattern % (rep, rep, explanation) return explanation, result def _is_builtin_name(self, name): pattern = "%r not in globals() and %r not in locals()" source = pattern % (name.id, name.id) co = self._compile(source) try: return self.frame.eval(co) except Exception: return False def visit_Attribute(self, attr): if not isinstance(attr.ctx, ast.Load): return self.generic_visit(attr) source_explanation, source_result = self.visit(attr.value) explanation = "%s.%s" % (source_explanation, attr.attr) source = "__exprinfo_expr.%s" % (attr.attr,) co = self._compile(source) try: try: result = self.frame.eval(co, __exprinfo_expr=source_result) except AttributeError: # Maybe the attribute name needs to be mangled? if not attr.attr.startswith("__") or attr.attr.endswith("__"): raise source = "getattr(__exprinfo_expr.__class__, '__name__', '')" co = self._compile(source) class_name = self.frame.eval(co, __exprinfo_expr=source_result) mangled_attr = "_" + class_name + attr.attr source = "__exprinfo_expr.%s" % (mangled_attr,) co = self._compile(source) result = self.frame.eval(co, __exprinfo_expr=source_result) except Exception: raise Failure(explanation) explanation = "%s\n{%s = %s.%s\n}" % (self.frame.repr(result), self.frame.repr(result), source_explanation, attr.attr) # Check if the attr is from an instance. source = "%r in getattr(__exprinfo_expr, '__dict__', {})" source = source % (attr.attr,) co = self._compile(source) try: from_instance = self.frame.eval(co, __exprinfo_expr=source_result) except Exception: from_instance = None if from_instance is None or self.frame.is_true(from_instance): rep = self.frame.repr(result) pattern = "%s\n{%s = %s\n}" explanation = pattern % (rep, rep, explanation) return explanation, result def visit_Assert(self, assrt): test_explanation, test_result = self.visit(assrt.test) explanation = "assert %s" % (test_explanation,) if not self.frame.is_true(test_result): try: raise BuiltinAssertionError except Exception: raise Failure(explanation) return explanation, test_result def visit_Assign(self, assign): value_explanation, value_result = self.visit(assign.value) explanation = "... = %s" % (value_explanation,) name = ast.Name("__exprinfo_expr", ast.Load(), lineno=assign.value.lineno, col_offset=assign.value.col_offset) new_assign = ast.Assign(assign.targets, name, lineno=assign.lineno, col_offset=assign.col_offset) mod = ast.Module([new_assign]) co = self._compile(mod, "exec") try: self.frame.exec_(co, __exprinfo_expr=value_result) except Exception: raise Failure(explanation) return explanation, value_result