import assertion code from pylib

This commit is contained in:
Benjamin Peterson 2011-05-25 17:54:02 -05:00
parent 491c05cea7
commit f423ce9c01
8 changed files with 1344 additions and 22 deletions

View File

@ -31,15 +31,16 @@ def pytest_configure(config):
config._cleanup.append(m.undo) config._cleanup.append(m.undo)
warn_about_missing_assertion() warn_about_missing_assertion()
if not config.getvalue("noassert") and not config.getvalue("nomagic"): if not config.getvalue("noassert") and not config.getvalue("nomagic"):
from _pytest.assertion import reinterpret
def callbinrepr(op, left, right): def callbinrepr(op, left, right):
hook_result = config.hook.pytest_assertrepr_compare( hook_result = config.hook.pytest_assertrepr_compare(
config=config, op=op, left=left, right=right) config=config, op=op, left=left, right=right)
for new_expl in hook_result: for new_expl in hook_result:
if new_expl: if new_expl:
return '\n~'.join(new_expl) return '\n~'.join(new_expl)
m.setattr(py.builtin.builtins, m.setattr(py.builtin.builtins, 'AssertionError',
'AssertionError', py.code._AssertionError) reinterpret.AssertionError)
m.setattr(py.code, '_reprcompare', callbinrepr) m.setattr(sys.modules[__name__], '_reprcompare', callbinrepr)
else: else:
rewrite_asserts = None rewrite_asserts = None
@ -98,6 +99,53 @@ def warn_about_missing_assertion():
sys.stderr.write("WARNING: failing tests may report as passing because " sys.stderr.write("WARNING: failing tests may report as passing because "
"assertions are turned off! (are you using python -O?)\n") "assertions are turned off! (are you using python -O?)\n")
# if set, will be called by assert reinterp for comparison ops
_reprcompare = None
def _format_explanation(explanation):
"""This formats an explanation
Normally all embedded newlines are escaped, however there are
three exceptions: \n{, \n} and \n~. The first two are intended
cover nested explanations, see function and attribute explanations
for examples (.visit_Call(), visit_Attribute()). The last one is
for when one explanation needs to span multiple lines, e.g. when
displaying diffs.
"""
raw_lines = (explanation or '').split('\n')
# escape newlines not followed by {, } and ~
lines = [raw_lines[0]]
for l in raw_lines[1:]:
if l.startswith('{') or l.startswith('}') or l.startswith('~'):
lines.append(l)
else:
lines[-1] += '\\n' + l
result = lines[:1]
stack = [0]
stackcnt = [0]
for line in lines[1:]:
if line.startswith('{'):
if stackcnt[-1]:
s = 'and '
else:
s = 'where '
stack.append(len(result))
stackcnt[-1] += 1
stackcnt.append(0)
result.append(' +' + ' '*(len(stack)-1) + s + line[1:])
elif line.startswith('}'):
assert line.startswith('}')
stack.pop()
stackcnt.pop()
result[stack[-1]] += line[1:]
else:
assert line.startswith('~')
result.append(' '*len(stack) + line[1:])
assert len(stack) == 1
return '\n'.join(result)
# Provide basestring in python3 # Provide basestring in python3
try: try:
basestring = basestring basestring = basestring

View File

