rewrite with proper short-circuting on boolean operators (fixes #57)
This commit is contained in:
parent
c6e3606c6b
commit
f286a02582
|
@ -17,13 +17,8 @@ def rewrite_asserts(mod):
|
||||||
_saferepr = py.io.saferepr
|
_saferepr = py.io.saferepr
|
||||||
from _pytest.assertion.util import format_explanation as _format_explanation
|
from _pytest.assertion.util import format_explanation as _format_explanation
|
||||||
|
|
||||||
def _format_boolop(operands, explanations, is_or):
|
def _format_boolop(explanations, is_or):
|
||||||
show_explanations = []
|
return "(" + (is_or and " or " or " and ").join(explanations) + ")"
|
||||||
for operand, expl in zip(operands, explanations):
|
|
||||||
show_explanations.append(expl)
|
|
||||||
if operand == is_or:
|
|
||||||
break
|
|
||||||
return "(" + (is_or and " or " or " and ").join(show_explanations) + ")"
|
|
||||||
|
|
||||||
def _call_reprcompare(ops, results, expls, each_obj):
|
def _call_reprcompare(ops, results, expls, each_obj):
|
||||||
for i, res, expl in zip(range(len(ops)), results, expls):
|
for i, res, expl in zip(range(len(ops)), results, expls):
|
||||||
|
@ -143,7 +138,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
"""Get a new variable."""
|
"""Get a new variable."""
|
||||||
# Use a character invalid in python identifiers to avoid clashing.
|
# Use a character invalid in python identifiers to avoid clashing.
|
||||||
name = "@py_assert" + str(next(self.variable_counter))
|
name = "@py_assert" + str(next(self.variable_counter))
|
||||||
self.variables.add(name)
|
self.variables[self.cond_chain].add(name)
|
||||||
return name
|
return name
|
||||||
|
|
||||||
def assign(self, expr):
|
def assign(self, expr):
|
||||||
|
@ -172,6 +167,13 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
self.explanation_specifiers[specifier] = expr
|
self.explanation_specifiers[specifier] = expr
|
||||||
return "%(" + specifier + ")s"
|
return "%(" + specifier + ")s"
|
||||||
|
|
||||||
|
def enter_cond(self, cond, body):
|
||||||
|
self.statements.append(ast.If(cond, body, []))
|
||||||
|
self.cond_chain += cond,
|
||||||
|
|
||||||
|
def leave_cond(self, n=1):
|
||||||
|
self.cond_chain = self.cond_chain[:-n]
|
||||||
|
|
||||||
def push_format_context(self):
|
def push_format_context(self):
|
||||||
self.explanation_specifiers = {}
|
self.explanation_specifiers = {}
|
||||||
self.stack.append(self.explanation_specifiers)
|
self.stack.append(self.explanation_specifiers)
|
||||||
|
@ -198,7 +200,8 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
# There's already a message. Don't mess with it.
|
# There's already a message. Don't mess with it.
|
||||||
return [assert_]
|
return [assert_]
|
||||||
self.statements = []
|
self.statements = []
|
||||||
self.variables = set()
|
self.cond_chain = ()
|
||||||
|
self.variables = collections.defaultdict(set)
|
||||||
self.variable_counter = itertools.count()
|
self.variable_counter = itertools.count()
|
||||||
self.stack = []
|
self.stack = []
|
||||||
self.on_failure = []
|
self.on_failure = []
|
||||||
|
@ -220,11 +223,22 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
else:
|
else:
|
||||||
raise_ = ast.Raise(exc, None, None)
|
raise_ = ast.Raise(exc, None, None)
|
||||||
body.append(raise_)
|
body.append(raise_)
|
||||||
# Delete temporary variables.
|
# Delete temporary variables. This requires a bit cleverness about the
|
||||||
names = [ast.Name(name, ast.Del()) for name in self.variables]
|
# order, so we don't delete variables that are themselves conditions for
|
||||||
if names:
|
# later variables.
|
||||||
delete = ast.Delete(names)
|
for chain in sorted(self.variables, key=len, reverse=True):
|
||||||
self.statements.append(delete)
|
if chain:
|
||||||
|
where = []
|
||||||
|
if len(chain) > 1:
|
||||||
|
cond = ast.Boolop(ast.And(), chain)
|
||||||
|
else:
|
||||||
|
cond = chain[0]
|
||||||
|
self.statements.append(ast.If(cond, where, []))
|
||||||
|
else:
|
||||||
|
where = self.statements
|
||||||
|
v = self.variables[chain]
|
||||||
|
names = [ast.Name(name, ast.Del()) for name in v]
|
||||||
|
where.append(ast.Delete(names))
|
||||||
# Fix line numbers.
|
# Fix line numbers.
|
||||||
for stmt in self.statements:
|
for stmt in self.statements:
|
||||||
set_location(stmt, assert_.lineno, assert_.col_offset)
|
set_location(stmt, assert_.lineno, assert_.col_offset)
|
||||||
|
@ -240,21 +254,32 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
return name, self.explanation_param(expr)
|
return name, self.explanation_param(expr)
|
||||||
|
|
||||||
def visit_BoolOp(self, boolop):
|
def visit_BoolOp(self, boolop):
|
||||||
operands = []
|
res_var = self.variable()
|
||||||
explanations = []
|
expl_list = self.assign(ast.List([], ast.Load()))
|
||||||
|
app = ast.Attribute(expl_list, "append", ast.Load())
|
||||||
|
is_or = isinstance(boolop.op, ast.Or)
|
||||||
|
body = save = self.statements
|
||||||
|
levels = len(boolop.values) - 1
|
||||||
self.push_format_context()
|
self.push_format_context()
|
||||||
for operand in boolop.values:
|
# Process each operand, short-circuting if needed.
|
||||||
res, explanation = self.visit(operand)
|
for i, v in enumerate(boolop.values):
|
||||||
operands.append(res)
|
res, expl = self.visit(v)
|
||||||
explanations.append(explanation)
|
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
|
||||||
expls = ast.Tuple([ast.Str(expl) for expl in explanations], ast.Load())
|
call = ast.Call(app, [ast.Str(expl)], [], None, None)
|
||||||
is_or = ast.Num(isinstance(boolop.op, ast.Or))
|
body.append(ast.Expr(call))
|
||||||
expl_template = self.helper("format_boolop",
|
if i < levels:
|
||||||
ast.Tuple(operands, ast.Load()), expls,
|
inner = []
|
||||||
is_or)
|
cond = res
|
||||||
|
if is_or:
|
||||||
|
cond = ast.UnaryOp(ast.Not(), cond)
|
||||||
|
self.enter_cond(cond, inner)
|
||||||
|
self.statements = body = inner
|
||||||
|
# Leave all conditions.
|
||||||
|
self.leave_cond(levels)
|
||||||
|
self.statements = save
|
||||||
|
expl_template = self.helper("format_boolop", expl_list, ast.Num(is_or))
|
||||||
expl = self.pop_format_context(expl_template)
|
expl = self.pop_format_context(expl_template)
|
||||||
res = self.assign(ast.BoolOp(boolop.op, operands))
|
return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
|
||||||
return res, self.explanation_param(expl)
|
|
||||||
|
|
||||||
def visit_UnaryOp(self, unary):
|
def visit_UnaryOp(self, unary):
|
||||||
pattern = unary_map[unary.op.__class__]
|
pattern = unary_map[unary.op.__class__]
|
||||||
|
|
|
@ -128,6 +128,10 @@ class TestAssertionRewrite:
|
||||||
f = g = False
|
f = g = False
|
||||||
assert f or g
|
assert f or g
|
||||||
assert getmsg(f) == "assert (False or False)"
|
assert getmsg(f) == "assert (False or False)"
|
||||||
|
def f():
|
||||||
|
f = g = False
|
||||||
|
assert not f and not g
|
||||||
|
getmsg(f, must_pass=True)
|
||||||
def f():
|
def f():
|
||||||
f = True
|
f = True
|
||||||
g = False
|
g = False
|
||||||
|
@ -135,10 +139,13 @@ class TestAssertionRewrite:
|
||||||
getmsg(f, must_pass=True)
|
getmsg(f, must_pass=True)
|
||||||
|
|
||||||
def test_short_circut_evaluation(self):
|
def test_short_circut_evaluation(self):
|
||||||
pytest.xfail("complicated fix; I'm not sure if it's important")
|
|
||||||
def f():
|
def f():
|
||||||
assert True or explode
|
assert True or explode
|
||||||
getmsg(f, must_pass=True)
|
getmsg(f, must_pass=True)
|
||||||
|
def f():
|
||||||
|
x = 1
|
||||||
|
assert x == 1 or x == 2
|
||||||
|
getmsg(f, must_pass=True)
|
||||||
|
|
||||||
def test_unary_op(self):
|
def test_unary_op(self):
|
||||||
def f():
|
def f():
|
||||||
|
|
Loading…
Reference in New Issue