merge Benjamin's assertion-rewrite branch: all assertion related code is now part of py.test core distribution - the builtin assertion plugin to be precise.

See doc/assert.txt for details on how what has been improved.
This commit is contained in:
holger krekel 2011-05-31 14:11:53 +02:00
commit 5690beab5a
21 changed files with 2246 additions and 121 deletions

View File

@ -1,6 +1,10 @@
Changes between 2.0.3 and DEV Changes between 2.0.3 and 2.1.0.DEV
---------------------------------------------- ----------------------------------------------
- merge Benjamin's assertionrewrite branch: now assertions
for test modules on python 2.6 and above are done by rewriting
the AST and saving the pyc file before the test module is imported.
see doc/assert.txt for more info.
- fix issue43: improve doctests with better traceback reporting on - fix issue43: improve doctests with better traceback reporting on
unexpected exceptions unexpected exceptions
- fix issue47: timing output in junitxml for test cases is now correct - fix issue47: timing output in junitxml for test cases is now correct

View File

@ -1,2 +1,2 @@
# #
__version__ = '2.1.0.dev1' __version__ = '2.1.0.dev2'

View File

@ -0,0 +1,128 @@
"""
support for presenting detailed information in failing assertions.
"""
import py
import imp
import marshal
import struct
import sys
import pytest
from _pytest.monkeypatch import monkeypatch
from _pytest.assertion import reinterpret, util
try:
from _pytest.assertion.rewrite import rewrite_asserts
except ImportError:
rewrite_asserts = None
else:
import ast
def pytest_addoption(parser):
group = parser.getgroup("debugconfig")
group.addoption('--assertmode', action="store", dest="assertmode",
choices=("on", "old", "off", "default"), default="default",
metavar="on|old|off",
help="""control assertion debugging tools.
'off' performs no assertion debugging.
'old' reinterprets the expressions in asserts to glean information.
'on' (the default) rewrites the assert statements in test modules to provide
sub-expression results.""")
group.addoption('--no-assert', action="store_true", default=False,
dest="noassert", help="DEPRECATED equivalent to --assertmode=off")
group.addoption('--nomagic', action="store_true", default=False,
dest="nomagic", help="DEPRECATED equivalent to --assertmode=off")
class AssertionState:
"""State for the assertion plugin."""
def __init__(self, config, mode):
self.mode = mode
self.trace = config.trace.root.get("assertion")
def pytest_configure(config):
warn_about_missing_assertion()
mode = config.getvalue("assertmode")
if config.getvalue("noassert") or config.getvalue("nomagic"):
if mode not in ("off", "default"):
raise pytest.UsageError("assertion options conflict")
mode = "off"
elif mode == "default":
mode = "on"
if mode != "off":
def callbinrepr(op, left, right):
hook_result = config.hook.pytest_assertrepr_compare(
config=config, op=op, left=left, right=right)
for new_expl in hook_result:
if new_expl:
return '\n~'.join(new_expl)
m = monkeypatch()
config._cleanup.append(m.undo)
m.setattr(py.builtin.builtins, 'AssertionError',
reinterpret.AssertionError)
m.setattr(util, '_reprcompare', callbinrepr)
if mode == "on" and rewrite_asserts is None:
mode = "old"
config._assertstate = AssertionState(config, mode)
config._assertstate.trace("configured with mode set to %r" % (mode,))
def _write_pyc(co, source_path):
if hasattr(imp, "cache_from_source"):
# Handle PEP 3147 pycs.
pyc = py.path(imp.cache_from_source(source_math))
pyc.dirname.ensure(dir=True)
else:
pyc = source_path + "c"
mtime = int(source_path.mtime())
fp = pyc.open("wb")
try:
fp.write(imp.get_magic())
fp.write(struct.pack("<l", mtime))
marshal.dump(co, fp)
finally:
fp.close()
return pyc
def before_module_import(mod):
if mod.config._assertstate.mode != "on":
return
# Some deep magic: load the source, rewrite the asserts, and write a
# fake pyc, so that it'll be loaded when the module is imported.
source = mod.fspath.read()
try:
tree = ast.parse(source)
except SyntaxError:
# Let this pop up again in the real import.
mod.config._assertstate.trace("failed to parse: %r" % (mod.fspath,))
return
rewrite_asserts(tree)
try:
co = compile(tree, str(mod.fspath), "exec")
except SyntaxError:
# It's possible that this error is from some bug in the assertion
# rewriting, but I don't know of a fast way to tell.
mod.config._assertstate.trace("failed to compile: %r" % (mod.fspath,))
return
mod._pyc = _write_pyc(co, mod.fspath)
mod.config._assertstate.trace("wrote pyc: %r" % (mod._pyc,))
def after_module_import(mod):
if not hasattr(mod, "_pyc"):
return
state = mod.config._assertstate
try:
mod._pyc.remove()
except py.error.ENOENT:
state.trace("couldn't find pyc: %r" % (mod._pyc,))
else:
state.trace("removed pyc: %r" % (mod._pyc,))
def warn_about_missing_assertion():
try:
assert False
except AssertionError:
pass
else:
sys.stderr.write("WARNING: failing tests may report as passing because "
"assertions are turned off! (are you using python -O?)\n")
pytest_assertrepr_compare = util.assertrepr_compare

View File

@ -0,0 +1,339 @@
"""
Find intermediate evalutation results in assert statements through builtin AST.
This should replace oldinterpret.py eventually.
"""
import sys
import ast
import py
from _pytest.assertion import util
from _pytest.assertion.reinterpret import BuiltinAssertionError
if sys.platform.startswith("java") and sys.version_info < (2, 5, 2):
# See http://bugs.jython.org/issue1497
_exprs = ("BoolOp", "BinOp", "UnaryOp", "Lambda", "IfExp", "Dict",
"ListComp", "GeneratorExp", "Yield", "Compare", "Call",
"Repr", "Num", "Str", "Attribute", "Subscript", "Name",
"List", "Tuple")
_stmts = ("FunctionDef", "ClassDef", "Return", "Delete", "Assign",
"AugAssign", "Print", "For", "While", "If", "With", "Raise",
"TryExcept", "TryFinally", "Assert", "Import", "ImportFrom",
"Exec", "Global", "Expr", "Pass", "Break", "Continue")
_expr_nodes = set(getattr(ast, name) for name in _exprs)
_stmt_nodes = set(getattr(ast, name) for name in _stmts)
def _is_ast_expr(node):
return node.__class__ in _expr_nodes
def _is_ast_stmt(node):
return node.__class__ in _stmt_nodes
else:
def _is_ast_expr(node):
return isinstance(node, ast.expr)
def _is_ast_stmt(node):
return isinstance(node, ast.stmt)
class Failure(Exception):
"""Error found while interpreting AST."""
def __init__(self, explanation=""):
self.cause = sys.exc_info()
self.explanation = explanation
def interpret(source, frame, should_fail=False):
mod = ast.parse(source)
visitor = DebugInterpreter(frame)
try:
visitor.visit(mod)
except Failure:
failure = sys.exc_info()[1]
return getfailure(failure)
if should_fail:
return ("(assertion failed, but when it was re-run for "
"printing intermediate values, it did not fail. Suggestions: "
"compute assert expression before the assert or use --no-assert)")
def run(offending_line, frame=None):
if frame is None:
frame = py.code.Frame(sys._getframe(1))
return interpret(offending_line, frame)
def getfailure(failure):
explanation = util.format_explanation(failure.explanation)
value = failure.cause[1]
if str(value):
lines = explanation.splitlines()
if not lines:
lines.append("")
lines[0] += " << %s" % (value,)
explanation = "\n".join(lines)
text = "%s: %s" % (failure.cause[0].__name__, explanation)
if text.startswith("AssertionError: assert "):
text = text[16:]
return text
operator_map = {
ast.BitOr : "|",
ast.BitXor : "^",
ast.BitAnd : "&",
ast.LShift : "<<",
ast.RShift : ">>",
ast.Add : "+",
ast.Sub : "-",
ast.Mult : "*",
ast.Div : "/",
ast.FloorDiv : "//",
ast.Mod : "%",
ast.Eq : "==",
ast.NotEq : "!=",
ast.Lt : "<",
ast.LtE : "<=",
ast.Gt : ">",
ast.GtE : ">=",
ast.Pow : "**",
ast.Is : "is",
ast.IsNot : "is not",
ast.In : "in",
ast.NotIn : "not in"
}
unary_map = {
ast.Not : "not %s",
ast.Invert : "~%s",
ast.USub : "-%s",
ast.UAdd : "+%s"
}
class DebugInterpreter(ast.NodeVisitor):
"""Interpret AST nodes to gleam useful debugging information. """
def __init__(self, frame):
self.frame = frame
def generic_visit(self, node):
# Fallback when we don't have a special implementation.
if _is_ast_expr(node):
mod = ast.Expression(node)
co = self._compile(mod)
try:
result = self.frame.eval(co)
except Exception:
raise Failure()
explanation = self.frame.repr(result)
return explanation, result
elif _is_ast_stmt(node):
mod = ast.Module([node])
co = self._compile(mod, "exec")
try:
self.frame.exec_(co)
except Exception:
raise Failure()
return None, None
else:
raise AssertionError("can't handle %s" %(node,))
def _compile(self, source, mode="eval"):
return compile(source, "<assertion interpretation>", mode)
def visit_Expr(self, expr):
return self.visit(expr.value)
def visit_Module(self, mod):
for stmt in mod.body:
self.visit(stmt)
def visit_Name(self, name):
explanation, result = self.generic_visit(name)
# See if the name is local.
source = "%r in locals() is not globals()" % (name.id,)
co = self._compile(source)
try:
local = self.frame.eval(co)
except Exception:
# have to assume it isn't
local = False
if not local:
return name.id, result
return explanation, result
def visit_Compare(self, comp):
left = comp.left
left_explanation, left_result = self.visit(left)
for op, next_op in zip(comp.ops, comp.comparators):
next_explanation, next_result = self.visit(next_op)
op_symbol = operator_map[op.__class__]
explanation = "%s %s %s" % (left_explanation, op_symbol,
next_explanation)
source = "__exprinfo_left %s __exprinfo_right" % (op_symbol,)
co = self._compile(source)
try:
result = self.frame.eval(co, __exprinfo_left=left_result,
__exprinfo_right=next_result)
except Exception:
raise Failure(explanation)
try:
if not result:
break
except KeyboardInterrupt:
raise
except:
break
left_explanation, left_result = next_explanation, next_result
if util._reprcompare is not None:
res = util._reprcompare(op_symbol, left_result, next_result)
if res:
explanation = res
return explanation, result
def visit_BoolOp(self, boolop):
is_or = isinstance(boolop.op, ast.Or)
explanations = []
for operand in boolop.values:
explanation, result = self.visit(operand)
explanations.append(explanation)
if result == is_or:
break
name = is_or and " or " or " and "
explanation = "(" + name.join(explanations) + ")"
return explanation, result
def visit_UnaryOp(self, unary):
pattern = unary_map[unary.op.__class__]
operand_explanation, operand_result = self.visit(unary.operand)
explanation = pattern % (operand_explanation,)
co = self._compile(pattern % ("__exprinfo_expr",))
try:
result = self.frame.eval(co, __exprinfo_expr=operand_result)
except Exception:
raise Failure(explanation)
return explanation, result
def visit_BinOp(self, binop):
left_explanation, left_result = self.visit(binop.left)
right_explanation, right_result = self.visit(binop.right)
symbol = operator_map[binop.op.__class__]
explanation = "(%s %s %s)" % (left_explanation, symbol,
right_explanation)
source = "__exprinfo_left %s __exprinfo_right" % (symbol,)
co = self._compile(source)
try:
result = self.frame.eval(co, __exprinfo_left=left_result,
__exprinfo_right=right_result)
except Exception:
raise Failure(explanation)
return explanation, result
def visit_Call(self, call):
func_explanation, func = self.visit(call.func)
arg_explanations = []
ns = {"__exprinfo_func" : func}
arguments = []
for arg in call.args:
arg_explanation, arg_result = self.visit(arg)
arg_name = "__exprinfo_%s" % (len(ns),)
ns[arg_name] = arg_result
arguments.append(arg_name)
arg_explanations.append(arg_explanation)
for keyword in call.keywords:
arg_explanation, arg_result = self.visit(keyword.value)
arg_name = "__exprinfo_%s" % (len(ns),)
ns[arg_name] = arg_result
keyword_source = "%s=%%s" % (keyword.arg)
arguments.append(keyword_source % (arg_name,))
arg_explanations.append(keyword_source % (arg_explanation,))
if call.starargs:
arg_explanation, arg_result = self.visit(call.starargs)
arg_name = "__exprinfo_star"
ns[arg_name] = arg_result
arguments.append("*%s" % (arg_name,))
arg_explanations.append("*%s" % (arg_explanation,))
if call.kwargs:
arg_explanation, arg_result = self.visit(call.kwargs)
arg_name = "__exprinfo_kwds"
ns[arg_name] = arg_result
arguments.append("**%s" % (arg_name,))
arg_explanations.append("**%s" % (arg_explanation,))
args_explained = ", ".join(arg_explanations)
explanation = "%s(%s)" % (func_explanation, args_explained)
args = ", ".join(arguments)
source = "__exprinfo_func(%s)" % (args,)
co = self._compile(source)
try:
result = self.frame.eval(co, **ns)
except Exception:
raise Failure(explanation)
pattern = "%s\n{%s = %s\n}"
rep = self.frame.repr(result)
explanation = pattern % (rep, rep, explanation)
return explanation, result
def _is_builtin_name(self, name):
pattern = "%r not in globals() and %r not in locals()"
source = pattern % (name.id, name.id)
co = self._compile(source)
try:
return self.frame.eval(co)
except Exception:
return False
def visit_Attribute(self, attr):
if not isinstance(attr.ctx, ast.Load):
return self.generic_visit(attr)
source_explanation, source_result = self.visit(attr.value)
explanation = "%s.%s" % (source_explanation, attr.attr)
source = "__exprinfo_expr.%s" % (attr.attr,)
co = self._compile(source)
try:
result = self.frame.eval(co, __exprinfo_expr=source_result)
except Exception:
raise Failure(explanation)
explanation = "%s\n{%s = %s.%s\n}" % (self.frame.repr(result),
self.frame.repr(result),
source_explanation, attr.attr)
# Check if the attr is from an instance.
source = "%r in getattr(__exprinfo_expr, '__dict__', {})"
source = source % (attr.attr,)
co = self._compile(source)
try:
from_instance = self.frame.eval(co, __exprinfo_expr=source_result)
except Exception:
from_instance = True
if from_instance:
rep = self.frame.repr(result)
pattern = "%s\n{%s = %s\n}"
explanation = pattern % (rep, rep, explanation)
return explanation, result
def visit_Assert(self, assrt):
test_explanation, test_result = self.visit(assrt.test)
if test_explanation.startswith("False\n{False =") and \
test_explanation.endswith("\n"):
test_explanation = test_explanation[15:-2]
explanation = "assert %s" % (test_explanation,)
if not test_result:
try:
raise BuiltinAssertionError
except Exception:
raise Failure(explanation)
return explanation, test_result
def visit_Assign(self, assign):
value_explanation, value_result = self.visit(assign.value)
explanation = "... = %s" % (value_explanation,)
name = ast.Name("__exprinfo_expr", ast.Load(),
lineno=assign.value.lineno,
col_offset=assign.value.col_offset)
new_assign = ast.Assign(assign.targets, name, lineno=assign.lineno,
col_offset=assign.col_offset)
mod = ast.Module([new_assign])
co = self._compile(mod, "exec")
try:
self.frame.exec_(co, __exprinfo_expr=value_result)
except Exception:
raise Failure(explanation)
return explanation, value_result

