simplify rewrite-on-import

Use load_module on the import hook to load the rewritten module. This allows the
removal of the complicated code related to copying pyc files in and out of the
cache location. It also plays more nicely with parallel py.test processes like
the ones found in xdist.
This commit is contained in:
Benjamin Peterson 2011-07-06 23:24:04 -05:00
parent df85ddf0d2
commit c13fa886d9
3 changed files with 79 additions and 77 deletions

View File

@ -28,7 +28,6 @@ class AssertionState:
def __init__(self, config, mode):
self.mode = mode
self.trace = config.trace.root.get("assertion")
self.pycs = []
def pytest_configure(config):
mode = config.getvalue("assertmode")
@ -62,8 +61,6 @@ def pytest_configure(config):
config._assertstate.trace("configured with mode set to %r" % (mode,))
def pytest_unconfigure(config):
if config._assertstate.mode == "rewrite":
rewrite._drain_pycs(config._assertstate)
hook = config._assertstate.hook
if hook is not None:
sys.meta_path.remove(hook)
@ -77,8 +74,6 @@ def pytest_collection(session):
hook.set_session(session)
def pytest_sessionfinish(session):
if session.config._assertstate.mode == "rewrite":
rewrite._drain_pycs(session.config._assertstate)
hook = session.config._assertstate.hook
if hook is not None:
hook.session = None

View File

@ -8,6 +8,7 @@ import marshal
import os
import struct
import sys
import types
import py
from _pytest.assertion import util
@ -22,14 +23,11 @@ else:
del ver
class AssertionRewritingHook(object):
"""Import hook which rewrites asserts.
Note this hook doesn't load modules itself. It uses find_module to write a
fake pyc, so the normal import system will find it.
"""
"""Import hook which rewrites asserts."""
def __init__(self):
self.session = None
self.modules = {}
def set_session(self, session):
self.fnpats = session.config.getini("python_files")
@ -77,37 +75,53 @@ class AssertionRewritingHook(object):
return None
finally:
self.session = sess
# This looks like a test file, so rewrite it. This is the most magical
# part of the process: load the source, rewrite the asserts, and write a
# fake pyc, so that it'll be loaded when the module is imported. This is
# complicated by the fact we cache rewritten pycs.
pyc = _compute_pyc_location(fn_pypath)
state.pycs.append(pyc)
cache_fn = fn_pypath.basename[:-3] + "." + PYTEST_TAG + ".pyc"
cache = py.path.local(fn_pypath.dirname).join("__pycache__", cache_fn)
if _use_cached_pyc(fn_pypath, cache):
state.trace("found cached rewritten pyc for %r" % (fn,))
_atomic_copy(cache, pyc)
else:
# The requested module looks like a test file, so rewrite it. This is
# the most magical part of the process: load the source, rewrite the
# asserts, and load the rewritten source. We also cache the rewritten
# module code in a special pyc. We must be aware of the possibility of
# concurrent py.test processes rewriting and loading pycs. To avoid
# tricky race conditions, we maintain the following invariant: The
# cached pyc is always a complete, valid pyc. Operations on it must be
# atomic. POSIX's atomic rename comes in handy.
cache_dir = os.path.join(fn_pypath.dirname, "__pycache__")
py.path.local(cache_dir).ensure(dir=True)
cache_name = fn_pypath.basename[:-3] + "." + PYTEST_TAG + ".pyc"
pyc = os.path.join(cache_dir, cache_name)
co = _read_pyc(fn_pypath, pyc)
if co is None:
state.trace("rewriting %r" % (fn,))
_make_rewritten_pyc(state, fn_pypath, pyc)
# Try to cache it in the __pycache__ directory.
_cache_pyc(state, pyc, cache)
return None
def _drain_pycs(state):
for pyc in state.pycs:
try:
pyc.remove()
except py.error.ENOENT:
state.trace("couldn't find pyc: %r" % (pyc,))
co = _make_rewritten_pyc(state, fn_pypath, pyc)
if co is None:
# Probably a SyntaxError in the module.
return None
else:
state.trace("removed pyc: %r" % (pyc,))
del state.pycs[:]
state.trace("found cached rewritten pyc for %r" % (fn,))
self.modules[name] = co, pyc
return self
def load_module(self, name):
    """Create, register and execute the rewritten module called *name*.

    The ``(code, pyc_path)`` pair stored under *name* in ``self.modules``
    is popped, so each stored entry can be loaded only once.

    Returns the entry found in ``sys.modules`` under *name* after the
    module code has run.  Raises ``KeyError`` if nothing is stored for
    *name*.  Any exception raised while executing the module code
    propagates after the partially initialised module has been removed
    from ``sys.modules``.
    """
    co, pyc = self.modules.pop(name)
    # I wish I could just call imp.load_compiled here, but __file__ has to
    # be set properly. In Python 3.2+, this all would be handled correctly
    # by load_compiled.
    # The module is registered before its code runs so that imports
    # triggered during execution can already see it.
    mod = sys.modules[name] = types.ModuleType(name)
    try:
        mod.__file__ = co.co_filename
        # Normally, this attribute is 3.2+.
        mod.__cached__ = pyc
        # The call form of exec is accepted by both Python 2 (as an exec
        # statement applied to a tuple) and Python 3 (as a function).
        exec(co, mod.__dict__)
    except:
        # Bare except on purpose: the entry must be dropped whatever was
        # raised, and the exception is re-raised unchanged.  pop() with a
        # default avoids masking it if the module already removed itself.
        sys.modules.pop(name, None)
        raise
    return sys.modules[name]
