Merge pull request #3895 from nicoddemus/issue-3506

Avoid possible infinite recursion when writing pyc files in assert rewrite
This commit is contained in:
Bruno Oliveira 2018-08-28 18:16:10 -03:00 committed by GitHub
commit 10c1c7c41a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 1 deletions

View File

@ -0,0 +1 @@
Fix possible infinite recursion when writing ``.pyc`` files.

View File

@ -64,11 +64,16 @@ class AssertionRewritingHook(object):
self._rewritten_names = set() self._rewritten_names = set()
self._register_with_pkg_resources() self._register_with_pkg_resources()
self._must_rewrite = set() self._must_rewrite = set()
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
# which might result in infinite recursion (#3506)
self._writing_pyc = False
def set_session(self, session): def set_session(self, session):
self.session = session self.session = session
def find_module(self, name, path=None): def find_module(self, name, path=None):
if self._writing_pyc:
return None
state = self.config._assertstate state = self.config._assertstate
state.trace("find_module called for: %s" % name) state.trace("find_module called for: %s" % name)
names = name.rsplit(".", 1) names = name.rsplit(".", 1)
@ -151,7 +156,11 @@ class AssertionRewritingHook(object):
# Probably a SyntaxError in the test. # Probably a SyntaxError in the test.
return None return None
if write: if write:
_write_pyc(state, co, source_stat, pyc) self._writing_pyc = True
try:
_write_pyc(state, co, source_stat, pyc)
finally:
self._writing_pyc = False
else: else:
state.trace("found cached rewritten pyc for %r" % (fn,)) state.trace("found cached rewritten pyc for %r" % (fn,))
self.modules[name] = co, pyc self.modules[name] = co, pyc

View File

@ -1124,3 +1124,32 @@ def test_simple_failure():
result = testdir.runpytest() result = testdir.runpytest()
result.stdout.fnmatch_lines("*E*assert (1 + 1) == 3") result.stdout.fnmatch_lines("*E*assert (1 + 1) == 3")
def test_rewrite_infinite_recursion(testdir, pytestconfig, monkeypatch):
"""Fix infinite recursion when writing pyc files: if an import happens to be triggered when writing the pyc
file, this would cause another call to the hook, which would trigger another pyc writing, which could
trigger another import, and so on. (#3506)"""
from _pytest.assertion import rewrite
testdir.syspathinsert()
testdir.makepyfile(test_foo="def test_foo(): pass")
testdir.makepyfile(test_bar="def test_bar(): pass")
original_write_pyc = rewrite._write_pyc
write_pyc_called = []
def spy_write_pyc(*args, **kwargs):
# make a note that we have called _write_pyc
write_pyc_called.append(True)
# try to import a module at this point: we should not try to rewrite this module
assert hook.find_module("test_bar") is None
return original_write_pyc(*args, **kwargs)
monkeypatch.setattr(rewrite, "_write_pyc", spy_write_pyc)
monkeypatch.setattr(sys, "dont_write_bytecode", False)
hook = AssertionRewritingHook(pytestconfig)
assert hook.find_module("test_foo") is not None
assert len(write_pyc_called) == 1