From c742e47de03965455d1477eb2fc0cd88594ac3ba Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Wed, 18 May 2011 15:31:10 -0500 Subject: [PATCH 01/59] new assertion debugger which rewrites asserts before they are run --- _pytest/assertrewrite.py | 295 ++++++++++++++++++++++++++++++++++ testing/test_assertrewrite.py | 195 ++++++++++++++++++++++ 2 files changed, 490 insertions(+) create mode 100644 _pytest/assertrewrite.py create mode 100644 testing/test_assertrewrite.py diff --git a/_pytest/assertrewrite.py b/_pytest/assertrewrite.py new file mode 100644 index 000000000..4aca62d5c --- /dev/null +++ b/_pytest/assertrewrite.py @@ -0,0 +1,295 @@ +"""Rewrite assertion AST to produce nice error messages""" + +import ast +import collections +import itertools + +import py + + +def rewrite_asserts(mod): + """Rewrite the assert statements in mod.""" + AssertionRewriter().run(mod) + + +_saferepr = py.io.saferepr +_format_explanation = py.code._format_explanation + +def _format_boolop(operands, explanations, is_or): + show_explanations = [] + for operand, expl in zip(operands, explanations): + show_explanations.append(expl) + if operand == is_or: + break + return "(" + (is_or and " or " or " and ").join(show_explanations) + ")" + +def _call_reprcompare(ops, results, expls, each_obj): + for i, res, expl in zip(range(len(ops)), results, expls): + if not res: + break + if py.code._reprcompare is not None: + custom = py.code._reprcompare(ops[i], each_obj[i], each_obj[i + 1]) + if custom is not None: + return custom + return expl + + +unary_map = { + ast.Not : "not %s", + ast.Invert : "~%s", + ast.USub : "-%s", + ast.UAdd : "+%s" +} + +binop_map = { + ast.BitOr : "|", + ast.BitXor : "^", + ast.BitAnd : "&", + ast.LShift : "<<", + ast.RShift : ">>", + ast.Add : "+", + ast.Sub : "-", + ast.Mult : "*", + ast.Div : "/", + ast.FloorDiv : "//", + ast.Mod : "%", + ast.Eq : "==", + ast.NotEq : "!=", + ast.Lt : "<", + ast.LtE : "<=", + ast.Gt : ">", + ast.GtE : ">=", + ast.Pow : "**", + ast.Is : "is", + ast.IsNot : "is not", + ast.In : "in", + ast.NotIn : "not in" +} + + +class AssertionRewriter(ast.NodeVisitor): + + def run(self, mod): + """Find all assert statements in *mod* and rewrite them.""" + if not mod.body: + # Nothing to do. + return + # Insert some special imports at top but after any docstrings. + aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"), + ast.alias("py", "@pylib"), + ast.alias("_pytest.assertrewrite", "@pytest_ar")] + imports = [ast.Import([alias], lineno=0, col_offset=0) + for alias in aliases] + pos = 0 + if isinstance(mod.body[0], ast.Str): + pos = 1 + mod.body[pos:pos] = imports + # Collect asserts. + asserts = [] + nodes = collections.deque([mod]) + while nodes: + node = nodes.popleft() + for name, field in ast.iter_fields(node): + if isinstance(field, list): + for i, child in enumerate(field): + if isinstance(child, ast.Assert): + asserts.append((field, i, child)) + elif isinstance(child, ast.AST): + nodes.append(child) + elif (isinstance(field, ast.AST) and + # Don't recurse into expressions as they can't contain + # asserts. + not isinstance(field, ast.expr)): + nodes.append(field) + # Transform asserts. + for parent, pos, assert_ in asserts: + parent[pos:pos + 1] = self.visit(assert_) + + def assign(self, expr): + """Give *expr* a name.""" + # Use a character invalid in python identifiers to avoid clashing. + name = "@py_assert" + str(next(self.variable_counter)) + self.variables.add(name) + self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr)) + return ast.Name(name, ast.Load()) + + def display(self, expr): + """Call py.io.saferepr on the expression.""" + return self.helper("saferepr", expr) + + def helper(self, name, *args): + """Call a helper in this module.""" + py_name = ast.Name("@pytest_ar", ast.Load()) + attr = ast.Attribute(py_name, "_" + name, ast.Load()) + return ast.Call(attr, list(args), [], None, None) + + def builtin(self, name): + """Return the builtin called *name*.""" + builtin_name = ast.Name("@py_builtins", ast.Load()) + return ast.Attribute(builtin_name, name, ast.Load()) + + def explanation_param(self, expr): + specifier = "py" + str(next(self.variable_counter)) + self.explanation_specifiers[specifier] = expr + return "%(" + specifier + ")s" + + def push_format_context(self): + self.explanation_specifiers = {} + self.stack.append(self.explanation_specifiers) + + def pop_format_context(self, expl_expr): + current = self.stack.pop() + if self.stack: + self.explanation_specifiers = self.stack[-1] + keys = [ast.Str(key) for key in current.keys()] + format_dict = ast.Dict(keys, current.values()) + form = ast.BinOp(expl_expr, ast.Mod(), format_dict) + name = "@py_format" + str(next(self.variable_counter)) + self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form)) + return ast.Name(name, ast.Load()) + + def generic_visit(self, node): + """Handle expressions we don't have custom code for.""" + assert isinstance(node, ast.expr) + res = self.assign(node) + return res, self.explanation_param(self.display(res)) + + def visit_Assert(self, assert_): + if assert_.msg: + # There's already a message. Don't mess with it. + return [assert_] + self.statements = [] + self.variables = set() + self.variable_counter = itertools.count() + self.stack = [] + self.on_failure = [] + self.push_format_context() + # Rewrite assert into a bunch of statements. + top_condition, explanation = self.visit(assert_.test) + # Create failure message. + body = self.on_failure + negation = ast.UnaryOp(ast.Not(), top_condition) + self.statements.append(ast.If(negation, body, [])) + explanation = "assert " + explanation + template = ast.Str(explanation) + msg = self.pop_format_context(template) + fmt = self.helper("format_explanation", msg) + body.append(ast.Assert(top_condition, fmt)) + # Delete temporary variables. + names = [ast.Name(name, ast.Del()) for name in self.variables] + if names: + delete = ast.Delete(names) + self.statements.append(delete) + # Fix line numbers. + for stmt in self.statements: + stmt.lineno = assert_.lineno + stmt.col_offset = assert_.col_offset + ast.fix_missing_locations(stmt) + return self.statements + + def visit_Name(self, name): + # Check if the name is local or not. + locs = ast.Call(self.builtin("locals"), [], [], None, None) + globs = ast.Call(self.builtin("globals"), [], [], None, None) + ops = [ast.In(), ast.IsNot()] + test = ast.Compare(ast.Str(name.id), ops, [locs, globs]) + expr = ast.IfExp(test, self.display(name), ast.Str(name.id)) + return name, self.explanation_param(expr) + + def visit_BoolOp(self, boolop): + operands = [] + explanations = [] + self.push_format_context() + for operand in boolop.values: + res, explanation = self.visit(operand) + operands.append(res) + explanations.append(explanation) + expls = ast.Tuple([ast.Str(expl) for expl in explanations], ast.Load()) + is_or = ast.Num(isinstance(boolop.op, ast.Or)) + expl_template = self.helper("format_boolop", + ast.Tuple(operands, ast.Load()), expls, + is_or) + expl = self.pop_format_context(expl_template) + res = self.assign(ast.BoolOp(boolop.op, operands)) + return res, self.explanation_param(expl) + + def visit_UnaryOp(self, unary): + pattern = unary_map[unary.op.__class__] + operand_res, operand_expl = self.visit(unary.operand) + res = self.assign(ast.UnaryOp(unary.op, operand_res)) + return res, pattern % (operand_expl,) + + def visit_BinOp(self, binop): + symbol = binop_map[binop.op.__class__] + left_expr, left_expl = self.visit(binop.left) + right_expr, right_expl = self.visit(binop.right) + explanation = "(%s %s %s)" % (left_expl, symbol, right_expl) + res = self.assign(ast.BinOp(left_expr, binop.op, right_expr)) + return res, explanation + + def visit_Call(self, call): + new_func, func_expl = self.visit(call.func) + arg_expls = [] + new_args = [] + new_kwargs = [] + new_star = new_kwarg = None + for arg in call.args: + res, expl = self.visit(arg) + new_args.append(res) + arg_expls.append(expl) + for keyword in call.keywords: + res, expl = self.visit(keyword.value) + new_kwargs.append(ast.keyword(keyword.arg, res)) + arg_expls.append(keyword.arg + "=" + expl) + if call.starargs: + new_star, expl = self.visit(call.starargs) + arg_expls.append("*" + expl) + if call.kwargs: + new_kwarg, expl = self.visit(call.kwarg) + arg_expls.append("**" + expl) + expl = "%s(%s)" % (func_expl, ', '.join(arg_expls)) + new_call = ast.Call(new_func, new_args, new_kwargs, new_star, new_kwarg) + res = self.assign(new_call) + res_expl = self.explanation_param(self.display(res)) + outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl) + return res, outer_expl + + def visit_Attribute(self, attr): + if not isinstance(attr.ctx, ast.Load): + return self.generic_visit(attr) + value, value_expl = self.visit(attr.value) + res = self.assign(ast.Attribute(value, attr.attr, ast.Load())) + res_expl = self.explanation_param(self.display(res)) + pat = "%s\n{%s = %s.%s\n}" + expl = pat % (res_expl, res_expl, value_expl, attr.attr) + return res, expl + + def visit_Compare(self, comp): + self.push_format_context() + left_res, left_expl = self.visit(comp.left) + res_variables = ["@py_assert" + str(next(self.variable_counter)) + for i in range(len(comp.ops))] + load_names = [ast.Name(v, ast.Load()) for v in res_variables] + store_names = [ast.Name(v, ast.Store()) for v in res_variables] + it = zip(range(len(comp.ops)), comp.ops, comp.comparators) + expls = [] + syms = [] + results = [left_res] + for i, op, next_operand in it: + next_res, next_expl = self.visit(next_operand) + results.append(next_res) + sym = binop_map[op.__class__] + syms.append(ast.Str(sym)) + expl = "%s %s %s" % (left_expl, sym, next_expl) + expls.append(ast.Str(expl)) + res_expr = ast.Compare(left_res, [op], [next_res]) + self.statements.append(ast.Assign([store_names[i]], res_expr)) + left_res, left_expl = next_res, next_expl + # Use py.code._reprcompare if that's available. + expl_call = self.helper("call_reprcompare", ast.Tuple(syms, ast.Load()), + ast.Tuple(load_names, ast.Load()), + ast.Tuple(expls, ast.Load()), + ast.Tuple(results, ast.Load())) + args = [ast.List(load_names, ast.Load())] + res = ast.Call(self.builtin("all"), args, [], None, None) + return res, self.explanation_param(self.pop_format_context(expl_call)) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py new file mode 100644 index 000000000..052a1cd27 --- /dev/null +++ b/testing/test_assertrewrite.py @@ -0,0 +1,195 @@ +import sys +import py +import pytest + +ast = pytest.importorskip("ast") + +from _pytest.assertrewrite import rewrite_asserts + + +def setup_module(mod): + mod._old_reprcompare = py.code._reprcompare + py.code._reprcompare = None + +def teardown_module(mod): + py.code._reprcompare = mod._old_reprcompare + del mod._old_reprcompare + + +def getmsg(f, extra_ns=None, must_pass=False): + """Rewrite the assertions in f, run it, and get the failure message.""" + src = '\n'.join(py.code.Code(f).source().lines) + mod = ast.parse(src) + rewrite_asserts(mod) + code = compile(mod, "", "exec") + ns = {} + if extra_ns is not None: + ns.update(extra_ns) + exec code in ns + func = ns[f.__name__] + try: + func() + except AssertionError: + if must_pass: + pytest.fail("shouldn't have raised") + s = str(sys.exc_info()[1]) + if not s.startswith("assert"): + return "AssertionError: " + s + return s + else: + if not must_pass: + pytest.fail("function didn't raise at all") + + +class TestAssertionRewrite: + + def test_name(self): + def f(): + assert False + assert getmsg(f) == "assert False" + def f(): + f = False + assert f + assert getmsg(f) == "assert False" + def f(): + assert a_global + assert getmsg(f, {"a_global" : False}) == "assert a_global" + + def test_assert_already_has_message(self): + def f(): + assert False, "something bad!" + assert getmsg(f) == "AssertionError: something bad!" + + def test_boolop(self): + def f(): + f = g = False + assert f and g + assert getmsg(f) == "assert (False)" + def f(): + f = True + g = False + assert f and g + assert getmsg(f) == "assert (True and False)" + def f(): + f = False + g = True + assert f and g + assert getmsg(f) == "assert (False)" + def f(): + f = g = False + assert f or g + assert getmsg(f) == "assert (False or False)" + def f(): + f = True + g = False + assert f or g + getmsg(f, must_pass=True) + + def test_short_circut_evaluation(self): + pytest.xfail("complicated fix; I'm not sure if it's important") + def f(): + assert True or explode + getmsg(f, must_pass=True) + + def test_unary_op(self): + def f(): + x = True + assert not x + assert getmsg(f) == "assert not True" + def f(): + x = 0 + assert ~x + 1 + assert getmsg(f) == "assert (~0 + 1)" + def f(): + x = 3 + assert -x + x + assert getmsg(f) == "assert (-3 + 3)" + def f(): + x = 0 + assert +x + x + assert getmsg(f) == "assert (+0 + 0)" + + def test_binary_op(self): + def f(): + x = 1 + y = -1 + assert x + y + assert getmsg(f) == "assert (1 + -1)" + + def test_call(self): + def g(a=42, *args, **kwargs): + return False + ns = {"g" : g} + def f(): + assert g() + assert getmsg(f, ns) == """assert False + + where False = g()""" + def f(): + assert g(1) + assert getmsg(f, ns) == """assert False + + where False = g(1)""" + def f(): + assert g(1, 2) + assert getmsg(f, ns) == """assert False + + where False = g(1, 2)""" + def f(): + assert g(1, g=42) + assert getmsg(f, ns) == """assert False + + where False = g(1, g=42)""" + def f(): + assert g(1, 3, g=23) + assert getmsg(f, ns) == """assert False + + where False = g(1, 3, g=23)""" + + def test_attribute(self): + class X(object): + g = 3 + ns = {"X" : X, "x" : X()} + def f(): + assert not x.g + assert getmsg(f, ns) == """assert not 3 + + where 3 = x.g""" + def f(): + x.a = False + assert x.a + assert getmsg(f, ns) == """assert False + + where False = x.a""" + + def test_comparisons(self): + def f(): + a, b = range(2) + assert b < a + assert getmsg(f) == """assert 1 < 0""" + def f(): + a, b, c = range(3) + assert a > b > c + assert getmsg(f) == """assert 0 > 1""" + def f(): + a, b, c = range(3) + assert a < b > c + assert getmsg(f) == """assert 1 > 2""" + def f(): + a, b, c = range(3) + assert a < b <= c + getmsg(f, must_pass=True) + + def test_len(self): + def f(): + l = range(10) + assert len(l) == 11 + assert getmsg(f).startswith("""assert 10 == 11 + + where 10 = len([""") + + def test_custom_reprcompare(self, monkeypatch): + def my_reprcompare(op, left, right): + return "42" + monkeypatch.setattr(py.code, "_reprcompare", my_reprcompare) + def f(): + assert 42 < 3 + assert getmsg(f) == "assert 42" + def my_reprcompare(op, left, right): + return "%s %s %s" % (left, op, right) + monkeypatch.setattr(py.code, "_reprcompare", my_reprcompare) + def f(): + assert 1 < 3 < 5 <= 4 < 7 + assert getmsg(f) == "assert 5 <= 4" From 9e6dfaefd9ceba878538e3f5c785e0e39c0e17e9 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 19 May 2011 16:53:13 -0500 Subject: [PATCH 02/59] place assertion imports after __future__ statements and docstrings --- _pytest/assertrewrite.py | 14 +++++++++++--- testing/test_assertrewrite.py | 28 ++++++++++++++++++++++++++-- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/_pytest/assertrewrite.py b/_pytest/assertrewrite.py index 4aca62d5c..14ddc93ce 100644 --- a/_pytest/assertrewrite.py +++ b/_pytest/assertrewrite.py @@ -74,15 +74,23 @@ class AssertionRewriter(ast.NodeVisitor): if not mod.body: # Nothing to do. return - # Insert some special imports at top but after any docstrings. + # Insert some special imports at top but after any docstrings and + # __future__ imports. aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"), ast.alias("py", "@pylib"), ast.alias("_pytest.assertrewrite", "@pytest_ar")] imports = [ast.Import([alias], lineno=0, col_offset=0) for alias in aliases] + expect_docstring = True pos = 0 - if isinstance(mod.body[0], ast.Str): - pos = 1 + for item in mod.body: + if (expect_docstring and isinstance(item, ast.Expr) and + isinstance(item.value, ast.Str)): + expect_docstring = False + elif (not isinstance(item, ast.ImportFrom) or item.level > 0 and + item.identifier != "__future__"): + break + pos += 1 mod.body[pos:pos] = imports # Collect asserts. asserts = [] diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 052a1cd27..64274fb07 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -16,11 +16,15 @@ def teardown_module(mod): del mod._old_reprcompare +def rewrite(src): + tree = ast.parse(src) + rewrite_asserts(tree) + return tree + def getmsg(f, extra_ns=None, must_pass=False): """Rewrite the assertions in f, run it, and get the failure message.""" src = '\n'.join(py.code.Code(f).source().lines) - mod = ast.parse(src) - rewrite_asserts(mod) + mod = rewrite(src) code = compile(mod, "", "exec") ns = {} if extra_ns is not None: @@ -43,6 +47,26 @@ def getmsg(f, extra_ns=None, must_pass=False): class TestAssertionRewrite: + def test_place_initial_imports(self): + s = """'Doc string'""" + m = rewrite(s) + assert isinstance(m.body[0], ast.Expr) + assert isinstance(m.body[0].value, ast.Str) + for imp in m.body[1:]: + assert isinstance(imp, ast.Import) + s = """from __future__ import with_statement""" + m = rewrite(s) + assert isinstance(m.body[0], ast.ImportFrom) + for imp in m.body[1:]: + assert isinstance(imp, ast.Import) + s = """'doc string'\nfrom __future__ import with_statement""" + m = rewrite(s) + assert isinstance(m.body[0], ast.Expr) + assert isinstance(m.body[0].value, ast.Str) + assert isinstance(m.body[1], ast.ImportFrom) + for imp in m.body[2:]: + assert isinstance(imp, ast.Import) + def test_name(self): def f(): assert False From 9ac818fb5ce42c98270a5ff19707219a3a2f4163 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 19 May 2011 18:32:48 -0500 Subject: [PATCH 03/59] small refactoring --- _pytest/assertrewrite.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/_pytest/assertrewrite.py b/_pytest/assertrewrite.py index 14ddc93ce..9109eadc3 100644 --- a/_pytest/assertrewrite.py +++ b/_pytest/assertrewrite.py @@ -113,11 +113,16 @@ class AssertionRewriter(ast.NodeVisitor): for parent, pos, assert_ in asserts: parent[pos:pos + 1] = self.visit(assert_) - def assign(self, expr): - """Give *expr* a name.""" + def variable(self): + """Get a new variable.""" # Use a character invalid in python identifiers to avoid clashing. name = "@py_assert" + str(next(self.variable_counter)) self.variables.add(name) + return name + + def assign(self, expr): + """Give *expr* a name.""" + name = self.variable() self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr)) return ast.Name(name, ast.Load()) @@ -275,8 +280,7 @@ class AssertionRewriter(ast.NodeVisitor): def visit_Compare(self, comp): self.push_format_context() left_res, left_expl = self.visit(comp.left) - res_variables = ["@py_assert" + str(next(self.variable_counter)) - for i in range(len(comp.ops))] + res_variables = [self.variable() for i in range(len(comp.ops))] load_names = [ast.Name(v, ast.Load()) for v in res_variables] store_names = [ast.Name(v, ast.Store()) for v in res_variables] it = zip(range(len(comp.ops)), comp.ops, comp.comparators) From aae89cd02175418ec51fa06da959a9d63bdb0442 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 19 May 2011 18:56:48 -0500 Subject: [PATCH 04/59] correctly handle multiple asserts --- _pytest/assertrewrite.py | 13 ++++++++----- testing/test_assertrewrite.py | 5 +++++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/_pytest/assertrewrite.py b/_pytest/assertrewrite.py index 9109eadc3..bc755123f 100644 --- a/_pytest/assertrewrite.py +++ b/_pytest/assertrewrite.py @@ -99,19 +99,22 @@ class AssertionRewriter(ast.NodeVisitor): node = nodes.popleft() for name, field in ast.iter_fields(node): if isinstance(field, list): + new = [] for i, child in enumerate(field): if isinstance(child, ast.Assert): + # Transform assert. + new.extend(self.visit(child)) asserts.append((field, i, child)) - elif isinstance(child, ast.AST): - nodes.append(child) + else: + new.append(child) + if isinstance(child, ast.AST): + nodes.append(child) + setattr(node, name, new) elif (isinstance(field, ast.AST) and # Don't recurse into expressions as they can't contain # asserts. not isinstance(field, ast.expr)): nodes.append(field) - # Transform asserts. - for parent, pos, assert_ in asserts: - parent[pos:pos + 1] = self.visit(assert_) def variable(self): """Get a new variable.""" diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 64274fb07..988f09264 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -196,6 +196,11 @@ class TestAssertionRewrite: a, b, c = range(3) assert a < b <= c getmsg(f, must_pass=True) + def f(): + a, b, c = range(3) + assert a < b + assert b < c + getmsg(f, must_pass=True) def test_len(self): def f(): From 78be3db9bb0bb15588d7a26e8c2e7ac818f1ccb3 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 19 May 2011 19:15:20 -0500 Subject: [PATCH 05/59] remove unneeded list --- _pytest/assertrewrite.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/_pytest/assertrewrite.py b/_pytest/assertrewrite.py index bc755123f..6b978bff8 100644 --- a/_pytest/assertrewrite.py +++ b/_pytest/assertrewrite.py @@ -93,7 +93,6 @@ class AssertionRewriter(ast.NodeVisitor): pos += 1 mod.body[pos:pos] = imports # Collect asserts. - asserts = [] nodes = collections.deque([mod]) while nodes: node = nodes.popleft() @@ -104,7 +103,6 @@ class AssertionRewriter(ast.NodeVisitor): if isinstance(child, ast.Assert): # Transform assert. new.extend(self.visit(child)) - asserts.append((field, i, child)) else: new.append(child) if isinstance(child, ast.AST): From bf039fea74a3d1657a849863fa67385bae3bb58b Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 19 May 2011 21:45:33 -0500 Subject: [PATCH 06/59] add hooks before and after a module is imported --- _pytest/hookspec.py | 6 ++++++ _pytest/python.py | 3 +++ 2 files changed, 9 insertions(+) diff --git a/_pytest/hookspec.py b/_pytest/hookspec.py index 898ffee2a..7cb60a131 100644 --- a/_pytest/hookspec.py +++ b/_pytest/hookspec.py @@ -104,6 +104,12 @@ def pytest_pycollect_makemodule(path, parent): """ pytest_pycollect_makemodule.firstresult = True +def pytest_pycollect_before_module_import(mod): + """Called before a module is imported.""" + +def pytest_pycollect_after_module_import(mod): + """Called after a module is imported.""" + def pytest_pycollect_makeitem(collector, name, obj): """ return custom item/collector for a python object in a module, or None. """ pytest_pycollect_makeitem.firstresult = True diff --git a/_pytest/python.py b/_pytest/python.py index f05aa1447..e2aa5a754 100644 --- a/_pytest/python.py +++ b/_pytest/python.py @@ -225,6 +225,7 @@ class Module(pytest.File, PyCollectorMixin): return self._memoizedcall('_obj', self._importtestmodule) def _importtestmodule(self): + self.ihook.pytest_pycollect_before_module_import(mod=self) # we assume we are only called once per module try: mod = self.fspath.pyimport(ensuresyspath=True) @@ -242,6 +243,8 @@ class Module(pytest.File, PyCollectorMixin): "HINT: use a unique basename for your test file modules" % e.args ) + finally: + self.ihook.pytest_pycollect_after_module_import(mod=self) #print "imported test module", mod self.config.pluginmanager.consider_module(mod) return mod From e0c128beec6d6cf88f4e0299d5fb0774434b6b2a Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 19 May 2011 21:49:37 -0500 Subject: [PATCH 07/59] unconditionally override lineno and col_offset on generated ast --- _pytest/assertrewrite.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/_pytest/assertrewrite.py b/_pytest/assertrewrite.py index 6b978bff8..ed9e0ca55 100644 --- a/_pytest/assertrewrite.py +++ b/_pytest/assertrewrite.py @@ -67,6 +67,19 @@ binop_map = { } +def set_location(node, lineno, col_offset): + """Set node location information recursively.""" + def _fix(node, lineno, col_offset): + if "lineno" in node._attributes: + node.lineno = lineno + if "col_offset" in node._attributes: + node.col_offset = col_offset + for child in ast.iter_child_nodes(node): + _fix(child, lineno, col_offset) + _fix(node, lineno, col_offset) + return node + + class AssertionRewriter(ast.NodeVisitor): def run(self, mod): @@ -196,9 +209,7 @@ class AssertionRewriter(ast.NodeVisitor): self.statements.append(delete) # Fix line numbers. for stmt in self.statements: - stmt.lineno = assert_.lineno - stmt.col_offset = assert_.col_offset - ast.fix_missing_locations(stmt) + set_location(stmt, assert_.lineno, assert_.col_offset) return self.statements def visit_Name(self, name): From 4f2166c997027aca4d180de574d95f508e45cf7a Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 19 May 2011 21:52:10 -0500 Subject: [PATCH 08/59] use assertion rewriting on test files This works by writing a fake pyc with the asserts rewritten. --- _pytest/assertion.py | 56 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/_pytest/assertion.py b/_pytest/assertion.py index d40981c32..2d760381f 100644 --- a/_pytest/assertion.py +++ b/_pytest/assertion.py @@ -2,9 +2,19 @@ support for presented detailed information in failing assertions. """ import py +import imp +import marshal +import struct import sys from _pytest.monkeypatch import monkeypatch +try: + from _pytest.assertrewrite import rewrite_asserts +except ImportError: + rewrite_asserts = None +else: + import ast + def pytest_addoption(parser): group = parser.getgroup("debugconfig") group._addoption('--no-assert', action="store_true", default=False, @@ -12,6 +22,7 @@ def pytest_addoption(parser): help="disable python assert expression reinterpretation."), def pytest_configure(config): + global rewrite_asserts # The _reprcompare attribute on the py.code module is used by # py._code._assertionnew to detect this plugin was loaded and in # turn call the hooks defined here as part of the @@ -29,6 +40,51 @@ def pytest_configure(config): m.setattr(py.builtin.builtins, 'AssertionError', py.code._AssertionError) m.setattr(py.code, '_reprcompare', callbinrepr) + else: + rewrite_asserts = None + +def pytest_pycollect_before_module_import(mod): + if rewrite_asserts is None: + return + # Some deep magic: load the source, rewrite the asserts, and write a + # fake pyc, so that it'll be loaded further down this function. + source = mod.fspath.read() + try: + tree = ast.parse(source) + except SyntaxError: + # Let this pop up again in the real import. + return + rewrite_asserts(tree) + try: + co = compile(tree, str(mod.fspath), "exec") + except SyntaxError: + # It's possible that this error is from some bug in the assertion + # rewriting, but I don't know of a fast way to tell. + return + if hasattr(imp, "cache_from_source"): + # Handle PEP 3147 pycs. + pyc = py.path(imp.cache_from_source(mod.fspath)) + pyc.dirname.ensure(dir=True) + else: + pyc = mod.fspath + "c" + mod._pyc = pyc + mtime = int(mod.fspath.mtime()) + fp = pyc.open("wb") + try: + fp.write(imp.get_magic()) + fp.write(struct.pack(" Date: Thu, 19 May 2011 21:57:27 -0500 Subject: [PATCH 09/59] a less silly way to check comparison results --- _pytest/assertrewrite.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/_pytest/assertrewrite.py b/_pytest/assertrewrite.py index ed9e0ca55..3d847b648 100644 --- a/_pytest/assertrewrite.py +++ b/_pytest/assertrewrite.py @@ -314,6 +314,5 @@ class AssertionRewriter(ast.NodeVisitor): ast.Tuple(load_names, ast.Load()), ast.Tuple(expls, ast.Load()), ast.Tuple(results, ast.Load())) - args = [ast.List(load_names, ast.Load())] - res = ast.Call(self.builtin("all"), args, [], None, None) + res = ast.BoolOp(ast.And(), load_names) return res, self.explanation_param(self.pop_format_context(expl_call)) From 265b7458cbb53e928bae0b5f49e94e8b5dc3e1e3 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 19 May 2011 22:11:18 -0500 Subject: [PATCH 10/59] in the common case, the and operation isn't needed --- _pytest/assertrewrite.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/_pytest/assertrewrite.py b/_pytest/assertrewrite.py index 3d847b648..f30e817d9 100644 --- a/_pytest/assertrewrite.py +++ b/_pytest/assertrewrite.py @@ -314,5 +314,8 @@ class AssertionRewriter(ast.NodeVisitor): ast.Tuple(load_names, ast.Load()), ast.Tuple(expls, ast.Load()), ast.Tuple(results, ast.Load())) - res = ast.BoolOp(ast.And(), load_names) + if len(comp.ops) > 1: + res = ast.BoolOp(ast.And(), load_names) + else: + res = load_names[0] return res, self.explanation_param(self.pop_format_context(expl_call)) From 7ba8fee3dc61233d6a1d9c23192b266ab978b156 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Fri, 20 May 2011 09:44:36 -0500 Subject: [PATCH 11/59] improve this test --- testing/test_assertrewrite.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 988f09264..f6b74d97e 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -48,24 +48,27 @@ def getmsg(f, extra_ns=None, must_pass=False): class TestAssertionRewrite: def test_place_initial_imports(self): - s = """'Doc string'""" + s = """'Doc string'\nother = stuff""" m = rewrite(s) assert isinstance(m.body[0], ast.Expr) assert isinstance(m.body[0].value, ast.Str) - for imp in m.body[1:]: + for imp in m.body[1:4]: assert isinstance(imp, ast.Import) - s = """from __future__ import with_statement""" + assert isinstance(m.body[4], ast.Assign) + s = """from __future__ import with_statement\nother_stuff""" m = rewrite(s) assert isinstance(m.body[0], ast.ImportFrom) - for imp in m.body[1:]: + for imp in m.body[1:4]: assert isinstance(imp, ast.Import) - s = """'doc string'\nfrom __future__ import with_statement""" + assert isinstance(m.body[4], ast.Expr) + s = """'doc string'\nfrom __future__ import with_statement\nother""" m = rewrite(s) assert isinstance(m.body[0], ast.Expr) assert isinstance(m.body[0].value, ast.Str) assert isinstance(m.body[1], ast.ImportFrom) - for imp in m.body[2:]: + for imp in m.body[2:5]: assert isinstance(imp, ast.Import) + assert isinstance(m.body[5], ast.Expr) def test_name(self): def f(): From 9c4f6791e5a545c2fc53f68ef9e3ce031fa6842b Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Tue, 24 May 2011 17:21:58 -0500 Subject: [PATCH 12/59] give initial imports a reasonable lineno --- _pytest/assertrewrite.py | 7 +++++-- testing/test_assertrewrite.py | 6 ++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/_pytest/assertrewrite.py b/_pytest/assertrewrite.py index f30e817d9..be49b2266 100644 --- a/_pytest/assertrewrite.py +++ b/_pytest/assertrewrite.py @@ -92,18 +92,21 @@ class AssertionRewriter(ast.NodeVisitor): aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"), ast.alias("py", "@pylib"), ast.alias("_pytest.assertrewrite", "@pytest_ar")] - imports = [ast.Import([alias], lineno=0, col_offset=0) - for alias in aliases] expect_docstring = True pos = 0 + lineno = 0 for item in mod.body: if (expect_docstring and isinstance(item, ast.Expr) and isinstance(item.value, ast.Str)): + lineno += len(item.value.s.splitlines()) - 1 expect_docstring = False elif (not isinstance(item, ast.ImportFrom) or item.level > 0 and item.identifier != "__future__"): + lineno = item.lineno break pos += 1 + imports = [ast.Import([alias], lineno=lineno, col_offset=0) + for alias in aliases] mod.body[pos:pos] = imports # Collect asserts. nodes = collections.deque([mod]) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index f6b74d97e..a3d831b22 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -54,12 +54,16 @@ class TestAssertionRewrite: assert isinstance(m.body[0].value, ast.Str) for imp in m.body[1:4]: assert isinstance(imp, ast.Import) + assert imp.lineno == 2 + assert imp.col_offset == 0 assert isinstance(m.body[4], ast.Assign) s = """from __future__ import with_statement\nother_stuff""" m = rewrite(s) assert isinstance(m.body[0], ast.ImportFrom) for imp in m.body[1:4]: assert isinstance(imp, ast.Import) + assert imp.lineno == 2 + assert imp.col_offset == 0 assert isinstance(m.body[4], ast.Expr) s = """'doc string'\nfrom __future__ import with_statement\nother""" m = rewrite(s) @@ -68,6 +72,8 @@ class TestAssertionRewrite: assert isinstance(m.body[1], ast.ImportFrom) for imp in m.body[2:5]: assert isinstance(imp, ast.Import) + assert imp.lineno == 3 + assert imp.col_offset == 0 assert isinstance(m.body[5], ast.Expr) def test_name(self): From 993efe927b1f1112c6de624b68c84d107af52b34 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Tue, 24 May 2011 17:28:20 -0500 Subject: [PATCH 13/59] fix sentence --- _pytest/assertrewrite.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/_pytest/assertrewrite.py b/_pytest/assertrewrite.py index be49b2266..f6bf31516 100644 --- a/_pytest/assertrewrite.py +++ b/_pytest/assertrewrite.py @@ -87,8 +87,8 @@ class AssertionRewriter(ast.NodeVisitor): if not mod.body: # Nothing to do. return - # Insert some special imports at top but after any docstrings and - # __future__ imports. + # Insert some special imports at the top of the module but after any + # docstrings and __future__ imports. aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"), ast.alias("py", "@pylib"), ast.alias("_pytest.assertrewrite", "@pytest_ar")] From 76cede83c0a4bfbe099b8df5ab23b79976c45d6a Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Tue, 24 May 2011 17:30:35 -0500 Subject: [PATCH 14/59] add a way to disable assertion rewriting for a module --- _pytest/assertrewrite.py | 6 +++++- testing/test_assertrewrite.py | 8 ++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/_pytest/assertrewrite.py b/_pytest/assertrewrite.py index f6bf31516..ad4f48f1a 100644 --- a/_pytest/assertrewrite.py +++ b/_pytest/assertrewrite.py @@ -98,7 +98,11 @@ class AssertionRewriter(ast.NodeVisitor): for item in mod.body: if (expect_docstring and isinstance(item, ast.Expr) and isinstance(item.value, ast.Str)): - lineno += len(item.value.s.splitlines()) - 1 + doc = item.value.s + if "PYTEST_DONT_REWRITE" in doc: + # The module has disabled assertion rewriting. + return + lineno += len(doc) - 1 expect_docstring = False elif (not isinstance(item, ast.ImportFrom) or item.level > 0 and item.identifier != "__future__"): diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index a3d831b22..4e478d6d6 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -76,6 +76,14 @@ class TestAssertionRewrite: assert imp.col_offset == 0 assert isinstance(m.body[5], ast.Expr) + def test_dont_rewrite(self): + s = """'PYTEST_DONT_REWRITE'\nassert 14""" + m = rewrite(s) + assert len(m.body) == 2 + assert isinstance(m.body[0].value, ast.Str) + assert isinstance(m.body[1], ast.Assert) + assert m.body[1].msg is None + def test_name(self): def f(): assert False From 7fc2f8786fadcb0528595178b6631b0bed40bf0f Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Tue, 24 May 2011 17:48:56 -0500 Subject: [PATCH 15/59] refactor writing the fake pyc into its own function --- _pytest/assertion.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/_pytest/assertion.py b/_pytest/assertion.py index 2d760381f..be5ac2d3f 100644 --- a/_pytest/assertion.py +++ b/_pytest/assertion.py @@ -43,6 +43,23 @@ def pytest_configure(config): else: rewrite_asserts = None +def _write_pyc(co, source_path): + if hasattr(imp, "cache_from_source"): + # Handle PEP 3147 pycs. + pyc = py.path(imp.cache_from_source(source_math)) + pyc.dirname.ensure(dir=True) + else: + pyc = source_path + "c" + mtime = int(source_path.mtime()) + fp = pyc.open("wb") + try: + fp.write(imp.get_magic()) + fp.write(struct.pack(" Date: Tue, 24 May 2011 17:52:17 -0500 Subject: [PATCH 16/59] test that python loads our fake pycs --- testing/test_assertion.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/testing/test_assertion.py b/testing/test_assertion.py index 9fa88d2e0..567cebbf1 100644 --- a/testing/test_assertion.py +++ b/testing/test_assertion.py @@ -221,3 +221,10 @@ def test_warn_missing(testdir): result.stderr.fnmatch_lines([ "*WARNING*assertion*", ]) + +def test_load_fake_pyc(testdir): + path = testdir.makepyfile("x = 'hello'") + co = compile("x = 'bye'", str(path), "exec") + plugin._write_pyc(co, path) + mod = path.pyimport() + assert mod.x == "bye" From 0bb84abca7d933b712afab8a5b11d379c1e74571 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Tue, 24 May 2011 18:15:08 -0500 Subject: [PATCH 17/59] handle comparison results which raise when asked for their truth value --- _pytest/assertrewrite.py | 6 +++++- testing/test_assertrewrite.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/_pytest/assertrewrite.py b/_pytest/assertrewrite.py index ad4f48f1a..04d20dd5c 100644 --- a/_pytest/assertrewrite.py +++ b/_pytest/assertrewrite.py @@ -25,7 +25,11 @@ def _format_boolop(operands, explanations, is_or): def _call_reprcompare(ops, results, expls, each_obj): for i, res, expl in zip(range(len(ops)), results, expls): - if not res: + try: + done = not res + except Exception: + done = True + if done: break if py.code._reprcompare is not None: custom = py.code._reprcompare(ops[i], each_obj[i], each_obj[i + 1]) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 4e478d6d6..df6620d11 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -239,3 +239,17 @@ class TestAssertionRewrite: def f(): assert 1 < 3 < 5 <= 4 < 7 assert getmsg(f) == "assert 5 <= 4" + + def test_assert_raising_nonzero_in_comparison(self): + def f(): + class A(object): + def __nonzero__(self): + raise ValueError(42) + def __lt__(self, other): + return A() + def __repr__(self): + return "" + def myany(x): + return False + assert myany(A() < 0) + assert " < 0" in getmsg(f) From fa412675fc003f2540b8f3dc63894fcf9bf91af3 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Tue, 24 May 2011 18:28:05 -0500 Subject: [PATCH 18/59] use py.builtin.exec_ --- testing/test_assertrewrite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index df6620d11..1d1f04f97 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -29,7 +29,7 @@ def getmsg(f, extra_ns=None, must_pass=False): ns = {} if extra_ns is not None: ns.update(extra_ns) - exec code in ns + py.builtin.exec_(code, ns) func = ns[f.__name__] try: func() From b061e71da9de1d302a4952710dbd569959326f17 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Tue, 24 May 2011 18:28:20 -0500 Subject: [PATCH 19/59] account for py3 dict.values --- _pytest/assertrewrite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_pytest/assertrewrite.py b/_pytest/assertrewrite.py index 04d20dd5c..e0f9507c2 100644 --- a/_pytest/assertrewrite.py +++ b/_pytest/assertrewrite.py @@ -180,7 +180,7 @@ class AssertionRewriter(ast.NodeVisitor): if self.stack: self.explanation_specifiers = self.stack[-1] keys = [ast.Str(key) for key in current.keys()] - format_dict = ast.Dict(keys, current.values()) + format_dict = ast.Dict(keys, list(current.values())) form = ast.BinOp(expl_expr, ast.Mod(), format_dict) name = "@py_format" + str(next(self.variable_counter)) self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form)) From c0910abf2f37e7cc1f41e14099298086d0774090 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Tue, 24 May 2011 18:30:18 -0500 Subject: [PATCH 20/59] account py3 range objects --- testing/test_assertrewrite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 1d1f04f97..826b175bf 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -221,7 +221,7 @@ class TestAssertionRewrite: def test_len(self): def f(): - l = range(10) + l = list(range(10)) assert len(l) == 11 assert getmsg(f).startswith("""assert 10 == 11 + where 10 = len([""") From e02d22aa4f0ebb768f93b939b67853ad3d9c2087 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Wed, 25 May 2011 15:55:57 -0500 Subject: [PATCH 21/59] expand try/except/finally which py2.4 does't like --- _pytest/python.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/_pytest/python.py b/_pytest/python.py index e2aa5a754..7a01b62fc 100644 --- a/_pytest/python.py +++ b/_pytest/python.py @@ -228,7 +228,10 @@ class Module(pytest.File, PyCollectorMixin): self.ihook.pytest_pycollect_before_module_import(mod=self) # we assume we are only called once per module try: - mod = self.fspath.pyimport(ensuresyspath=True) + try: + mod = self.fspath.pyimport(ensuresyspath=True) + finally: + self.ihook.pytest_pycollect_after_module_import(mod=self) except SyntaxError: excinfo = py.code.ExceptionInfo() raise self.CollectError(excinfo.getrepr(style="short")) @@ -243,8 +246,6 @@ class Module(pytest.File, PyCollectorMixin): "HINT: use a unique basename for your test file modules" % e.args ) - finally: - self.ihook.pytest_pycollect_after_module_import(mod=self) #print "imported test module", mod self.config.pluginmanager.consider_module(mod) return mod From 491c05cea7d339ea83d5f035bcf407632589e5f7 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Wed, 25 May 2011 16:18:45 -0500 Subject: [PATCH 22/59] create the _pytest/assertion package --- _pytest/{assertion.py => assertion/__init__.py} | 2 +- _pytest/{assertrewrite.py => assertion/rewrite.py} | 2 +- testing/test_assertrewrite.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) rename _pytest/{assertion.py => assertion/__init__.py} (99%) rename _pytest/{assertrewrite.py => assertion/rewrite.py} (99%) diff --git a/_pytest/assertion.py b/_pytest/assertion/__init__.py similarity index 99% rename from _pytest/assertion.py rename to _pytest/assertion/__init__.py index be5ac2d3f..26fed96b4 100644 --- a/_pytest/assertion.py +++ b/_pytest/assertion/__init__.py @@ -9,7 +9,7 @@ import sys from _pytest.monkeypatch import monkeypatch try: - from _pytest.assertrewrite import rewrite_asserts + from _pytest.assertion.rewrite import rewrite_asserts except ImportError: rewrite_asserts = None else: diff --git a/_pytest/assertrewrite.py b/_pytest/assertion/rewrite.py similarity index 99% rename from _pytest/assertrewrite.py rename to _pytest/assertion/rewrite.py index e0f9507c2..29ce43869 100644 --- a/_pytest/assertrewrite.py +++ b/_pytest/assertion/rewrite.py @@ -95,7 +95,7 @@ class AssertionRewriter(ast.NodeVisitor): # docstrings and __future__ imports. aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"), ast.alias("py", "@pylib"), - ast.alias("_pytest.assertrewrite", "@pytest_ar")] + ast.alias("_pytest.assertion.rewrite", "@pytest_ar")] expect_docstring = True pos = 0 lineno = 0 diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 826b175bf..580eed420 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -4,7 +4,7 @@ import pytest ast = pytest.importorskip("ast") -from _pytest.assertrewrite import rewrite_asserts +from _pytest.assertion.rewrite import rewrite_asserts def setup_module(mod): From f423ce9c016aff0d84c4a68f6d972833d032181e Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Wed, 25 May 2011 17:54:02 -0500 Subject: [PATCH 23/59] import assertion code from pylib --- _pytest/assertion/__init__.py | 54 ++- _pytest/assertion/newinterpret.py | 340 ++++++++++++++++++ _pytest/assertion/oldinterpret.py | 556 ++++++++++++++++++++++++++++++ _pytest/assertion/reinterpret.py | 48 +++ _pytest/assertion/rewrite.py | 21 +- testing/test_assertinterpret.py | 327 ++++++++++++++++++ testing/test_assertion.py | 11 +- testing/test_assertrewrite.py | 9 +- 8 files changed, 1344 insertions(+), 22 deletions(-) create mode 100644 _pytest/assertion/newinterpret.py create mode 100644 _pytest/assertion/oldinterpret.py create mode 100644 _pytest/assertion/reinterpret.py create mode 100644 testing/test_assertinterpret.py diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index 26fed96b4..9f89d17fc 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -31,15 +31,16 @@ def pytest_configure(config): config._cleanup.append(m.undo) warn_about_missing_assertion() if not config.getvalue("noassert") and not config.getvalue("nomagic"): + from _pytest.assertion import reinterpret def callbinrepr(op, left, right): hook_result = config.hook.pytest_assertrepr_compare( config=config, op=op, left=left, right=right) for new_expl in hook_result: if new_expl: return '\n~'.join(new_expl) - m.setattr(py.builtin.builtins, - 'AssertionError', py.code._AssertionError) - m.setattr(py.code, '_reprcompare', callbinrepr) + m.setattr(py.builtin.builtins, 'AssertionError', + reinterpret.AssertionError) + m.setattr(sys.modules[__name__], '_reprcompare', callbinrepr) else: rewrite_asserts = None @@ -98,6 +99,53 @@ def warn_about_missing_assertion(): sys.stderr.write("WARNING: failing tests may report as passing because " "assertions are turned off! (are you using python -O?)\n") +# if set, will be called by assert reinterp for comparison ops +_reprcompare = None + +def _format_explanation(explanation): + """This formats an explanation + + Normally all embedded newlines are escaped, however there are + three exceptions: \n{, \n} and \n~. The first two are intended + cover nested explanations, see function and attribute explanations + for examples (.visit_Call(), visit_Attribute()). The last one is + for when one explanation needs to span multiple lines, e.g. when + displaying diffs. + """ + raw_lines = (explanation or '').split('\n') + # escape newlines not followed by {, } and ~ + lines = [raw_lines[0]] + for l in raw_lines[1:]: + if l.startswith('{') or l.startswith('}') or l.startswith('~'): + lines.append(l) + else: + lines[-1] += '\\n' + l + + result = lines[:1] + stack = [0] + stackcnt = [0] + for line in lines[1:]: + if line.startswith('{'): + if stackcnt[-1]: + s = 'and ' + else: + s = 'where ' + stack.append(len(result)) + stackcnt[-1] += 1 + stackcnt.append(0) + result.append(' +' + ' '*(len(stack)-1) + s + line[1:]) + elif line.startswith('}'): + assert line.startswith('}') + stack.pop() + stackcnt.pop() + result[stack[-1]] += line[1:] + else: + assert line.startswith('~') + result.append(' '*len(stack) + line[1:]) + assert len(stack) == 1 + return '\n'.join(result) + + # Provide basestring in python3 try: basestring = basestring diff --git a/_pytest/assertion/newinterpret.py b/_pytest/assertion/newinterpret.py new file mode 100644 index 000000000..1d061aa46 --- /dev/null +++ b/_pytest/assertion/newinterpret.py @@ -0,0 +1,340 @@ +""" +Find intermediate evalutation results in assert statements through builtin AST. +This should replace oldinterpret.py eventually. +""" + +import sys +import ast + +import py +from _pytest import assertion +from _pytest.assertion import _format_explanation +from _pytest.assertion.reinterpret import BuiltinAssertionError + + +if sys.platform.startswith("java") and sys.version_info < (2, 5, 2): + # See http://bugs.jython.org/issue1497 + _exprs = ("BoolOp", "BinOp", "UnaryOp", "Lambda", "IfExp", "Dict", + "ListComp", "GeneratorExp", "Yield", "Compare", "Call", + "Repr", "Num", "Str", "Attribute", "Subscript", "Name", + "List", "Tuple") + _stmts = ("FunctionDef", "ClassDef", "Return", "Delete", "Assign", + "AugAssign", "Print", "For", "While", "If", "With", "Raise", + "TryExcept", "TryFinally", "Assert", "Import", "ImportFrom", + "Exec", "Global", "Expr", "Pass", "Break", "Continue") + _expr_nodes = set(getattr(ast, name) for name in _exprs) + _stmt_nodes = set(getattr(ast, name) for name in _stmts) + def _is_ast_expr(node): + return node.__class__ in _expr_nodes + def _is_ast_stmt(node): + return node.__class__ in _stmt_nodes +else: + def _is_ast_expr(node): + return isinstance(node, ast.expr) + def _is_ast_stmt(node): + return isinstance(node, ast.stmt) + + +class Failure(Exception): + """Error found while interpreting AST.""" + + def __init__(self, explanation=""): + self.cause = sys.exc_info() + self.explanation = explanation + + +def interpret(source, frame, should_fail=False): + mod = ast.parse(source) + visitor = DebugInterpreter(frame) + try: + visitor.visit(mod) + except Failure: + failure = sys.exc_info()[1] + return getfailure(failure) + if should_fail: + return ("(assertion failed, but when it was re-run for " + "printing intermediate values, it did not fail. Suggestions: " + "compute assert expression before the assert or use --no-assert)") + +def run(offending_line, frame=None): + if frame is None: + frame = py.code.Frame(sys._getframe(1)) + return interpret(offending_line, frame) + +def getfailure(failure): + explanation = _format_explanation(failure.explanation) + value = failure.cause[1] + if str(value): + lines = explanation.splitlines() + if not lines: + lines.append("") + lines[0] += " << %s" % (value,) + explanation = "\n".join(lines) + text = "%s: %s" % (failure.cause[0].__name__, explanation) + if text.startswith("AssertionError: assert "): + text = text[16:] + return text + + +operator_map = { + ast.BitOr : "|", + ast.BitXor : "^", + ast.BitAnd : "&", + ast.LShift : "<<", + ast.RShift : ">>", + ast.Add : "+", + ast.Sub : "-", + ast.Mult : "*", + ast.Div : "/", + ast.FloorDiv : "//", + ast.Mod : "%", + ast.Eq : "==", + ast.NotEq : "!=", + ast.Lt : "<", + ast.LtE : "<=", + ast.Gt : ">", + ast.GtE : ">=", + ast.Pow : "**", + ast.Is : "is", + ast.IsNot : "is not", + ast.In : "in", + ast.NotIn : "not in" +} + +unary_map = { + ast.Not : "not %s", + ast.Invert : "~%s", + ast.USub : "-%s", + ast.UAdd : "+%s" +} + + +class DebugInterpreter(ast.NodeVisitor): + """Interpret AST nodes to gleam useful debugging information. """ + + def __init__(self, frame): + self.frame = frame + + def generic_visit(self, node): + # Fallback when we don't have a special implementation. + if _is_ast_expr(node): + mod = ast.Expression(node) + co = self._compile(mod) + try: + result = self.frame.eval(co) + except Exception: + raise Failure() + explanation = self.frame.repr(result) + return explanation, result + elif _is_ast_stmt(node): + mod = ast.Module([node]) + co = self._compile(mod, "exec") + try: + self.frame.exec_(co) + except Exception: + raise Failure() + return None, None + else: + raise AssertionError("can't handle %s" %(node,)) + + def _compile(self, source, mode="eval"): + return compile(source, "", mode) + + def visit_Expr(self, expr): + return self.visit(expr.value) + + def visit_Module(self, mod): + for stmt in mod.body: + self.visit(stmt) + + def visit_Name(self, name): + explanation, result = self.generic_visit(name) + # See if the name is local. + source = "%r in locals() is not globals()" % (name.id,) + co = self._compile(source) + try: + local = self.frame.eval(co) + except Exception: + # have to assume it isn't + local = False + if not local: + return name.id, result + return explanation, result + + def visit_Compare(self, comp): + left = comp.left + left_explanation, left_result = self.visit(left) + for op, next_op in zip(comp.ops, comp.comparators): + next_explanation, next_result = self.visit(next_op) + op_symbol = operator_map[op.__class__] + explanation = "%s %s %s" % (left_explanation, op_symbol, + next_explanation) + source = "__exprinfo_left %s __exprinfo_right" % (op_symbol,) + co = self._compile(source) + try: + result = self.frame.eval(co, __exprinfo_left=left_result, + __exprinfo_right=next_result) + except Exception: + raise Failure(explanation) + try: + if not result: + break + except KeyboardInterrupt: + raise + except: + break + left_explanation, left_result = next_explanation, next_result + + if assertion._reprcompare is not None: + res = assertion._reprcompare(op_symbol, left_result, next_result) + if res: + explanation = res + return explanation, result + + def visit_BoolOp(self, boolop): + is_or = isinstance(boolop.op, ast.Or) + explanations = [] + for operand in boolop.values: + explanation, result = self.visit(operand) + explanations.append(explanation) + if result == is_or: + break + name = is_or and " or " or " and " + explanation = "(" + name.join(explanations) + ")" + return explanation, result + + def visit_UnaryOp(self, unary): + pattern = unary_map[unary.op.__class__] + operand_explanation, operand_result = self.visit(unary.operand) + explanation = pattern % (operand_explanation,) + co = self._compile(pattern % ("__exprinfo_expr",)) + try: + result = self.frame.eval(co, __exprinfo_expr=operand_result) + except Exception: + raise Failure(explanation) + return explanation, result + + def visit_BinOp(self, binop): + left_explanation, left_result = self.visit(binop.left) + right_explanation, right_result = self.visit(binop.right) + symbol = operator_map[binop.op.__class__] + explanation = "(%s %s %s)" % (left_explanation, symbol, + right_explanation) + source = "__exprinfo_left %s __exprinfo_right" % (symbol,) + co = self._compile(source) + try: + result = self.frame.eval(co, __exprinfo_left=left_result, + __exprinfo_right=right_result) + except Exception: + raise Failure(explanation) + return explanation, result + + def visit_Call(self, call): + func_explanation, func = self.visit(call.func) + arg_explanations = [] + ns = {"__exprinfo_func" : func} + arguments = [] + for arg in call.args: + arg_explanation, arg_result = self.visit(arg) + arg_name = "__exprinfo_%s" % (len(ns),) + ns[arg_name] = arg_result + arguments.append(arg_name) + arg_explanations.append(arg_explanation) + for keyword in call.keywords: + arg_explanation, arg_result = self.visit(keyword.value) + arg_name = "__exprinfo_%s" % (len(ns),) + ns[arg_name] = arg_result + keyword_source = "%s=%%s" % (keyword.arg) + arguments.append(keyword_source % (arg_name,)) + arg_explanations.append(keyword_source % (arg_explanation,)) + if call.starargs: + arg_explanation, arg_result = self.visit(call.starargs) + arg_name = "__exprinfo_star" + ns[arg_name] = arg_result + arguments.append("*%s" % (arg_name,)) + arg_explanations.append("*%s" % (arg_explanation,)) + if call.kwargs: + arg_explanation, arg_result = self.visit(call.kwargs) + arg_name = "__exprinfo_kwds" + ns[arg_name] = arg_result + arguments.append("**%s" % (arg_name,)) + arg_explanations.append("**%s" % (arg_explanation,)) + args_explained = ", ".join(arg_explanations) + explanation = "%s(%s)" % (func_explanation, args_explained) + args = ", ".join(arguments) + source = "__exprinfo_func(%s)" % (args,) + co = self._compile(source) + try: + result = self.frame.eval(co, **ns) + except Exception: + raise Failure(explanation) + pattern = "%s\n{%s = %s\n}" + rep = self.frame.repr(result) + explanation = pattern % (rep, rep, explanation) + return explanation, result + + def _is_builtin_name(self, name): + pattern = "%r not in globals() and %r not in locals()" + source = pattern % (name.id, name.id) + co = self._compile(source) + try: + return self.frame.eval(co) + except Exception: + return False + + def visit_Attribute(self, attr): + if not isinstance(attr.ctx, ast.Load): + return self.generic_visit(attr) + source_explanation, source_result = self.visit(attr.value) + explanation = "%s.%s" % (source_explanation, attr.attr) + source = "__exprinfo_expr.%s" % (attr.attr,) + co = self._compile(source) + try: + result = self.frame.eval(co, __exprinfo_expr=source_result) + except Exception: + raise Failure(explanation) + explanation = "%s\n{%s = %s.%s\n}" % (self.frame.repr(result), + self.frame.repr(result), + source_explanation, attr.attr) + # Check if the attr is from an instance. + source = "%r in getattr(__exprinfo_expr, '__dict__', {})" + source = source % (attr.attr,) + co = self._compile(source) + try: + from_instance = self.frame.eval(co, __exprinfo_expr=source_result) + except Exception: + from_instance = True + if from_instance: + rep = self.frame.repr(result) + pattern = "%s\n{%s = %s\n}" + explanation = pattern % (rep, rep, explanation) + return explanation, result + + def visit_Assert(self, assrt): + test_explanation, test_result = self.visit(assrt.test) + if test_explanation.startswith("False\n{False =") and \ + test_explanation.endswith("\n"): + test_explanation = test_explanation[15:-2] + explanation = "assert %s" % (test_explanation,) + if not test_result: + try: + raise BuiltinAssertionError + except Exception: + raise Failure(explanation) + return explanation, test_result + + def visit_Assign(self, assign): + value_explanation, value_result = self.visit(assign.value) + explanation = "... = %s" % (value_explanation,) + name = ast.Name("__exprinfo_expr", ast.Load(), + lineno=assign.value.lineno, + col_offset=assign.value.col_offset) + new_assign = ast.Assign(assign.targets, name, lineno=assign.lineno, + col_offset=assign.col_offset) + mod = ast.Module([new_assign]) + co = self._compile(mod, "exec") + try: + self.frame.exec_(co, __exprinfo_expr=value_result) + except Exception: + raise Failure(explanation) + return explanation, value_result diff --git a/_pytest/assertion/oldinterpret.py b/_pytest/assertion/oldinterpret.py new file mode 100644 index 000000000..3e8f1c0b3 --- /dev/null +++ b/_pytest/assertion/oldinterpret.py @@ -0,0 +1,556 @@ +import py +import sys, inspect +from compiler import parse, ast, pycodegen +from _pytest.assertion import _format_explanation +from _pytest.assertion.reinterpret import BuiltinAssertionError + +passthroughex = py.builtin._sysex + +class Failure: + def __init__(self, node): + self.exc, self.value, self.tb = sys.exc_info() + self.node = node + +class View(object): + """View base class. + + If C is a subclass of View, then C(x) creates a proxy object around + the object x. The actual class of the proxy is not C in general, + but a *subclass* of C determined by the rules below. To avoid confusion + we call view class the class of the proxy (a subclass of C, so of View) + and object class the class of x. + + Attributes and methods not found in the proxy are automatically read on x. + Other operations like setting attributes are performed on the proxy, as + determined by its view class. The object x is available from the proxy + as its __obj__ attribute. + + The view class selection is determined by the __view__ tuples and the + optional __viewkey__ method. By default, the selected view class is the + most specific subclass of C whose __view__ mentions the class of x. + If no such subclass is found, the search proceeds with the parent + object classes. For example, C(True) will first look for a subclass + of C with __view__ = (..., bool, ...) and only if it doesn't find any + look for one with __view__ = (..., int, ...), and then ..., object,... + If everything fails the class C itself is considered to be the default. + + Alternatively, the view class selection can be driven by another aspect + of the object x, instead of the class of x, by overriding __viewkey__. + See last example at the end of this module. + """ + + _viewcache = {} + __view__ = () + + def __new__(rootclass, obj, *args, **kwds): + self = object.__new__(rootclass) + self.__obj__ = obj + self.__rootclass__ = rootclass + key = self.__viewkey__() + try: + self.__class__ = self._viewcache[key] + except KeyError: + self.__class__ = self._selectsubclass(key) + return self + + def __getattr__(self, attr): + # attributes not found in the normal hierarchy rooted on View + # are looked up in the object's real class + return getattr(self.__obj__, attr) + + def __viewkey__(self): + return self.__obj__.__class__ + + def __matchkey__(self, key, subclasses): + if inspect.isclass(key): + keys = inspect.getmro(key) + else: + keys = [key] + for key in keys: + result = [C for C in subclasses if key in C.__view__] + if result: + return result + return [] + + def _selectsubclass(self, key): + subclasses = list(enumsubclasses(self.__rootclass__)) + for C in subclasses: + if not isinstance(C.__view__, tuple): + C.__view__ = (C.__view__,) + choices = self.__matchkey__(key, subclasses) + if not choices: + return self.__rootclass__ + elif len(choices) == 1: + return choices[0] + else: + # combine the multiple choices + return type('?', tuple(choices), {}) + + def __repr__(self): + return '%s(%r)' % (self.__rootclass__.__name__, self.__obj__) + + +def enumsubclasses(cls): + for subcls in cls.__subclasses__(): + for subsubclass in enumsubclasses(subcls): + yield subsubclass + yield cls + + +class Interpretable(View): + """A parse tree node with a few extra methods.""" + explanation = None + + def is_builtin(self, frame): + return False + + def eval(self, frame): + # fall-back for unknown expression nodes + try: + expr = ast.Expression(self.__obj__) + expr.filename = '' + self.__obj__.filename = '' + co = pycodegen.ExpressionCodeGenerator(expr).getCode() + result = frame.eval(co) + except passthroughex: + raise + except: + raise Failure(self) + self.result = result + self.explanation = self.explanation or frame.repr(self.result) + + def run(self, frame): + # fall-back for unknown statement nodes + try: + expr = ast.Module(None, ast.Stmt([self.__obj__])) + expr.filename = '' + co = pycodegen.ModuleCodeGenerator(expr).getCode() + frame.exec_(co) + except passthroughex: + raise + except: + raise Failure(self) + + def nice_explanation(self): + return _format_explanation(self.explanation) + + +class Name(Interpretable): + __view__ = ast.Name + + def is_local(self, frame): + source = '%r in locals() is not globals()' % self.name + try: + return frame.is_true(frame.eval(source)) + except passthroughex: + raise + except: + return False + + def is_global(self, frame): + source = '%r in globals()' % self.name + try: + return frame.is_true(frame.eval(source)) + except passthroughex: + raise + except: + return False + + def is_builtin(self, frame): + source = '%r not in locals() and %r not in globals()' % ( + self.name, self.name) + try: + return frame.is_true(frame.eval(source)) + except passthroughex: + raise + except: + return False + + def eval(self, frame): + super(Name, self).eval(frame) + if not self.is_local(frame): + self.explanation = self.name + +class Compare(Interpretable): + __view__ = ast.Compare + + def eval(self, frame): + expr = Interpretable(self.expr) + expr.eval(frame) + for operation, expr2 in self.ops: + if hasattr(self, 'result'): + # shortcutting in chained expressions + if not frame.is_true(self.result): + break + expr2 = Interpretable(expr2) + expr2.eval(frame) + self.explanation = "%s %s %s" % ( + expr.explanation, operation, expr2.explanation) + source = "__exprinfo_left %s __exprinfo_right" % operation + try: + self.result = frame.eval(source, + __exprinfo_left=expr.result, + __exprinfo_right=expr2.result) + except passthroughex: + raise + except: + raise Failure(self) + expr = expr2 + +class And(Interpretable): + __view__ = ast.And + + def eval(self, frame): + explanations = [] + for expr in self.nodes: + expr = Interpretable(expr) + expr.eval(frame) + explanations.append(expr.explanation) + self.result = expr.result + if not frame.is_true(expr.result): + break + self.explanation = '(' + ' and '.join(explanations) + ')' + +class Or(Interpretable): + __view__ = ast.Or + + def eval(self, frame): + explanations = [] + for expr in self.nodes: + expr = Interpretable(expr) + expr.eval(frame) + explanations.append(expr.explanation) + self.result = expr.result + if frame.is_true(expr.result): + break + self.explanation = '(' + ' or '.join(explanations) + ')' + + +# == Unary operations == +keepalive = [] +for astclass, astpattern in { + ast.Not : 'not __exprinfo_expr', + ast.Invert : '(~__exprinfo_expr)', + }.items(): + + class UnaryArith(Interpretable): + __view__ = astclass + + def eval(self, frame, astpattern=astpattern): + expr = Interpretable(self.expr) + expr.eval(frame) + self.explanation = astpattern.replace('__exprinfo_expr', + expr.explanation) + try: + self.result = frame.eval(astpattern, + __exprinfo_expr=expr.result) + except passthroughex: + raise + except: + raise Failure(self) + + keepalive.append(UnaryArith) + +# == Binary operations == +for astclass, astpattern in { + ast.Add : '(__exprinfo_left + __exprinfo_right)', + ast.Sub : '(__exprinfo_left - __exprinfo_right)', + ast.Mul : '(__exprinfo_left * __exprinfo_right)', + ast.Div : '(__exprinfo_left / __exprinfo_right)', + ast.Mod : '(__exprinfo_left % __exprinfo_right)', + ast.Power : '(__exprinfo_left ** __exprinfo_right)', + }.items(): + + class BinaryArith(Interpretable): + __view__ = astclass + + def eval(self, frame, astpattern=astpattern): + left = Interpretable(self.left) + left.eval(frame) + right = Interpretable(self.right) + right.eval(frame) + self.explanation = (astpattern + .replace('__exprinfo_left', left .explanation) + .replace('__exprinfo_right', right.explanation)) + try: + self.result = frame.eval(astpattern, + __exprinfo_left=left.result, + __exprinfo_right=right.result) + except passthroughex: + raise + except: + raise Failure(self) + + keepalive.append(BinaryArith) + + +class CallFunc(Interpretable): + __view__ = ast.CallFunc + + def is_bool(self, frame): + source = 'isinstance(__exprinfo_value, bool)' + try: + return frame.is_true(frame.eval(source, + __exprinfo_value=self.result)) + except passthroughex: + raise + except: + return False + + def eval(self, frame): + node = Interpretable(self.node) + node.eval(frame) + explanations = [] + vars = {'__exprinfo_fn': node.result} + source = '__exprinfo_fn(' + for a in self.args: + if isinstance(a, ast.Keyword): + keyword = a.name + a = a.expr + else: + keyword = None + a = Interpretable(a) + a.eval(frame) + argname = '__exprinfo_%d' % len(vars) + vars[argname] = a.result + if keyword is None: + source += argname + ',' + explanations.append(a.explanation) + else: + source += '%s=%s,' % (keyword, argname) + explanations.append('%s=%s' % (keyword, a.explanation)) + if self.star_args: + star_args = Interpretable(self.star_args) + star_args.eval(frame) + argname = '__exprinfo_star' + vars[argname] = star_args.result + source += '*' + argname + ',' + explanations.append('*' + star_args.explanation) + if self.dstar_args: + dstar_args = Interpretable(self.dstar_args) + dstar_args.eval(frame) + argname = '__exprinfo_kwds' + vars[argname] = dstar_args.result + source += '**' + argname + ',' + explanations.append('**' + dstar_args.explanation) + self.explanation = "%s(%s)" % ( + node.explanation, ', '.join(explanations)) + if source.endswith(','): + source = source[:-1] + source += ')' + try: + self.result = frame.eval(source, **vars) + except passthroughex: + raise + except: + raise Failure(self) + if not node.is_builtin(frame) or not self.is_bool(frame): + r = frame.repr(self.result) + self.explanation = '%s\n{%s = %s\n}' % (r, r, self.explanation) + +class Getattr(Interpretable): + __view__ = ast.Getattr + + def eval(self, frame): + expr = Interpretable(self.expr) + expr.eval(frame) + source = '__exprinfo_expr.%s' % self.attrname + try: + self.result = frame.eval(source, __exprinfo_expr=expr.result) + except passthroughex: + raise + except: + raise Failure(self) + self.explanation = '%s.%s' % (expr.explanation, self.attrname) + # if the attribute comes from the instance, its value is interesting + source = ('hasattr(__exprinfo_expr, "__dict__") and ' + '%r in __exprinfo_expr.__dict__' % self.attrname) + try: + from_instance = frame.is_true( + frame.eval(source, __exprinfo_expr=expr.result)) + except passthroughex: + raise + except: + from_instance = True + if from_instance: + r = frame.repr(self.result) + self.explanation = '%s\n{%s = %s\n}' % (r, r, self.explanation) + +# == Re-interpretation of full statements == + +class Assert(Interpretable): + __view__ = ast.Assert + + def run(self, frame): + test = Interpretable(self.test) + test.eval(frame) + # simplify 'assert False where False = ...' + if (test.explanation.startswith('False\n{False = ') and + test.explanation.endswith('\n}')): + test.explanation = test.explanation[15:-2] + # print the result as 'assert ' + self.result = test.result + self.explanation = 'assert ' + test.explanation + if not frame.is_true(test.result): + try: + raise BuiltinAssertionError + except passthroughex: + raise + except: + raise Failure(self) + +class Assign(Interpretable): + __view__ = ast.Assign + + def run(self, frame): + expr = Interpretable(self.expr) + expr.eval(frame) + self.result = expr.result + self.explanation = '... = ' + expr.explanation + # fall-back-run the rest of the assignment + ass = ast.Assign(self.nodes, ast.Name('__exprinfo_expr')) + mod = ast.Module(None, ast.Stmt([ass])) + mod.filename = '' + co = pycodegen.ModuleCodeGenerator(mod).getCode() + try: + frame.exec_(co, __exprinfo_expr=expr.result) + except passthroughex: + raise + except: + raise Failure(self) + +class Discard(Interpretable): + __view__ = ast.Discard + + def run(self, frame): + expr = Interpretable(self.expr) + expr.eval(frame) + self.result = expr.result + self.explanation = expr.explanation + +class Stmt(Interpretable): + __view__ = ast.Stmt + + def run(self, frame): + for stmt in self.nodes: + stmt = Interpretable(stmt) + stmt.run(frame) + + +def report_failure(e): + explanation = e.node.nice_explanation() + if explanation: + explanation = ", in: " + explanation + else: + explanation = "" + sys.stdout.write("%s: %s%s\n" % (e.exc.__name__, e.value, explanation)) + +def check(s, frame=None): + if frame is None: + frame = sys._getframe(1) + frame = py.code.Frame(frame) + expr = parse(s, 'eval') + assert isinstance(expr, ast.Expression) + node = Interpretable(expr.node) + try: + node.eval(frame) + except passthroughex: + raise + except Failure: + e = sys.exc_info()[1] + report_failure(e) + else: + if not frame.is_true(node.result): + sys.stderr.write("assertion failed: %s\n" % node.nice_explanation()) + + +########################################################### +# API / Entry points +# ######################################################### + +def interpret(source, frame, should_fail=False): + module = Interpretable(parse(source, 'exec').node) + #print "got module", module + if isinstance(frame, py.std.types.FrameType): + frame = py.code.Frame(frame) + try: + module.run(frame) + except Failure: + e = sys.exc_info()[1] + return getfailure(e) + except passthroughex: + raise + except: + import traceback + traceback.print_exc() + if should_fail: + return ("(assertion failed, but when it was re-run for " + "printing intermediate values, it did not fail. Suggestions: " + "compute assert expression before the assert or use --nomagic)") + else: + return None + +def getmsg(excinfo): + if isinstance(excinfo, tuple): + excinfo = py.code.ExceptionInfo(excinfo) + #frame, line = gettbline(tb) + #frame = py.code.Frame(frame) + #return interpret(line, frame) + + tb = excinfo.traceback[-1] + source = str(tb.statement).strip() + x = interpret(source, tb.frame, should_fail=True) + if not isinstance(x, str): + raise TypeError("interpret returned non-string %r" % (x,)) + return x + +def getfailure(e): + explanation = e.node.nice_explanation() + if str(e.value): + lines = explanation.split('\n') + lines[0] += " << %s" % (e.value,) + explanation = '\n'.join(lines) + text = "%s: %s" % (e.exc.__name__, explanation) + if text.startswith('AssertionError: assert '): + text = text[16:] + return text + +def run(s, frame=None): + if frame is None: + frame = sys._getframe(1) + frame = py.code.Frame(frame) + module = Interpretable(parse(s, 'exec').node) + try: + module.run(frame) + except Failure: + e = sys.exc_info()[1] + report_failure(e) + + +if __name__ == '__main__': + # example: + def f(): + return 5 + def g(): + return 3 + def h(x): + return 'never' + check("f() * g() == 5") + check("not f()") + check("not (f() and g() or 0)") + check("f() == g()") + i = 4 + check("i == f()") + check("len(f()) == 0") + check("isinstance(2+3+4, float)") + + run("x = i") + check("x == 5") + + run("assert not f(), 'oops'") + run("a, b, c = 1, 2") + run("a, b, c = f()") + + check("max([f(),g()]) == 4") + check("'hello'[g()] == 'h'") + run("'guk%d' % h(f())") diff --git a/_pytest/assertion/reinterpret.py b/_pytest/assertion/reinterpret.py new file mode 100644 index 000000000..6e9465d8a --- /dev/null +++ b/_pytest/assertion/reinterpret.py @@ -0,0 +1,48 @@ +import sys +import py + +BuiltinAssertionError = py.builtin.builtins.AssertionError + +class AssertionError(BuiltinAssertionError): + def __init__(self, *args): + BuiltinAssertionError.__init__(self, *args) + if args: + try: + self.msg = str(args[0]) + except py.builtin._sysex: + raise + except: + self.msg = "<[broken __repr__] %s at %0xd>" %( + args[0].__class__, id(args[0])) + else: + f = py.code.Frame(sys._getframe(1)) + try: + source = f.code.fullsource + if source is not None: + try: + source = source.getstatement(f.lineno, assertion=True) + except IndexError: + source = None + else: + source = str(source.deindent()).strip() + except py.error.ENOENT: + source = None + # this can also occur during reinterpretation, when the + # co_filename is set to "". + if source: + self.msg = reinterpret(source, f, should_fail=True) + else: + self.msg = "" + if not self.args: + self.args = (self.msg,) + +if sys.version_info > (3, 0): + AssertionError.__module__ = "builtins" + reinterpret_old = "old reinterpretation not available for py3" +else: + from _pytest.assertion.oldinterpret import interpret as reinterpret_old +if sys.version_info >= (2, 6) or (sys.platform.startswith("java")): + from _pytest.assertion.newinterpret import interpret as reinterpret +else: + reinterpret = reinterpret_old + diff --git a/_pytest/assertion/rewrite.py b/_pytest/assertion/rewrite.py index 29ce43869..186d2425e 100644 --- a/_pytest/assertion/rewrite.py +++ b/_pytest/assertion/rewrite.py @@ -13,7 +13,6 @@ def rewrite_asserts(mod): _saferepr = py.io.saferepr -_format_explanation = py.code._format_explanation def _format_boolop(operands, explanations, is_or): show_explanations = [] @@ -31,8 +30,9 @@ def _call_reprcompare(ops, results, expls, each_obj): done = True if done: break - if py.code._reprcompare is not None: - custom = py.code._reprcompare(ops[i], each_obj[i], each_obj[i + 1]) + from _pytest.assertion import _reprcompare + if _reprcompare is not None: + custom = _reprcompare(ops[i], each_obj[i], each_obj[i + 1]) if custom is not None: return custom return expl @@ -94,7 +94,7 @@ class AssertionRewriter(ast.NodeVisitor): # Insert some special imports at the top of the module but after any # docstrings and __future__ imports. aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"), - ast.alias("py", "@pylib"), + ast.alias("_pytest.assertion", "@pytest_a"), ast.alias("_pytest.assertion.rewrite", "@pytest_ar")] expect_docstring = True pos = 0 @@ -153,11 +153,11 @@ class AssertionRewriter(ast.NodeVisitor): def display(self, expr): """Call py.io.saferepr on the expression.""" - return self.helper("saferepr", expr) + return self.helper("ar", "saferepr", expr) - def helper(self, name, *args): + def helper(self, mod, name, *args): """Call a helper in this module.""" - py_name = ast.Name("@pytest_ar", ast.Load()) + py_name = ast.Name("@pytest_" + mod, ast.Load()) attr = ast.Attribute(py_name, "_" + name, ast.Load()) return ast.Call(attr, list(args), [], None, None) @@ -211,7 +211,7 @@ class AssertionRewriter(ast.NodeVisitor): explanation = "assert " + explanation template = ast.Str(explanation) msg = self.pop_format_context(template) - fmt = self.helper("format_explanation", msg) + fmt = self.helper("a", "format_explanation", msg) body.append(ast.Assert(top_condition, fmt)) # Delete temporary variables. names = [ast.Name(name, ast.Del()) for name in self.variables] @@ -242,7 +242,7 @@ class AssertionRewriter(ast.NodeVisitor): explanations.append(explanation) expls = ast.Tuple([ast.Str(expl) for expl in explanations], ast.Load()) is_or = ast.Num(isinstance(boolop.op, ast.Or)) - expl_template = self.helper("format_boolop", + expl_template = self.helper("ar", "format_boolop", ast.Tuple(operands, ast.Load()), expls, is_or) expl = self.pop_format_context(expl_template) @@ -321,7 +321,8 @@ class AssertionRewriter(ast.NodeVisitor): self.statements.append(ast.Assign([store_names[i]], res_expr)) left_res, left_expl = next_res, next_expl # Use py.code._reprcompare if that's available. - expl_call = self.helper("call_reprcompare", ast.Tuple(syms, ast.Load()), + expl_call = self.helper("ar", "call_reprcompare", + ast.Tuple(syms, ast.Load()), ast.Tuple(load_names, ast.Load()), ast.Tuple(expls, ast.Load()), ast.Tuple(results, ast.Load())) diff --git a/testing/test_assertinterpret.py b/testing/test_assertinterpret.py new file mode 100644 index 000000000..318516eae --- /dev/null +++ b/testing/test_assertinterpret.py @@ -0,0 +1,327 @@ +"PYTEST_DONT_REWRITE" +import pytest, py + +from _pytest import assertion + +def exvalue(): + return py.std.sys.exc_info()[1] + +def f(): + return 2 + +def test_not_being_rewritten(): + assert "@py_builtins" not in globals() + +def test_assert(): + try: + assert f() == 3 + except AssertionError: + e = exvalue() + s = str(e) + assert s.startswith('assert 2 == 3\n') + +def test_assert_with_explicit_message(): + try: + assert f() == 3, "hello" + except AssertionError: + e = exvalue() + assert e.msg == 'hello' + +def test_assert_within_finally(): + class A: + def f(): + pass + excinfo = py.test.raises(TypeError, """ + try: + A().f() + finally: + i = 42 + """) + s = excinfo.exconly() + assert s.find("takes no argument") != -1 + + #def g(): + # A.f() + #excinfo = getexcinfo(TypeError, g) + #msg = getmsg(excinfo) + #assert msg.find("must be called with A") != -1 + + +def test_assert_multiline_1(): + try: + assert (f() == + 3) + except AssertionError: + e = exvalue() + s = str(e) + assert s.startswith('assert 2 == 3\n') + +def test_assert_multiline_2(): + try: + assert (f() == (4, + 3)[-1]) + except AssertionError: + e = exvalue() + s = str(e) + assert s.startswith('assert 2 ==') + +def test_in(): + try: + assert "hi" in [1, 2] + except AssertionError: + e = exvalue() + s = str(e) + assert s.startswith("assert 'hi' in") + +def test_is(): + try: + assert 1 is 2 + except AssertionError: + e = exvalue() + s = str(e) + assert s.startswith("assert 1 is 2") + + +@py.test.mark.skipif("sys.version_info < (2,6)") +def test_attrib(): + class Foo(object): + b = 1 + i = Foo() + try: + assert i.b == 2 + except AssertionError: + e = exvalue() + s = str(e) + assert s.startswith("assert 1 == 2") + +@py.test.mark.skipif("sys.version_info < (2,6)") +def test_attrib_inst(): + class Foo(object): + b = 1 + try: + assert Foo().b == 2 + except AssertionError: + e = exvalue() + s = str(e) + assert s.startswith("assert 1 == 2") + +def test_len(): + l = list(range(42)) + try: + assert len(l) == 100 + except AssertionError: + e = exvalue() + s = str(e) + assert s.startswith("assert 42 == 100") + assert "where 42 = len([" in s + +def test_assert_non_string_message(): + class A: + def __str__(self): + return "hello" + try: + assert 0 == 1, A() + except AssertionError: + e = exvalue() + assert e.msg == "hello" + +def test_assert_keyword_arg(): + def f(x=3): + return False + try: + assert f(x=5) + except AssertionError: + e = exvalue() + assert "x=5" in e.msg + +# These tests should both fail, but should fail nicely... +class WeirdRepr: + def __repr__(self): + return '' + +def bug_test_assert_repr(): + v = WeirdRepr() + try: + assert v == 1 + except AssertionError: + e = exvalue() + assert e.msg.find('WeirdRepr') != -1 + assert e.msg.find('second line') != -1 + assert 0 + +def test_assert_non_string(): + try: + assert 0, ['list'] + except AssertionError: + e = exvalue() + assert e.msg.find("list") != -1 + +def test_assert_implicit_multiline(): + try: + x = [1,2,3] + assert x != [1, + 2, 3] + except AssertionError: + e = exvalue() + assert e.msg.find('assert [1, 2, 3] !=') != -1 + + +def test_assert_with_brokenrepr_arg(): + class BrokenRepr: + def __repr__(self): 0 / 0 + e = AssertionError(BrokenRepr()) + if e.msg.find("broken __repr__") == -1: + py.test.fail("broken __repr__ not handle correctly") + +def test_multiple_statements_per_line(): + try: + a = 1; assert a == 2 + except AssertionError: + e = exvalue() + assert "assert 1 == 2" in e.msg + +def test_power(): + try: + assert 2**3 == 7 + except AssertionError: + e = exvalue() + assert "assert (2 ** 3) == 7" in e.msg + + +class TestView: + + def setup_class(cls): + cls.View = pytest.importorskip("_pytest.assertion.oldinterpret").View + + def test_class_dispatch(self): + ### Use a custom class hierarchy with existing instances + + class Picklable(self.View): + pass + + class Simple(Picklable): + __view__ = object + def pickle(self): + return repr(self.__obj__) + + class Seq(Picklable): + __view__ = list, tuple, dict + def pickle(self): + return ';'.join( + [Picklable(item).pickle() for item in self.__obj__]) + + class Dict(Seq): + __view__ = dict + def pickle(self): + return Seq.pickle(self) + '!' + Seq(self.values()).pickle() + + assert Picklable(123).pickle() == '123' + assert Picklable([1,[2,3],4]).pickle() == '1;2;3;4' + assert Picklable({1:2}).pickle() == '1!2' + + def test_viewtype_class_hierarchy(self): + # Use a custom class hierarchy based on attributes of existing instances + class Operation: + "Existing class that I don't want to change." + def __init__(self, opname, *args): + self.opname = opname + self.args = args + + existing = [Operation('+', 4, 5), + Operation('getitem', '', 'join'), + Operation('setattr', 'x', 'y', 3), + Operation('-', 12, 1)] + + class PyOp(self.View): + def __viewkey__(self): + return self.opname + def generate(self): + return '%s(%s)' % (self.opname, ', '.join(map(repr, self.args))) + + class PyBinaryOp(PyOp): + __view__ = ('+', '-', '*', '/') + def generate(self): + return '%s %s %s' % (self.args[0], self.opname, self.args[1]) + + codelines = [PyOp(op).generate() for op in existing] + assert codelines == ["4 + 5", "getitem('', 'join')", + "setattr('x', 'y', 3)", "12 - 1"] + +@py.test.mark.skipif("sys.version_info < (2,6)") +def test_assert_customizable_reprcompare(monkeypatch): + monkeypatch.setattr(assertion, '_reprcompare', lambda *args: 'hello') + try: + assert 3 == 4 + except AssertionError: + e = exvalue() + s = str(e) + assert "hello" in s + +def test_assert_long_source_1(): + try: + assert len == [ + (None, ['somet text', 'more text']), + ] + except AssertionError: + e = exvalue() + s = str(e) + assert 're-run' not in s + assert 'somet text' in s + +def test_assert_long_source_2(): + try: + assert(len == [ + (None, ['somet text', 'more text']), + ]) + except AssertionError: + e = exvalue() + s = str(e) + assert 're-run' not in s + assert 'somet text' in s + +def test_assert_raise_alias(testdir): + testdir.makepyfile(""" + "PYTEST_DONT_REWRITE" + import sys + EX = AssertionError + def test_hello(): + raise EX("hello" + "multi" + "line") + """) + result = testdir.runpytest() + result.stdout.fnmatch_lines([ + "*def test_hello*", + "*raise EX*", + "*1 failed*", + ]) + + +@pytest.mark.skipif("sys.version_info < (2,5)") +def test_assert_raise_subclass(): + class SomeEx(AssertionError): + def __init__(self, *args): + super(SomeEx, self).__init__() + try: + raise SomeEx("hello") + except AssertionError: + s = str(exvalue()) + assert 're-run' not in s + assert 'could not determine' in s + +def test_assert_raises_in_nonzero_of_object_pytest_issue10(): + class A(object): + def __nonzero__(self): + raise ValueError(42) + def __lt__(self, other): + return A() + def __repr__(self): + return "" + def myany(x): + return True + try: + assert not(myany(A() < 0)) + except AssertionError: + e = exvalue() + s = str(e) + assert " < 0" in s diff --git a/testing/test_assertion.py b/testing/test_assertion.py index 567cebbf1..5470f6416 100644 --- a/testing/test_assertion.py +++ b/testing/test_assertion.py @@ -2,11 +2,12 @@ import sys import py, pytest import _pytest.assertion as plugin +from _pytest.assertion import reinterpret needsnewassert = pytest.mark.skipif("sys.version_info < (2,6)") def interpret(expr): - return py.code._reinterpret(expr, py.code.Frame(sys._getframe(1))) + return reinterpret.reinterpret(expr, py.code.Frame(sys._getframe(1))) class TestBinReprIntegration: pytestmark = needsnewassert @@ -25,7 +26,7 @@ class TestBinReprIntegration: self.right = right mockhook = MockHook() monkeypatch = request.getfuncargvalue("monkeypatch") - monkeypatch.setattr(py.code, '_reprcompare', mockhook) + monkeypatch.setattr(plugin, '_reprcompare', mockhook) return mockhook def test_pytest_assertrepr_compare_called(self, hook): @@ -40,13 +41,13 @@ class TestBinReprIntegration: assert hook.right == [0, 2] def test_configure_unconfigure(self, testdir, hook): - assert hook == py.code._reprcompare + assert hook == plugin._reprcompare config = testdir.parseconfig() plugin.pytest_configure(config) - assert hook != py.code._reprcompare + assert hook != plugin._reprcompare from _pytest.config import pytest_unconfigure pytest_unconfigure(config) - assert hook == py.code._reprcompare + assert hook == plugin._reprcompare def callequal(left, right): return plugin.pytest_assertrepr_compare('==', left, right) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 580eed420..d713b6e25 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -4,15 +4,16 @@ import pytest ast = pytest.importorskip("ast") +from _pytest import assertion from _pytest.assertion.rewrite import rewrite_asserts def setup_module(mod): - mod._old_reprcompare = py.code._reprcompare + mod._old_reprcompare = assertion._reprcompare py.code._reprcompare = None def teardown_module(mod): - py.code._reprcompare = mod._old_reprcompare + assertion._reprcompare = mod._old_reprcompare del mod._old_reprcompare @@ -229,13 +230,13 @@ class TestAssertionRewrite: def test_custom_reprcompare(self, monkeypatch): def my_reprcompare(op, left, right): return "42" - monkeypatch.setattr(py.code, "_reprcompare", my_reprcompare) + monkeypatch.setattr(assertion, "_reprcompare", my_reprcompare) def f(): assert 42 < 3 assert getmsg(f) == "assert 42" def my_reprcompare(op, left, right): return "%s %s %s" % (left, op, right) - monkeypatch.setattr(py.code, "_reprcompare", my_reprcompare) + monkeypatch.setattr(assertion, "_reprcompare", my_reprcompare) def f(): assert 1 < 3 < 5 <= 4 < 7 assert getmsg(f) == "assert 5 <= 4" From 250160b4b061e98d77db6e03c673c8c2ddac2722 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 12:01:34 -0500 Subject: [PATCH 24/59] refactor explanation formatting things into their own module --- _pytest/assertion/__init__.py | 188 +----------------------------- _pytest/assertion/newinterpret.py | 9 +- _pytest/assertion/oldinterpret.py | 4 +- _pytest/assertion/rewrite.py | 20 ++-- testing/test_assertinterpret.py | 4 +- testing/test_assertion.py | 10 +- testing/test_assertrewrite.py | 22 ++-- 7 files changed, 37 insertions(+), 220 deletions(-) diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index 9f89d17fc..c1c8f3be1 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -7,6 +7,7 @@ import marshal import struct import sys from _pytest.monkeypatch import monkeypatch +from _pytest.assertion import reinterpret, util try: from _pytest.assertion.rewrite import rewrite_asserts @@ -31,7 +32,6 @@ def pytest_configure(config): config._cleanup.append(m.undo) warn_about_missing_assertion() if not config.getvalue("noassert") and not config.getvalue("nomagic"): - from _pytest.assertion import reinterpret def callbinrepr(op, left, right): hook_result = config.hook.pytest_assertrepr_compare( config=config, op=op, left=left, right=right) @@ -40,7 +40,7 @@ def pytest_configure(config): return '\n~'.join(new_expl) m.setattr(py.builtin.builtins, 'AssertionError', reinterpret.AssertionError) - m.setattr(sys.modules[__name__], '_reprcompare', callbinrepr) + m.setattr(util, '_reprcompare', callbinrepr) else: rewrite_asserts = None @@ -99,186 +99,4 @@ def warn_about_missing_assertion(): sys.stderr.write("WARNING: failing tests may report as passing because " "assertions are turned off! (are you using python -O?)\n") -# if set, will be called by assert reinterp for comparison ops -_reprcompare = None - -def _format_explanation(explanation): - """This formats an explanation - - Normally all embedded newlines are escaped, however there are - three exceptions: \n{, \n} and \n~. The first two are intended - cover nested explanations, see function and attribute explanations - for examples (.visit_Call(), visit_Attribute()). The last one is - for when one explanation needs to span multiple lines, e.g. when - displaying diffs. - """ - raw_lines = (explanation or '').split('\n') - # escape newlines not followed by {, } and ~ - lines = [raw_lines[0]] - for l in raw_lines[1:]: - if l.startswith('{') or l.startswith('}') or l.startswith('~'): - lines.append(l) - else: - lines[-1] += '\\n' + l - - result = lines[:1] - stack = [0] - stackcnt = [0] - for line in lines[1:]: - if line.startswith('{'): - if stackcnt[-1]: - s = 'and ' - else: - s = 'where ' - stack.append(len(result)) - stackcnt[-1] += 1 - stackcnt.append(0) - result.append(' +' + ' '*(len(stack)-1) + s + line[1:]) - elif line.startswith('}'): - assert line.startswith('}') - stack.pop() - stackcnt.pop() - result[stack[-1]] += line[1:] - else: - assert line.startswith('~') - result.append(' '*len(stack) + line[1:]) - assert len(stack) == 1 - return '\n'.join(result) - - -# Provide basestring in python3 -try: - basestring = basestring -except NameError: - basestring = str - - -def pytest_assertrepr_compare(op, left, right): - """return specialised explanations for some operators/operands""" - width = 80 - 15 - len(op) - 2 # 15 chars indentation, 1 space around op - left_repr = py.io.saferepr(left, maxsize=int(width/2)) - right_repr = py.io.saferepr(right, maxsize=width-len(left_repr)) - summary = '%s %s %s' % (left_repr, op, right_repr) - - issequence = lambda x: isinstance(x, (list, tuple)) - istext = lambda x: isinstance(x, basestring) - isdict = lambda x: isinstance(x, dict) - isset = lambda x: isinstance(x, set) - - explanation = None - try: - if op == '==': - if istext(left) and istext(right): - explanation = _diff_text(left, right) - elif issequence(left) and issequence(right): - explanation = _compare_eq_sequence(left, right) - elif isset(left) and isset(right): - explanation = _compare_eq_set(left, right) - elif isdict(left) and isdict(right): - explanation = _diff_text(py.std.pprint.pformat(left), - py.std.pprint.pformat(right)) - elif op == 'not in': - if istext(left) and istext(right): - explanation = _notin_text(left, right) - except py.builtin._sysex: - raise - except: - excinfo = py.code.ExceptionInfo() - explanation = ['(pytest_assertion plugin: representation of ' - 'details failed. Probably an object has a faulty __repr__.)', - str(excinfo) - ] - - - if not explanation: - return None - - # Don't include pageloads of data, should be configurable - if len(''.join(explanation)) > 80*8: - explanation = ['Detailed information too verbose, truncated'] - - return [summary] + explanation - - -def _diff_text(left, right): - """Return the explanation for the diff between text - - This will skip leading and trailing characters which are - identical to keep the diff minimal. - """ - explanation = [] - i = 0 # just in case left or right has zero length - for i in range(min(len(left), len(right))): - if left[i] != right[i]: - break - if i > 42: - i -= 10 # Provide some context - explanation = ['Skipping %s identical ' - 'leading characters in diff' % i] - left = left[i:] - right = right[i:] - if len(left) == len(right): - for i in range(len(left)): - if left[-i] != right[-i]: - break - if i > 42: - i -= 10 # Provide some context - explanation += ['Skipping %s identical ' - 'trailing characters in diff' % i] - left = left[:-i] - right = right[:-i] - explanation += [line.strip('\n') - for line in py.std.difflib.ndiff(left.splitlines(), - right.splitlines())] - return explanation - - -def _compare_eq_sequence(left, right): - explanation = [] - for i in range(min(len(left), len(right))): - if left[i] != right[i]: - explanation += ['At index %s diff: %r != %r' % - (i, left[i], right[i])] - break - if len(left) > len(right): - explanation += ['Left contains more items, ' - 'first extra item: %s' % py.io.saferepr(left[len(right)],)] - elif len(left) < len(right): - explanation += ['Right contains more items, ' - 'first extra item: %s' % py.io.saferepr(right[len(left)],)] - return explanation # + _diff_text(py.std.pprint.pformat(left), - # py.std.pprint.pformat(right)) - - -def _compare_eq_set(left, right): - explanation = [] - diff_left = left - right - diff_right = right - left - if diff_left: - explanation.append('Extra items in the left set:') - for item in diff_left: - explanation.append(py.io.saferepr(item)) - if diff_right: - explanation.append('Extra items in the right set:') - for item in diff_right: - explanation.append(py.io.saferepr(item)) - return explanation - - -def _notin_text(term, text): - index = text.find(term) - head = text[:index] - tail = text[index+len(term):] - correct_text = head + tail - diff = _diff_text(correct_text, text) - newdiff = ['%s is contained here:' % py.io.saferepr(term, maxsize=42)] - for line in diff: - if line.startswith('Skipping'): - continue - if line.startswith('- '): - continue - if line.startswith('+ '): - newdiff.append(' ' + line[2:]) - else: - newdiff.append(line) - return newdiff +pytest_assertrepr_compare = util.assertrepr_compare diff --git a/_pytest/assertion/newinterpret.py b/_pytest/assertion/newinterpret.py index 1d061aa46..c6e2dea17 100644 --- a/_pytest/assertion/newinterpret.py +++ b/_pytest/assertion/newinterpret.py @@ -7,8 +7,7 @@ import sys import ast import py -from _pytest import assertion -from _pytest.assertion import _format_explanation +from _pytest.assertion import util from _pytest.assertion.reinterpret import BuiltinAssertionError @@ -62,7 +61,7 @@ def run(offending_line, frame=None): return interpret(offending_line, frame) def getfailure(failure): - explanation = _format_explanation(failure.explanation) + explanation = util.format_explanation(failure.explanation) value = failure.cause[1] if str(value): lines = explanation.splitlines() @@ -185,8 +184,8 @@ class DebugInterpreter(ast.NodeVisitor): break left_explanation, left_result = next_explanation, next_result - if assertion._reprcompare is not None: - res = assertion._reprcompare(op_symbol, left_result, next_result) + if util._reprcompare is not None: + res = util._reprcompare(op_symbol, left_result, next_result) if res: explanation = res return explanation, result diff --git a/_pytest/assertion/oldinterpret.py b/_pytest/assertion/oldinterpret.py index 3e8f1c0b3..0c91558a1 100644 --- a/_pytest/assertion/oldinterpret.py +++ b/_pytest/assertion/oldinterpret.py @@ -1,7 +1,7 @@ import py import sys, inspect from compiler import parse, ast, pycodegen -from _pytest.assertion import _format_explanation +from _pytest.assertion.util import format_explanation from _pytest.assertion.reinterpret import BuiltinAssertionError passthroughex = py.builtin._sysex @@ -132,7 +132,7 @@ class Interpretable(View): raise Failure(self) def nice_explanation(self): - return _format_explanation(self.explanation) + return format_explanation(self.explanation) class Name(Interpretable): diff --git a/_pytest/assertion/rewrite.py b/_pytest/assertion/rewrite.py index 186d2425e..7e18f2c30 100644 --- a/_pytest/assertion/rewrite.py +++ b/_pytest/assertion/rewrite.py @@ -5,6 +5,7 @@ import collections import itertools import py +from _pytest.assertion import util def rewrite_asserts(mod): @@ -13,6 +14,7 @@ def rewrite_asserts(mod): _saferepr = py.io.saferepr +from _pytest.assertion.util import format_explanation as _format_explanation def _format_boolop(operands, explanations, is_or): show_explanations = [] @@ -30,9 +32,8 @@ def _call_reprcompare(ops, results, expls, each_obj): done = True if done: break - from _pytest.assertion import _reprcompare - if _reprcompare is not None: - custom = _reprcompare(ops[i], each_obj[i], each_obj[i + 1]) + if util._reprcompare is not None: + custom = util._reprcompare(ops[i], each_obj[i], each_obj[i + 1]) if custom is not None: return custom return expl @@ -94,7 +95,6 @@ class AssertionRewriter(ast.NodeVisitor): # Insert some special imports at the top of the module but after any # docstrings and __future__ imports. aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"), - ast.alias("_pytest.assertion", "@pytest_a"), ast.alias("_pytest.assertion.rewrite", "@pytest_ar")] expect_docstring = True pos = 0 @@ -153,11 +153,11 @@ class AssertionRewriter(ast.NodeVisitor): def display(self, expr): """Call py.io.saferepr on the expression.""" - return self.helper("ar", "saferepr", expr) + return self.helper("saferepr", expr) - def helper(self, mod, name, *args): + def helper(self, name, *args): """Call a helper in this module.""" - py_name = ast.Name("@pytest_" + mod, ast.Load()) + py_name = ast.Name("@pytest_ar", ast.Load()) attr = ast.Attribute(py_name, "_" + name, ast.Load()) return ast.Call(attr, list(args), [], None, None) @@ -211,7 +211,7 @@ class AssertionRewriter(ast.NodeVisitor): explanation = "assert " + explanation template = ast.Str(explanation) msg = self.pop_format_context(template) - fmt = self.helper("a", "format_explanation", msg) + fmt = self.helper("format_explanation", msg) body.append(ast.Assert(top_condition, fmt)) # Delete temporary variables. names = [ast.Name(name, ast.Del()) for name in self.variables] @@ -242,7 +242,7 @@ class AssertionRewriter(ast.NodeVisitor): explanations.append(explanation) expls = ast.Tuple([ast.Str(expl) for expl in explanations], ast.Load()) is_or = ast.Num(isinstance(boolop.op, ast.Or)) - expl_template = self.helper("ar", "format_boolop", + expl_template = self.helper("format_boolop", ast.Tuple(operands, ast.Load()), expls, is_or) expl = self.pop_format_context(expl_template) @@ -321,7 +321,7 @@ class AssertionRewriter(ast.NodeVisitor): self.statements.append(ast.Assign([store_names[i]], res_expr)) left_res, left_expl = next_res, next_expl # Use py.code._reprcompare if that's available. - expl_call = self.helper("ar", "call_reprcompare", + expl_call = self.helper("call_reprcompare", ast.Tuple(syms, ast.Load()), ast.Tuple(load_names, ast.Load()), ast.Tuple(expls, ast.Load()), diff --git a/testing/test_assertinterpret.py b/testing/test_assertinterpret.py index 318516eae..316cf49d4 100644 --- a/testing/test_assertinterpret.py +++ b/testing/test_assertinterpret.py @@ -1,7 +1,7 @@ "PYTEST_DONT_REWRITE" import pytest, py -from _pytest import assertion +from _pytest.assertion import util def exvalue(): return py.std.sys.exc_info()[1] @@ -249,7 +249,7 @@ class TestView: @py.test.mark.skipif("sys.version_info < (2,6)") def test_assert_customizable_reprcompare(monkeypatch): - monkeypatch.setattr(assertion, '_reprcompare', lambda *args: 'hello') + monkeypatch.setattr(util, '_reprcompare', lambda *args: 'hello') try: assert 3 == 4 except AssertionError: diff --git a/testing/test_assertion.py b/testing/test_assertion.py index 5470f6416..24b665066 100644 --- a/testing/test_assertion.py +++ b/testing/test_assertion.py @@ -2,7 +2,7 @@ import sys import py, pytest import _pytest.assertion as plugin -from _pytest.assertion import reinterpret +from _pytest.assertion import reinterpret, util needsnewassert = pytest.mark.skipif("sys.version_info < (2,6)") @@ -26,7 +26,7 @@ class TestBinReprIntegration: self.right = right mockhook = MockHook() monkeypatch = request.getfuncargvalue("monkeypatch") - monkeypatch.setattr(plugin, '_reprcompare', mockhook) + monkeypatch.setattr(util, '_reprcompare', mockhook) return mockhook def test_pytest_assertrepr_compare_called(self, hook): @@ -41,13 +41,13 @@ class TestBinReprIntegration: assert hook.right == [0, 2] def test_configure_unconfigure(self, testdir, hook): - assert hook == plugin._reprcompare + assert hook == util._reprcompare config = testdir.parseconfig() plugin.pytest_configure(config) - assert hook != plugin._reprcompare + assert hook != util._reprcompare from _pytest.config import pytest_unconfigure pytest_unconfigure(config) - assert hook == plugin._reprcompare + assert hook == util._reprcompare def callequal(left, right): return plugin.pytest_assertrepr_compare('==', left, right) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index d713b6e25..ffb544fdd 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -4,16 +4,16 @@ import pytest ast = pytest.importorskip("ast") -from _pytest import assertion +from _pytest.assertion import util from _pytest.assertion.rewrite import rewrite_asserts def setup_module(mod): - mod._old_reprcompare = assertion._reprcompare + mod._old_reprcompare = util._reprcompare py.code._reprcompare = None def teardown_module(mod): - assertion._reprcompare = mod._old_reprcompare + util._reprcompare = mod._old_reprcompare del mod._old_reprcompare @@ -53,29 +53,29 @@ class TestAssertionRewrite: m = rewrite(s) assert isinstance(m.body[0], ast.Expr) assert isinstance(m.body[0].value, ast.Str) - for imp in m.body[1:4]: + for imp in m.body[1:3]: assert isinstance(imp, ast.Import) assert imp.lineno == 2 assert imp.col_offset == 0 - assert isinstance(m.body[4], ast.Assign) + assert isinstance(m.body[3], ast.Assign) s = """from __future__ import with_statement\nother_stuff""" m = rewrite(s) assert isinstance(m.body[0], ast.ImportFrom) - for imp in m.body[1:4]: + for imp in m.body[1:3]: assert isinstance(imp, ast.Import) assert imp.lineno == 2 assert imp.col_offset == 0 - assert isinstance(m.body[4], ast.Expr) + assert isinstance(m.body[3], ast.Expr) s = """'doc string'\nfrom __future__ import with_statement\nother""" m = rewrite(s) assert isinstance(m.body[0], ast.Expr) assert isinstance(m.body[0].value, ast.Str) assert isinstance(m.body[1], ast.ImportFrom) - for imp in m.body[2:5]: + for imp in m.body[2:4]: assert isinstance(imp, ast.Import) assert imp.lineno == 3 assert imp.col_offset == 0 - assert isinstance(m.body[5], ast.Expr) + assert isinstance(m.body[4], ast.Expr) def test_dont_rewrite(self): s = """'PYTEST_DONT_REWRITE'\nassert 14""" @@ -230,13 +230,13 @@ class TestAssertionRewrite: def test_custom_reprcompare(self, monkeypatch): def my_reprcompare(op, left, right): return "42" - monkeypatch.setattr(assertion, "_reprcompare", my_reprcompare) + monkeypatch.setattr(util, "_reprcompare", my_reprcompare) def f(): assert 42 < 3 assert getmsg(f) == "assert 42" def my_reprcompare(op, left, right): return "%s %s %s" % (left, op, right) - monkeypatch.setattr(assertion, "_reprcompare", my_reprcompare) + monkeypatch.setattr(util, "_reprcompare", my_reprcompare) def f(): assert 1 < 3 < 5 <= 4 < 7 assert getmsg(f) == "assert 5 <= 4" From 4fe13e59a7b4769c2a9133945e9e6105837d7545 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 13:15:03 -0500 Subject: [PATCH 25/59] fix comment --- _pytest/assertion/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index c1c8f3be1..5c63be76a 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -65,7 +65,7 @@ def pytest_pycollect_before_module_import(mod): if rewrite_asserts is None: return # Some deep magic: load the source, rewrite the asserts, and write a - # fake pyc, so that it'll be loaded further down this function. + # fake pyc, so that it'll be loaded when the module is imported. source = mod.fspath.read() try: tree = ast.parse(source) From ee64da4badece9f2151b8a005b769a81c90f1940 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 13:15:21 -0500 Subject: [PATCH 26/59] fix grammar --- _pytest/assertion/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index 5c63be76a..4ce02aac3 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -1,5 +1,5 @@ """ -support for presented detailed information in failing assertions. +support for presenting detailed information in failing assertions. """ import py import imp From 15b9e8ed7dbb62daeace2e4a3fc2d3115a9ea896 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 13:17:26 -0500 Subject: [PATCH 27/59] forgot to util module --- _pytest/assertion/util.py | 191 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 191 insertions(+) create mode 100644 _pytest/assertion/util.py diff --git a/_pytest/assertion/util.py b/_pytest/assertion/util.py new file mode 100644 index 000000000..4db9ab161 --- /dev/null +++ b/_pytest/assertion/util.py @@ -0,0 +1,191 @@ +"""Utilities for assertion debugging""" + +import py + + +# The _reprcompare attribute on the util module is used by the new assertion +# interpretation code and assertion rewriter to detect this plugin was +# loaded and in turn call the hooks defined here as part of the +# DebugInterpreter. +_reprcompare = None + +def format_explanation(explanation): + """This formats an explanation + + Normally all embedded newlines are escaped, however there are + three exceptions: \n{, \n} and \n~. The first two are intended + cover nested explanations, see function and attribute explanations + for examples (.visit_Call(), visit_Attribute()). The last one is + for when one explanation needs to span multiple lines, e.g. when + displaying diffs. + """ + raw_lines = (explanation or '').split('\n') + # escape newlines not followed by {, } and ~ + lines = [raw_lines[0]] + for l in raw_lines[1:]: + if l.startswith('{') or l.startswith('}') or l.startswith('~'): + lines.append(l) + else: + lines[-1] += '\\n' + l + + result = lines[:1] + stack = [0] + stackcnt = [0] + for line in lines[1:]: + if line.startswith('{'): + if stackcnt[-1]: + s = 'and ' + else: + s = 'where ' + stack.append(len(result)) + stackcnt[-1] += 1 + stackcnt.append(0) + result.append(' +' + ' '*(len(stack)-1) + s + line[1:]) + elif line.startswith('}'): + assert line.startswith('}') + stack.pop() + stackcnt.pop() + result[stack[-1]] += line[1:] + else: + assert line.startswith('~') + result.append(' '*len(stack) + line[1:]) + assert len(stack) == 1 + return '\n'.join(result) + + +# Provide basestring in python3 +try: + basestring = basestring +except NameError: + basestring = str + + +def assertrepr_compare(op, left, right): + """return specialised explanations for some operators/operands""" + width = 80 - 15 - len(op) - 2 # 15 chars indentation, 1 space around op + left_repr = py.io.saferepr(left, maxsize=int(width/2)) + right_repr = py.io.saferepr(right, maxsize=width-len(left_repr)) + summary = '%s %s %s' % (left_repr, op, right_repr) + + issequence = lambda x: isinstance(x, (list, tuple)) + istext = lambda x: isinstance(x, basestring) + isdict = lambda x: isinstance(x, dict) + isset = lambda x: isinstance(x, set) + + explanation = None + try: + if op == '==': + if istext(left) and istext(right): + explanation = _diff_text(left, right) + elif issequence(left) and issequence(right): + explanation = _compare_eq_sequence(left, right) + elif isset(left) and isset(right): + explanation = _compare_eq_set(left, right) + elif isdict(left) and isdict(right): + explanation = _diff_text(py.std.pprint.pformat(left), + py.std.pprint.pformat(right)) + elif op == 'not in': + if istext(left) and istext(right): + explanation = _notin_text(left, right) + except py.builtin._sysex: + raise + except: + excinfo = py.code.ExceptionInfo() + explanation = ['(pytest_assertion plugin: representation of ' + 'details failed. Probably an object has a faulty __repr__.)', + str(excinfo) + ] + + + if not explanation: + return None + + # Don't include pageloads of data, should be configurable + if len(''.join(explanation)) > 80*8: + explanation = ['Detailed information too verbose, truncated'] + + return [summary] + explanation + + +def _diff_text(left, right): + """Return the explanation for the diff between text + + This will skip leading and trailing characters which are + identical to keep the diff minimal. + """ + explanation = [] + i = 0 # just in case left or right has zero length + for i in range(min(len(left), len(right))): + if left[i] != right[i]: + break + if i > 42: + i -= 10 # Provide some context + explanation = ['Skipping %s identical ' + 'leading characters in diff' % i] + left = left[i:] + right = right[i:] + if len(left) == len(right): + for i in range(len(left)): + if left[-i] != right[-i]: + break + if i > 42: + i -= 10 # Provide some context + explanation += ['Skipping %s identical ' + 'trailing characters in diff' % i] + left = left[:-i] + right = right[:-i] + explanation += [line.strip('\n') + for line in py.std.difflib.ndiff(left.splitlines(), + right.splitlines())] + return explanation + + +def _compare_eq_sequence(left, right): + explanation = [] + for i in range(min(len(left), len(right))): + if left[i] != right[i]: + explanation += ['At index %s diff: %r != %r' % + (i, left[i], right[i])] + break + if len(left) > len(right): + explanation += ['Left contains more items, ' + 'first extra item: %s' % py.io.saferepr(left[len(right)],)] + elif len(left) < len(right): + explanation += ['Right contains more items, ' + 'first extra item: %s' % py.io.saferepr(right[len(left)],)] + return explanation # + _diff_text(py.std.pprint.pformat(left), + # py.std.pprint.pformat(right)) + + +def _compare_eq_set(left, right): + explanation = [] + diff_left = left - right + diff_right = right - left + if diff_left: + explanation.append('Extra items in the left set:') + for item in diff_left: + explanation.append(py.io.saferepr(item)) + if diff_right: + explanation.append('Extra items in the right set:') + for item in diff_right: + explanation.append(py.io.saferepr(item)) + return explanation + + +def _notin_text(term, text): + index = text.find(term) + head = text[:index] + tail = text[index+len(term):] + correct_text = head + tail + diff = _diff_text(correct_text, text) + newdiff = ['%s is contained here:' % py.io.saferepr(term, maxsize=42)] + for line in diff: + if line.startswith('Skipping'): + continue + if line.startswith('- '): + continue + if line.startswith('+ '): + newdiff.append(' ' + line[2:]) + else: + newdiff.append(line) + return newdiff From d3645758ea3333f1de55bc61280448a8d6eecb9f Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 13:17:39 -0500 Subject: [PATCH 28/59] this comment was moved away --- _pytest/assertion/__init__.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index 4ce02aac3..41534db02 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -24,10 +24,6 @@ def pytest_addoption(parser): def pytest_configure(config): global rewrite_asserts - # The _reprcompare attribute on the py.code module is used by - # py._code._assertionnew to detect this plugin was loaded and in - # turn call the hooks defined here as part of the - # DebugInterpreter. m = monkeypatch() config._cleanup.append(m.undo) warn_about_missing_assertion() From d438a0bd83f562d78590934ab424431503a53905 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 14:34:27 -0500 Subject: [PATCH 29/59] introduce --assertmode option --- _pytest/assertion/__init__.py | 27 +++++++++++++++++++++------ _pytest/helpconfig.py | 3 --- testing/test_assertion.py | 28 +++++++++++++++++++++++++--- 3 files changed, 46 insertions(+), 12 deletions(-) diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index 41534db02..cfe978804 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -6,6 +6,7 @@ import imp import marshal import struct import sys +import pytest from _pytest.monkeypatch import monkeypatch from _pytest.assertion import reinterpret, util @@ -18,26 +19,40 @@ else: def pytest_addoption(parser): group = parser.getgroup("debugconfig") + group._addoption('--assertmode', action="store", dest="assertmode", + choices=("on", "old", "off", "default"), default="default", + metavar="on|old|off", + help="Control assertion debugging tools") group._addoption('--no-assert', action="store_true", default=False, - dest="noassert", - help="disable python assert expression reinterpretation."), + dest="noassert", help="DEPRECATED equivalent to --assertmode=off") + group._addoption('--nomagic', action="store_true", default=False, + dest="nomagic", + help="DEPRECATED equivalent to --assertmode=off") + def pytest_configure(config): global rewrite_asserts - m = monkeypatch() - config._cleanup.append(m.undo) warn_about_missing_assertion() - if not config.getvalue("noassert") and not config.getvalue("nomagic"): + mode = config.getvalue("assertmode") + if config.getvalue("noassert") or config.getvalue("nomagic"): + if mode not in ("off", "default"): + raise pytest.UsageError("assertion options conflict") + mode = "off" + elif mode == "default": + mode = "on" + if mode != "off": def callbinrepr(op, left, right): hook_result = config.hook.pytest_assertrepr_compare( config=config, op=op, left=left, right=right) for new_expl in hook_result: if new_expl: return '\n~'.join(new_expl) + m = monkeypatch() + config._cleanup.append(m.undo) m.setattr(py.builtin.builtins, 'AssertionError', reinterpret.AssertionError) m.setattr(util, '_reprcompare', callbinrepr) - else: + if mode != "on": rewrite_asserts = None def _write_pyc(co, source_path): diff --git a/_pytest/helpconfig.py b/_pytest/helpconfig.py index b89b33b56..fa81f87e9 100644 --- a/_pytest/helpconfig.py +++ b/_pytest/helpconfig.py @@ -16,9 +16,6 @@ def pytest_addoption(parser): group.addoption('--traceconfig', action="store_true", dest="traceconfig", default=False, help="trace considerations of conftest.py files."), - group._addoption('--nomagic', - action="store_true", dest="nomagic", default=False, - help="don't reinterpret asserts, no traceback cutting. ") group.addoption('--debug', action="store_true", dest="debug", default=False, help="generate and show internal debugging information.") diff --git a/testing/test_assertion.py b/testing/test_assertion.py index 24b665066..beb74b9ad 100644 --- a/testing/test_assertion.py +++ b/testing/test_assertion.py @@ -160,7 +160,7 @@ def test_sequence_comparison_uses_repr(testdir): ]) -def test_functional(testdir): +def test_assertion_options(testdir): testdir.makepyfile(""" def test_hello(): x = 3 @@ -168,8 +168,30 @@ def test_functional(testdir): """) result = testdir.runpytest() assert "3 == 4" in result.stdout.str() - result = testdir.runpytest("--no-assert") - assert "3 == 4" not in result.stdout.str() + off_options = (("--no-assert",), + ("--nomagic",), + ("--no-assert", "--nomagic"), + ("--assertmode=off",), + ("--assertmode=off", "--no-assert"), + ("--assertmode=off", "--nomagic"), + ("--assertmode=off," "--no-assert", "--nomagic")) + for opt in off_options: + result = testdir.runpytest(*opt) + assert "3 == 4" not in result.stdout.str() + for mode in "on", "old": + for other_opt in off_options[:3]: + opt = ("--assertmode=" + mode,) + other_opt + result = testdir.runpytest(*opt) + assert result.ret == 3 + assert "assertion options conflict" in result.stderr.str() + +def test_old_assert_mode(testdir): + testdir.makepyfile(""" + def test_in_old_mode(): + assert "@py_builtins" not in globals() + """) + result = testdir.runpytest("--assertmode=old") + assert result.ret == 0 def test_triple_quoted_string_issue113(testdir): testdir.makepyfile(""" From 32a67f962263fb8a17b45ba29de010ae5a0464fd Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 16:08:25 -0500 Subject: [PATCH 30/59] add some tracing in the assert plugin --- _pytest/assertion/__init__.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index cfe978804..9b62c571a 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -29,6 +29,12 @@ def pytest_addoption(parser): dest="nomagic", help="DEPRECATED equivalent to --assertmode=off") +class AssertionState: + """State for the assertion plugin.""" + + def __init__(self, config, mode): + self.mode = mode + self.trace = config.trace.root.get("assertion") def pytest_configure(config): global rewrite_asserts @@ -54,6 +60,8 @@ def pytest_configure(config): m.setattr(util, '_reprcompare', callbinrepr) if mode != "on": rewrite_asserts = None + config._assertion = AssertionState(config, mode) + config._assertion.trace("configured with mode set to %r" % (mode,)) def _write_pyc(co, source_path): if hasattr(imp, "cache_from_source"): @@ -82,6 +90,7 @@ def pytest_pycollect_before_module_import(mod): tree = ast.parse(source) except SyntaxError: # Let this pop up again in the real import. + mod.config._assertstate.trace("failed to parse: %r" % (mod.fspath,)) return rewrite_asserts(tree) try: @@ -89,8 +98,10 @@ def pytest_pycollect_before_module_import(mod): except SyntaxError: # It's possible that this error is from some bug in the assertion # rewriting, but I don't know of a fast way to tell. + mod.config._assertstate.trace("failed to compile: %r" % (mod.fspath,)) return mod._pyc = _write_pyc(co, mod.fspath) + mod.config._assertstate.trace("wrote pyc: %r" % (mod._pyc,)) def pytest_pycollect_after_module_import(mod): if rewrite_asserts is None or not hasattr(mod, "_pyc"): @@ -99,7 +110,9 @@ def pytest_pycollect_after_module_import(mod): try: mod._pyc.remove() except py.error.ENOENT: - pass + mod.config._assertstate.trace("couldn't find pyc: %r" % (mod._pyc,)) + else: + mod.config._assertstate.trace("removed pyc: %r" % (mod._pyc,)) def warn_about_missing_assertion(): try: From bf3d9f37370b320147a0983da8a40137e9a37fe1 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 16:18:18 -0500 Subject: [PATCH 31/59] correct attribute name --- _pytest/assertion/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index 9b62c571a..5682f5535 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -60,8 +60,8 @@ def pytest_configure(config): m.setattr(util, '_reprcompare', callbinrepr) if mode != "on": rewrite_asserts = None - config._assertion = AssertionState(config, mode) - config._assertion.trace("configured with mode set to %r" % (mode,)) + config._assertstate = AssertionState(config, mode) + config._assertstate.trace("configured with mode set to %r" % (mode,)) def _write_pyc(co, source_path): if hasattr(imp, "cache_from_source"): From c4d761fe993699725a80adbb36c32c881d2d33ed Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 16:50:04 -0500 Subject: [PATCH 32/59] these tests should cause pytest_configure to be called --- testing/test_collection.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/testing/test_collection.py b/testing/test_collection.py index 564b5bc63..5b60841a1 100644 --- a/testing/test_collection.py +++ b/testing/test_collection.py @@ -328,7 +328,7 @@ class TestSession: def test_collect_protocol_single_function(self, testdir): p = testdir.makepyfile("def test_func(): pass") id = "::".join([p.basename, "test_func"]) - config = testdir.parseconfig(id) + config = testdir.parseconfigure(id) topdir = testdir.tmpdir rcol = Session(config) assert topdir == rcol.fspath @@ -363,7 +363,7 @@ class TestSession: p.basename + "::TestClass::()", normid, ]: - config = testdir.parseconfig(id) + config = testdir.parseconfigure(id) rcol = Session(config=config) rcol.perform_collect() items = rcol.items @@ -388,7 +388,7 @@ class TestSession: """ % p.basename) id = p.basename - config = testdir.parseconfig(id) + config = testdir.parseconfigure(id) rcol = Session(config) hookrec = testdir.getreportrecorder(config) rcol.perform_collect() @@ -413,7 +413,7 @@ class TestSession: aaa = testdir.mkpydir("aaa") test_aaa = aaa.join("test_aaa.py") p.move(test_aaa) - config = testdir.parseconfig() + config = testdir.parseconfigure() rcol = Session(config) hookrec = testdir.getreportrecorder(config) rcol.perform_collect() @@ -437,7 +437,7 @@ class TestSession: p.move(test_bbb) id = "." - config = testdir.parseconfig(id) + config = testdir.parseconfigure(id) rcol = Session(config) hookrec = testdir.getreportrecorder(config) rcol.perform_collect() @@ -455,7 +455,7 @@ class TestSession: def test_serialization_byid(self, testdir): p = testdir.makepyfile("def test_func(): pass") - config = testdir.parseconfig() + config = testdir.parseconfigure() rcol = Session(config) rcol.perform_collect() items = rcol.items @@ -476,7 +476,7 @@ class TestSession: pass """) arg = p.basename + ("::TestClass::test_method") - config = testdir.parseconfig(arg) + config = testdir.parseconfigure(arg) rcol = Session(config) rcol.perform_collect() items = rcol.items From 89d6defd68e508b31e1fcd82bbb77b1eb0a7b398 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 17:08:44 -0500 Subject: [PATCH 33/59] correctly initialize and shutdown sessions --- _pytest/pytester.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/_pytest/pytester.py b/_pytest/pytester.py index 619bba988..cc25bf1c0 100644 --- a/_pytest/pytester.py +++ b/_pytest/pytester.py @@ -6,7 +6,7 @@ import re import inspect import time from fnmatch import fnmatch -from _pytest.main import Session +from _pytest.main import Session, EXIT_OK from py.builtin import print_ from _pytest.core import HookRelay @@ -292,13 +292,19 @@ class TmpTestdir: assert '::' not in str(arg) p = py.path.local(arg) x = session.fspath.bestrelpath(p) - return session.perform_collect([x], genitems=False)[0] + config.hook.pytest_sessionstart(session=session) + res = session.perform_collect([x], genitems=False)[0] + config.hook.pytest_sessionfinish(session=session, exitstatus=EXIT_OK) + return res def getpathnode(self, path): config = self.parseconfig(path) session = Session(config) x = session.fspath.bestrelpath(path) - return session.perform_collect([x], genitems=False)[0] + config.hook.pytest_sessionstart(session=session) + res = session.perform_collect([x], genitems=False)[0] + config.hook.pytest_sessionfinish(session=session, exitstatus=EXIT_OK) + return res def genitems(self, colitems): session = colitems[0].session @@ -312,7 +318,9 @@ class TmpTestdir: config = self.parseconfigure(*args) rec = self.getreportrecorder(config) session = Session(config) + config.hook.pytest_sessionstart(session=session) session.perform_collect() + config.hook.pytest_sessionfinish(session=session, exitstatus=EXIT_OK) return session.items, rec def runitem(self, source): From dd199d255cb1b03de8ac88f72782b60ae222b023 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 17:08:56 -0500 Subject: [PATCH 34/59] move _setupstate into session --- _pytest/python.py | 4 ++-- _pytest/runner.py | 22 ++++++++++------------ testing/test_python.py | 8 ++++---- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/_pytest/python.py b/_pytest/python.py index 7a01b62fc..5d1ad4713 100644 --- a/_pytest/python.py +++ b/_pytest/python.py @@ -378,7 +378,7 @@ class Generator(FunctionMixin, PyCollectorMixin, pytest.Collector): # test generators are seen as collectors but they also # invoke setup/teardown on popular request # (induced by the common "test_*" naming shared with normal tests) - self.config._setupstate.prepare(self) + self.session._setupstate.prepare(self) # see FunctionMixin.setup and test_setupstate_is_preserved_134 self._preservedparent = self.parent.obj l = [] @@ -730,7 +730,7 @@ class FuncargRequest: def _addfinalizer(self, finalizer, scope): colitem = self._getscopeitem(scope) - self.config._setupstate.addfinalizer( + self._pyfuncitem.session._setupstate.addfinalizer( finalizer=finalizer, colitem=colitem) def __repr__(self): diff --git a/_pytest/runner.py b/_pytest/runner.py index 4deb8685b..c1b73e94c 100644 --- a/_pytest/runner.py +++ b/_pytest/runner.py @@ -14,17 +14,15 @@ def pytest_namespace(): # # pytest plugin hooks -# XXX move to pytest_sessionstart and fix py.test owns tests -def pytest_configure(config): - config._setupstate = SetupState() +def pytest_sessionstart(session): + session._setupstate = SetupState() def pytest_sessionfinish(session, exitstatus): - if hasattr(session.config, '_setupstate'): - hook = session.config.hook - rep = hook.pytest__teardown_final(session=session) - if rep: - hook.pytest__teardown_final_logerror(session=session, report=rep) - session.exitstatus = 1 + hook = session.config.hook + rep = hook.pytest__teardown_final(session=session) + if rep: + hook.pytest__teardown_final_logerror(session=session, report=rep) + session.exitstatus = 1 class NodeInfo: def __init__(self, location): @@ -46,16 +44,16 @@ def runtestprotocol(item, log=True): return reports def pytest_runtest_setup(item): - item.config._setupstate.prepare(item) + item.session._setupstate.prepare(item) def pytest_runtest_call(item): item.runtest() def pytest_runtest_teardown(item): - item.config._setupstate.teardown_exact(item) + item.session._setupstate.teardown_exact(item) def pytest__teardown_final(session): - call = CallInfo(session.config._setupstate.teardown_all, when="teardown") + call = CallInfo(session._setupstate.teardown_all, when="teardown") if call.excinfo: ntraceback = call.excinfo.traceback .cut(excludepath=py._pydir) call.excinfo.traceback = ntraceback.filter() diff --git a/testing/test_python.py b/testing/test_python.py index 8675bca0f..bba7bb94c 100644 --- a/testing/test_python.py +++ b/testing/test_python.py @@ -705,11 +705,11 @@ class TestRequest: def test_func(something): pass """) req = funcargs.FuncargRequest(item) - req.config._setupstate.prepare(item) # XXX + req._pyfuncitem.session._setupstate.prepare(item) # XXX req._fillfuncargs() # successively check finalization calls teardownlist = item.getparent(pytest.Module).obj.teardownlist - ss = item.config._setupstate + ss = item.session._setupstate assert not teardownlist ss.teardown_exact(item) print(ss.stack) @@ -834,11 +834,11 @@ class TestRequestCachedSetup: ret1 = req1.cached_setup(setup, teardown, scope="function") assert l == ['setup'] # artificial call of finalizer - req1.config._setupstate._callfinalizers(item1) + req1._pyfuncitem.session._setupstate._callfinalizers(item1) assert l == ["setup", "teardown"] ret2 = req1.cached_setup(setup, teardown, scope="function") assert l == ["setup", "teardown", "setup"] - req1.config._setupstate._callfinalizers(item1) + req1._pyfuncitem.session._setupstate._callfinalizers(item1) assert l == ["setup", "teardown", "setup", "teardown"] def test_request_cached_setup_two_args(self, testdir): From 657522b629f6e07fdce9b85544f73ccabcd08ea3 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 17:17:48 -0500 Subject: [PATCH 35/59] a less ugly way to detect if assert rewriting is enabled --- _pytest/assertion/__init__.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index 5682f5535..3be766d9b 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -37,7 +37,6 @@ class AssertionState: self.trace = config.trace.root.get("assertion") def pytest_configure(config): - global rewrite_asserts warn_about_missing_assertion() mode = config.getvalue("assertmode") if config.getvalue("noassert") or config.getvalue("nomagic"): @@ -58,8 +57,8 @@ def pytest_configure(config): m.setattr(py.builtin.builtins, 'AssertionError', reinterpret.AssertionError) m.setattr(util, '_reprcompare', callbinrepr) - if mode != "on": - rewrite_asserts = None + if mode == "on" and rewrite_asserts is None: + mode = "old" config._assertstate = AssertionState(config, mode) config._assertstate.trace("configured with mode set to %r" % (mode,)) @@ -81,7 +80,7 @@ def _write_pyc(co, source_path): return pyc def pytest_pycollect_before_module_import(mod): - if rewrite_asserts is None: + if mod.config._assertstate.mode != "on": return # Some deep magic: load the source, rewrite the asserts, and write a # fake pyc, so that it'll be loaded when the module is imported. @@ -104,7 +103,7 @@ def pytest_pycollect_before_module_import(mod): mod.config._assertstate.trace("wrote pyc: %r" % (mod._pyc,)) def pytest_pycollect_after_module_import(mod): - if rewrite_asserts is None or not hasattr(mod, "_pyc"): + if mod.config._assertstate.mode != "on" or not hasattr(mod, "_pyc"): return # Remove our tweaked pyc to avoid subtle bugs. try: From 7cf8afef477176df4b4853daa735d9db8d45cc6c Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 18:10:49 -0500 Subject: [PATCH 36/59] cause configure hooks to be called --- testing/test_collection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testing/test_collection.py b/testing/test_collection.py index 5b60841a1..1b6b3bded 100644 --- a/testing/test_collection.py +++ b/testing/test_collection.py @@ -313,7 +313,7 @@ class TestSession: def test_collect_topdir(self, testdir): p = testdir.makepyfile("def test_func(): pass") id = "::".join([p.basename, "test_func"]) - config = testdir.parseconfig(id) + config = testdir.parseconfigure(id) topdir = testdir.tmpdir rcol = Session(config) assert topdir == rcol.fspath From 96521ada684aed2321591699663344dd786395a6 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 18:11:12 -0500 Subject: [PATCH 37/59] call configure hooks in reparseconfig --- _pytest/pytester.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/_pytest/pytester.py b/_pytest/pytester.py index cc25bf1c0..cfd9ea358 100644 --- a/_pytest/pytester.py +++ b/_pytest/pytester.py @@ -390,6 +390,8 @@ class TmpTestdir: c.basetemp = py.path.local.make_numbered_dir(prefix="reparse", keep=0, rootdir=self.tmpdir, lock_timeout=None) c.parse(args) + c.pluginmanager.do_configure(c) + self.request.addfinalizer(lambda: c.pluginmanager.do_unconfigure(c)) return c finally: py.test.config = oldconfig From 411e9b136b5b7636dfa1d88f05931ed0ce4b0b85 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 18:37:04 -0500 Subject: [PATCH 38/59] do configure hooks here, too --- _pytest/pytester.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_pytest/pytester.py b/_pytest/pytester.py index cfd9ea358..8bfa3d37b 100644 --- a/_pytest/pytester.py +++ b/_pytest/pytester.py @@ -298,7 +298,7 @@ class TmpTestdir: return res def getpathnode(self, path): - config = self.parseconfig(path) + config = self.parseconfigure(path) session = Session(config) x = session.fspath.bestrelpath(path) config.hook.pytest_sessionstart(session=session) From 241ff0b43ad04d81a423566267d9c045bae1b036 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 18:56:45 -0500 Subject: [PATCH 39/59] add a hook called when a Module is successfully created --- _pytest/hookspec.py | 3 +++ _pytest/python.py | 5 ++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/_pytest/hookspec.py b/_pytest/hookspec.py index 7cb60a131..99ae77d55 100644 --- a/_pytest/hookspec.py +++ b/_pytest/hookspec.py @@ -104,6 +104,9 @@ def pytest_pycollect_makemodule(path, parent): """ pytest_pycollect_makemodule.firstresult = True +def pytest_pycollect_onmodule(mod): + """ Called when a module is collected.""" + def pytest_pycollect_before_module_import(mod): """Called before a module is imported.""" diff --git a/_pytest/python.py b/_pytest/python.py index 5d1ad4713..615e01a0c 100644 --- a/_pytest/python.py +++ b/_pytest/python.py @@ -60,8 +60,11 @@ def pytest_collect_file(path, parent): break else: return - return parent.ihook.pytest_pycollect_makemodule( + mod = parent.ihook.pytest_pycollect_makemodule( path=path, parent=parent) + if mod is not None: + parent.ihook.pytest_pycollect_onmodule(mod=mod) + return mod def pytest_pycollect_makemodule(path, parent): return Module(path, parent) From 196cece338f7624c74772dd4c920fc53e400d711 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 18:57:37 -0500 Subject: [PATCH 40/59] add a hook called after the inital fs collection --- _pytest/hookspec.py | 3 +++ _pytest/main.py | 5 ++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/_pytest/hookspec.py b/_pytest/hookspec.py index 99ae77d55..a2f3d6024 100644 --- a/_pytest/hookspec.py +++ b/_pytest/hookspec.py @@ -79,6 +79,9 @@ def pytest_collect_file(path, parent): def pytest_collectstart(collector): """ collector starts collecting. """ +def pytest_after_initial_collect(collector): + """ after the initial file system walk before genitems""" + def pytest_itemcollected(item): """ we just collected a test item. """ diff --git a/_pytest/main.py b/_pytest/main.py index f73ff597a..9769f4a2d 100644 --- a/_pytest/main.py +++ b/_pytest/main.py @@ -386,7 +386,10 @@ class Session(FSCollector): self._initialparts.append(parts) self._initialpaths.add(parts[0]) self.ihook.pytest_collectstart(collector=self) - rep = self.ihook.pytest_make_collect_report(collector=self) + try: + rep = self.ihook.pytest_make_collect_report(collector=self) + finally: + self.ihook.pytest_after_initial_collect(collector=self) self.ihook.pytest_collectreport(report=rep) self.trace.root.indent -= 1 if self._notfound: From f684a9ed563020537000705c7eeb15aa02498d19 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 18:58:31 -0500 Subject: [PATCH 41/59] expose Session on pytest namespace --- _pytest/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/_pytest/main.py b/_pytest/main.py index 9769f4a2d..de61807ad 100644 --- a/_pytest/main.py +++ b/_pytest/main.py @@ -46,7 +46,8 @@ def pytest_addoption(parser): def pytest_namespace(): - return dict(collect=dict(Item=Item, Collector=Collector, File=File)) + collect = dict(Item=Item, Collector=Collector, File=File, Session=Session) + return dict(collect=collect) def pytest_configure(config): py.test.config = config # compatibiltiy From 0a7237b72fd9f42d57ef82532613f7cd4300b7e9 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 19:09:42 -0500 Subject: [PATCH 42/59] refactor common config/session protocol code for main() functions --- _pytest/main.py | 17 ++++++++++++----- _pytest/python.py | 6 ++++-- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/_pytest/main.py b/_pytest/main.py index de61807ad..d58b42b12 100644 --- a/_pytest/main.py +++ b/_pytest/main.py @@ -54,16 +54,14 @@ def pytest_configure(config): if config.option.exitfirst: config.option.maxfail = 1 -def pytest_cmdline_main(config): - """ default command line protocol for initialization, session, - running tests and reporting. """ +def wrap_session(config, doit): + """Skeleton command line program""" session = Session(config) session.exitstatus = EXIT_OK try: config.pluginmanager.do_configure(config) config.hook.pytest_sessionstart(session=session) - config.hook.pytest_collection(session=session) - config.hook.pytest_runtestloop(session=session) + doit(config, session) except pytest.UsageError: raise except KeyboardInterrupt: @@ -83,6 +81,15 @@ def pytest_cmdline_main(config): config.pluginmanager.do_unconfigure(config) return session.exitstatus +def pytest_cmdline_main(config): + return wrap_session(config, _main) + +def _main(config, session): + """ default command line protocol for initialization, session, + running tests and reporting. """ + config.hook.pytest_collection(session=session) + config.hook.pytest_runtestloop(session=session) + def pytest_collection(session): session.perform_collect() hook = session.config.hook diff --git a/_pytest/python.py b/_pytest/python.py index 615e01a0c..9402d93d0 100644 --- a/_pytest/python.py +++ b/_pytest/python.py @@ -754,8 +754,10 @@ class FuncargRequest: raise self.LookupError(msg) def showfuncargs(config): - from _pytest.main import Session - session = Session(config) + from _pytest.main import wrap_session + return wrap_session(config, _showfuncargs_main) + +def _showfuncargs_main(config, session): session.perform_collect() if session.items: plugins = session.items[0].getplugins() From 2f984e0c238be4450309a41e21acf62ff59b8526 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 19:43:02 -0500 Subject: [PATCH 43/59] remove after_initial_collect hook --- _pytest/hookspec.py | 3 --- _pytest/main.py | 5 +---- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/_pytest/hookspec.py b/_pytest/hookspec.py index a2f3d6024..99ae77d55 100644 --- a/_pytest/hookspec.py +++ b/_pytest/hookspec.py @@ -79,9 +79,6 @@ def pytest_collect_file(path, parent): def pytest_collectstart(collector): """ collector starts collecting. """ -def pytest_after_initial_collect(collector): - """ after the initial file system walk before genitems""" - def pytest_itemcollected(item): """ we just collected a test item. """ diff --git a/_pytest/main.py b/_pytest/main.py index d58b42b12..bd1fcc095 100644 --- a/_pytest/main.py +++ b/_pytest/main.py @@ -394,10 +394,7 @@ class Session(FSCollector): self._initialparts.append(parts) self._initialpaths.add(parts[0]) self.ihook.pytest_collectstart(collector=self) - try: - rep = self.ihook.pytest_make_collect_report(collector=self) - finally: - self.ihook.pytest_after_initial_collect(collector=self) + rep = self.ihook.pytest_make_collect_report(collector=self) self.ihook.pytest_collectreport(report=rep) self.trace.root.indent -= 1 if self._notfound: From cf6949c9a3b4e90d423527d6dbe140cccf33ba5d Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 19:53:47 -0500 Subject: [PATCH 44/59] stuff contents of pytest_collection hook into perform_collect --- _pytest/main.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/_pytest/main.py b/_pytest/main.py index bd1fcc095..a8645f4f3 100644 --- a/_pytest/main.py +++ b/_pytest/main.py @@ -91,12 +91,7 @@ def _main(config, session): config.hook.pytest_runtestloop(session=session) def pytest_collection(session): - session.perform_collect() - hook = session.config.hook - hook.pytest_collection_modifyitems(session=session, - config=session.config, items=session.items) - hook.pytest_collection_finish(session=session) - return True + return session.perform_collect() def pytest_runtestloop(session): if session.config.option.collectonly: @@ -382,6 +377,16 @@ class Session(FSCollector): return HookProxy(fspath, self.config) def perform_collect(self, args=None, genitems=True): + hook = self.config.hook + try: + items = self._perform_collect(args, genitems) + hook.pytest_collection_modifyitems(session=self, + config=self.config, items=items) + finally: + hook.pytest_collection_finish(session=self) + return items + + def _perform_collect(self, args, genitems): if args is None: args = self.config.args self.trace("perform_collect", self, args) From abb07fc732ddf0999dcb427857d9488a3bbe26d9 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 19:57:30 -0500 Subject: [PATCH 45/59] new way to rewrite tests: do it all during fs collection This should allow modules to be rewritten before some other test module loads them. --- _pytest/assertion/__init__.py | 41 +++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index 3be766d9b..a23e78ae4 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -62,6 +62,10 @@ def pytest_configure(config): config._assertstate = AssertionState(config, mode) config._assertstate.trace("configured with mode set to %r" % (mode,)) +def pytest_collectstart(collector): + if isinstance(collector, pytest.Session): + collector._rewritten_pycs = [] + def _write_pyc(co, source_path): if hasattr(imp, "cache_from_source"): # Handle PEP 3147 pycs. @@ -79,9 +83,9 @@ def _write_pyc(co, source_path): fp.close() return pyc -def pytest_pycollect_before_module_import(mod): - if mod.config._assertstate.mode != "on": - return +def pytest_pycollect_onmodule(mod): + if mod is None or mod.config._assertstate.mode != "on": + return mod # Some deep magic: load the source, rewrite the asserts, and write a # fake pyc, so that it'll be loaded when the module is imported. source = mod.fspath.read() @@ -90,7 +94,7 @@ def pytest_pycollect_before_module_import(mod): except SyntaxError: # Let this pop up again in the real import. mod.config._assertstate.trace("failed to parse: %r" % (mod.fspath,)) - return + return mod rewrite_asserts(tree) try: co = compile(tree, str(mod.fspath), "exec") @@ -98,20 +102,25 @@ def pytest_pycollect_before_module_import(mod): # It's possible that this error is from some bug in the assertion # rewriting, but I don't know of a fast way to tell. mod.config._assertstate.trace("failed to compile: %r" % (mod.fspath,)) - return - mod._pyc = _write_pyc(co, mod.fspath) - mod.config._assertstate.trace("wrote pyc: %r" % (mod._pyc,)) + return mod + pyc = _write_pyc(co, mod.fspath) + mod.session._rewritten_pycs.append(pyc) + mod.config._assertstate.trace("wrote pyc: %r" % (pyc,)) + return mod -def pytest_pycollect_after_module_import(mod): - if mod.config._assertstate.mode != "on" or not hasattr(mod, "_pyc"): +def pytest_collection_finish(session): + if not hasattr(session, "_rewritten_pycs"): return - # Remove our tweaked pyc to avoid subtle bugs. - try: - mod._pyc.remove() - except py.error.ENOENT: - mod.config._assertstate.trace("couldn't find pyc: %r" % (mod._pyc,)) - else: - mod.config._assertstate.trace("removed pyc: %r" % (mod._pyc,)) + state = session.config._assertstate + # Remove our tweaked pycs to avoid subtle bugs. + for pyc in session._rewritten_pycs: + try: + pyc.remove() + except py.error.ENOENT: + state.trace("couldn't find pyc: %r" % (pyc,)) + else: + state.trace("removed pyc: %r" % (pyc,)) + del session._rewritten_pycs[:] def warn_about_missing_assertion(): try: From 16b4f5454565ba9419a49e7891655119a8d7e04e Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 20:00:29 -0500 Subject: [PATCH 46/59] remove module before/after import hooks --- _pytest/hookspec.py | 6 ------ _pytest/python.py | 6 +----- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/_pytest/hookspec.py b/_pytest/hookspec.py index 99ae77d55..a3f79fd01 100644 --- a/_pytest/hookspec.py +++ b/_pytest/hookspec.py @@ -107,12 +107,6 @@ pytest_pycollect_makemodule.firstresult = True def pytest_pycollect_onmodule(mod): """ Called when a module is collected.""" -def pytest_pycollect_before_module_import(mod): - """Called before a module is imported.""" - -def pytest_pycollect_after_module_import(mod): - """Called after a module is imported.""" - def pytest_pycollect_makeitem(collector, name, obj): """ return custom item/collector for a python object in a module, or None. """ pytest_pycollect_makeitem.firstresult = True diff --git a/_pytest/python.py b/_pytest/python.py index 9402d93d0..2b47476c9 100644 --- a/_pytest/python.py +++ b/_pytest/python.py @@ -228,13 +228,9 @@ class Module(pytest.File, PyCollectorMixin): return self._memoizedcall('_obj', self._importtestmodule) def _importtestmodule(self): - self.ihook.pytest_pycollect_before_module_import(mod=self) # we assume we are only called once per module try: - try: - mod = self.fspath.pyimport(ensuresyspath=True) - finally: - self.ihook.pytest_pycollect_after_module_import(mod=self) + mod = self.fspath.pyimport(ensuresyspath=True) except SyntaxError: excinfo = py.code.ExceptionInfo() raise self.CollectError(excinfo.getrepr(style="short")) From 971f34147ad144f23e840e8f1fe6f6a62dce170b Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 20:06:11 -0500 Subject: [PATCH 47/59] test that tests get rewritten --- testing/test_assertion.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/testing/test_assertion.py b/testing/test_assertion.py index beb74b9ad..13ed9d62f 100644 --- a/testing/test_assertion.py +++ b/testing/test_assertion.py @@ -120,6 +120,10 @@ class TestAssert_reprcompare: expl = ' '.join(callequal('foo', 'bar')) assert 'raised in repr()' not in expl +@pytest.mark.skipif("config._assertstate.mode != 'on'") +def test_rewritten(): + assert "@py_builtins" in globals() + def test_reprcompare_notin(): detail = plugin.pytest_assertrepr_compare('not in', 'foo', 'aaafoobbb')[1:] assert detail == ["'foo' is contained here:", ' aaafoobbb', '? +++'] From 914f689ee8e79c387ac6f7c0c529f70da79a028e Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 20:33:12 -0500 Subject: [PATCH 48/59] beef up --assertmode help --- _pytest/assertion/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index a23e78ae4..5bac5b161 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -22,7 +22,11 @@ def pytest_addoption(parser): group._addoption('--assertmode', action="store", dest="assertmode", choices=("on", "old", "off", "default"), default="default", metavar="on|old|off", - help="Control assertion debugging tools") + help="""control assertion debugging tools. +'off' performs no assertion debugging. +'old' reinterprets the expressions in asserts to glean information. +'new' rewrites the assert statements in test modules to provide sub-expression +results.""") group._addoption('--no-assert', action="store_true", default=False, dest="noassert", help="DEPRECATED equivalent to --assertmode=off") group._addoption('--nomagic', action="store_true", default=False, From d53feaf6f0432f015741abf8d1415a73c1217cc4 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 20:59:43 -0500 Subject: [PATCH 49/59] fix help for --assertmode --- _pytest/assertion/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index 5bac5b161..91b15bcd5 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -25,8 +25,8 @@ def pytest_addoption(parser): help="""control assertion debugging tools. 'off' performs no assertion debugging. 'old' reinterprets the expressions in asserts to glean information. -'new' rewrites the assert statements in test modules to provide sub-expression -results.""") +'on' (the default) rewrites the assert statements in test modules to provide +sub-expression results.""") group._addoption('--no-assert', action="store_true", default=False, dest="noassert", help="DEPRECATED equivalent to --assertmode=off") group._addoption('--nomagic', action="store_true", default=False, From e22d3e03fea5bf3d601533b30222a063deb40b35 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 21:08:55 -0500 Subject: [PATCH 50/59] doc updates for new assertion debugging --- doc/assert.txt | 57 +++++++++++++++++++++++++++++++++++--------------- doc/faq.txt | 18 +++++++++------- 2 files changed, 51 insertions(+), 24 deletions(-) diff --git a/doc/assert.txt b/doc/assert.txt index 7d1c21da3..191cd1a66 100644 --- a/doc/assert.txt +++ b/doc/assert.txt @@ -39,23 +39,6 @@ assertion fails you will see the value of ``x``:: test_assert1.py:5: AssertionError ========================= 1 failed in 0.02 seconds ========================= -Reporting details about the failing assertion is achieved by re-evaluating -the assert expression and recording the intermediate values. - -Note: If evaluating the assert expression has side effects you may get a -warning that the intermediate values could not be determined safely. A -common example of this issue is an assertion which reads from a file:: - - assert f.read() != '...' - -If this assertion fails then the re-evaluation will probably succeed! -This is because ``f.read()`` will return an empty string when it is -called the second time during the re-evaluation. However, it is -easy to rewrite the assertion and avoid any trouble:: - - content = f.read() - assert content != '...' - assertions about expected exceptions ------------------------------------------ @@ -137,6 +120,46 @@ Special comparisons are done for a number of cases: See the :ref:`reporting demo ` for many more examples. + +Assertion debugging details +--------------------------- + +Reporting details about the failing assertion is achieved either by rewriting +assert statements before they are run or re-evaluating the assert expression and +recording the intermediate values. Which technique is used depends on the +location of the assert, py.test's configuration, and Python version being used +to run py.test. + +By default, if the Python version is greater than or equal to 2.6, py.test +rewrites assert statements in test modules. Rewritten assert statements put +debugging information into the assertion failure message. Note py.test only +rewrites test modules directly discovered by its test collection process, so +asserts in supporting modules will not be rewritten. + +If an assert statement has not been rewritten or the Python version is less than +2.6, py.test falls back on assert reinterpretation. In assert reinterpretation, +py.test walks the frame of the function containing the assert statement to +discover sub-expression results of the failing assert statement. You can force +py.test to always use assertion reinterpretation by passing the +``--assertmode=old`` option. + +Assert reinterpretation has a caveat not present with assert rewriting: If +evaluating the assert expression has side effects you may get a warning that the +intermediate values could not be determined safely. A common example of this +issue is an assertion which reads from a file:: + + assert f.read() != '...' + +If this assertion fails then the re-evaluation will probably succeed! +This is because ``f.read()`` will return an empty string when it is +called the second time during the re-evaluation. However, it is +easy to rewrite the assertion and avoid any trouble:: + + content = f.read() + assert content != '...' + +All assert debugging can be turned off by passing ``--assertmode=off``. + .. Defining your own comparison ---------------------------------------------- diff --git a/doc/faq.txt b/doc/faq.txt index 0b76dab26..916d0c001 100644 --- a/doc/faq.txt +++ b/doc/faq.txt @@ -47,13 +47,17 @@ customizable testing frameworks for Python. However, ``py.test`` still uses many metaprogramming techniques and reading its source is thus likely not something for Python beginners. -A second "magic" issue arguably the assert statement re-intepreation: -When an ``assert`` statement fails, py.test re-interprets the expression -to show intermediate values if a test fails. If your expression -has side effects (better to avoid them anyway!) the intermediate values -may not be the same, obfuscating the initial error (this is also -explained at the command line if it happens). -``py.test --no-assert`` turns off assert re-interpretation. +A second "magic" issue arguably the assert statement debugging feature. When +loading test modules py.test rewrites the source code of assert statements. When +a rewritten assert statement fails, its error message has more information than +the original. py.test has a second assert debugging technique. When an +``assert`` statement that was missed by the rewriter fails, py.test +re-interprets the expression to show intermediate values if a test fails. This +second technique suffers from caveat that the rewriting does not: If your +expression has side effects (better to avoid them anyway!) the intermediate +values may not be the same, confusing the reinterpreter and obfuscating the +initial error (this is also explained at the command line if it happens). +You can turn off all assertion debugging with ``py.test --assertmode=off``. .. _`py namespaces`: index.html .. _`py/__init__.py`: http://bitbucket.org/hpk42/py-trunk/src/trunk/py/__init__.py From e56838cb6c2398197200a9501d11e3438f82e574 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 21:15:40 -0500 Subject: [PATCH 51/59] write an explicit raise if the assertion fails --- _pytest/assertion/rewrite.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/_pytest/assertion/rewrite.py b/_pytest/assertion/rewrite.py index 7e18f2c30..6c9067525 100644 --- a/_pytest/assertion/rewrite.py +++ b/_pytest/assertion/rewrite.py @@ -3,6 +3,7 @@ import ast import collections import itertools +import sys import py from _pytest.assertion import util @@ -212,7 +213,13 @@ class AssertionRewriter(ast.NodeVisitor): template = ast.Str(explanation) msg = self.pop_format_context(template) fmt = self.helper("format_explanation", msg) - body.append(ast.Assert(top_condition, fmt)) + err_name = ast.Name("AssertionError", ast.Load()) + exc = ast.Call(err_name, [fmt], [], None, None) + if sys.version_info[0] >= 3: + raise_ = ast.Raise(exc, None) + else: + raise_ = ast.Raise(exc, None, None) + body.append(raise_) # Delete temporary variables. names = [ast.Name(name, ast.Del()) for name in self.variables] if names: From 606ea870f08980e0fb1eada970280aa59fcba00f Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 23:13:39 -0500 Subject: [PATCH 52/59] versionadded and versionchanged for asserts --- doc/assert.txt | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/doc/assert.txt b/doc/assert.txt index 191cd1a66..04a1023a3 100644 --- a/doc/assert.txt +++ b/doc/assert.txt @@ -160,6 +160,15 @@ easy to rewrite the assertion and avoid any trouble:: All assert debugging can be turned off by passing ``--assertmode=off``. +.. versionadded:: 2.1 + + Add assert rewriting as an alternate debugging technique. + +.. versionchanged:: 2.1 + + Introduce the ``--assertmode`` option. Deprecate ``--no-assert`` and + ``--nomagic``. + .. Defining your own comparison ---------------------------------------------- From 5f75c5851fc6cc75e459b76ee596569690a216b2 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 26 May 2011 23:15:33 -0500 Subject: [PATCH 53/59] can use non-underscored addoption --- _pytest/assertion/__init__.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index 91b15bcd5..535f5b2ca 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -19,19 +19,18 @@ else: def pytest_addoption(parser): group = parser.getgroup("debugconfig") - group._addoption('--assertmode', action="store", dest="assertmode", - choices=("on", "old", "off", "default"), default="default", - metavar="on|old|off", - help="""control assertion debugging tools. + group.addoption('--assertmode', action="store", dest="assertmode", + choices=("on", "old", "off", "default"), default="default", + metavar="on|old|off", + help="""control assertion debugging tools. 'off' performs no assertion debugging. 'old' reinterprets the expressions in asserts to glean information. 'on' (the default) rewrites the assert statements in test modules to provide sub-expression results.""") - group._addoption('--no-assert', action="store_true", default=False, + group.addoption('--no-assert', action="store_true", default=False, dest="noassert", help="DEPRECATED equivalent to --assertmode=off") - group._addoption('--nomagic', action="store_true", default=False, - dest="nomagic", - help="DEPRECATED equivalent to --assertmode=off") + group.addoption('--nomagic', action="store_true", default=False, + dest="nomagic", help="DEPRECATED equivalent to --assertmode=off") class AssertionState: """State for the assertion plugin.""" From e98057130dea15b92903d2ab022981365f8a076f Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Fri, 27 May 2011 12:30:27 -0500 Subject: [PATCH 54/59] a few more sentences --- doc/assert.txt | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/doc/assert.txt b/doc/assert.txt index 04a1023a3..ebc5e9062 100644 --- a/doc/assert.txt +++ b/doc/assert.txt @@ -18,8 +18,8 @@ following:: def test_function(): assert f() == 4 -to assert that your object returns a certain value. If this -assertion fails you will see the value of ``x``:: +to assert that your function returns a certain value. If this assertion fails +you will see the value of ``x``:: $ py.test test_assert1.py =========================== test session starts ============================ @@ -39,6 +39,13 @@ assertion fails you will see the value of ``x``:: test_assert1.py:5: AssertionError ========================= 1 failed in 0.02 seconds ========================= +py.test has support for showing the values of the most common subexpressions +including calls, attributes, comparisons, and binary and unary operators. This +allows you to use the idiomatic python constructs without boilerplate code while +not losing debugging information. + +See :ref:`assert-details` for more information on assertion debugging. + assertions about expected exceptions ------------------------------------------ @@ -121,6 +128,8 @@ Special comparisons are done for a number of cases: See the :ref:`reporting demo ` for many more examples. +.. _assert-details: + Assertion debugging details --------------------------- From 326b63adf82b8f8f6d3cdd243f306733051e938c Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Sat, 28 May 2011 10:02:51 -0500 Subject: [PATCH 55/59] bump pylib required --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 801dc6efd..9bfed71b7 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ def main(): author='holger krekel, Guido Wesdorp, Carl Friedrich Bolz, Armin Rigo, Maciej Fijalkowski & others', author_email='holger at merlinux.eu', entry_points= make_entry_points(), - install_requires=['py>1.4.1'], + install_requires=['py>1.4.3'], classifiers=['Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', 'License :: OSI Approved :: MIT License', @@ -67,4 +67,4 @@ def make_entry_points(): return {'console_scripts': l} if __name__ == '__main__': - main() \ No newline at end of file + main() From f63ff5267c1853c2b61cd888dc8c5f516644f505 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Sat, 28 May 2011 16:01:02 -0500 Subject: [PATCH 56/59] s/debugging/introspection/ --- doc/assert.txt | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/doc/assert.txt b/doc/assert.txt index ebc5e9062..9f26a8446 100644 --- a/doc/assert.txt +++ b/doc/assert.txt @@ -42,9 +42,9 @@ you will see the value of ``x``:: py.test has support for showing the values of the most common subexpressions including calls, attributes, comparisons, and binary and unary operators. This allows you to use the idiomatic python constructs without boilerplate code while -not losing debugging information. +not losing introspection information. -See :ref:`assert-details` for more information on assertion debugging. +See :ref:`assert-details` for more information on assertion introspection. assertions about expected exceptions @@ -130,8 +130,8 @@ See the :ref:`reporting demo ` for many more examples. .. _assert-details: -Assertion debugging details ---------------------------- +Assertion introspection details +------------------------------- Reporting details about the failing assertion is achieved either by rewriting assert statements before they are run or re-evaluating the assert expression and @@ -141,7 +141,7 @@ to run py.test. By default, if the Python version is greater than or equal to 2.6, py.test rewrites assert statements in test modules. Rewritten assert statements put -debugging information into the assertion failure message. Note py.test only +introspection information into the assertion failure message. Note py.test only rewrites test modules directly discovered by its test collection process, so asserts in supporting modules will not be rewritten. @@ -167,11 +167,11 @@ easy to rewrite the assertion and avoid any trouble:: content = f.read() assert content != '...' -All assert debugging can be turned off by passing ``--assertmode=off``. +All assert introspeciton can be turned off by passing ``--assertmode=off``. .. versionadded:: 2.1 - Add assert rewriting as an alternate debugging technique. + Add assert rewriting as an alternate introspection technique. .. versionchanged:: 2.1 From 6fdcecb8641d67b670dfc1452bccb346fb163d77 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Sat, 28 May 2011 16:04:36 -0500 Subject: [PATCH 57/59] typo --- doc/assert.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/assert.txt b/doc/assert.txt index 9f26a8446..ca99612ca 100644 --- a/doc/assert.txt +++ b/doc/assert.txt @@ -167,7 +167,7 @@ easy to rewrite the assertion and avoid any trouble:: content = f.read() assert content != '...' -All assert introspeciton can be turned off by passing ``--assertmode=off``. +All assert introspection can be turned off by passing ``--assertmode=off``. .. versionadded:: 2.1 From 5e31624315be819497ea0180399387505400c288 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Sat, 28 May 2011 18:47:16 -0500 Subject: [PATCH 58/59] return to the old scheme of rewriting test modules from _importtestmodule --- _pytest/assertion/__init__.py | 43 ++++++++++++++--------------------- _pytest/hookspec.py | 3 --- _pytest/python.py | 12 ++++++---- 3 files changed, 24 insertions(+), 34 deletions(-) diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index 535f5b2ca..6daad05c4 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -65,10 +65,6 @@ def pytest_configure(config): config._assertstate = AssertionState(config, mode) config._assertstate.trace("configured with mode set to %r" % (mode,)) -def pytest_collectstart(collector): - if isinstance(collector, pytest.Session): - collector._rewritten_pycs = [] - def _write_pyc(co, source_path): if hasattr(imp, "cache_from_source"): # Handle PEP 3147 pycs. @@ -86,9 +82,9 @@ def _write_pyc(co, source_path): fp.close() return pyc -def pytest_pycollect_onmodule(mod): - if mod is None or mod.config._assertstate.mode != "on": - return mod +def before_module_import(mod): + if mod.config._assertstate.mode != "on": + return # Some deep magic: load the source, rewrite the asserts, and write a # fake pyc, so that it'll be loaded when the module is imported. source = mod.fspath.read() @@ -97,7 +93,7 @@ def pytest_pycollect_onmodule(mod): except SyntaxError: # Let this pop up again in the real import. mod.config._assertstate.trace("failed to parse: %r" % (mod.fspath,)) - return mod + return rewrite_asserts(tree) try: co = compile(tree, str(mod.fspath), "exec") @@ -105,25 +101,20 @@ def pytest_pycollect_onmodule(mod): # It's possible that this error is from some bug in the assertion # rewriting, but I don't know of a fast way to tell. mod.config._assertstate.trace("failed to compile: %r" % (mod.fspath,)) - return mod - pyc = _write_pyc(co, mod.fspath) - mod.session._rewritten_pycs.append(pyc) - mod.config._assertstate.trace("wrote pyc: %r" % (pyc,)) - return mod - -def pytest_collection_finish(session): - if not hasattr(session, "_rewritten_pycs"): return - state = session.config._assertstate - # Remove our tweaked pycs to avoid subtle bugs. - for pyc in session._rewritten_pycs: - try: - pyc.remove() - except py.error.ENOENT: - state.trace("couldn't find pyc: %r" % (pyc,)) - else: - state.trace("removed pyc: %r" % (pyc,)) - del session._rewritten_pycs[:] + mod._pyc = _write_pyc(co, mod.fspath) + mod.config._assertstate.trace("wrote pyc: %r" % (mod._pyc,)) + +def after_module_import(mod): + if not hasattr(mod, "_pyc"): + return + state = mod.config._assertstate + try: + mod._pyc.remove() + except py.error.ENOENT: + state.trace("couldn't find pyc: %r" % (mod._pyc,)) + else: + state.trace("removed pyc: %r" % (mod._pyc,)) def warn_about_missing_assertion(): try: diff --git a/_pytest/hookspec.py b/_pytest/hookspec.py index a3f79fd01..898ffee2a 100644 --- a/_pytest/hookspec.py +++ b/_pytest/hookspec.py @@ -104,9 +104,6 @@ def pytest_pycollect_makemodule(path, parent): """ pytest_pycollect_makemodule.firstresult = True -def pytest_pycollect_onmodule(mod): - """ Called when a module is collected.""" - def pytest_pycollect_makeitem(collector, name, obj): """ return custom item/collector for a python object in a module, or None. """ pytest_pycollect_makeitem.firstresult = True diff --git a/_pytest/python.py b/_pytest/python.py index 2b47476c9..ae964f0b8 100644 --- a/_pytest/python.py +++ b/_pytest/python.py @@ -4,6 +4,7 @@ import inspect import sys import pytest from py._code.code import TerminalRepr +from _pytest import assertion import _pytest cutdir = py.path.local(_pytest.__file__).dirpath() @@ -60,11 +61,8 @@ def pytest_collect_file(path, parent): break else: return - mod = parent.ihook.pytest_pycollect_makemodule( + return parent.ihook.pytest_pycollect_makemodule( path=path, parent=parent) - if mod is not None: - parent.ihook.pytest_pycollect_onmodule(mod=mod) - return mod def pytest_pycollect_makemodule(path, parent): return Module(path, parent) @@ -229,8 +227,12 @@ class Module(pytest.File, PyCollectorMixin): def _importtestmodule(self): # we assume we are only called once per module + assertion.before_module_import(self) try: - mod = self.fspath.pyimport(ensuresyspath=True) + try: + mod = self.fspath.pyimport(ensuresyspath=True) + finally: + assertion.after_module_import(self) except SyntaxError: excinfo = py.code.ExceptionInfo() raise self.CollectError(excinfo.getrepr(style="short")) From 00dee742b0dbc3bc838af687ae8ec7d75fd566c9 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Sat, 28 May 2011 19:00:47 -0500 Subject: [PATCH 59/59] describe how assert rewriting interacts with cross test imports --- doc/assert.txt | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/doc/assert.txt b/doc/assert.txt index ca99612ca..e32e31489 100644 --- a/doc/assert.txt +++ b/doc/assert.txt @@ -145,6 +145,15 @@ introspection information into the assertion failure message. Note py.test only rewrites test modules directly discovered by its test collection process, so asserts in supporting modules will not be rewritten. +.. note:: + + py.test rewrites test modules as it collects tests from them. It does this by + writing a new pyc file which Python loads when the test module is + imported. If the module has already been loaded (it is in sys.modules), + though, Python will not load the rewritten module. This means if a test + module imports another test module which has not already been rewritten, then + py.test will not be able to rewrite the second module. + If an assert statement has not been rewritten or the Python version is less than 2.6, py.test falls back on assert reinterpretation. In assert reinterpretation, py.test walks the frame of the function containing the assert statement to