View File

@ -0,0 +1,556 @@
import py
import sys, inspect
from compiler import parse, ast, pycodegen
from _pytest.assertion.util import format_explanation
from _pytest.assertion.reinterpret import BuiltinAssertionError
passthroughex = py.builtin._sysex
class Failure:
def __init__(self, node):
self.exc, self.value, self.tb = sys.exc_info()
self.node = node
class View(object):
"""View base class.
If C is a subclass of View, then C(x) creates a proxy object around
the object x. The actual class of the proxy is not C in general,
but a *subclass* of C determined by the rules below. To avoid confusion
we call view class the class of the proxy (a subclass of C, so of View)
and object class the class of x.
Attributes and methods not found in the proxy are automatically read on x.
Other operations like setting attributes are performed on the proxy, as
determined by its view class. The object x is available from the proxy
as its __obj__ attribute.
The view class selection is determined by the __view__ tuples and the
optional __viewkey__ method. By default, the selected view class is the
most specific subclass of C whose __view__ mentions the class of x.
If no such subclass is found, the search proceeds with the parent
object classes. For example, C(True) will first look for a subclass
of C with __view__ = (..., bool, ...) and only if it doesn't find any
look for one with __view__ = (..., int, ...), and then ..., object,...
If everything fails the class C itself is considered to be the default.
Alternatively, the view class selection can be driven by another aspect
of the object x, instead of the class of x, by overriding __viewkey__.
See last example at the end of this module.
"""
_viewcache = {}
__view__ = ()
def __new__(rootclass, obj, *args, **kwds):
self = object.__new__(rootclass)
self.__obj__ = obj
self.__rootclass__ = rootclass
key = self.__viewkey__()
try:
self.__class__ = self._viewcache[key]
except KeyError:
self.__class__ = self._selectsubclass(key)
return self
def __getattr__(self, attr):
# attributes not found in the normal hierarchy rooted on View
# are looked up in the object's real class
return getattr(self.__obj__, attr)
def __viewkey__(self):
return self.__obj__.__class__
def __matchkey__(self, key, subclasses):
if inspect.isclass(key):
keys = inspect.getmro(key)
else:
keys = [key]
for key in keys:
result = [C for C in subclasses if key in C.__view__]
if result:
return result
return []
def _selectsubclass(self, key):
subclasses = list(enumsubclasses(self.__rootclass__))
for C in subclasses:
if not isinstance(C.__view__, tuple):
C.__view__ = (C.__view__,)
choices = self.__matchkey__(key, subclasses)
if not choices:
return self.__rootclass__
elif len(choices) == 1:
return choices[0]
else:
# combine the multiple choices
return type('?', tuple(choices), {})
def __repr__(self):
return '%s(%r)' % (self.__rootclass__.__name__, self.__obj__)
def enumsubclasses(cls):
for subcls in cls.__subclasses__():
for subsubclass in enumsubclasses(subcls):
yield subsubclass
yield cls
class Interpretable(View):
"""A parse tree node with a few extra methods."""
explanation = None
def is_builtin(self, frame):
return False
def eval(self, frame):
# fall-back for unknown expression nodes
try:
expr = ast.Expression(self.__obj__)
expr.filename = '<eval>'
self.__obj__.filename = '<eval>'
co = pycodegen.ExpressionCodeGenerator(expr).getCode()
result = frame.eval(co)
except passthroughex:
raise
except:
raise Failure(self)
self.result = result
self.explanation = self.explanation or frame.repr(self.result)
def run(self, frame):
# fall-back for unknown statement nodes
try:
expr = ast.Module(None, ast.Stmt([self.__obj__]))
expr.filename = '<run>'
co = pycodegen.ModuleCodeGenerator(expr).getCode()
frame.exec_(co)
except passthroughex:
raise
except:
raise Failure(self)
def nice_explanation(self):
return format_explanation(self.explanation)
class Name(Interpretable):
__view__ = ast.Name
def is_local(self, frame):
source = '%r in locals() is not globals()' % self.name
try:
return frame.is_true(frame.eval(source))
except passthroughex:
raise
except:
return False
def is_global(self, frame):
source = '%r in globals()' % self.name
try:
return frame.is_true(frame.eval(source))
except passthroughex:
raise
except:
return False
def is_builtin(self, frame):
source = '%r not in locals() and %r not in globals()' % (
self.name, self.name)
try:
return frame.is_true(frame.eval(source))
except passthroughex:
raise
except:
return False
def eval(self, frame):
super(Name, self).eval(frame)
if not self.is_local(frame):
self.explanation = self.name
class Compare(Interpretable):
__view__ = ast.Compare
def eval(self, frame):
expr = Interpretable(self.expr)
expr.eval(frame)
for operation, expr2 in self.ops:
if hasattr(self, 'result'):
# shortcutting in chained expressions
if not frame.is_true(self.result):
break
expr2 = Interpretable(expr2)
expr2.eval(frame)
self.explanation = "%s %s %s" % (
expr.explanation, operation, expr2.explanation)
source = "__exprinfo_left %s __exprinfo_right" % operation
try:
self.result = frame.eval(source,
__exprinfo_left=expr.result,
__exprinfo_right=expr2.result)
except passthroughex:
raise
except:
raise Failure(self)
expr = expr2
class And(Interpretable):
__view__ = ast.And
def eval(self, frame):
explanations = []
for expr in self.nodes:
expr = Interpretable(expr)
expr.eval(frame)
explanations.append(expr.explanation)
self.result = expr.result
if not frame.is_true(expr.result):
break
self.explanation = '(' + ' and '.join(explanations) + ')'
class Or(Interpretable):
__view__ = ast.Or
def eval(self, frame):
explanations = []
for expr in self.nodes:
expr = Interpretable(expr)
expr.eval(frame)
explanations.append(expr.explanation)
self.result = expr.result
if frame.is_true(expr.result):
break
self.explanation = '(' + ' or '.join(explanations) + ')'
# == Unary operations ==
keepalive = []
for astclass, astpattern in {
ast.Not : 'not __exprinfo_expr',
ast.Invert : '(~__exprinfo_expr)',
}.items():
class UnaryArith(Interpretable):
__view__ = astclass
def eval(self, frame, astpattern=astpattern):
expr = Interpretable(self.expr)
expr.eval(frame)
self.explanation = astpattern.replace('__exprinfo_expr',
expr.explanation)
try:
self.result = frame.eval(astpattern,
__exprinfo_expr=expr.result)
except passthroughex:
raise
except:
raise Failure(self)
keepalive.append(UnaryArith)
# == Binary operations ==
for astclass, astpattern in {
ast.Add : '(__exprinfo_left + __exprinfo_right)',
ast.Sub : '(__exprinfo_left - __exprinfo_right)',
ast.Mul : '(__exprinfo_left * __exprinfo_right)',
ast.Div : '(__exprinfo_left / __exprinfo_right)',
ast.Mod : '(__exprinfo_left % __exprinfo_right)',
ast.Power : '(__exprinfo_left ** __exprinfo_right)',
}.items():
class BinaryArith(Interpretable):
__view__ = astclass
def eval(self, frame, astpattern=astpattern):
left = Interpretable(self.left)
left.eval(frame)
right = Interpretable(self.right)
right.eval(frame)
self.explanation = (astpattern
.replace('__exprinfo_left', left .explanation)
.replace('__exprinfo_right', right.explanation))
try:
self.result = frame.eval(astpattern,
__exprinfo_left=left.result,
__exprinfo_right=right.result)
except passthroughex:
raise
except:
raise Failure(self)
keepalive.append(BinaryArith)
class CallFunc(Interpretable):
__view__ = ast.CallFunc
def is_bool(self, frame):
source = 'isinstance(__exprinfo_value, bool)'
try:
return frame.is_true(frame.eval(source,
__exprinfo_value=self.result))
except passthroughex:
raise
except:
return False
def eval(self, frame):
node = Interpretable(self.node)
node.eval(frame)
explanations = []
vars = {'__exprinfo_fn': node.result}
source = '__exprinfo_fn('
for a in self.args:
if isinstance(a, ast.Keyword):
keyword = a.name
a = a.expr
else:
keyword = None
a = Interpretable(a)
a.eval(frame)
argname = '__exprinfo_%d' % len(vars)
vars[argname] = a.result
if keyword is None:
source += argname + ','
explanations.append(a.explanation)
else:
source += '%s=%s,' % (keyword, argname)
explanations.append('%s=%s' % (keyword, a.explanation))
if self.star_args:
star_args = Interpretable(self.star_args)
star_args.eval(frame)
argname = '__exprinfo_star'
vars[argname] = star_args.result
source += '*' + argname + ','
explanations.append('*' + star_args.explanation)
if self.dstar_args:
dstar_args = Interpretable(self.dstar_args)
dstar_args.eval(frame)
argname = '__exprinfo_kwds'
vars[argname] = dstar_args.result
source += '**' + argname + ','
explanations.append('**' + dstar_args.explanation)
self.explanation = "%s(%s)" % (
node.explanation, ', '.join(explanations))
if source.endswith(','):
source = source[:-1]
source += ')'
try:
self.result = frame.eval(source, **vars)
except passthroughex:
raise
except:
raise Failure(self)
if not node.is_builtin(frame) or not self.is_bool(frame):
r = frame.repr(self.result)
self.explanation = '%s\n{%s = %s\n}' % (r, r, self.explanation)
class Getattr(Interpretable):
__view__ = ast.Getattr
def eval(self, frame):
expr = Interpretable(self.expr)
expr.eval(frame)
source = '__exprinfo_expr.%s' % self.attrname
try:
self.result = frame.eval(source, __exprinfo_expr=expr.result)
except passthroughex:
raise
except:
raise Failure(self)
self.explanation = '%s.%s' % (expr.explanation, self.attrname)
# if the attribute comes from the instance, its value is interesting
source = ('hasattr(__exprinfo_expr, "__dict__") and '
'%r in __exprinfo_expr.__dict__' % self.attrname)
try:
from_instance = frame.is_true(
frame.eval(source, __exprinfo_expr=expr.result))
except passthroughex:
raise
except:
from_instance = True
if from_instance:
r = frame.repr(self.result)
self.explanation = '%s\n{%s = %s\n}' % (r, r, self.explanation)
# == Re-interpretation of full statements ==
class Assert(Interpretable):
__view__ = ast.Assert
def run(self, frame):
test = Interpretable(self.test)
test.eval(frame)
# simplify 'assert False where False = ...'
if (test.explanation.startswith('False\n{False = ') and
test.explanation.endswith('\n}')):
test.explanation = test.explanation[15:-2]
# print the result as 'assert <explanation>'
self.result = test.result
self.explanation = 'assert ' + test.explanation
if not frame.is_true(test.result):
try:
raise BuiltinAssertionError
except passthroughex:
raise
except:
raise Failure(self)
class Assign(Interpretable):
__view__ = ast.Assign
def run(self, frame):
expr = Interpretable(self.expr)
expr.eval(frame)
self.result = expr.result
self.explanation = '... = ' + expr.explanation
# fall-back-run the rest of the assignment
ass = ast.Assign(self.nodes, ast.Name('__exprinfo_expr'))
mod = ast.Module(None, ast.Stmt([ass]))
mod.filename = '<run>'
co = pycodegen.ModuleCodeGenerator(mod).getCode()
try:
frame.exec_(co, __exprinfo_expr=expr.result)
except passthroughex:
raise
except:
raise Failure(self)
class Discard(Interpretable):
__view__ = ast.Discard
def run(self, frame):
expr = Interpretable(self.expr)
expr.eval(frame)
self.result = expr.result
self.explanation = expr.explanation
class Stmt(Interpretable):
__view__ = ast.Stmt
def run(self, frame):
for stmt in self.nodes:
stmt = Interpretable(stmt)
stmt.run(frame)
def report_failure(e):
explanation = e.node.nice_explanation()
if explanation:
explanation = ", in: " + explanation
else:
explanation = ""
sys.stdout.write("%s: %s%s\n" % (e.exc.__name__, e.value, explanation))
def check(s, frame=None):
if frame is None:
frame = sys._getframe(1)
frame = py.code.Frame(frame)
expr = parse(s, 'eval')
assert isinstance(expr, ast.Expression)
node = Interpretable(expr.node)
try:
node.eval(frame)
except passthroughex:
raise
except Failure:
e = sys.exc_info()[1]
report_failure(e)
else:
if not frame.is_true(node.result):
sys.stderr.write("assertion failed: %s\n" % node.nice_explanation())
###########################################################
# API / Entry points
# #########################################################
def interpret(source, frame, should_fail=False):
module = Interpretable(parse(source, 'exec').node)
#print "got module", module
if isinstance(frame, py.std.types.FrameType):
frame = py.code.Frame(frame)
try:
module.run(frame)
except Failure:
e = sys.exc_info()[1]
return getfailure(e)
except passthroughex:
raise
except:
import traceback
traceback.print_exc()
if should_fail:
return ("(assertion failed, but when it was re-run for "
"printing intermediate values, it did not fail. Suggestions: "
"compute assert expression before the assert or use --nomagic)")
else:
return None
def getmsg(excinfo):
if isinstance(excinfo, tuple):
excinfo = py.code.ExceptionInfo(excinfo)
#frame, line = gettbline(tb)
#frame = py.code.Frame(frame)
#return interpret(line, frame)
tb = excinfo.traceback[-1]
source = str(tb.statement).strip()
x = interpret(source, tb.frame, should_fail=True)
if not isinstance(x, str):
raise TypeError("interpret returned non-string %r" % (x,))
return x
def getfailure(e):
explanation = e.node.nice_explanation()
if str(e.value):
lines = explanation.split('\n')
lines[0] += " << %s" % (e.value,)
explanation = '\n'.join(lines)
text = "%s: %s" % (e.exc.__name__, explanation)
if text.startswith('AssertionError: assert '):
text = text[16:]
return text
def run(s, frame=None):
if frame is None:
frame = sys._getframe(1)
frame = py.code.Frame(frame)
module = Interpretable(parse(s, 'exec').node)
try:
module.run(frame)
except Failure:
e = sys.exc_info()[1]
report_failure(e)
if __name__ == '__main__':
# example:
def f():
return 5
def g():
return 3
def h(x):
return 'never'
check("f() * g() == 5")
check("not f()")
check("not (f() and g() or 0)")
check("f() == g()")
i = 4
check("i == f()")
check("len(f()) == 0")
check("isinstance(2+3+4, float)")
run("x = i")
check("x == 5")
run("assert not f(), 'oops'")
run("a, b, c = 1, 2")
run("a, b, c = f()")
check("max([f(),g()]) == 4")
check("'hello'[g()] == 'h'")
run("'guk%d' % h(f())")

