clear instead of deleting temporary assertion variables
This commit is contained in:
parent
661a8a4a92
commit
8b211983ff
|
@ -374,7 +374,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
"""Get a new variable."""
|
||||
# Use a character invalid in python identifiers to avoid clashing.
|
||||
name = "@py_assert" + str(next(self.variable_counter))
|
||||
self.variables[self.cond_chain].add(name)
|
||||
self.variables.append(name)
|
||||
return name
|
||||
|
||||
def assign(self, expr):
|
||||
|
@ -437,7 +437,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
return [assert_]
|
||||
self.statements = []
|
||||
self.cond_chain = ()
|
||||
self.variables = collections.defaultdict(set)
|
||||
self.variables = []
|
||||
self.variable_counter = itertools.count()
|
||||
self.stack = []
|
||||
self.on_failure = []
|
||||
|
@ -459,22 +459,11 @@ class AssertionRewriter(ast.NodeVisitor):
|
|||
else:
|
||||
raise_ = ast.Raise(exc, None, None)
|
||||
body.append(raise_)
|
||||
# Delete temporary variables. This requires a bit cleverness about the
|
||||
# order, so we don't delete variables that are themselves conditions for
|
||||
# later variables.
|
||||
for chain in sorted(self.variables, key=len, reverse=True):
|
||||
if chain:
|
||||
where = []
|
||||
if len(chain) > 1:
|
||||
cond = ast.BoolOp(ast.And(), list(chain))
|
||||
else:
|
||||
cond = chain[0]
|
||||
self.statements.append(ast.If(cond, where, []))
|
||||
else:
|
||||
where = self.statements
|
||||
v = self.variables[chain]
|
||||
names = [ast.Name(name, ast.Del()) for name in v]
|
||||
where.append(ast.Delete(names))
|
||||
# Clear temporary variables by setting them to None.
|
||||
if self.variables:
|
||||
variables = [ast.Name(name, ast.Store()) for name in self.variables]
|
||||
clear = ast.Assign(variables, ast.Name("None", ast.Load()))
|
||||
self.statements.append(clear)
|
||||
# Fix line numbers.
|
||||
for stmt in self.statements:
|
||||
set_location(stmt, assert_.lineno, assert_.col_offset)
|
||||
|
|
Loading…
Reference in New Issue