@ -0,0 +1,340 @@
"""
Find intermediate evalutation results in assert statements through builtin AST.
This should replace oldinterpret.py eventually.
"""
import sys
import ast
import py
from _pytest import assertion
from _pytest.assertion import _format_explanation
from _pytest.assertion.reinterpret import BuiltinAssertionError
if sys.platform.startswith("java") and sys.version_info < (2, 5, 2):
# See http://bugs.jython.org/issue1497
_exprs = ("BoolOp", "BinOp", "UnaryOp", "Lambda", "IfExp", "Dict",
"ListComp", "GeneratorExp", "Yield", "Compare", "Call",
"Repr", "Num", "Str", "Attribute", "Subscript", "Name",
"List", "Tuple")
_stmts = ("FunctionDef", "ClassDef", "Return", "Delete", "Assign",
"AugAssign", "Print", "For", "While", "If", "With", "Raise",
"TryExcept", "TryFinally", "Assert", "Import", "ImportFrom",
"Exec", "Global", "Expr", "Pass", "Break", "Continue")
_expr_nodes = set(getattr(ast, name) for name in _exprs)
_stmt_nodes = set(getattr(ast, name) for name in _stmts)
def _is_ast_expr(node):
return node.__class__ in _expr_nodes
def _is_ast_stmt(node):
return node.__class__ in _stmt_nodes
else:
def _is_ast_expr(node):
return isinstance(node, ast.expr)
def _is_ast_stmt(node):
return isinstance(node, ast.stmt)
class Failure(Exception):
"""Error found while interpreting AST."""
def __init__(self, explanation=""):
self.cause = sys.exc_info()
self.explanation = explanation
def interpret(source, frame, should_fail=False):
mod = ast.parse(source)
visitor = DebugInterpreter(frame)
try:
visitor.visit(mod)
except Failure:
failure = sys.exc_info()[1]
return getfailure(failure)
if should_fail:
return ("(assertion failed, but when it was re-run for "
"printing intermediate values, it did not fail. Suggestions: "
"compute assert expression before the assert or use --no-assert)")
def run(offending_line, frame=None):
if frame is None:
frame = py.code.Frame(sys._getframe(1))
return interpret(offending_line, frame)
def getfailure(failure):
explanation = _format_explanation(failure.explanation)
value = failure.cause[1]
if str(value):
lines = explanation.splitlines()
if not lines:
lines.append("")
lines[0] += " << %s" % (value,)
explanation = "\n".join(lines)
text = "%s: %s" % (failure.cause[0].__name__, explanation)
if text.startswith("AssertionError: assert "):
text = text[16:]
return text
operator_map = {
ast.BitOr : "|",
ast.BitXor : "^",
ast.BitAnd : "&",
ast.LShift : "<<",
ast.RShift : ">>",
ast.Add : "+",
ast.Sub : "-",
ast.Mult : "*",
ast.Div : "/",
ast.FloorDiv : "//",
ast.Mod : "%",
ast.Eq : "==",
ast.NotEq : "!=",
ast.Lt : "<",
ast.LtE : "<=",
ast.Gt : ">",
ast.GtE : ">=",
ast.Pow : "**",
ast.Is : "is",
ast.IsNot : "is not",
ast.In : "in",
ast.NotIn : "not in"
}
unary_map = {
ast.Not : "not %s",
ast.Invert : "~%s",
ast.USub : "-%s",
ast.UAdd : "+%s"
}
class DebugInterpreter(ast.NodeVisitor):
"""Interpret AST nodes to gleam useful debugging information. """
def __init__(self, frame):
self.frame = frame
def generic_visit(self, node):
# Fallback when we don't have a special implementation.
if _is_ast_expr(node):
mod = ast.Expression(node)
co = self._compile(mod)
try:
result = self.frame.eval(co)
except Exception:
raise Failure()
explanation = self.frame.repr(result)
return explanation, result
elif _is_ast_stmt(node):
mod = ast.Module([node])
co = self._compile(mod, "exec")
try:
self.frame.exec_(co)
except Exception:
raise Failure()
return None, None
else:
raise AssertionError("can't handle %s" %(node,))
def _compile(self, source, mode="eval"):
return compile(source, "<assertion interpretation>", mode)
def visit_Expr(self, expr):
return self.visit(expr.value)
def visit_Module(self, mod):
for stmt in mod.body:
self.visit(stmt)
def visit_Name(self, name):
explanation, result = self.generic_visit(name)
# See if the name is local.
source = "%r in locals() is not globals()" % (name.id,)
co = self._compile(source)
try:
local = self.frame.eval(co)
except Exception:
# have to assume it isn't
local = False
if not local:
return name.id, result
return explanation, result
def visit_Compare(self, comp):
left = comp.left
left_explanation, left_result = self.visit(left)
for op, next_op in zip(comp.ops, comp.comparators):
next_explanation, next_result = self.visit(next_op)
op_symbol = operator_map[op.__class__]
explanation = "%s %s %s" % (left_explanation, op_symbol,
next_explanation)
source = "__exprinfo_left %s __exprinfo_right" % (op_symbol,)
co = self._compile(source)
try:
result = self.frame.eval(co, __exprinfo_left=left_result,
__exprinfo_right=next_result)
except Exception:
raise Failure(explanation)
try:
if not result:
break
except KeyboardInterrupt:
raise
except:
break
left_explanation, left_result = next_explanation, next_result
if assertion._reprcompare is not None:
res = assertion._reprcompare(op_symbol, left_result, next_result)
if res:
explanation = res
return explanation, result
def visit_BoolOp(self, boolop):
is_or = isinstance(boolop.op, ast.Or)
explanations = []
for operand in boolop.values:
explanation, result = self.visit(operand)
explanations.append(explanation)
if result == is_or:
break
name = is_or and " or " or " and "
explanation = "(" + name.join(explanations) + ")"
return explanation, result
def visit_UnaryOp(self, unary):
pattern = unary_map[unary.op.__class__]
operand_explanation, operand_result = self.visit(unary.operand)
explanation = pattern % (operand_explanation,)
co = self._compile(pattern % ("__exprinfo_expr",))
try:
result = self.frame.eval(co, __exprinfo_expr=operand_result)
except Exception:
raise Failure(explanation)
return explanation, result
def visit_BinOp(self, binop):
left_explanation, left_result = self.visit(binop.left)
right_explanation, right_result = self.visit(binop.right)
symbol = operator_map[binop.op.__class__]
explanation = "(%s %s %s)" % (left_explanation, symbol,
right_explanation)
source = "__exprinfo_left %s __exprinfo_right" % (symbol,)
co = self._compile(source)
try:
result = self.frame.eval(co, __exprinfo_left=left_result,
__exprinfo_right=right_result)
except Exception:
raise Failure(explanation)
return explanation, result
def visit_Call(self, call):
func_explanation, func = self.visit(call.func)
arg_explanations = []
ns = {"__exprinfo_func" : func}
arguments = []
for arg in call.args:
arg_explanation, arg_result = self.visit(arg)
arg_name = "__exprinfo_%s" % (len(ns),)
ns[arg_name] = arg_result
arguments.append(arg_name)
arg_explanations.append(arg_explanation)
for keyword in call.keywords:
arg_explanation, arg_result = self.visit(keyword.value)
arg_name = "__exprinfo_%s" % (len(ns),)
ns[arg_name] = arg_result
keyword_source = "%s=%%s" % (keyword.arg)
arguments.append(keyword_source % (arg_name,))
arg_explanations.append(keyword_source % (arg_explanation,))
if call.starargs:
arg_explanation, arg_result = self.visit(call.starargs)
arg_name = "__exprinfo_star"
ns[arg_name] = arg_result
arguments.append("*%s" % (arg_name,))
arg_explanations.append("*%s" % (arg_explanation,))
if call.kwargs:
arg_explanation, arg_result = self.visit(call.kwargs)
arg_name = "__exprinfo_kwds"
ns[arg_name] = arg_result
arguments.append("**%s" % (arg_name,))
arg_explanations.append("**%s" % (arg_explanation,))
args_explained = ", ".join(arg_explanations)
explanation = "%s(%s)" % (func_explanation, args_explained)
args = ", ".join(arguments)
source = "__exprinfo_func(%s)" % (args,)
co = self._compile(source)
try:
result = self.frame.eval(co, **ns)
except Exception:
raise Failure(explanation)
pattern = "%s\n{%s = %s\n}"
rep = self.frame.repr(result)
explanation = pattern % (rep, rep, explanation)
return explanation, result
def _is_builtin_name(self, name):
pattern = "%r not in globals() and %r not in locals()"
source = pattern % (name.id, name.id)
co = self._compile(source)
try:
return self.frame.eval(co)
except Exception:
return False
def visit_Attribute(self, attr):
if not isinstance(attr.ctx, ast.Load):
return self.generic_visit(attr)
source_explanation, source_result = self.visit(attr.value)
explanation = "%s.%s" % (source_explanation, attr.attr)
source = "__exprinfo_expr.%s" % (attr.attr,)
co = self._compile(source)
try:
result = self.frame.eval(co, __exprinfo_expr=source_result)
except Exception:
raise Failure(explanation)
explanation = "%s\n{%s = %s.%s\n}" % (self.frame.repr(result),
self.frame.repr(result),
source_explanation, attr.attr)
# Check if the attr is from an instance.
source = "%r in getattr(__exprinfo_expr, '__dict__', {})"
source = source % (attr.attr,)
co = self._compile(source)
try:
from_instance = self.frame.eval(co, __exprinfo_expr=source_result)
except Exception:
from_instance = True
if from_instance:
rep = self.frame.repr(result)
pattern = "%s\n{%s = %s\n}"
explanation = pattern % (rep, rep, explanation)
return explanation, result
def visit_Assert(self, assrt):
test_explanation, test_result = self.visit(assrt.test)
if test_explanation.startswith("False\n{False =") and \
test_explanation.endswith("\n"):
test_explanation = test_explanation[15:-2]
explanation = "assert %s" % (test_explanation,)
if not test_result:
try:
raise BuiltinAssertionError
except Exception:
raise Failure(explanation)
return explanation, test_result
def visit_Assign(self, assign):
value_explanation, value_result = self.visit(assign.value)
explanation = "... = %s" % (value_explanation,)
name = ast.Name("__exprinfo_expr", ast.Load(),
lineno=assign.value.lineno,
col_offset=assign.value.col_offset)
new_assign = ast.Assign(assign.targets, name, lineno=assign.lineno,
col_offset=assign.col_offset)
mod = ast.Module([new_assign])
co = self._compile(mod, "exec")
try:
self.frame.exec_(co, __exprinfo_expr=value_result)
except Exception:
raise Failure(explanation)
return explanation, value_result