View File

@ -0,0 +1,48 @@
import sys
import py
BuiltinAssertionError = py.builtin.builtins.AssertionError
class AssertionError(BuiltinAssertionError):
def __init__(self, *args):
BuiltinAssertionError.__init__(self, *args)
if args:
try:
self.msg = str(args[0])
except py.builtin._sysex:
raise
except:
self.msg = "<[broken __repr__] %s at %0xd>" %(
args[0].__class__, id(args[0]))
else:
f = py.code.Frame(sys._getframe(1))
try:
source = f.code.fullsource
if source is not None:
try:
source = source.getstatement(f.lineno, assertion=True)
except IndexError:
source = None
else:
source = str(source.deindent()).strip()
except py.error.ENOENT:
source = None
# this can also occur during reinterpretation, when the
# co_filename is set to "<run>".
if source:
self.msg = reinterpret(source, f, should_fail=True)
else:
self.msg = "<could not determine information>"
if not self.args:
self.args = (self.msg,)
if sys.version_info > (3, 0):
AssertionError.__module__ = "builtins"
reinterpret_old = "old reinterpretation not available for py3"
else:
from _pytest.assertion.oldinterpret import interpret as reinterpret_old
if sys.version_info >= (2, 6) or (sys.platform.startswith("java")):
from _pytest.assertion.newinterpret import interpret as reinterpret
else:
reinterpret = reinterpret_old

View File