def _write_pyc(co, source_path, pyc):
# Technically, we don't have to have the same pyc format as (C)Python, since
# these "pycs" should never be seen by builtin import. However, there's
# little reason to deviate, and I hope sometime to be able to use
# imp.load_compiled to load them. (See the comment in load_module above.)
mtime = int(source_path.mtime())
fp = pyc.open("wb")
fp = open(pyc, "wb")
try:
fp.write(imp.get_magic())
fp.write(struct.pack("<l", mtime))
@ -116,6 +130,11 @@ def _write_pyc(co, source_path, pyc):
fp.close()
def _make_rewritten_pyc(state, fn, pyc):
"""Try to rewrite *fn* and dump the rewritten code to *pyc*.
Return the code object of the rewritten module on success. Return None if
there are problems parsing or compiling the module.
"""
try:
source = fn.read("rb")
except EnvironmentError:
@ -134,44 +153,40 @@ def _make_rewritten_pyc(state, fn, pyc):
# assertion rewriting, but I don't know of a fast way to tell.
state.trace("failed to compile: %r" % (fn,))
return None
_write_pyc(co, fn, pyc)
# Dump the code object into a file specific to this process.
proc_pyc = pyc + "." + str(os.getpid())
_write_pyc(co, fn, proc_pyc)
# Atomically replace the pyc.
os.rename(proc_pyc, pyc)
return co
def _compute_pyc_location(source_path):
# Return the pyc path that corresponds to *source_path*.
# *source_path* is used both with str() and with "+", so it is presumably a
# py.path.local object -- confirm against the caller.
if hasattr(imp, "cache_from_source"):
# Handle PEP 3147 pycs.
pyc = py.path.local(imp.cache_from_source(str(source_path)))
# NOTE(review): ensure() presumably creates the file (and the __pycache__
# directory) if missing -- confirm against the py.path documentation.
pyc.ensure()
else:
# No imp.cache_from_source available: use the source path with "c" appended.
pyc = source_path + "c"
return pyc
def _read_pyc(source, pyc):
"""Possibly read a py.test pyc containing rewritten code.
def _use_cached_pyc(source, cache):
Return rewritten code if successful or None if not.
"""
try:
fp = open(pyc, "rb")
except IOError:
return None
try:
mtime = int(source.mtime())
fp = cache.open("rb")
try:
mtime = int(source.mtime())
data = fp.read(8)
finally:
fp.close()
except EnvironmentError:
return False
# Check for invalid or out of date pyc file.
return (len(data) == 8 and
data[:4] == imp.get_magic() and
struct.unpack("<l", data[4:])[0] == mtime)
def _cache_pyc(state, pyc, cache):
# Best-effort copy of *pyc* to *cache*: an EnvironmentError is reported
# through state.trace instead of being raised to the caller.
try:
# Make sure the directory that will contain *cache* exists.
cache.dirpath().ensure(dir=True)
_atomic_copy(pyc, cache)
except EnvironmentError:
state.trace("failed to cache %r as %r" % (pyc, cache))
def _atomic_copy(orig, to):
"""An atomic copy (at least on POSIX platforms)"""
# Copy *orig* to a temporary path (orig's path with this process's pid
# appended), then rename the temporary onto *to*. The pid suffix keeps
# concurrent processes from sharing a temporary name.
temp = py.path.local(orig.strpath + str(os.getpid()))
orig.copy(temp)
temp.rename(to)
except EnvironmentError:
return None
# Check for invalid or out of date pyc file.
if (len(data) != 8 or
data[:4] != imp.get_magic() or
struct.unpack("<l", data[4:])[0] != mtime):
return None
co = marshal.load(fp)
if not isinstance(co, types.CodeType):
# That's interesting....
return None
return co
finally:
fp.close()
def rewrite_asserts(mod):

View File

@ -246,11 +246,3 @@ def test_warn_missing(testdir):
result.stderr.fnmatch_lines([
"*WARNING*assert statements are not executed*",
])
def test_load_fake_pyc(testdir):
# Check that a pyc written by rewrite._write_pyc at the computed pyc
# location is what gets imported, rather than the source file.
rewrite = pytest.importorskip("_pytest.assertion.rewrite")
# The source on disk sets x to 'hello' ...
path = testdir.makepyfile(a_random_module="x = 'hello'")
# ... while the code object written into the pyc sets x to 'bye'.
co = compile("x = 'bye'", str(path), "exec")
rewrite._write_pyc(co, path, rewrite._compute_pyc_location(path))
mod = path.pyimport()
# Seeing 'bye' shows the module came from the fake pyc.
assert mod.x == "bye"