rewrite with proper short-circuting on boolean operators (fixes #57)

This commit is contained in:
Benjamin Peterson 2011-06-28 20:21:22 -05:00
parent c6e3606c6b
commit f286a02582
2 changed files with 60 additions and 28 deletions

View File

@ -17,13 +17,8 @@ def rewrite_asserts(mod):
_saferepr = py.io.saferepr _saferepr = py.io.saferepr
from _pytest.assertion.util import format_explanation as _format_explanation from _pytest.assertion.util import format_explanation as _format_explanation
def _format_boolop(operands, explanations, is_or): def _format_boolop(explanations, is_or):
show_explanations = [] return "(" + (is_or and " or " or " and ").join(explanations) + ")"
for operand, expl in zip(operands, explanations):
show_explanations.append(expl)
if operand == is_or:
break
return "(" + (is_or and " or " or " and ").join(show_explanations) + ")"
def _call_reprcompare(ops, results, expls, each_obj): def _call_reprcompare(ops, results, expls, each_obj):
for i, res, expl in zip(range(len(ops)), results, expls): for i, res, expl in zip(range(len(ops)), results, expls):
@ -143,7 +138,7 @@ class AssertionRewriter(ast.NodeVisitor):
"""Get a new variable.""" """Get a new variable."""
# Use a character invalid in python identifiers to avoid clashing. # Use a character invalid in python identifiers to avoid clashing.
name = "@py_assert" + str(next(self.variable_counter)) name = "@py_assert" + str(next(self.variable_counter))
self.variables.add(name) self.variables[self.cond_chain].add(name)
return name return name
def assign(self, expr): def assign(self, expr):
@ -172,6 +167,13 @@ class AssertionRewriter(ast.NodeVisitor):
self.explanation_specifiers[specifier] = expr self.explanation_specifiers[specifier] = expr
return "%(" + specifier + ")s" return "%(" + specifier + ")s"
def enter_cond(self, cond, body):
self.statements.append(ast.If(cond, body, []))
self.cond_chain += cond,
def leave_cond(self, n=1):
self.cond_chain = self.cond_chain[:-n]
def push_format_context(self): def push_format_context(self):
self.explanation_specifiers = {} self.explanation_specifiers = {}
self.stack.append(self.explanation_specifiers) self.stack.append(self.explanation_specifiers)
@ -198,7 +200,8 @@ class AssertionRewriter(ast.NodeVisitor):
# There's already a message. Don't mess with it. # There's already a message. Don't mess with it.
return [assert_] return [assert_]
self.statements = [] self.statements = []
self.variables = set() self.cond_chain = ()
self.variables = collections.defaultdict(set)
self.variable_counter = itertools.count() self.variable_counter = itertools.count()
self.stack = [] self.stack = []
self.on_failure = [] self.on_failure = []
@ -220,11 +223,22 @@ class AssertionRewriter(ast.NodeVisitor):
else: else:
raise_ = ast.Raise(exc, None, None) raise_ = ast.Raise(exc, None, None)
body.append(raise_) body.append(raise_)
# Delete temporary variables. # Delete temporary variables. This requires a bit cleverness about the
names = [ast.Name(name, ast.Del()) for name in self.variables] # order, so we don't delete variables that are themselves conditions for
if names: # later variables.
delete = ast.Delete(names) for chain in sorted(self.variables, key=len, reverse=True):
self.statements.append(delete) if chain:
where = []
if len(chain) > 1:
cond = ast.Boolop(ast.And(), chain)
else:
cond = chain[0]
self.statements.append(ast.If(cond, where, []))
else:
where = self.statements
v = self.variables[chain]
names = [ast.Name(name, ast.Del()) for name in v]
where.append(ast.Delete(names))
# Fix line numbers. # Fix line numbers.
for stmt in self.statements: for stmt in self.statements:
set_location(stmt, assert_.lineno, assert_.col_offset) set_location(stmt, assert_.lineno, assert_.col_offset)
@ -240,21 +254,32 @@ class AssertionRewriter(ast.NodeVisitor):
return name, self.explanation_param(expr) return name, self.explanation_param(expr)
def visit_BoolOp(self, boolop): def visit_BoolOp(self, boolop):
operands = [] res_var = self.variable()
explanations = [] expl_list = self.assign(ast.List([], ast.Load()))
app = ast.Attribute(expl_list, "append", ast.Load())
is_or = isinstance(boolop.op, ast.Or)
body = save = self.statements
levels = len(boolop.values) - 1
self.push_format_context() self.push_format_context()
for operand in boolop.values: # Process each operand, short-circuting if needed.
res, explanation = self.visit(operand) for i, v in enumerate(boolop.values):
operands.append(res) res, expl = self.visit(v)
explanations.append(explanation) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
expls = ast.Tuple([ast.Str(expl) for expl in explanations], ast.Load()) call = ast.Call(app, [ast.Str(expl)], [], None, None)
is_or = ast.Num(isinstance(boolop.op, ast.Or)) body.append(ast.Expr(call))
expl_template = self.helper("format_boolop", if i < levels:
ast.Tuple(operands, ast.Load()), expls, inner = []
is_or) cond = res
if is_or:
cond = ast.UnaryOp(ast.Not(), cond)
self.enter_cond(cond, inner)
self.statements = body = inner
# Leave all conditions.
self.leave_cond(levels)
self.statements = save
expl_template = self.helper("format_boolop", expl_list, ast.Num(is_or))
expl = self.pop_format_context(expl_template) expl = self.pop_format_context(expl_template)
res = self.assign(ast.BoolOp(boolop.op, operands)) return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
return res, self.explanation_param(expl)
def visit_UnaryOp(self, unary): def visit_UnaryOp(self, unary):
pattern = unary_map[unary.op.__class__] pattern = unary_map[unary.op.__class__]

View File

@ -128,6 +128,10 @@ class TestAssertionRewrite:
f = g = False f = g = False
assert f or g assert f or g
assert getmsg(f) == "assert (False or False)" assert getmsg(f) == "assert (False or False)"
def f():
f = g = False
assert not f and not g
getmsg(f, must_pass=True)
def f(): def f():
f = True f = True
g = False g = False
@ -135,10 +139,13 @@ class TestAssertionRewrite:
getmsg(f, must_pass=True) getmsg(f, must_pass=True)
def test_short_circut_evaluation(self): def test_short_circut_evaluation(self):
pytest.xfail("complicated fix; I'm not sure if it's important")
def f(): def f():
assert True or explode assert True or explode
getmsg(f, must_pass=True) getmsg(f, must_pass=True)
def f():
x = 1
assert x == 1 or x == 2
getmsg(f, must_pass=True)
def test_unary_op(self): def test_unary_op(self):
def f(): def f():