small refactoring
This commit is contained in:
parent
9e6dfaefd9
commit
9ac818fb5c
|
@ -113,11 +113,16 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
for parent, pos, assert_ in asserts:
|
||||
parent[pos:pos + 1] = self.visit(assert_)
|
||||
|
||||
def assign(self, expr):
|
||||
"""Give *expr* a name."""
|
||||
def variable(self):
|
||||
"""Get a new variable."""
|
||||
# Use a character invalid in python identifiers to avoid clashing.
|
||||
name = "@py_assert" + str(next(self.variable_counter))
|
||||
self.variables.add(name)
|
||||
return name
|
||||
|
||||
def assign(self, expr):
|
||||
"""Give *expr* a name."""
|
||||
name = self.variable()
|
||||
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
|
||||
return ast.Name(name, ast.Load())
|
||||
|
||||
|
@ -275,8 +280,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
def visit_Compare(self, comp):
|
||||
self.push_format_context()
|
||||
left_res, left_expl = self.visit(comp.left)
|
||||
res_variables = ["@py_assert" + str(next(self.variable_counter))
|
||||
for i in range(len(comp.ops))]
|
||||
res_variables = [self.variable() for i in range(len(comp.ops))]
|
||||
load_names = [ast.Name(v, ast.Load()) for v in res_variables]
|
||||
store_names = [ast.Name(v, ast.Store()) for v in res_variables]
|
||||
it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
|
||||
|
|
Loading…
Reference in New Issue