View File

@ -0,0 +1,556 @@
import py
import sys, inspect
from compiler import parse, ast, pycodegen
from _pytest.assertion import _format_explanation
from _pytest.assertion.reinterpret import BuiltinAssertionError
passthroughex = py.builtin._sysex
class Failure:
def __init__(self, node):
self.exc, self.value, self.tb = sys.exc_info()
self.node = node
class View(object):
"""View base class.
If C is a subclass of View, then C(x) creates a proxy object around
the object x. The actual class of the proxy is not C in general,
but a *subclass* of C determined by the rules below. To avoid confusion
we call view class the class of the proxy (a subclass of C, so of View)
and object class the class of x.
Attributes and methods not found in the proxy are automatically read on x.
Other operations like setting attributes are performed on the proxy, as
determined by its view class. The object x is available from the proxy
as its __obj__ attribute.
The view class selection is determined by the __view__ tuples and the
optional __viewkey__ method. By default, the selected view class is the
most specific subclass of C whose __view__ mentions the class of x.
If no such subclass is found, the search proceeds with the parent
object classes. For example, C(True) will first look for a subclass
of C with __view__ = (..., bool, ...) and only if it doesn't find any
look for one with __view__ = (..., int, ...), and then ..., object,...
If everything fails the class C itself is considered to be the default.
Alternatively, the view class selection can be driven by another aspect
of the object x, instead of the class of x, by overriding __viewkey__.
See last example at the end of this module.
"""
_viewcache = {}
__view__ = ()
def __new__(rootclass, obj, *args, **kwds):
self = object.__new__(rootclass)
self.__obj__ = obj
self.__rootclass__ = rootclass
key = self.__viewkey__()
try:
self.__class__ = self._viewcache[key]
except KeyError:
self.__class__ = self._selectsubclass(key)
return self
def __getattr__(self, attr):
# attributes not found in the normal hierarchy rooted on View
# are looked up in the object's real class
return getattr(self.__obj__, attr)
def __viewkey__(self):
return self.__obj__.__class__
def __matchkey__(self, key, subclasses):
if inspect.isclass(key):
keys = inspect.getmro(key)
else:
keys = [key]
for key in keys:
result = [C for C in subclasses if key in C.__view__]
if result:
return result
return []
def _selectsubclass(self, key):
subclasses = list(enumsubclasses(self.__rootclass__))
for C in subclasses:
if not isinstance(C.__view__, tuple):
C.__view__ = (C.__view__,)
choices = self.__matchkey__(key, subclasses)
if not choices:
return self.__rootclass__
elif len(choices) == 1:
return choices[0]
else:
# combine the multiple choices
return type('?', tuple(choices), {})
def __repr__(self):
return '%s(%r)' % (self.__rootclass__.__name__, self.__obj__)
def enumsubclasses(cls):
for subcls in cls.__subclasses__():
for subsubclass in enumsubclasses(subcls):
yield subsubclass
yield cls
class Interpretable(View):
"""A parse tree node with a few extra methods."""
explanation = None
def is_builtin(self, frame):
return False
def eval(self, frame):
# fall-back for unknown expression nodes
try:
expr = ast.Expression(self.__obj__)
expr.filename = '<eval>'
self.__obj__.filename = '<eval>'
co = pycodegen.ExpressionCodeGenerator(expr).getCode()
result = frame.eval(co)
except passthroughex:
raise
except:
raise Failure(self)
self.result = result
self.explanation = self.explanation or frame.repr(self.result)
def run(self, frame):
# fall-back for unknown statement nodes
try:
expr = ast.Module(None, ast.Stmt([self.__obj__]))
expr.filename = '<run>'
co = pycodegen.ModuleCodeGenerator(expr).getCode()
frame.exec_(co)
except passthroughex:
raise
except:
raise Failure(self)
def nice_explanation(self):
return _format_explanation(self.explanation)
class Name(Interpretable):
__view__ = ast.Name
def is_local(self, frame):
source = '%r in locals() is not globals()' % self.name
try:
return frame.is_true(frame.eval(source))
except passthroughex:
raise
except:
return False
def is_global(self, frame):
source = '%r in globals()' % self.name
try:
return frame.is_true(frame.eval(source))
except passthroughex:
raise
except:
return False
def is_builtin(self, frame):
source = '%r not in locals() and %r not in globals()' % (
self.name, self.name)
try:
return frame.is_true(frame.eval(source))
except passthroughex:
raise
except:
return False
def eval(self, frame):
super(Name, self).eval(frame)
if not self.is_local(frame):
self.explanation = self.name
class Compare(Interpretable):
__view__ = ast.Compare
def eval(self, frame):
expr = Interpretable(self.expr)
expr.eval(frame)
for operation, expr2 in self.ops:
if hasattr(self, 'result'):
# shortcutting in chained expressions
if not frame.is_true(self.result):
break
expr2 = Interpretable(expr2)
expr2.eval(frame)
self.explanation = "%s %s %s" % (
expr.explanation, operation, expr2.explanation)
source = "__exprinfo_left %s __exprinfo_right" % operation
try:
self.result = frame.eval(source,
__exprinfo_left=expr.result,
__exprinfo_right=expr2.result)
except passthroughex:
raise
except:
raise Failure(self)
expr = expr2
class And(Interpretable):
__view__ = ast.And
def eval(self, frame):
explanations = []
for expr in self.nodes:
expr = Interpretable(expr)
expr.eval(frame)
explanations.append(expr.explanation)
self.result = expr.result
if not frame.is_true(expr.result):
break
self.explanation = '(' + ' and '.join(explanations) + ')'
class Or(Interpretable):
__view__ = ast.Or
def eval(self, frame):
explanations = []
for expr in self.nodes:
expr = Interpretable(expr)
expr.eval(frame)
explanations.append(expr.explanation)
self.result = expr.result
if frame.is_true(expr.result):
break
self.explanation = '(' + ' or '.join(explanations) + ')'
# == Unary operations ==
keepalive = []
for astclass, astpattern in {
ast.Not : 'not __exprinfo_expr',
ast.Invert : '(~__exprinfo_expr)',
}.items():
class UnaryArith(Interpretable):
__view__ = astclass
def eval(self, frame, astpattern=astpattern):
expr = Interpretable(self.expr)
expr.eval(frame)
self.explanation = astpattern.replace('__exprinfo_expr',
expr.explanation)
try:
self.result = frame.eval(astpattern,
__exprinfo_expr=expr.result)
except passthroughex:
raise
except:
raise Failure(self)
keepalive.append(UnaryArith)
# == Binary operations ==
for astclass, astpattern in {
ast.Add : '(__exprinfo_left + __exprinfo_right)',
ast.Sub : '(__exprinfo_left - __exprinfo_right)',
ast.Mul : '(__exprinfo_left * __exprinfo_right)',
ast.Div : '(__exprinfo_left / __exprinfo_right)',
ast.Mod : '(__exprinfo_left % __exprinfo_right)',
ast.Power : '(__exprinfo_left ** __exprinfo_right)',
}.items():
class BinaryArith(Interpretable):
__view__ = astclass
def eval(self, frame, astpattern=astpattern):
left = Interpretable(self.left)
left.eval(frame)
right = Interpretable(self.right)
right.eval(frame)
self.explanation = (astpattern
.replace('__exprinfo_left', left .explanation)
.replace('__exprinfo_right', right.explanation))
try:
self.result = frame.eval(astpattern,
__exprinfo_left=left.result,
__exprinfo_right=right.result)
except passthroughex:
raise
except:
raise Failure(self)
keepalive.append(BinaryArith)
class CallFunc(Interpretable):
__view__ = ast.CallFunc
def is_bool(self, frame):
source = 'isinstance(__exprinfo_value, bool)'
try:
return frame.is_true(frame.eval(source,
__exprinfo_value=self.result))
except passthroughex:
raise
except:
return False
def eval(self, frame):
node = Interpretable(self.node)
node.eval(frame)
explanations = []
vars = {'__exprinfo_fn': node.result}
source = '__exprinfo_fn('
for a in self.args:
if isinstance(a, ast.Keyword):
keyword = a.name
a = a.expr
else:
keyword = None
a = Interpretable(a)
a.eval(frame)
argname = '__exprinfo_%d' % len(vars)
vars[argname] = a.result
if keyword is None:
source += argname + ','
explanations.append(a.explanation)
else:
source += '%s=%s,' % (keyword, argname)
explanations.append('%s=%s' % (keyword, a.explanation))
if self.star_args:
star_args = Interpretable(self.star_args)
star_args.eval(frame)
argname = '__exprinfo_star'
vars[argname] = star_args.result
source += '*' + argname + ','
explanations.append('*' + star_args.explanation)
if self.dstar_args:
dstar_args = Interpretable(self.dstar_args)
dstar_args.eval(frame)
argname = '__exprinfo_kwds'
vars[argname] = dstar_args.result
source += '**' + argname + ','
explanations.append('**' + dstar_args.explanation)
self.explanation = "%s(%s)" % (
node.explanation, ', '.join(explanations))
if source.endswith(','):
source = source[:-1]
source += ')'
try:
self.result = frame.eval(source, **vars)
except passthroughex:
raise
except:
raise Failure(self)
if not node.is_builtin(frame) or not self.is_bool(frame):
r = frame.repr(self.result)
self.explanation = '%s\n{%s = %s\n}' % (r, r, self.explanation)
class Getattr(Interpretable):
__view__ = ast.Getattr
def eval(self, frame):
expr = Interpretable(self.expr)
expr.eval(frame)
source = '__exprinfo_expr.%s' % self.attrname
try:
self.result = frame.eval(source, __exprinfo_expr=expr.result)
except passthroughex:
raise
except:
raise Failure(self)
self.explanation = '%s.%s' % (expr.explanation, self.attrname)
# if the attribute comes from the instance, its value is interesting
source = ('hasattr(__exprinfo_expr, "__dict__") and '
'%r in __exprinfo_expr.__dict__' % self.attrname)
try:
from_instance = frame.is_true(
frame.eval(source, __exprinfo_expr=expr.result))
except passthroughex:
raise
except:
from_instance = True
if from_instance:
r = frame.repr(self.result)
self.explanation = '%s\n{%s = %s\n}' % (r, r, self.explanation)
# == Re-interpretation of full statements ==
class Assert(Interpretable):
__view__ = ast.Assert
def run(self, frame):
test = Interpretable(self.test)
test.eval(frame)
# simplify 'assert False where False = ...'
if (test.explanation.startswith('False\n{False = ') and
test.explanation.endswith('\n}')):
test.explanation = test.explanation[15:-2]
# print the result as 'assert <explanation>'
self.result = test.result
self.explanation = 'assert ' + test.explanation
if not frame.is_true(test.result):
try:
raise BuiltinAssertionError
except passthroughex:
raise
except:
raise Failure(self)
class Assign(Interpretable):
__view__ = ast.Assign
def run(self, frame):
expr = Interpretable(self.expr)
expr.eval(frame)
self.result = expr.result
self.explanation = '... = ' + expr.explanation
# fall-back-run the rest of the assignment
ass = ast.Assign(self.nodes, ast.Name('__exprinfo_expr'))
mod = ast.Module(None, ast.Stmt([ass]))
mod.filename = '<run>'
co = pycodegen.ModuleCodeGenerator(mod).getCode()
try:
frame.exec_(co, __exprinfo_expr=expr.result)
except passthroughex:
raise
except:
raise Failure(self)
class Discard(Interpretable):
__view__ = ast.Discard
def run(self, frame):
expr = Interpretable(self.expr)
expr.eval(frame)
self.result = expr.result
self.explanation = expr.explanation
class Stmt(Interpretable):
__view__ = ast.Stmt
def run(self, frame):
for stmt in self.nodes:
stmt = Interpretable(stmt)
stmt.run(frame)
def report_failure(e):
explanation = e.node.nice_explanation()
if explanation:
explanation = ", in: " + explanation
else:
explanation = ""
sys.stdout.write("%s: %s%s\n" % (e.exc.__name__, e.value, explanation))
def check(s, frame=None):
if frame is None:
frame = sys._getframe(1)
frame = py.code.Frame(frame)
expr = parse(s, 'eval')
assert isinstance(expr, ast.Expression)
node = Interpretable(expr.node)
try:
node.eval(frame)
except passthroughex:
raise
except Failure:
e = sys.exc_info()[1]
report_failure(e)
else:
if not frame.is_true(node.result):
sys.stderr.write("assertion failed: %s\n" % node.nice_explanation())
###########################################################
# API / Entry points
# #########################################################
def interpret(source, frame, should_fail=False):
module = Interpretable(parse(source, 'exec').node)
#print "got module", module
if isinstance(frame, py.std.types.FrameType):
frame = py.code.Frame(frame)
try:
module.run(frame)
except Failure:
e = sys.exc_info()[1]
return getfailure(e)
except passthroughex:
raise
except:
import traceback
traceback.print_exc()
if should_fail:
return ("(assertion failed, but when it was re-run for "
"printing intermediate values, it did not fail. Suggestions: "
"compute assert expression before the assert or use --nomagic)")
else:
return None
def getmsg(excinfo):
if isinstance(excinfo, tuple):
excinfo = py.code.ExceptionInfo(excinfo)
#frame, line = gettbline(tb)
#frame = py.code.Frame(frame)
#return interpret(line, frame)
tb = excinfo.traceback[-1]
source = str(tb.statement).strip()
x = interpret(source, tb.frame, should_fail=True)
if not isinstance(x, str):
raise TypeError("interpret returned non-string %r" % (x,))
return x
def getfailure(e):
explanation = e.node.nice_explanation()
if str(e.value):
lines = explanation.split('\n')
lines[0] += " << %s" % (e.value,)
explanation = '\n'.join(lines)
text = "%s: %s" % (e.exc.__name__, explanation)
if text.startswith('AssertionError: assert '):
text = text[16:]
return text
def run(s, frame=None):
if frame is None:
frame = sys._getframe(1)
frame = py.code.Frame(frame)
module = Interpretable(parse(s, 'exec').node)
try:
module.run(frame)
except Failure:
e = sys.exc_info()[1]
report_failure(e)
if __name__ == '__main__':
# example:
def f():
return 5
def g():
return 3
def h(x):
return 'never'
check("f() * g() == 5")
check("not f()")
check("not (f() and g() or 0)")
check("f() == g()")
i = 4
check("i == f()")
check("len(f()) == 0")
check("isinstance(2+3+4, float)")
run("x = i")
check("x == 5")
run("assert not f(), 'oops'")
run("a, b, c = 1, 2")
run("a, b, c = f()")
check("max([f(),g()]) == 4")
check("'hello'[g()] == 'h'")
run("'guk%d' % h(f())")

