small refactoring

This commit is contained in:
Benjamin Peterson 2011-05-19 18:32:48 -05:00
parent 9e6dfaefd9
commit 9ac818fb5c
1 changed files with 8 additions and 4 deletions

View File

@ -113,11 +113,16 @@ class AssertionRewriter(ast.NodeVisitor):
for parent, pos, assert_ in asserts: for parent, pos, assert_ in asserts:
parent[pos:pos + 1] = self.visit(assert_) parent[pos:pos + 1] = self.visit(assert_)
def assign(self, expr): def variable(self):
"""Give *expr* a name.""" """Get a new variable."""
# Use a character invalid in python identifiers to avoid clashing. # Use a character invalid in python identifiers to avoid clashing.
name = "@py_assert" + str(next(self.variable_counter)) name = "@py_assert" + str(next(self.variable_counter))
self.variables.add(name) self.variables.add(name)
return name
def assign(self, expr):
"""Give *expr* a name."""
name = self.variable()
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr)) self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
return ast.Name(name, ast.Load()) return ast.Name(name, ast.Load())
@ -275,8 +280,7 @@ class AssertionRewriter(ast.NodeVisitor):
def visit_Compare(self, comp): def visit_Compare(self, comp):
self.push_format_context() self.push_format_context()
left_res, left_expl = self.visit(comp.left) left_res, left_expl = self.visit(comp.left)
res_variables = ["@py_assert" + str(next(self.variable_counter)) res_variables = [self.variable() for i in range(len(comp.ops))]
for i in range(len(comp.ops))]
load_names = [ast.Name(v, ast.Load()) for v in res_variables] load_names = [ast.Name(v, ast.Load()) for v in res_variables]
store_names = [ast.Name(v, ast.Store()) for v in res_variables] store_names = [ast.Name(v, ast.Store()) for v in res_variables]
it = zip(range(len(comp.ops)), comp.ops, comp.comparators) it = zip(range(len(comp.ops)), comp.ops, comp.comparators)