Added config back to AssertionWriter and fixed typo in check_if_assertion_pass_impl function call.

This commit is contained in:
Victor Maryama 2019-06-26 19:00:31 +02:00
parent 6f851e6cbb
commit 53234bf613
1 changed files with 7 additions and 7 deletions

View File

@ -135,7 +135,7 @@ class AssertionRewritingHook:
co = _read_pyc(fn, pyc, state.trace) co = _read_pyc(fn, pyc, state.trace)
if co is None: if co is None:
state.trace("rewriting {!r}".format(fn)) state.trace("rewriting {!r}".format(fn))
source_stat, co = _rewrite_test(fn) source_stat, co = _rewrite_test(fn, self.config)
if write: if write:
self._writing_pyc = True self._writing_pyc = True
try: try:
@ -279,13 +279,13 @@ def _write_pyc(state, co, source_stat, pyc):
return True return True
def _rewrite_test(fn): def _rewrite_test(fn, config):
"""read and rewrite *fn* and return the code object.""" """read and rewrite *fn* and return the code object."""
stat = os.stat(fn) stat = os.stat(fn)
with open(fn, "rb") as f: with open(fn, "rb") as f:
source = f.read() source = f.read()
tree = ast.parse(source, filename=fn) tree = ast.parse(source, filename=fn)
rewrite_asserts(tree, fn) rewrite_asserts(tree, fn, config)
co = compile(tree, fn, "exec", dont_inherit=True) co = compile(tree, fn, "exec", dont_inherit=True)
return stat, co return stat, co
@ -327,9 +327,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
return co return co
def rewrite_asserts(mod, module_path=None): def rewrite_asserts(mod, module_path=None, config=None):
"""Rewrite the assert statements in mod.""" """Rewrite the assert statements in mod."""
AssertionRewriter(module_path).run(mod) AssertionRewriter(module_path, config).run(mod)
def _saferepr(obj): def _saferepr(obj):
@ -523,7 +523,7 @@ class AssertionRewriter(ast.NodeVisitor):
""" """
def __init__(self, module_path): def __init__(self, module_path, config):
super().__init__() super().__init__()
self.module_path = module_path self.module_path = module_path
self.config = config self.config = config
@ -761,7 +761,7 @@ class AssertionRewriter(ast.NodeVisitor):
) )
# If any hooks implement assert_pass hook # If any hooks implement assert_pass hook
hook_impl_test = ast.If( hook_impl_test = ast.If(
self.helper("_check_if_assertionpass_impl"), [hook_call_pass], [] self.helper("_check_if_assertion_pass_impl"), [hook_call_pass], []
) )
statements_pass = [] statements_pass = []
statements_pass.extend(self.expl_stmts) statements_pass.extend(self.expl_stmts)