diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 2afe76b82..a79a52157 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -135,7 +135,7 @@ class AssertionRewritingHook: co = _read_pyc(fn, pyc, state.trace) if co is None: state.trace("rewriting {!r}".format(fn)) - source_stat, co = _rewrite_test(fn) + source_stat, co = _rewrite_test(fn, self.config) if write: self._writing_pyc = True try: @@ -279,13 +279,13 @@ def _write_pyc(state, co, source_stat, pyc): return True -def _rewrite_test(fn): +def _rewrite_test(fn, config): """read and rewrite *fn* and return the code object.""" stat = os.stat(fn) with open(fn, "rb") as f: source = f.read() tree = ast.parse(source, filename=fn) - rewrite_asserts(tree, fn) + rewrite_asserts(tree, fn, config) co = compile(tree, fn, "exec", dont_inherit=True) return stat, co @@ -327,9 +327,9 @@ def _read_pyc(source, pyc, trace=lambda x: None): return co -def rewrite_asserts(mod, module_path=None): +def rewrite_asserts(mod, module_path=None, config=None): """Rewrite the assert statements in mod.""" - AssertionRewriter(module_path).run(mod) + AssertionRewriter(module_path, config).run(mod) def _saferepr(obj): @@ -523,7 +523,7 @@ class AssertionRewriter(ast.NodeVisitor): """ - def __init__(self, module_path): + def __init__(self, module_path, config): super().__init__() self.module_path = module_path self.config = config @@ -761,7 +761,7 @@ class AssertionRewriter(ast.NodeVisitor): ) # If any hooks implement assert_pass hook hook_impl_test = ast.If( - self.helper("_check_if_assertionpass_impl"), [hook_call_pass], [] + self.helper("_check_if_assertion_pass_impl"), [hook_call_pass], [] ) statements_pass = [] statements_pass.extend(self.expl_stmts)