diff --git a/changelog/3506.bugfix.rst b/changelog/3506.bugfix.rst new file mode 100644 index 000000000..ccce61d04 --- /dev/null +++ b/changelog/3506.bugfix.rst @@ -0,0 +1 @@ +Fix possible infinite recursion when writing ``.pyc`` files. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 5cf63a063..a48a931ac 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -64,11 +64,16 @@ class AssertionRewritingHook(object): self._rewritten_names = set() self._register_with_pkg_resources() self._must_rewrite = set() + # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file, + # which might result in infinite recursion (#3506) + self._writing_pyc = False def set_session(self, session): self.session = session def find_module(self, name, path=None): + if self._writing_pyc: + return None state = self.config._assertstate state.trace("find_module called for: %s" % name) names = name.rsplit(".", 1) @@ -151,7 +156,11 @@ class AssertionRewritingHook(object): # Probably a SyntaxError in the test. return None if write: - _write_pyc(state, co, source_stat, pyc) + self._writing_pyc = True + try: + _write_pyc(state, co, source_stat, pyc) + finally: + self._writing_pyc = False else: state.trace("found cached rewritten pyc for %r" % (fn,)) self.modules[name] = co, pyc diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 79e7cf0e3..c436ab0de 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1124,3 +1124,32 @@ def test_simple_failure(): result = testdir.runpytest() result.stdout.fnmatch_lines("*E*assert (1 + 1) == 3") + + +def test_rewrite_infinite_recursion(testdir, pytestconfig, monkeypatch): + """Fix infinite recursion when writing pyc files: if an import happens to be triggered when writing the pyc + file, this would cause another call to the hook, which would trigger another pyc writing, which could + trigger another import, and so on. (#3506)""" + from _pytest.assertion import rewrite + + testdir.syspathinsert() + testdir.makepyfile(test_foo="def test_foo(): pass") + testdir.makepyfile(test_bar="def test_bar(): pass") + + original_write_pyc = rewrite._write_pyc + + write_pyc_called = [] + + def spy_write_pyc(*args, **kwargs): + # make a note that we have called _write_pyc + write_pyc_called.append(True) + # try to import a module at this point: we should not try to rewrite this module + assert hook.find_module("test_bar") is None + return original_write_pyc(*args, **kwargs) + + monkeypatch.setattr(rewrite, "_write_pyc", spy_write_pyc) + monkeypatch.setattr(sys, "dont_write_bytecode", False) + + hook = AssertionRewritingHook(pytestconfig) + assert hook.find_module("test_foo") is not None + assert len(write_pyc_called) == 1