@ -0,0 +1,340 @@
"""Rewrite assertion AST to produce nice error messages"""
import ast
import collections
import itertools
import sys
import py
from _pytest.assertion import util
def rewrite_asserts(mod):
"""Rewrite the assert statements in mod."""
AssertionRewriter().run(mod)
_saferepr = py.io.saferepr
from _pytest.assertion.util import format_explanation as _format_explanation
def _format_boolop(operands, explanations, is_or):
show_explanations = []
for operand, expl in zip(operands, explanations):
show_explanations.append(expl)
if operand == is_or:
break
return "(" + (is_or and " or " or " and ").join(show_explanations) + ")"
def _call_reprcompare(ops, results, expls, each_obj):
for i, res, expl in zip(range(len(ops)), results, expls):
try:
done = not res
except Exception:
done = True
if done:
break
if util._reprcompare is not None:
custom = util._reprcompare(ops[i], each_obj[i], each_obj[i + 1])
if custom is not None:
return custom
return expl
unary_map = {
ast.Not : "not %s",
ast.Invert : "~%s",
ast.USub : "-%s",
ast.UAdd : "+%s"
}
binop_map = {
ast.BitOr : "|",
ast.BitXor : "^",
ast.BitAnd : "&",
ast.LShift : "<<",
ast.RShift : ">>",
ast.Add : "+",
ast.Sub : "-",
ast.Mult : "*",
ast.Div : "/",
ast.FloorDiv : "//",
ast.Mod : "%",
ast.Eq : "==",
ast.NotEq : "!=",
ast.Lt : "<",
ast.LtE : "<=",
ast.Gt : ">",
ast.GtE : ">=",
ast.Pow : "**",
ast.Is : "is",
ast.IsNot : "is not",
ast.In : "in",
ast.NotIn : "not in"
}
def set_location(node, lineno, col_offset):
"""Set node location information recursively."""
def _fix(node, lineno, col_offset):
if "lineno" in node._attributes:
node.lineno = lineno
if "col_offset" in node._attributes:
node.col_offset = col_offset
for child in ast.iter_child_nodes(node):
_fix(child, lineno, col_offset)
_fix(node, lineno, col_offset)
return node
class AssertionRewriter(ast.NodeVisitor):
def run(self, mod):
"""Find all assert statements in *mod* and rewrite them."""
if not mod.body:
# Nothing to do.
return
# Insert some special imports at the top of the module but after any
# docstrings and __future__ imports.
aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
ast.alias("_pytest.assertion.rewrite", "@pytest_ar")]
expect_docstring = True
pos = 0
lineno = 0
for item in mod.body:
if (expect_docstring and isinstance(item, ast.Expr) and
isinstance(item.value, ast.Str)):
doc = item.value.s
if "PYTEST_DONT_REWRITE" in doc:
# The module has disabled assertion rewriting.
return
lineno += len(doc) - 1
expect_docstring = False
elif (not isinstance(item, ast.ImportFrom) or item.level > 0 and
item.identifier != "__future__"):
lineno = item.lineno
break
pos += 1
imports = [ast.Import([alias], lineno=lineno, col_offset=0)
for alias in aliases]
mod.body[pos:pos] = imports
# Collect asserts.
nodes = collections.deque([mod])
while nodes:
node = nodes.popleft()
for name, field in ast.iter_fields(node):
if isinstance(field, list):
new = []
for i, child in enumerate(field):
if isinstance(child, ast.Assert):
# Transform assert.
new.extend(self.visit(child))
else:
new.append(child)
if isinstance(child, ast.AST):
nodes.append(child)
setattr(node, name, new)
elif (isinstance(field, ast.AST) and
# Don't recurse into expressions as they can't contain
# asserts.
not isinstance(field, ast.expr)):
nodes.append(field)
def variable(self):
"""Get a new variable."""
# Use a character invalid in python identifiers to avoid clashing.
name = "@py_assert" + str(next(self.variable_counter))
self.variables.add(name)
return name
def assign(self, expr):
"""Give *expr* a name."""
name = self.variable()
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
return ast.Name(name, ast.Load())
def display(self, expr):
"""Call py.io.saferepr on the expression."""
return self.helper("saferepr", expr)
def helper(self, name, *args):
"""Call a helper in this module."""
py_name = ast.Name("@pytest_ar", ast.Load())
attr = ast.Attribute(py_name, "_" + name, ast.Load())
return ast.Call(attr, list(args), [], None, None)
def builtin(self, name):
"""Return the builtin called *name*."""
builtin_name = ast.Name("@py_builtins", ast.Load())
return ast.Attribute(builtin_name, name, ast.Load())
def explanation_param(self, expr):
specifier = "py" + str(next(self.variable_counter))
self.explanation_specifiers[specifier] = expr
return "%(" + specifier + ")s"
def push_format_context(self):
self.explanation_specifiers = {}
self.stack.append(self.explanation_specifiers)
def pop_format_context(self, expl_expr):
current = self.stack.pop()
if self.stack:
self.explanation_specifiers = self.stack[-1]
keys = [ast.Str(key) for key in current.keys()]
format_dict = ast.Dict(keys, list(current.values()))
form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
name = "@py_format" + str(next(self.variable_counter))
self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form))
return ast.Name(name, ast.Load())
def generic_visit(self, node):
"""Handle expressions we don't have custom code for."""
assert isinstance(node, ast.expr)
res = self.assign(node)
return res, self.explanation_param(self.display(res))
def visit_Assert(self, assert_):
if assert_.msg:
# There's already a message. Don't mess with it.
return [assert_]
self.statements = []
self.variables = set()
self.variable_counter = itertools.count()
self.stack = []
self.on_failure = []
self.push_format_context()
# Rewrite assert into a bunch of statements.
top_condition, explanation = self.visit(assert_.test)
# Create failure message.
body = self.on_failure
negation = ast.UnaryOp(ast.Not(), top_condition)
self.statements.append(ast.If(negation, body, []))
explanation = "assert " + explanation
template = ast.Str(explanation)
msg = self.pop_format_context(template)
fmt = self.helper("format_explanation", msg)
err_name = ast.Name("AssertionError", ast.Load())
exc = ast.Call(err_name, [fmt], [], None, None)
if sys.version_info[0] >= 3:
raise_ = ast.Raise(exc, None)
else:
raise_ = ast.Raise(exc, None, None)
body.append(raise_)
# Delete temporary variables.
names = [ast.Name(name, ast.Del()) for name in self.variables]
if names:
delete = ast.Delete(names)
self.statements.append(delete)
# Fix line numbers.
for stmt in self.statements:
set_location(stmt, assert_.lineno, assert_.col_offset)
return self.statements
def visit_Name(self, name):
# Check if the name is local or not.
locs = ast.Call(self.builtin("locals"), [], [], None, None)
globs = ast.Call(self.builtin("globals"), [], [], None, None)
ops = [ast.In(), ast.IsNot()]
test = ast.Compare(ast.Str(name.id), ops, [locs, globs])
expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
return name, self.explanation_param(expr)
def visit_BoolOp(self, boolop):
operands = []
explanations = []
self.push_format_context()
for operand in boolop.values:
res, explanation = self.visit(operand)
operands.append(res)
explanations.append(explanation)
expls = ast.Tuple([ast.Str(expl) for expl in explanations], ast.Load())
is_or = ast.Num(isinstance(boolop.op, ast.Or))
expl_template = self.helper("format_boolop",
ast.Tuple(operands, ast.Load()), expls,
is_or)
expl = self.pop_format_context(expl_template)
res = self.assign(ast.BoolOp(boolop.op, operands))
return res, self.explanation_param(expl)
def visit_UnaryOp(self, unary):
pattern = unary_map[unary.op.__class__]
operand_res, operand_expl = self.visit(unary.operand)
res = self.assign(ast.UnaryOp(unary.op, operand_res))
return res, pattern % (operand_expl,)
def visit_BinOp(self, binop):
symbol = binop_map[binop.op.__class__]
left_expr, left_expl = self.visit(binop.left)
right_expr, right_expl = self.visit(binop.right)
explanation = "(%s %s %s)" % (left_expl, symbol, right_expl)
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
return res, explanation
def visit_Call(self, call):
new_func, func_expl = self.visit(call.func)
arg_expls = []
new_args = []
new_kwargs = []
new_star = new_kwarg = None
for arg in call.args:
res, expl = self.visit(arg)
new_args.append(res)
arg_expls.append(expl)
for keyword in call.keywords:
res, expl = self.visit(keyword.value)
new_kwargs.append(ast.keyword(keyword.arg, res))
arg_expls.append(keyword.arg + "=" + expl)
if call.starargs:
new_star, expl = self.visit(call.starargs)
arg_expls.append("*" + expl)
if call.kwargs:
new_kwarg, expl = self.visit(call.kwarg)
arg_expls.append("**" + expl)
expl = "%s(%s)" % (func_expl, ', '.join(arg_expls))
new_call = ast.Call(new_func, new_args, new_kwargs, new_star, new_kwarg)
res = self.assign(new_call)
res_expl = self.explanation_param(self.display(res))
outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl)
return res, outer_expl
def visit_Attribute(self, attr):
if not isinstance(attr.ctx, ast.Load):
return self.generic_visit(attr)
value, value_expl = self.visit(attr.value)
res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
res_expl = self.explanation_param(self.display(res))
pat = "%s\n{%s = %s.%s\n}"
expl = pat % (res_expl, res_expl, value_expl, attr.attr)
return res, expl
def visit_Compare(self, comp):
self.push_format_context()
left_res, left_expl = self.visit(comp.left)
res_variables = [self.variable() for i in range(len(comp.ops))]
load_names = [ast.Name(v, ast.Load()) for v in res_variables]
store_names = [ast.Name(v, ast.Store()) for v in res_variables]
it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
expls = []
syms = []
results = [left_res]
for i, op, next_operand in it:
next_res, next_expl = self.visit(next_operand)
results.append(next_res)
sym = binop_map[op.__class__]
syms.append(ast.Str(sym))
expl = "%s %s %s" % (left_expl, sym, next_expl)
expls.append(ast.Str(expl))
res_expr = ast.Compare(left_res, [op], [next_res])
self.statements.append(ast.Assign([store_names[i]], res_expr))
left_res, left_expl = next_res, next_expl
# Use py.code._reprcompare if that's available.
expl_call = self.helper("call_reprcompare",
ast.Tuple(syms, ast.Load()),
ast.Tuple(load_names, ast.Load()),
ast.Tuple(expls, ast.Load()),
ast.Tuple(results, ast.Load()))
if len(comp.ops) > 1:
res = ast.BoolOp(ast.And(), load_names)
else:
res = load_names[0]
return res, self.explanation_param(self.pop_format_context(expl_call))

View File

