From 9ac818fb5ce42c98270a5ff19707219a3a2f4163 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 19 May 2011 18:32:48 -0500 Subject: [PATCH] small refactoring --- _pytest/assertrewrite.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/_pytest/assertrewrite.py b/_pytest/assertrewrite.py index 14ddc93ce..9109eadc3 100644 --- a/_pytest/assertrewrite.py +++ b/_pytest/assertrewrite.py @@ -113,11 +113,16 @@ class AssertionRewriter(ast.NodeVisitor): for parent, pos, assert_ in asserts: parent[pos:pos + 1] = self.visit(assert_) - def assign(self, expr): - """Give *expr* a name.""" + def variable(self): + """Get a new variable.""" # Use a character invalid in python identifiers to avoid clashing. name = "@py_assert" + str(next(self.variable_counter)) self.variables.add(name) + return name + + def assign(self, expr): + """Give *expr* a name.""" + name = self.variable() self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr)) return ast.Name(name, ast.Load()) @@ -275,8 +280,7 @@ class AssertionRewriter(ast.NodeVisitor): def visit_Compare(self, comp): self.push_format_context() left_res, left_expl = self.visit(comp.left) - res_variables = ["@py_assert" + str(next(self.variable_counter)) - for i in range(len(comp.ops))] + res_variables = [self.variable() for i in range(len(comp.ops))] load_names = [ast.Name(v, ast.Load()) for v in res_variables] store_names = [ast.Name(v, ast.Store()) for v in res_variables] it = zip(range(len(comp.ops)), comp.ops, comp.comparators)