View File

@ -0,0 +1,48 @@
import sys
import py
BuiltinAssertionError = py.builtin.builtins.AssertionError
class AssertionError(BuiltinAssertionError):
def __init__(self, *args):
BuiltinAssertionError.__init__(self, *args)
if args:
try:
self.msg = str(args[0])
except py.builtin._sysex:
raise
except:
self.msg = "<[broken __repr__] %s at %0xd>" %(
args[0].__class__, id(args[0]))
else:
f = py.code.Frame(sys._getframe(1))
try:
source = f.code.fullsource
if source is not None:
try:
source = source.getstatement(f.lineno, assertion=True)
except IndexError:
source = None
else:
source = str(source.deindent()).strip()
except py.error.ENOENT:
source = None
# this can also occur during reinterpretation, when the
# co_filename is set to "<run>".
if source:
self.msg = reinterpret(source, f, should_fail=True)
else:
self.msg = "<could not determine information>"
if not self.args:
self.args = (self.msg,)
if sys.version_info > (3, 0):
AssertionError.__module__ = "builtins"
reinterpret_old = "old reinterpretation not available for py3"
else:
from _pytest.assertion.oldinterpret import interpret as reinterpret_old
if sys.version_info >= (2, 6) or (sys.platform.startswith("java")):
from _pytest.assertion.newinterpret import interpret as reinterpret
else:
reinterpret = reinterpret_old

