put explanation simplification in format_explanation so everyone can benefit

This commit is contained in:
Benjamin Peterson 2011-06-12 22:41:58 -05:00
parent d853d9a9af
commit 8e81ed693a
4 changed files with 28 additions and 19 deletions

View File

@ -308,9 +308,6 @@ class DebugInterpreter(ast.NodeVisitor):
def visit_Assert(self, assrt):
test_explanation, test_result = self.visit(assrt.test)
if test_explanation.startswith("False\n{False =") and \
test_explanation.endswith("\n}"):
test_explanation = test_explanation[15:-2]
explanation = "assert %s" % (test_explanation,)
if not self.frame.is_true(test_result):
try:

View File

@ -384,10 +384,6 @@ class Assert(Interpretable):
def run(self, frame):
test = Interpretable(self.test)
test.eval(frame)
# simplify 'assert False where False = ...'
if (test.explanation.startswith('False\n{False = ') and
test.explanation.endswith('\n}')):
test.explanation = test.explanation[15:-2]
# print the result as 'assert <explanation>'
self.result = test.result
self.explanation = 'assert ' + test.explanation

View File

@ -19,6 +19,28 @@ def format_explanation(explanation):
for when one explanation needs to span multiple lines, e.g. when
displaying diffs.
"""
# simplify 'assert False where False = ...'
where = 0
while True:
start = where = explanation.find("False\n{False = ", where)
if where == -1:
break
level = 0
for i, c in enumerate(explanation[start:]):
if c == "{":
level += 1
elif c == "}":
level -= 1
if not level:
break
else:
raise AssertionError("unbalanced braces: %r" % (explanation,))
end = start + i
where = end
if explanation[end - 1] == '\n':
explanation = (explanation[:start] + explanation[start+15:end-1] +
explanation[end+1:])
where -= 17
raw_lines = (explanation or '').split('\n')
# escape newlines not followed by {, } and ~
lines = [raw_lines[0]]

View File

@ -164,24 +164,19 @@ class TestAssertionRewrite:
ns = {"g" : g}
def f():
assert g()
assert getmsg(f, ns) == """assert False
+ where False = g()"""
assert getmsg(f, ns) == """assert g()"""
def f():
assert g(1)
assert getmsg(f, ns) == """assert False
+ where False = g(1)"""
assert getmsg(f, ns) == """assert g(1)"""
def f():
assert g(1, 2)
assert getmsg(f, ns) == """assert False
+ where False = g(1, 2)"""
assert getmsg(f, ns) == """assert g(1, 2)"""
def f():
assert g(1, g=42)
assert getmsg(f, ns) == """assert False
+ where False = g(1, g=42)"""
assert getmsg(f, ns) == """assert g(1, g=42)"""
def f():
assert g(1, 3, g=23)
assert getmsg(f, ns) == """assert False
+ where False = g(1, 3, g=23)"""
assert getmsg(f, ns) == """assert g(1, 3, g=23)"""
def test_attribute(self):
class X(object):
@ -194,8 +189,7 @@ class TestAssertionRewrite:
def f():
x.a = False
assert x.a
assert getmsg(f, ns) == """assert False
+ where False = x.a"""
assert getmsg(f, ns) == """assert x.a"""
def test_comparisons(self):
def f():