@ -1,43 +1,57 @@
""" """Utilities for assertion debugging"""
support for presented detailed information in failing assertions.
"""
import py import py
import sys
from _pytest.monkeypatch import monkeypatch
def pytest_addoption(parser):
group = parser.getgroup("debugconfig")
group._addoption('--no-assert', action="store_true", default=False,
dest="noassert",
help="disable python assert expression reinterpretation."),
def pytest_configure(config): # The _reprcompare attribute on the util module is used by the new assertion
# The _reprcompare attribute on the py.code module is used by # interpretation code and assertion rewriter to detect this plugin was
# py._code._assertionnew to detect this plugin was loaded and in # loaded and in turn call the hooks defined here as part of the
# turn call the hooks defined here as part of the
# DebugInterpreter. # DebugInterpreter.
m = monkeypatch() _reprcompare = None
config._cleanup.append(m.undo)
warn_about_missing_assertion()
if not config.getvalue("noassert") and not config.getvalue("nomagic"):
def callbinrepr(op, left, right):
hook_result = config.hook.pytest_assertrepr_compare(
config=config, op=op, left=left, right=right)
for new_expl in hook_result:
if new_expl:
return '\n~'.join(new_expl)
m.setattr(py.builtin.builtins,
'AssertionError', py.code._AssertionError)
m.setattr(py.code, '_reprcompare', callbinrepr)
def warn_about_missing_assertion(): def format_explanation(explanation):
try: """This formats an explanation
assert False
except AssertionError: Normally all embedded newlines are escaped, however there are
pass three exceptions: \n{, \n} and \n~. The first two are intended
cover nested explanations, see function and attribute explanations
for examples (.visit_Call(), visit_Attribute()). The last one is
for when one explanation needs to span multiple lines, e.g. when
displaying diffs.
"""
raw_lines = (explanation or '').split('\n')
# escape newlines not followed by {, } and ~
lines = [raw_lines[0]]
for l in raw_lines[1:]:
if l.startswith('{') or l.startswith('}') or l.startswith('~'):
lines.append(l)
else: else:
sys.stderr.write("WARNING: failing tests may report as passing because " lines[-1] += '\\n' + l
"assertions are turned off! (are you using python -O?)\n")
result = lines[:1]
stack = [0]
stackcnt = [0]
for line in lines[1:]:
if line.startswith('{'):
if stackcnt[-1]:
s = 'and '
else:
s = 'where '
stack.append(len(result))
stackcnt[-1] += 1
stackcnt.append(0)
result.append(' +' + ' '*(len(stack)-1) + s + line[1:])
elif line.startswith('}'):
assert line.startswith('}')
stack.pop()
stackcnt.pop()
result[stack[-1]] += line[1:]
else:
assert line.startswith('~')
result.append(' '*len(stack) + line[1:])
assert len(stack) == 1
return '\n'.join(result)
# Provide basestring in python3 # Provide basestring in python3
try: try:
@ -46,7 +60,7 @@ except NameError:
basestring = str basestring = str
def pytest_assertrepr_compare(op, left, right): def assertrepr_compare(op, left, right):
"""return specialised explanations for some operators/operands""" """return specialised explanations for some operators/operands"""
width = 80 - 15 - len(op) - 2 # 15 chars indentation, 1 space around op width = 80 - 15 - len(op) - 2 # 15 chars indentation, 1 space around op
left_repr = py.io.saferepr(left, maxsize=int(width/2)) left_repr = py.io.saferepr(left, maxsize=int(width/2))

View File

@ -16,9 +16,6 @@ def pytest_addoption(parser):
group.addoption('--traceconfig', group.addoption('--traceconfig',
action="store_true", dest="traceconfig", default=False, action="store_true", dest="traceconfig", default=False,
help="trace considerations of conftest.py files."), help="trace considerations of conftest.py files."),
group._addoption('--nomagic',
action="store_true", dest="nomagic", default=False,
help="don't reinterpret asserts, no traceback cutting. ")
group.addoption('--debug', group.addoption('--debug',
action="store_true", dest="debug", default=False, action="store_true", dest="debug", default=False,
help="generate and show internal debugging information.") help="generate and show internal debugging information.")

View File

@ -46,23 +46,22 @@ def pytest_addoption(parser):
def pytest_namespace(): def pytest_namespace():
return dict(collect=dict(Item=Item, Collector=Collector, File=File)) collect = dict(Item=Item, Collector=Collector, File=File, Session=Session)
return dict(collect=collect)
def pytest_configure(config): def pytest_configure(config):
py.test.config = config # compatibiltiy py.test.config = config # compatibiltiy
if config.option.exitfirst: if config.option.exitfirst:
config.option.maxfail = 1 config.option.maxfail = 1
def pytest_cmdline_main(config): def wrap_session(config, doit):
""" default command line protocol for initialization, session, """Skeleton command line program"""
running tests and reporting. """
session = Session(config) session = Session(config)
session.exitstatus = EXIT_OK session.exitstatus = EXIT_OK
try: try:
config.pluginmanager.do_configure(config) config.pluginmanager.do_configure(config)
config.hook.pytest_sessionstart(session=session) config.hook.pytest_sessionstart(session=session)
config.hook.pytest_collection(session=session) doit(config, session)
config.hook.pytest_runtestloop(session=session)
except pytest.UsageError: except pytest.UsageError:
raise raise
except KeyboardInterrupt: except KeyboardInterrupt:
@ -82,13 +81,17 @@ def pytest_cmdline_main(config):
config.pluginmanager.do_unconfigure(config) config.pluginmanager.do_unconfigure(config)
return session.exitstatus return session.exitstatus
def pytest_cmdline_main(config):
return wrap_session(config, _main)
def _main(config, session):
""" default command line protocol for initialization, session,
running tests and reporting. """
config.hook.pytest_collection(session=session)
config.hook.pytest_runtestloop(session=session)
def pytest_collection(session): def pytest_collection(session):
session.perform_collect() return session.perform_collect()
hook = session.config.hook
hook.pytest_collection_modifyitems(session=session,
config=session.config, items=session.items)
hook.pytest_collection_finish(session=session)
return True
def pytest_runtestloop(session): def pytest_runtestloop(session):
if session.config.option.collectonly: if session.config.option.collectonly:
@ -374,6 +377,16 @@ class Session(FSCollector):
return HookProxy(fspath, self.config) return HookProxy(fspath, self.config)
def perform_collect(self, args=None, genitems=True): def perform_collect(self, args=None, genitems=True):
hook = self.config.hook
try:
items = self._perform_collect(args, genitems)
hook.pytest_collection_modifyitems(session=self,
config=self.config, items=items)
finally:
hook.pytest_collection_finish(session=self)
return items
def _perform_collect(self, args, genitems):
if args is None: if args is None:
args = self.config.args args = self.config.args
self.trace("perform_collect", self, args) self.trace("perform_collect", self, args)

View File

@ -6,7 +6,7 @@ import re
import inspect import inspect
import time import time
from fnmatch import fnmatch from fnmatch import fnmatch
from _pytest.main import Session from _pytest.main import Session, EXIT_OK
from py.builtin import print_ from py.builtin import print_
from _pytest.core import HookRelay from _pytest.core import HookRelay
@ -292,13 +292,19 @@ class TmpTestdir:
assert '::' not in str(arg) assert '::' not in str(arg)
p = py.path.local(arg) p = py.path.local(arg)
x = session.fspath.bestrelpath(p) x = session.fspath.bestrelpath(p)
return session.perform_collect([x], genitems=False)[0] config.hook.pytest_sessionstart(session=session)
res = session.perform_collect([x], genitems=False)[0]
config.hook.pytest_sessionfinish(session=session, exitstatus=EXIT_OK)
return res
def getpathnode(self, path): def getpathnode(self, path):
config = self.parseconfig(path) config = self.parseconfigure(path)
session = Session(config) session = Session(config)
x = session.fspath.bestrelpath(path) x = session.fspath.bestrelpath(path)
return session.perform_collect([x], genitems=False)[0] config.hook.pytest_sessionstart(session=session)
res = session.perform_collect([x], genitems=False)[0]
config.hook.pytest_sessionfinish(session=session, exitstatus=EXIT_OK)
return res
def genitems(self, colitems): def genitems(self, colitems):
session = colitems[0].session session = colitems[0].session
@ -312,7 +318,9 @@ class TmpTestdir:
config = self.parseconfigure(*args) config = self.parseconfigure(*args)
rec = self.getreportrecorder(config) rec = self.getreportrecorder(config)
session = Session(config) session = Session(config)
config.hook.pytest_sessionstart(session=session)
session.perform_collect() session.perform_collect()
config.hook.pytest_sessionfinish(session=session, exitstatus=EXIT_OK)
return session.items, rec return session.items, rec
def runitem(self, source): def runitem(self, source):
@ -382,6 +390,8 @@ class TmpTestdir:
c.basetemp = py.path.local.make_numbered_dir(prefix="reparse", c.basetemp = py.path.local.make_numbered_dir(prefix="reparse",
keep=0, rootdir=self.tmpdir, lock_timeout=None) keep=0, rootdir=self.tmpdir, lock_timeout=None)
c.parse(args) c.parse(args)
c.pluginmanager.do_configure(c)
self.request.addfinalizer(lambda: c.pluginmanager.do_unconfigure(c))
return c return c
finally: finally:
py.test.config = oldconfig py.test.config = oldconfig

View File

@ -4,6 +4,7 @@ import inspect
import sys import sys
import pytest import pytest
from py._code.code import TerminalRepr from py._code.code import TerminalRepr
from _pytest import assertion
import _pytest import _pytest
cutdir = py.path.local(_pytest.__file__).dirpath() cutdir = py.path.local(_pytest.__file__).dirpath()
@ -226,8 +227,12 @@ class Module(pytest.File, PyCollectorMixin):
def _importtestmodule(self): def _importtestmodule(self):
# we assume we are only called once per module # we assume we are only called once per module
assertion.before_module_import(self)
try:
try: try:
mod = self.fspath.pyimport(ensuresyspath=True) mod = self.fspath.pyimport(ensuresyspath=True)
finally:
assertion.after_module_import(self)
except SyntaxError: except SyntaxError:
excinfo = py.code.ExceptionInfo() excinfo = py.code.ExceptionInfo()
raise self.CollectError(excinfo.getrepr(style="short")) raise self.CollectError(excinfo.getrepr(style="short"))
@ -374,7 +379,7 @@ class Generator(FunctionMixin, PyCollectorMixin, pytest.Collector):
# test generators are seen as collectors but they also # test generators are seen as collectors but they also
# invoke setup/teardown on popular request # invoke setup/teardown on popular request
# (induced by the common "test_*" naming shared with normal tests) # (induced by the common "test_*" naming shared with normal tests)
self.config._setupstate.prepare(self) self.session._setupstate.prepare(self)
# see FunctionMixin.setup and test_setupstate_is_preserved_134 # see FunctionMixin.setup and test_setupstate_is_preserved_134
self._preservedparent = self.parent.obj self._preservedparent = self.parent.obj
l = [] l = []
@ -726,7 +731,7 @@ class FuncargRequest:
def _addfinalizer(self, finalizer, scope): def _addfinalizer(self, finalizer, scope):
colitem = self._getscopeitem(scope) colitem = self._getscopeitem(scope)
self.config._setupstate.addfinalizer( self._pyfuncitem.session._setupstate.addfinalizer(
finalizer=finalizer, colitem=colitem) finalizer=finalizer, colitem=colitem)
def __repr__(self): def __repr__(self):
@ -747,8 +752,10 @@ class FuncargRequest:
raise self.LookupError(msg) raise self.LookupError(msg)
def showfuncargs(config): def showfuncargs(config):
from _pytest.main import Session from _pytest.main import wrap_session
session = Session(config) return wrap_session(config, _showfuncargs_main)
def _showfuncargs_main(config, session):
session.perform_collect() session.perform_collect()
if session.items: if session.items:
plugins = session.items[0].getplugins() plugins = session.items[0].getplugins()

View File