View File

@ -13,7 +13,6 @@ def rewrite_asserts(mod):
_saferepr = py.io.saferepr _saferepr = py.io.saferepr
_format_explanation = py.code._format_explanation
def _format_boolop(operands, explanations, is_or): def _format_boolop(operands, explanations, is_or):
show_explanations = [] show_explanations = []
@ -31,8 +30,9 @@ def _call_reprcompare(ops, results, expls, each_obj):
done = True done = True
if done: if done:
break break
if py.code._reprcompare is not None: from _pytest.assertion import _reprcompare
custom = py.code._reprcompare(ops[i], each_obj[i], each_obj[i + 1]) if _reprcompare is not None:
custom = _reprcompare(ops[i], each_obj[i], each_obj[i + 1])
if custom is not None: if custom is not None:
return custom return custom
return expl return expl
@ -94,7 +94,7 @@ class AssertionRewriter(ast.NodeVisitor):
# Insert some special imports at the top of the module but after any # Insert some special imports at the top of the module but after any
# docstrings and __future__ imports. # docstrings and __future__ imports.
aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"), aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
ast.alias("py", "@pylib"), ast.alias("_pytest.assertion", "@pytest_a"),
ast.alias("_pytest.assertion.rewrite", "@pytest_ar")] ast.alias("_pytest.assertion.rewrite", "@pytest_ar")]
expect_docstring = True expect_docstring = True
pos = 0 pos = 0
@ -153,11 +153,11 @@ class AssertionRewriter(ast.NodeVisitor):
def display(self, expr): def display(self, expr):
"""Call py.io.saferepr on the expression.""" """Call py.io.saferepr on the expression."""
return self.helper("saferepr", expr) return self.helper("ar", "saferepr", expr)
def helper(self, name, *args): def helper(self, mod, name, *args):
"""Call a helper in this module.""" """Call a helper in this module."""
py_name = ast.Name("@pytest_ar", ast.Load()) py_name = ast.Name("@pytest_" + mod, ast.Load())
attr = ast.Attribute(py_name, "_" + name, ast.Load()) attr = ast.Attribute(py_name, "_" + name, ast.Load())
return ast.Call(attr, list(args), [], None, None) return ast.Call(attr, list(args), [], None, None)
@ -211,7 +211,7 @@ class AssertionRewriter(ast.NodeVisitor):
explanation = "assert " + explanation explanation = "assert " + explanation
template = ast.Str(explanation) template = ast.Str(explanation)
msg = self.pop_format_context(template) msg = self.pop_format_context(template)
fmt = self.helper("format_explanation", msg) fmt = self.helper("a", "format_explanation", msg)
body.append(ast.Assert(top_condition, fmt)) body.append(ast.Assert(top_condition, fmt))
# Delete temporary variables. # Delete temporary variables.
names = [ast.Name(name, ast.Del()) for name in self.variables] names = [ast.Name(name, ast.Del()) for name in self.variables]
@ -242,7 +242,7 @@ class AssertionRewriter(ast.NodeVisitor):
explanations.append(explanation) explanations.append(explanation)
expls = ast.Tuple([ast.Str(expl) for expl in explanations], ast.Load()) expls = ast.Tuple([ast.Str(expl) for expl in explanations], ast.Load())
is_or = ast.Num(isinstance(boolop.op, ast.Or)) is_or = ast.Num(isinstance(boolop.op, ast.Or))
expl_template = self.helper("format_boolop", expl_template = self.helper("ar", "format_boolop",
ast.Tuple(operands, ast.Load()), expls, ast.Tuple(operands, ast.Load()), expls,
is_or) is_or)
expl = self.pop_format_context(expl_template) expl = self.pop_format_context(expl_template)
@ -321,7 +321,8 @@ class AssertionRewriter(ast.NodeVisitor):
self.statements.append(ast.Assign([store_names[i]], res_expr)) self.statements.append(ast.Assign([store_names[i]], res_expr))
left_res, left_expl = next_res, next_expl left_res, left_expl = next_res, next_expl
# Use py.code._reprcompare if that's available. # Use py.code._reprcompare if that's available.
expl_call = self.helper("call_reprcompare", ast.Tuple(syms, ast.Load()), expl_call = self.helper("ar", "call_reprcompare",
ast.Tuple(syms, ast.Load()),
ast.Tuple(load_names, ast.Load()), ast.Tuple(load_names, ast.Load()),
ast.Tuple(expls, ast.Load()), ast.Tuple(expls, ast.Load()),
ast.Tuple(results, ast.Load())) ast.Tuple(results, ast.Load()))

