Revert "Fix all() unroll for non-generators/non-list comprehensions (#5360)"

This reverts commit 733f43b02e, reversing
changes made to e4fe41ebb7.
This commit is contained in:
Anthony Sottile 2019-06-03 08:34:25 -07:00
parent 5976f36240
commit 2125d04501
3 changed files with 5 additions and 39 deletions

View File

@ -1 +0,0 @@
Fix assertion rewriting of ``all()`` calls to deal with non-generators.

View File

@ -903,21 +903,11 @@ warn_explicit(
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr)) res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
return res, explanation return res, explanation
@staticmethod
def _is_any_call_with_generator_or_list_comprehension(call):
"""Return True if the Call node is an 'any' call with a generator or list comprehension"""
return (
isinstance(call.func, ast.Name)
and call.func.id == "all"
and len(call.args) == 1
and isinstance(call.args[0], (ast.GeneratorExp, ast.ListComp))
)
def visit_Call(self, call): def visit_Call(self, call):
""" """
visit `ast.Call` nodes visit `ast.Call` nodes
""" """
if self._is_any_call_with_generator_or_list_comprehension(call): if isinstance(call.func, ast.Name) and call.func.id == "all":
return self._visit_all(call) return self._visit_all(call)
new_func, func_expl = self.visit(call.func) new_func, func_expl = self.visit(call.func)
arg_expls = [] arg_expls = []
@ -944,6 +934,8 @@ warn_explicit(
def _visit_all(self, call): def _visit_all(self, call):
"""Special rewrite for the builtin all function, see #5062""" """Special rewrite for the builtin all function, see #5062"""
if not isinstance(call.args[0], (ast.GeneratorExp, ast.ListComp)):
return
gen_exp = call.args[0] gen_exp = call.args[0]
assertion_module = ast.Module( assertion_module = ast.Module(
body=[ast.Assert(test=gen_exp.elt, lineno=1, msg="", col_offset=1)] body=[ast.Assert(test=gen_exp.elt, lineno=1, msg="", col_offset=1)]

View File

@ -656,7 +656,7 @@ class TestAssertionRewrite:
assert "UnicodeDecodeError" not in msg assert "UnicodeDecodeError" not in msg
assert "UnicodeEncodeError" not in msg assert "UnicodeEncodeError" not in msg
def test_unroll_all_generator(self, testdir): def test_unroll_generator(self, testdir):
testdir.makepyfile( testdir.makepyfile(
""" """
def check_even(num): def check_even(num):
@ -671,7 +671,7 @@ class TestAssertionRewrite:
result = testdir.runpytest() result = testdir.runpytest()
result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"]) result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"])
def test_unroll_all_list_comprehension(self, testdir): def test_unroll_list_comprehension(self, testdir):
testdir.makepyfile( testdir.makepyfile(
""" """
def check_even(num): def check_even(num):
@ -686,31 +686,6 @@ class TestAssertionRewrite:
result = testdir.runpytest() result = testdir.runpytest()
result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"]) result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"])
def test_unroll_all_object(self, testdir):
"""all() for non generators/non list-comprehensions (#5358)"""
testdir.makepyfile(
"""
def test():
assert all((1, 0))
"""
)
result = testdir.runpytest()
result.stdout.fnmatch_lines(["*assert False*", "*where False = all((1, 0))*"])
def test_unroll_all_starred(self, testdir):
"""all() for non generators/non list-comprehensions (#5358)"""
testdir.makepyfile(
"""
def test():
x = ((1, 0),)
assert all(*x)
"""
)
result = testdir.runpytest()
result.stdout.fnmatch_lines(
["*assert False*", "*where False = all(*((1, 0),))*"]
)
def test_for_loop(self, testdir): def test_for_loop(self, testdir):
testdir.makepyfile( testdir.makepyfile(
""" """