Expand list comprehensions as well

This commit is contained in:
Tomer Keren 2019-05-09 18:51:03 +03:00
parent e37ff3042e
commit 437d6452c1
2 changed files with 4 additions and 4 deletions

View File

@ -991,13 +991,13 @@ warn_explicit(
def visit_all(self, call): def visit_all(self, call):
"""Special rewrite for the builtin all function, see #5602""" """Special rewrite for the builtin all function, see #5602"""
if not isinstance(call.args[0], ast.GeneratorExp): if not isinstance(call.args[0], (ast.GeneratorExp, ast.ListComp)):
return return
gen_exp = call.args[0] gen_exp = call.args[0]
assertion_module = ast.Module( assertion_module = ast.Module(
body=[ast.Assert(test=gen_exp.elt, lineno=1, msg="", col_offset=1)] body=[ast.Assert(test=gen_exp.elt, lineno=1, msg="", col_offset=1)]
) )
AssertionRewriter(None, None).run(assertion_module) AssertionRewriter(module_path=None, config=None).run(assertion_module)
for_loop = ast.For( for_loop = ast.For(
iter=gen_exp.generators[0].iter, iter=gen_exp.generators[0].iter,
target=gen_exp.generators[0].target, target=gen_exp.generators[0].target,

View File

@ -677,7 +677,7 @@ class TestAssertionRewrite(object):
assert "UnicodeDecodeError" not in msg assert "UnicodeDecodeError" not in msg
assert "UnicodeEncodeError" not in msg assert "UnicodeEncodeError" not in msg
def test_generator(self, testdir): def test_unroll_generator(self, testdir):
testdir.makepyfile( testdir.makepyfile(
""" """
def check_even(num): def check_even(num):
@ -692,7 +692,7 @@ class TestAssertionRewrite(object):
result = testdir.runpytest() result = testdir.runpytest()
result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"]) result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"])
def test_list_comprehension(self, testdir): def test_unroll_list_comprehension(self, testdir):
testdir.makepyfile( testdir.makepyfile(
""" """
def check_even(num): def check_even(num):