Merge pull request #3895 from nicoddemus/issue-3506
Avoid possible infinite recursion when writing pyc files in assert rewrite
This commit is contained in:
commit
10c1c7c41a
|
@ -0,0 +1 @@
|
|||
Fix possible infinite recursion when writing ``.pyc`` files.
|
|
@ -64,11 +64,16 @@ class AssertionRewritingHook(object):
|
|||
self._rewritten_names = set()
|
||||
self._register_with_pkg_resources()
|
||||
self._must_rewrite = set()
|
||||
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
|
||||
# which might result in infinite recursion (#3506)
|
||||
self._writing_pyc = False
|
||||
|
||||
def set_session(self, session):
|
||||
self.session = session
|
||||
|
||||
def find_module(self, name, path=None):
|
||||
if self._writing_pyc:
|
||||
return None
|
||||
state = self.config._assertstate
|
||||
state.trace("find_module called for: %s" % name)
|
||||
names = name.rsplit(".", 1)
|
||||
|
@ -151,7 +156,11 @@ class AssertionRewritingHook(object):
|
|||
# Probably a SyntaxError in the test.
|
||||
return None
|
||||
if write:
|
||||
self._writing_pyc = True
|
||||
try:
|
||||
_write_pyc(state, co, source_stat, pyc)
|
||||
finally:
|
||||
self._writing_pyc = False
|
||||
else:
|
||||
state.trace("found cached rewritten pyc for %r" % (fn,))
|
||||
self.modules[name] = co, pyc
|
||||
|
|
|
@ -1124,3 +1124,32 @@ def test_simple_failure():
|
|||
|
||||
result = testdir.runpytest()
|
||||
result.stdout.fnmatch_lines("*E*assert (1 + 1) == 3")
|
||||
|
||||
|
||||
def test_rewrite_infinite_recursion(testdir, pytestconfig, monkeypatch):
|
||||
"""Fix infinite recursion when writing pyc files: if an import happens to be triggered when writing the pyc
|
||||
file, this would cause another call to the hook, which would trigger another pyc writing, which could
|
||||
trigger another import, and so on. (#3506)"""
|
||||
from _pytest.assertion import rewrite
|
||||
|
||||
testdir.syspathinsert()
|
||||
testdir.makepyfile(test_foo="def test_foo(): pass")
|
||||
testdir.makepyfile(test_bar="def test_bar(): pass")
|
||||
|
||||
original_write_pyc = rewrite._write_pyc
|
||||
|
||||
write_pyc_called = []
|
||||
|
||||
def spy_write_pyc(*args, **kwargs):
|
||||
# make a note that we have called _write_pyc
|
||||
write_pyc_called.append(True)
|
||||
# try to import a module at this point: we should not try to rewrite this module
|
||||
assert hook.find_module("test_bar") is None
|
||||
return original_write_pyc(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(rewrite, "_write_pyc", spy_write_pyc)
|
||||
monkeypatch.setattr(sys, "dont_write_bytecode", False)
|
||||
|
||||
hook = AssertionRewritingHook(pytestconfig)
|
||||
assert hook.find_module("test_foo") is not None
|
||||
assert len(write_pyc_called) == 1
|
||||
|
|
Loading…
Reference in New Issue