@ -14,12 +14,10 @@ def pytest_namespace():
# #
# pytest plugin hooks # pytest plugin hooks
# XXX move to pytest_sessionstart and fix py.test owns tests def pytest_sessionstart(session):
def pytest_configure(config): session._setupstate = SetupState()
config._setupstate = SetupState()
def pytest_sessionfinish(session, exitstatus): def pytest_sessionfinish(session, exitstatus):
if hasattr(session.config, '_setupstate'):
hook = session.config.hook hook = session.config.hook
rep = hook.pytest__teardown_final(session=session) rep = hook.pytest__teardown_final(session=session)
if rep: if rep:
@ -46,16 +44,16 @@ def runtestprotocol(item, log=True):
return reports return reports
def pytest_runtest_setup(item): def pytest_runtest_setup(item):
item.config._setupstate.prepare(item) item.session._setupstate.prepare(item)
def pytest_runtest_call(item): def pytest_runtest_call(item):
item.runtest() item.runtest()
def pytest_runtest_teardown(item): def pytest_runtest_teardown(item):
item.config._setupstate.teardown_exact(item) item.session._setupstate.teardown_exact(item)
def pytest__teardown_final(session): def pytest__teardown_final(session):
call = CallInfo(session.config._setupstate.teardown_all, when="teardown") call = CallInfo(session._setupstate.teardown_all, when="teardown")
if call.excinfo: if call.excinfo:
ntraceback = call.excinfo.traceback .cut(excludepath=py._pydir) ntraceback = call.excinfo.traceback .cut(excludepath=py._pydir)
call.excinfo.traceback = ntraceback.filter() call.excinfo.traceback = ntraceback.filter()

View File

@ -18,8 +18,8 @@ following::
def test_function(): def test_function():
assert f() == 4 assert f() == 4
to assert that your object returns a certain value. If this to assert that your function returns a certain value. If this assertion fails
assertion fails you will see the value of ``x``:: you will see the value of ``x``::
$ py.test test_assert1.py $ py.test test_assert1.py
=========================== test session starts ============================ =========================== test session starts ============================
@ -39,22 +39,12 @@ assertion fails you will see the value of ``x``::
test_assert1.py:5: AssertionError test_assert1.py:5: AssertionError
========================= 1 failed in 0.02 seconds ========================= ========================= 1 failed in 0.02 seconds =========================
Reporting details about the failing assertion is achieved by re-evaluating py.test has support for showing the values of the most common subexpressions
the assert expression and recording the intermediate values. including calls, attributes, comparisons, and binary and unary operators. This
allows you to use the idiomatic python constructs without boilerplate code while
not losing introspection information.
Note: If evaluating the assert expression has side effects you may get a See :ref:`assert-details` for more information on assertion introspection.
warning that the intermediate values could not be determined safely. A
common example of this issue is an assertion which reads from a file::
assert f.read() != '...'
If this assertion fails then the re-evaluation will probably succeed!
This is because ``f.read()`` will return an empty string when it is
called the second time during the re-evaluation. However, it is
easy to rewrite the assertion and avoid any trouble::
content = f.read()
assert content != '...'
assertions about expected exceptions assertions about expected exceptions
@ -137,6 +127,66 @@ Special comparisons are done for a number of cases:
See the :ref:`reporting demo <tbreportdemo>` for many more examples. See the :ref:`reporting demo <tbreportdemo>` for many more examples.
.. _assert-details:
Assertion introspection details
-------------------------------
Reporting details about the failing assertion is achieved either by rewriting
assert statements before they are run or re-evaluating the assert expression and
recording the intermediate values. Which technique is used depends on the
location of the assert, py.test's configuration, and Python version being used
to run py.test.
By default, if the Python version is greater than or equal to 2.6, py.test
rewrites assert statements in test modules. Rewritten assert statements put
introspection information into the assertion failure message. Note py.test only
rewrites test modules directly discovered by its test collection process, so
asserts in supporting modules will not be rewritten.
.. note::
py.test rewrites test modules as it collects tests from them. It does this by
writing a new pyc file which Python loads when the test module is
imported. If the module has already been loaded (it is in sys.modules),
though, Python will not load the rewritten module. This means if a test
module imports another test module which has not already been rewritten, then
py.test will not be able to rewrite the second module.
If an assert statement has not been rewritten or the Python version is less than
2.6, py.test falls back on assert reinterpretation. In assert reinterpretation,
py.test walks the frame of the function containing the assert statement to
discover sub-expression results of the failing assert statement. You can force
py.test to always use assertion reinterpretation by passing the
``--assertmode=old`` option.
Assert reinterpretation has a caveat not present with assert rewriting: If
evaluating the assert expression has side effects you may get a warning that the
intermediate values could not be determined safely. A common example of this
issue is an assertion which reads from a file::
assert f.read() != '...'
If this assertion fails then the re-evaluation will probably succeed!
This is because ``f.read()`` will return an empty string when it is
called the second time during the re-evaluation. However, it is
easy to rewrite the assertion and avoid any trouble::
content = f.read()
assert content != '...'
All assert introspection can be turned off by passing ``--assertmode=off``.
.. versionadded:: 2.1
Add assert rewriting as an alternate introspection technique.
.. versionchanged:: 2.1
Introduce the ``--assertmode`` option. Deprecate ``--no-assert`` and
``--nomagic``.
.. ..
Defining your own comparison Defining your own comparison
---------------------------------------------- ----------------------------------------------

View File

@ -47,13 +47,17 @@ customizable testing frameworks for Python. However,
``py.test`` still uses many metaprogramming techniques and ``py.test`` still uses many metaprogramming techniques and
reading its source is thus likely not something for Python beginners. reading its source is thus likely not something for Python beginners.
A second "magic" issue arguably the assert statement re-intepreation: A second "magic" issue arguably the assert statement debugging feature. When
When an ``assert`` statement fails, py.test re-interprets the expression loading test modules py.test rewrites the source code of assert statements. When
to show intermediate values if a test fails. If your expression a rewritten assert statement fails, its error message has more information than
has side effects (better to avoid them anyway!) the intermediate values the original. py.test also has a second assert debugging technique. When an
may not be the same, obfuscating the initial error (this is also ``assert`` statement that was missed by the rewriter fails, py.test
explained at the command line if it happens). re-interprets the expression to show intermediate values if a test fails. This
``py.test --no-assert`` turns off assert re-interpretation. second technique suffers from caveat that the rewriting does not: If your
expression has side effects (better to avoid them anyway!) the intermediate
values may not be the same, confusing the reinterpreter and obfuscating the
initial error (this is also explained at the command line if it happens).
You can turn off all assertion debugging with ``py.test --assertmode=off``.
.. _`py namespaces`: index.html .. _`py namespaces`: index.html
.. _`py/__init__.py`: http://bitbucket.org/hpk42/py-trunk/src/trunk/py/__init__.py .. _`py/__init__.py`: http://bitbucket.org/hpk42/py-trunk/src/trunk/py/__init__.py

View File