View File

@ -0,0 +1,327 @@
"PYTEST_DONT_REWRITE"
import pytest, py
from _pytest import assertion
def exvalue():
return py.std.sys.exc_info()[1]
def f():
return 2
def test_not_being_rewritten():
assert "@py_builtins" not in globals()
def test_assert():
try:
assert f() == 3
except AssertionError:
e = exvalue()
s = str(e)
assert s.startswith('assert 2 == 3\n')
def test_assert_with_explicit_message():
try:
assert f() == 3, "hello"
except AssertionError:
e = exvalue()
assert e.msg == 'hello'
def test_assert_within_finally():
class A:
def f():
pass
excinfo = py.test.raises(TypeError, """
try:
A().f()
finally:
i = 42
""")
s = excinfo.exconly()
assert s.find("takes no argument") != -1
#def g():
# A.f()
#excinfo = getexcinfo(TypeError, g)
#msg = getmsg(excinfo)
#assert msg.find("must be called with A") != -1
def test_assert_multiline_1():
try:
assert (f() ==
3)
except AssertionError:
e = exvalue()
s = str(e)
assert s.startswith('assert 2 == 3\n')
def test_assert_multiline_2():
try:
assert (f() == (4,
3)[-1])
except AssertionError:
e = exvalue()
s = str(e)
assert s.startswith('assert 2 ==')
def test_in():
try:
assert "hi" in [1, 2]
except AssertionError:
e = exvalue()
s = str(e)
assert s.startswith("assert 'hi' in")
def test_is():
try:
assert 1 is 2
except AssertionError:
e = exvalue()
s = str(e)
assert s.startswith("assert 1 is 2")
@py.test.mark.skipif("sys.version_info < (2,6)")
def test_attrib():
class Foo(object):
b = 1
i = Foo()
try:
assert i.b == 2
except AssertionError:
e = exvalue()
s = str(e)
assert s.startswith("assert 1 == 2")
@py.test.mark.skipif("sys.version_info < (2,6)")
def test_attrib_inst():
class Foo(object):
b = 1
try:
assert Foo().b == 2
except AssertionError:
e = exvalue()
s = str(e)
assert s.startswith("assert 1 == 2")
def test_len():
l = list(range(42))
try:
assert len(l) == 100
except AssertionError:
e = exvalue()
s = str(e)
assert s.startswith("assert 42 == 100")
assert "where 42 = len([" in s
def test_assert_non_string_message():
class A:
def __str__(self):
return "hello"
try:
assert 0 == 1, A()
except AssertionError:
e = exvalue()
assert e.msg == "hello"
def test_assert_keyword_arg():
def f(x=3):
return False
try:
assert f(x=5)
except AssertionError:
e = exvalue()
assert "x=5" in e.msg
# These tests should both fail, but should fail nicely...
class WeirdRepr:
def __repr__(self):
return '<WeirdRepr\nsecond line>'
def bug_test_assert_repr():
v = WeirdRepr()
try:
assert v == 1
except AssertionError:
e = exvalue()
assert e.msg.find('WeirdRepr') != -1
assert e.msg.find('second line') != -1
assert 0
def test_assert_non_string():
try:
assert 0, ['list']
except AssertionError:
e = exvalue()
assert e.msg.find("list") != -1
def test_assert_implicit_multiline():
try:
x = [1,2,3]
assert x != [1,
2, 3]
except AssertionError:
e = exvalue()
assert e.msg.find('assert [1, 2, 3] !=') != -1
def test_assert_with_brokenrepr_arg():
class BrokenRepr:
def __repr__(self): 0 / 0
e = AssertionError(BrokenRepr())
if e.msg.find("broken __repr__") == -1:
py.test.fail("broken __repr__ not handle correctly")
def test_multiple_statements_per_line():
try:
a = 1; assert a == 2
except AssertionError:
e = exvalue()
assert "assert 1 == 2" in e.msg
def test_power():
try:
assert 2**3 == 7
except AssertionError:
e = exvalue()
assert "assert (2 ** 3) == 7" in e.msg
class TestView:
def setup_class(cls):
cls.View = pytest.importorskip("_pytest.assertion.oldinterpret").View
def test_class_dispatch(self):
### Use a custom class hierarchy with existing instances
class Picklable(self.View):
pass
class Simple(Picklable):
__view__ = object
def pickle(self):
return repr(self.__obj__)
class Seq(Picklable):
__view__ = list, tuple, dict
def pickle(self):
return ';'.join(
[Picklable(item).pickle() for item in self.__obj__])
class Dict(Seq):
__view__ = dict
def pickle(self):
return Seq.pickle(self) + '!' + Seq(self.values()).pickle()
assert Picklable(123).pickle() == '123'
assert Picklable([1,[2,3],4]).pickle() == '1;2;3;4'
assert Picklable({1:2}).pickle() == '1!2'
def test_viewtype_class_hierarchy(self):
# Use a custom class hierarchy based on attributes of existing instances
class Operation:
"Existing class that I don't want to change."
def __init__(self, opname, *args):
self.opname = opname
self.args = args
existing = [Operation('+', 4, 5),
Operation('getitem', '', 'join'),
Operation('setattr', 'x', 'y', 3),
Operation('-', 12, 1)]
class PyOp(self.View):
def __viewkey__(self):
return self.opname
def generate(self):
return '%s(%s)' % (self.opname, ', '.join(map(repr, self.args)))
class PyBinaryOp(PyOp):
__view__ = ('+', '-', '*', '/')
def generate(self):
return '%s %s %s' % (self.args[0], self.opname, self.args[1])
codelines = [PyOp(op).generate() for op in existing]
assert codelines == ["4 + 5", "getitem('', 'join')",
"setattr('x', 'y', 3)", "12 - 1"]
@py.test.mark.skipif("sys.version_info < (2,6)")
def test_assert_customizable_reprcompare(monkeypatch):
monkeypatch.setattr(assertion, '_reprcompare', lambda *args: 'hello')
try:
assert 3 == 4
except AssertionError:
e = exvalue()
s = str(e)
assert "hello" in s
def test_assert_long_source_1():
try:
assert len == [
(None, ['somet text', 'more text']),
]
except AssertionError:
e = exvalue()
s = str(e)
assert 're-run' not in s
assert 'somet text' in s
def test_assert_long_source_2():
try:
assert(len == [
(None, ['somet text', 'more text']),
])
except AssertionError:
e = exvalue()
s = str(e)
assert 're-run' not in s
assert 'somet text' in s
def test_assert_raise_alias(testdir):
testdir.makepyfile("""
"PYTEST_DONT_REWRITE"
import sys
EX = AssertionError
def test_hello():
raise EX("hello"
"multi"
"line")
""")
result = testdir.runpytest()
result.stdout.fnmatch_lines([
"*def test_hello*",
"*raise EX*",
"*1 failed*",
])
@pytest.mark.skipif("sys.version_info < (2,5)")
def test_assert_raise_subclass():
class SomeEx(AssertionError):
def __init__(self, *args):
super(SomeEx, self).__init__()
try:
raise SomeEx("hello")
except AssertionError:
s = str(exvalue())
assert 're-run' not in s
assert 'could not determine' in s
def test_assert_raises_in_nonzero_of_object_pytest_issue10():
class A(object):
def __nonzero__(self):
raise ValueError(42)
def __lt__(self, other):
return A()
def __repr__(self):
return "<MY42 object>"
def myany(x):
return True
try:
assert not(myany(A() < 0))
except AssertionError:
e = exvalue()
s = str(e)
assert "<MY42 object> < 0" in s

