Merge pull request #3895 from nicoddemus/issue-3506
Avoid possible infinite recursion when writing pyc files in assert rewrite
This commit is contained in:
commit
10c1c7c41a
|
@ -0,0 +1 @@
|
||||||
|
Fix possible infinite recursion when writing ``.pyc`` files.
|
|
@ -64,11 +64,16 @@ class AssertionRewritingHook(object):
|
||||||
self._rewritten_names = set()
|
self._rewritten_names = set()
|
||||||
self._register_with_pkg_resources()
|
self._register_with_pkg_resources()
|
||||||
self._must_rewrite = set()
|
self._must_rewrite = set()
|
||||||
|
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
|
||||||
|
# which might result in infinite recursion (#3506)
|
||||||
|
self._writing_pyc = False
|
||||||
|
|
||||||
def set_session(self, session):
|
def set_session(self, session):
|
||||||
self.session = session
|
self.session = session
|
||||||
|
|
||||||
def find_module(self, name, path=None):
|
def find_module(self, name, path=None):
|
||||||
|
if self._writing_pyc:
|
||||||
|
return None
|
||||||
state = self.config._assertstate
|
state = self.config._assertstate
|
||||||
state.trace("find_module called for: %s" % name)
|
state.trace("find_module called for: %s" % name)
|
||||||
names = name.rsplit(".", 1)
|
names = name.rsplit(".", 1)
|
||||||
|
@ -151,7 +156,11 @@ class AssertionRewritingHook(object):
|
||||||
# Probably a SyntaxError in the test.
|
# Probably a SyntaxError in the test.
|
||||||
return None
|
return None
|
||||||
if write:
|
if write:
|
||||||
_write_pyc(state, co, source_stat, pyc)
|
self._writing_pyc = True
|
||||||
|
try:
|
||||||
|
_write_pyc(state, co, source_stat, pyc)
|
||||||
|
finally:
|
||||||
|
self._writing_pyc = False
|
||||||
else:
|
else:
|
||||||
state.trace("found cached rewritten pyc for %r" % (fn,))
|
state.trace("found cached rewritten pyc for %r" % (fn,))
|
||||||
self.modules[name] = co, pyc
|
self.modules[name] = co, pyc
|
||||||
|
|
|
@ -1124,3 +1124,32 @@ def test_simple_failure():
|
||||||
|
|
||||||
result = testdir.runpytest()
|
result = testdir.runpytest()
|
||||||
result.stdout.fnmatch_lines("*E*assert (1 + 1) == 3")
|
result.stdout.fnmatch_lines("*E*assert (1 + 1) == 3")
|
||||||
|
|
||||||
|
|
||||||
|
def test_rewrite_infinite_recursion(testdir, pytestconfig, monkeypatch):
|
||||||
|
"""Fix infinite recursion when writing pyc files: if an import happens to be triggered when writing the pyc
|
||||||
|
file, this would cause another call to the hook, which would trigger another pyc writing, which could
|
||||||
|
trigger another import, and so on. (#3506)"""
|
||||||
|
from _pytest.assertion import rewrite
|
||||||
|
|
||||||
|
testdir.syspathinsert()
|
||||||
|
testdir.makepyfile(test_foo="def test_foo(): pass")
|
||||||
|
testdir.makepyfile(test_bar="def test_bar(): pass")
|
||||||
|
|
||||||
|
original_write_pyc = rewrite._write_pyc
|
||||||
|
|
||||||
|
write_pyc_called = []
|
||||||
|
|
||||||
|
def spy_write_pyc(*args, **kwargs):
|
||||||
|
# make a note that we have called _write_pyc
|
||||||
|
write_pyc_called.append(True)
|
||||||
|
# try to import a module at this point: we should not try to rewrite this module
|
||||||
|
assert hook.find_module("test_bar") is None
|
||||||
|
return original_write_pyc(*args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(rewrite, "_write_pyc", spy_write_pyc)
|
||||||
|
monkeypatch.setattr(sys, "dont_write_bytecode", False)
|
||||||
|
|
||||||
|
hook = AssertionRewritingHook(pytestconfig)
|
||||||
|
assert hook.find_module("test_foo") is not None
|
||||||
|
assert len(write_pyc_called) == 1
|
||||||
|
|
Loading…
Reference in New Issue