@ -22,14 +22,14 @@ def main():
name='pytest', name='pytest',
description='py.test: simple powerful testing with Python', description='py.test: simple powerful testing with Python',
long_description = long_description, long_description = long_description,
version='2.1.0.dev1', version='2.1.0.dev2',
url='http://pytest.org', url='http://pytest.org',
license='MIT license', license='MIT license',
platforms=['unix', 'linux', 'osx', 'cygwin', 'win32'], platforms=['unix', 'linux', 'osx', 'cygwin', 'win32'],
author='holger krekel, Guido Wesdorp, Carl Friedrich Bolz, Armin Rigo, Maciej Fijalkowski & others', author='holger krekel, Guido Wesdorp, Carl Friedrich Bolz, Armin Rigo, Maciej Fijalkowski & others',
author_email='holger at merlinux.eu', author_email='holger at merlinux.eu',
entry_points= make_entry_points(), entry_points= make_entry_points(),
install_requires=['py>1.4.1'], install_requires=['py>1.4.3'],
classifiers=['Development Status :: 5 - Production/Stable', classifiers=['Development Status :: 5 - Production/Stable',
'Intended Audience :: Developers', 'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License', 'License :: OSI Approved :: MIT License',

View File

@ -0,0 +1,327 @@
"PYTEST_DONT_REWRITE"
import pytest, py
from _pytest.assertion import util
def exvalue():
return py.std.sys.exc_info()[1]
def f():
return 2
def test_not_being_rewritten():
assert "@py_builtins" not in globals()
def test_assert():
try:
assert f() == 3
except AssertionError:
e = exvalue()
s = str(e)
assert s.startswith('assert 2 == 3\n')
def test_assert_with_explicit_message():
try:
assert f() == 3, "hello"
except AssertionError:
e = exvalue()
assert e.msg == 'hello'
def test_assert_within_finally():
class A:
def f():
pass
excinfo = py.test.raises(TypeError, """
try:
A().f()
finally:
i = 42
""")
s = excinfo.exconly()
assert s.find("takes no argument") != -1
#def g():
# A.f()
#excinfo = getexcinfo(TypeError, g)
#msg = getmsg(excinfo)
#assert msg.find("must be called with A") != -1
def test_assert_multiline_1():
try:
assert (f() ==
3)
except AssertionError:
e = exvalue()
s = str(e)
assert s.startswith('assert 2 == 3\n')
def test_assert_multiline_2():
try:
assert (f() == (4,
3)[-1])
except AssertionError:
e = exvalue()
s = str(e)
assert s.startswith('assert 2 ==')
def test_in():
try:
assert "hi" in [1, 2]
except AssertionError:
e = exvalue()
s = str(e)
assert s.startswith("assert 'hi' in")
def test_is():
try:
assert 1 is 2
except AssertionError:
e = exvalue()
s = str(e)
assert s.startswith("assert 1 is 2")
@py.test.mark.skipif("sys.version_info < (2,6)")
def test_attrib():
class Foo(object):
b = 1
i = Foo()
try:
assert i.b == 2
except AssertionError:
e = exvalue()
s = str(e)
assert s.startswith("assert 1 == 2")
@py.test.mark.skipif("sys.version_info < (2,6)")
def test_attrib_inst():
class Foo(object):
b = 1
try:
assert Foo().b == 2
except AssertionError:
e = exvalue()
s = str(e)
assert s.startswith("assert 1 == 2")
def test_len():
l = list(range(42))
try:
assert len(l) == 100
except AssertionError:
e = exvalue()
s = str(e)
assert s.startswith("assert 42 == 100")
assert "where 42 = len([" in s
def test_assert_non_string_message():
class A:
def __str__(self):
return "hello"
try:
assert 0 == 1, A()
except AssertionError:
e = exvalue()
assert e.msg == "hello"
def test_assert_keyword_arg():
def f(x=3):
return False
try:
assert f(x=5)
except AssertionError:
e = exvalue()
assert "x=5" in e.msg
# These tests should both fail, but should fail nicely...
class WeirdRepr:
def __repr__(self):
return '<WeirdRepr\nsecond line>'
def bug_test_assert_repr():
v = WeirdRepr()
try:
assert v == 1
except AssertionError:
e = exvalue()
assert e.msg.find('WeirdRepr') != -1
assert e.msg.find('second line') != -1
assert 0
def test_assert_non_string():
try:
assert 0, ['list']
except AssertionError:
e = exvalue()
assert e.msg.find("list") != -1
def test_assert_implicit_multiline():
try:
x = [1,2,3]
assert x != [1,
2, 3]
except AssertionError:
e = exvalue()
assert e.msg.find('assert [1, 2, 3] !=') != -1
def test_assert_with_brokenrepr_arg():
class BrokenRepr:
def __repr__(self): 0 / 0
e = AssertionError(BrokenRepr())
if e.msg.find("broken __repr__") == -1:
py.test.fail("broken __repr__ not handle correctly")
def test_multiple_statements_per_line():
try:
a = 1; assert a == 2
except AssertionError:
e = exvalue()
assert "assert 1 == 2" in e.msg
def test_power():
try:
assert 2**3 == 7
except AssertionError:
e = exvalue()
assert "assert (2 ** 3) == 7" in e.msg
class TestView:
def setup_class(cls):
cls.View = pytest.importorskip("_pytest.assertion.oldinterpret").View
def test_class_dispatch(self):
### Use a custom class hierarchy with existing instances
class Picklable(self.View):
pass
class Simple(Picklable):
__view__ = object
def pickle(self):
return repr(self.__obj__)
class Seq(Picklable):
__view__ = list, tuple, dict
def pickle(self):
return ';'.join(
[Picklable(item).pickle() for item in self.__obj__])
class Dict(Seq):
__view__ = dict
def pickle(self):
return Seq.pickle(self) + '!' + Seq(self.values()).pickle()
assert Picklable(123).pickle() == '123'
assert Picklable([1,[2,3],4]).pickle() == '1;2;3;4'
assert Picklable({1:2}).pickle() == '1!2'
def test_viewtype_class_hierarchy(self):
# Use a custom class hierarchy based on attributes of existing instances
class Operation:
"Existing class that I don't want to change."
def __init__(self, opname, *args):
self.opname = opname
self.args = args
existing = [Operation('+', 4, 5),
Operation('getitem', '', 'join'),
Operation('setattr', 'x', 'y', 3),
Operation('-', 12, 1)]
class PyOp(self.View):
def __viewkey__(self):
return self.opname
def generate(self):
return '%s(%s)' % (self.opname, ', '.join(map(repr, self.args)))
class PyBinaryOp(PyOp):
__view__ = ('+', '-', '*', '/')
def generate(self):
return '%s %s %s' % (self.args[0], self.opname, self.args[1])
codelines = [PyOp(op).generate() for op in existing]
assert codelines == ["4 + 5", "getitem('', 'join')",
"setattr('x', 'y', 3)", "12 - 1"]
@py.test.mark.skipif("sys.version_info < (2,6)")
def test_assert_customizable_reprcompare(monkeypatch):
monkeypatch.setattr(util, '_reprcompare', lambda *args: 'hello')
try:
assert 3 == 4
except AssertionError:
e = exvalue()
s = str(e)
assert "hello" in s
def test_assert_long_source_1():
try:
assert len == [
(None, ['somet text', 'more text']),
]
except AssertionError:
e = exvalue()
s = str(e)
assert 're-run' not in s
assert 'somet text' in s
def test_assert_long_source_2():
try:
assert(len == [
(None, ['somet text', 'more text']),
])
except AssertionError:
e = exvalue()
s = str(e)
assert 're-run' not in s
assert 'somet text' in s
def test_assert_raise_alias(testdir):
testdir.makepyfile("""
"PYTEST_DONT_REWRITE"
import sys
EX = AssertionError
def test_hello():
raise EX("hello"
"multi"
"line")
""")
result = testdir.runpytest()
result.stdout.fnmatch_lines([
"*def test_hello*",
"*raise EX*",
"*1 failed*",
])
@pytest.mark.skipif("sys.version_info < (2,5)")
def test_assert_raise_subclass():
class SomeEx(AssertionError):
def __init__(self, *args):
super(SomeEx, self).__init__()
try:
raise SomeEx("hello")
except AssertionError:
s = str(exvalue())
assert 're-run' not in s
assert 'could not determine' in s
def test_assert_raises_in_nonzero_of_object_pytest_issue10():
class A(object):
def __nonzero__(self):
raise ValueError(42)
def __lt__(self, other):
return A()
def __repr__(self):
return "<MY42 object>"
def myany(x):
return True
try:
assert not(myany(A() < 0))
except AssertionError:
e = exvalue()
s = str(e)
assert "<MY42 object> < 0" in s

View File

@ -2,11 +2,12 @@ import sys
import py, pytest import py, pytest
import _pytest.assertion as plugin import _pytest.assertion as plugin
from _pytest.assertion import reinterpret, util
needsnewassert = pytest.mark.skipif("sys.version_info < (2,6)") needsnewassert = pytest.mark.skipif("sys.version_info < (2,6)")
def interpret(expr): def interpret(expr):
return py.code._reinterpret(expr, py.code.Frame(sys._getframe(1))) return reinterpret.reinterpret(expr, py.code.Frame(sys._getframe(1)))
class TestBinReprIntegration: class TestBinReprIntegration:
pytestmark = needsnewassert pytestmark = needsnewassert
@ -25,7 +26,7 @@ class TestBinReprIntegration:
self.right = right self.right = right
mockhook = MockHook() mockhook = MockHook()
monkeypatch = request.getfuncargvalue("monkeypatch") monkeypatch = request.getfuncargvalue("monkeypatch")
monkeypatch.setattr(py.code, '_reprcompare', mockhook) monkeypatch.setattr(util, '_reprcompare', mockhook)
return mockhook return mockhook
def test_pytest_assertrepr_compare_called(self, hook): def test_pytest_assertrepr_compare_called(self, hook):
@ -40,13 +41,13 @@ class TestBinReprIntegration:
assert hook.right == [0, 2] assert hook.right == [0, 2]
def test_configure_unconfigure(self, testdir, hook): def test_configure_unconfigure(self, testdir, hook):
assert hook == py.code._reprcompare assert hook == util._reprcompare
config = testdir.parseconfig() config = testdir.parseconfig()
plugin.pytest_configure(config) plugin.pytest_configure(config)
assert hook != py.code._reprcompare assert hook != util._reprcompare
from _pytest.config import pytest_unconfigure from _pytest.config import pytest_unconfigure
pytest_unconfigure(config) pytest_unconfigure(config)
assert hook == py.code._reprcompare assert hook == util._reprcompare
def callequal(left, right): def callequal(left, right):
return plugin.pytest_assertrepr_compare('==', left, right) return plugin.pytest_assertrepr_compare('==', left, right)
@ -119,6 +120,10 @@ class TestAssert_reprcompare:
expl = ' '.join(callequal('foo', 'bar')) expl = ' '.join(callequal('foo', 'bar'))
assert 'raised in repr()' not in expl assert 'raised in repr()' not in expl
@pytest.mark.skipif("config._assertstate.mode != 'on'")
def test_rewritten():
assert "@py_builtins" in globals()
def test_reprcompare_notin(): def test_reprcompare_notin():
detail = plugin.pytest_assertrepr_compare('not in', 'foo', 'aaafoobbb')[1:] detail = plugin.pytest_assertrepr_compare('not in', 'foo', 'aaafoobbb')[1:]
assert detail == ["'foo' is contained here:", ' aaafoobbb', '? +++'] assert detail == ["'foo' is contained here:", ' aaafoobbb', '? +++']
@ -159,7 +164,7 @@ def test_sequence_comparison_uses_repr(testdir):
]) ])
def test_functional(testdir): def test_assertion_options(testdir):
testdir.makepyfile(""" testdir.makepyfile("""
def test_hello(): def test_hello():
x = 3 x = 3
@ -167,8 +172,30 @@ def test_functional(testdir):
""") """)
result = testdir.runpytest() result = testdir.runpytest()
assert "3 == 4" in result.stdout.str() assert "3 == 4" in result.stdout.str()
result = testdir.runpytest("--no-assert") off_options = (("--no-assert",),
("--nomagic",),
("--no-assert", "--nomagic"),
("--assertmode=off",),
("--assertmode=off", "--no-assert"),
("--assertmode=off", "--nomagic"),
("--assertmode=off," "--no-assert", "--nomagic"))
for opt in off_options:
result = testdir.runpytest(*opt)
assert "3 == 4" not in result.stdout.str() assert "3 == 4" not in result.stdout.str()
for mode in "on", "old":
for other_opt in off_options[:3]:
opt = ("--assertmode=" + mode,) + other_opt
result = testdir.runpytest(*opt)
assert result.ret == 3
assert "assertion options conflict" in result.stderr.str()
def test_old_assert_mode(testdir):
testdir.makepyfile("""
def test_in_old_mode():
assert "@py_builtins" not in globals()
""")
result = testdir.runpytest("--assertmode=old")
assert result.ret == 0
def test_triple_quoted_string_issue113(testdir): def test_triple_quoted_string_issue113(testdir):
testdir.makepyfile(""" testdir.makepyfile("""
@ -221,3 +248,10 @@ def test_warn_missing(testdir):
result.stderr.fnmatch_lines([ result.stderr.fnmatch_lines([
"*WARNING*assertion*", "*WARNING*assertion*",
]) ])
def test_load_fake_pyc(testdir):
path = testdir.makepyfile("x = 'hello'")
co = compile("x = 'bye'", str(path), "exec")
plugin._write_pyc(co, path)
mod = path.pyimport()
assert mod.x == "bye"

View File

@ -0,0 +1,256 @@
import sys
import py
import pytest
ast = pytest.importorskip("ast")
from _pytest.assertion import util
from _pytest.assertion.rewrite import rewrite_asserts
def setup_module(mod):
mod._old_reprcompare = util._reprcompare
py.code._reprcompare = None
def teardown_module(mod):
util._reprcompare = mod._old_reprcompare
del mod._old_reprcompare
def rewrite(src):
tree = ast.parse(src)
rewrite_asserts(tree)
return tree
def getmsg(f, extra_ns=None, must_pass=False):
"""Rewrite the assertions in f, run it, and get the failure message."""
src = '\n'.join(py.code.Code(f).source().lines)
mod = rewrite(src)
code = compile(mod, "<test>", "exec")
ns = {}
if extra_ns is not None:
ns.update(extra_ns)
py.builtin.exec_(code, ns)
func = ns[f.__name__]
try:
func()
except AssertionError:
if must_pass:
pytest.fail("shouldn't have raised")
s = str(sys.exc_info()[1])
if not s.startswith("assert"):
return "AssertionError: " + s
return s
else:
if not must_pass:
pytest.fail("function didn't raise at all")
class TestAssertionRewrite:
def test_place_initial_imports(self):
s = """'Doc string'\nother = stuff"""
m = rewrite(s)
assert isinstance(m.body[0], ast.Expr)
assert isinstance(m.body[0].value, ast.Str)
for imp in m.body[1:3]:
assert isinstance(imp, ast.Import)
assert imp.lineno == 2
assert imp.col_offset == 0
assert isinstance(m.body[3], ast.Assign)
s = """from __future__ import with_statement\nother_stuff"""
m = rewrite(s)
assert isinstance(m.body[0], ast.ImportFrom)
for imp in m.body[1:3]:
assert isinstance(imp, ast.Import)
assert imp.lineno == 2
assert imp.col_offset == 0
assert isinstance(m.body[3], ast.Expr)
s = """'doc string'\nfrom __future__ import with_statement\nother"""
m = rewrite(s)
assert isinstance(m.body[0], ast.Expr)
assert isinstance(m.body[0].value, ast.Str)
assert isinstance(m.body[1], ast.ImportFrom)
for imp in m.body[2:4]:
assert isinstance(imp, ast.Import)
assert imp.lineno == 3
assert imp.col_offset == 0
assert isinstance(m.body[4], ast.Expr)
def test_dont_rewrite(self):
s = """'PYTEST_DONT_REWRITE'\nassert 14"""
m = rewrite(s)
assert len(m.body) == 2
assert isinstance(m.body[0].value, ast.Str)
assert isinstance(m.body[1], ast.Assert)
assert m.body[1].msg is None
def test_name(self):
def f():
assert False
assert getmsg(f) == "assert False"
def f():
f = False
assert f
assert getmsg(f) == "assert False"
def f():
assert a_global
assert getmsg(f, {"a_global" : False}) == "assert a_global"
def test_assert_already_has_message(self):
def f():
assert False, "something bad!"
assert getmsg(f) == "AssertionError: something bad!"
def test_boolop(self):
def f():
f = g = False
assert f and g
assert getmsg(f) == "assert (False)"
def f():
f = True
g = False
assert f and g
assert getmsg(f) == "assert (True and False)"
def f():
f = False
g = True
assert f and g
assert getmsg(f) == "assert (False)"
def f():
f = g = False
assert f or g
assert getmsg(f) == "assert (False or False)"
def f():
f = True
g = False
assert f or g
getmsg(f, must_pass=True)
def test_short_circut_evaluation(self):
pytest.xfail("complicated fix; I'm not sure if it's important")
def f():
assert True or explode
getmsg(f, must_pass=True)
def test_unary_op(self):
def f():
x = True
assert not x
assert getmsg(f) == "assert not True"
def f():
x = 0
assert ~x + 1
assert getmsg(f) == "assert (~0 + 1)"
def f():
x = 3
assert -x + x
assert getmsg(f) == "assert (-3 + 3)"
def f():
x = 0
assert +x + x
assert getmsg(f) == "assert (+0 + 0)"
def test_binary_op(self):
def f():
x = 1
y = -1
assert x + y
assert getmsg(f) == "assert (1 + -1)"
def test_call(self):
def g(a=42, *args, **kwargs):
return False
ns = {"g" : g}
def f():
assert g()
assert getmsg(f, ns) == """assert False
+ where False = g()"""
def f():
assert g(1)
assert getmsg(f, ns) == """assert False
+ where False = g(1)"""
def f():
assert g(1, 2)
assert getmsg(f, ns) == """assert False
+ where False = g(1, 2)"""
def f():
assert g(1, g=42)
assert getmsg(f, ns) == """assert False
+ where False = g(1, g=42)"""
def f():
assert g(1, 3, g=23)
assert getmsg(f, ns) == """assert False
+ where False = g(1, 3, g=23)"""
def test_attribute(self):
class X(object):
g = 3
ns = {"X" : X, "x" : X()}
def f():
assert not x.g
assert getmsg(f, ns) == """assert not 3
+ where 3 = x.g"""
def f():
x.a = False
assert x.a
assert getmsg(f, ns) == """assert False
+ where False = x.a"""
def test_comparisons(self):
def f():
a, b = range(2)
assert b < a
assert getmsg(f) == """assert 1 < 0"""
def f():
a, b, c = range(3)
assert a > b > c
assert getmsg(f) == """assert 0 > 1"""
def f():
a, b, c = range(3)
assert a < b > c
assert getmsg(f) == """assert 1 > 2"""
def f():
a, b, c = range(3)
assert a < b <= c
getmsg(f, must_pass=True)
def f():
a, b, c = range(3)
assert a < b
assert b < c
getmsg(f, must_pass=True)
def test_len(self):
def f():
l = list(range(10))
assert len(l) == 11
assert getmsg(f).startswith("""assert 10 == 11
+ where 10 = len([""")
def test_custom_reprcompare(self, monkeypatch):
def my_reprcompare(op, left, right):
return "42"
monkeypatch.setattr(util, "_reprcompare", my_reprcompare)
def f():
assert 42 < 3
assert getmsg(f) == "assert 42"
def my_reprcompare(op, left, right):
return "%s %s %s" % (left, op, right)
monkeypatch.setattr(util, "_reprcompare", my_reprcompare)
def f():
assert 1 < 3 < 5 <= 4 < 7
assert getmsg(f) == "assert 5 <= 4"
def test_assert_raising_nonzero_in_comparison(self):
def f():
class A(object):
def __nonzero__(self):
raise ValueError(42)
def __lt__(self, other):
return A()
def __repr__(self):
return "<MY42 object>"
def myany(x):
return False
assert myany(A() < 0)
assert "<MY42 object> < 0" in getmsg(f)

View File

@ -313,7 +313,7 @@ class TestSession:
def test_collect_topdir(self, testdir): def test_collect_topdir(self, testdir):
p = testdir.makepyfile("def test_func(): pass") p = testdir.makepyfile("def test_func(): pass")
id = "::".join([p.basename, "test_func"]) id = "::".join([p.basename, "test_func"])
config = testdir.parseconfig(id) config = testdir.parseconfigure(id)
topdir = testdir.tmpdir topdir = testdir.tmpdir
rcol = Session(config) rcol = Session(config)
assert topdir == rcol.fspath assert topdir == rcol.fspath
@ -328,7 +328,7 @@ class TestSession:
def test_collect_protocol_single_function(self, testdir): def test_collect_protocol_single_function(self, testdir):
p = testdir.makepyfile("def test_func(): pass") p = testdir.makepyfile("def test_func(): pass")
id = "::".join([p.basename, "test_func"]) id = "::".join([p.basename, "test_func"])
config = testdir.parseconfig(id) config = testdir.parseconfigure(id)
topdir = testdir.tmpdir topdir = testdir.tmpdir
rcol = Session(config) rcol = Session(config)
assert topdir == rcol.fspath assert topdir == rcol.fspath
@ -363,7 +363,7 @@ class TestSession:
p.basename + "::TestClass::()", p.basename + "::TestClass::()",
normid, normid,
]: ]:
config = testdir.parseconfig(id) config = testdir.parseconfigure(id)
rcol = Session(config=config) rcol = Session(config=config)
rcol.perform_collect() rcol.perform_collect()
items = rcol.items items = rcol.items
@ -388,7 +388,7 @@ class TestSession:
""" % p.basename) """ % p.basename)
id = p.basename id = p.basename
config = testdir.parseconfig(id) config = testdir.parseconfigure(id)
rcol = Session(config) rcol = Session(config)
hookrec = testdir.getreportrecorder(config) hookrec = testdir.getreportrecorder(config)
rcol.perform_collect() rcol.perform_collect()
@ -413,7 +413,7 @@ class TestSession:
aaa = testdir.mkpydir("aaa") aaa = testdir.mkpydir("aaa")
test_aaa = aaa.join("test_aaa.py") test_aaa = aaa.join("test_aaa.py")
p.move(test_aaa) p.move(test_aaa)
config = testdir.parseconfig() config = testdir.parseconfigure()
rcol = Session(config) rcol = Session(config)
hookrec = testdir.getreportrecorder(config) hookrec = testdir.getreportrecorder(config)
rcol.perform_collect() rcol.perform_collect()
@ -437,7 +437,7 @@ class TestSession:
p.move(test_bbb) p.move(test_bbb)
id = "." id = "."
config = testdir.parseconfig(id) config = testdir.parseconfigure(id)
rcol = Session(config) rcol = Session(config)
hookrec = testdir.getreportrecorder(config) hookrec = testdir.getreportrecorder(config)
rcol.perform_collect() rcol.perform_collect()
@ -455,7 +455,7 @@ class TestSession:
def test_serialization_byid(self, testdir): def test_serialization_byid(self, testdir):
p = testdir.makepyfile("def test_func(): pass") p = testdir.makepyfile("def test_func(): pass")
config = testdir.parseconfig() config = testdir.parseconfigure()
rcol = Session(config) rcol = Session(config)
rcol.perform_collect() rcol.perform_collect()
items = rcol.items items = rcol.items
@ -476,7 +476,7 @@ class TestSession:
pass pass
""") """)
arg = p.basename + ("::TestClass::test_method") arg = p.basename + ("::TestClass::test_method")
config = testdir.parseconfig(arg) config = testdir.parseconfigure(arg)
rcol = Session(config) rcol = Session(config)
rcol.perform_collect() rcol.perform_collect()
items = rcol.items items = rcol.items