View File

@ -2,11 +2,12 @@ import sys
import py, pytest import py, pytest
import _pytest.assertion as plugin import _pytest.assertion as plugin
from _pytest.assertion import reinterpret
needsnewassert = pytest.mark.skipif("sys.version_info < (2,6)") needsnewassert = pytest.mark.skipif("sys.version_info < (2,6)")
def interpret(expr): def interpret(expr):
return py.code._reinterpret(expr, py.code.Frame(sys._getframe(1))) return reinterpret.reinterpret(expr, py.code.Frame(sys._getframe(1)))
class TestBinReprIntegration: class TestBinReprIntegration:
pytestmark = needsnewassert pytestmark = needsnewassert
@ -25,7 +26,7 @@ class TestBinReprIntegration:
self.right = right self.right = right
mockhook = MockHook() mockhook = MockHook()
monkeypatch = request.getfuncargvalue("monkeypatch") monkeypatch = request.getfuncargvalue("monkeypatch")
monkeypatch.setattr(py.code, '_reprcompare', mockhook) monkeypatch.setattr(plugin, '_reprcompare', mockhook)
return mockhook return mockhook
def test_pytest_assertrepr_compare_called(self, hook): def test_pytest_assertrepr_compare_called(self, hook):
@ -40,13 +41,13 @@ class TestBinReprIntegration:
assert hook.right == [0, 2] assert hook.right == [0, 2]
def test_configure_unconfigure(self, testdir, hook): def test_configure_unconfigure(self, testdir, hook):
assert hook == py.code._reprcompare assert hook == plugin._reprcompare
config = testdir.parseconfig() config = testdir.parseconfig()
plugin.pytest_configure(config) plugin.pytest_configure(config)
assert hook != py.code._reprcompare assert hook != plugin._reprcompare
from _pytest.config import pytest_unconfigure from _pytest.config import pytest_unconfigure
pytest_unconfigure(config) pytest_unconfigure(config)
assert hook == py.code._reprcompare assert hook == plugin._reprcompare
def callequal(left, right): def callequal(left, right):
return plugin.pytest_assertrepr_compare('==', left, right) return plugin.pytest_assertrepr_compare('==', left, right)