View File

@ -705,11 +705,11 @@ class TestRequest:
def test_func(something): pass def test_func(something): pass
""") """)
req = funcargs.FuncargRequest(item) req = funcargs.FuncargRequest(item)
req.config._setupstate.prepare(item) # XXX req._pyfuncitem.session._setupstate.prepare(item) # XXX
req._fillfuncargs() req._fillfuncargs()
# successively check finalization calls # successively check finalization calls
teardownlist = item.getparent(pytest.Module).obj.teardownlist teardownlist = item.getparent(pytest.Module).obj.teardownlist
ss = item.config._setupstate ss = item.session._setupstate
assert not teardownlist assert not teardownlist
ss.teardown_exact(item) ss.teardown_exact(item)
print(ss.stack) print(ss.stack)
@ -834,11 +834,11 @@ class TestRequestCachedSetup:
ret1 = req1.cached_setup(setup, teardown, scope="function") ret1 = req1.cached_setup(setup, teardown, scope="function")
assert l == ['setup'] assert l == ['setup']
# artificial call of finalizer # artificial call of finalizer
req1.config._setupstate._callfinalizers(item1) req1._pyfuncitem.session._setupstate._callfinalizers(item1)
assert l == ["setup", "teardown"] assert l == ["setup", "teardown"]
ret2 = req1.cached_setup(setup, teardown, scope="function") ret2 = req1.cached_setup(setup, teardown, scope="function")
assert l == ["setup", "teardown", "setup"] assert l == ["setup", "teardown", "setup"]
req1.config._setupstate._callfinalizers(item1) req1._pyfuncitem.session._setupstate._callfinalizers(item1)
assert l == ["setup", "teardown", "setup", "teardown"] assert l == ["setup", "teardown", "setup", "teardown"]
def test_request_cached_setup_two_args(self, testdir): def test_request_cached_setup_two_args(self, testdir):