View File

@ -4,15 +4,16 @@ import pytest
ast = pytest.importorskip("ast") ast = pytest.importorskip("ast")
from _pytest import assertion
from _pytest.assertion.rewrite import rewrite_asserts from _pytest.assertion.rewrite import rewrite_asserts
def setup_module(mod): def setup_module(mod):
mod._old_reprcompare = py.code._reprcompare mod._old_reprcompare = assertion._reprcompare
py.code._reprcompare = None py.code._reprcompare = None
def teardown_module(mod): def teardown_module(mod):
py.code._reprcompare = mod._old_reprcompare assertion._reprcompare = mod._old_reprcompare
del mod._old_reprcompare del mod._old_reprcompare
@ -229,13 +230,13 @@ class TestAssertionRewrite:
def test_custom_reprcompare(self, monkeypatch): def test_custom_reprcompare(self, monkeypatch):
def my_reprcompare(op, left, right): def my_reprcompare(op, left, right):
return "42" return "42"
monkeypatch.setattr(py.code, "_reprcompare", my_reprcompare) monkeypatch.setattr(assertion, "_reprcompare", my_reprcompare)
def f(): def f():
assert 42 < 3 assert 42 < 3
assert getmsg(f) == "assert 42" assert getmsg(f) == "assert 42"
def my_reprcompare(op, left, right): def my_reprcompare(op, left, right):
return "%s %s %s" % (left, op, right) return "%s %s %s" % (left, op, right)
monkeypatch.setattr(py.code, "_reprcompare", my_reprcompare) monkeypatch.setattr(assertion, "_reprcompare", my_reprcompare)
def f(): def f():
assert 1 < 3 < 5 <= 4 < 7 assert 1 < 3 < 5 <= 4 < 7
assert getmsg(f) == "assert 5 <= 4" assert getmsg(f) == "assert 5 <= 4"