simplify rewrite-on-import

Use load_module on the import hook to load the rewritten module. This allows the
removal of the complicated code related to copying pyc files in and out of the
cache location. It also plays more nicely with parallel py.test processes like
the ones found in xdist.
This commit is contained in:
Benjamin Peterson 2011-07-06 23:24:04 -05:00
parent df85ddf0d2
commit c13fa886d9
3 changed files with 79 additions and 77 deletions

View File

@ -28,7 +28,6 @@ class AssertionState:
def __init__(self, config, mode): def __init__(self, config, mode):
self.mode = mode self.mode = mode
self.trace = config.trace.root.get("assertion") self.trace = config.trace.root.get("assertion")
self.pycs = []
def pytest_configure(config): def pytest_configure(config):
mode = config.getvalue("assertmode") mode = config.getvalue("assertmode")
@ -62,8 +61,6 @@ def pytest_configure(config):
config._assertstate.trace("configured with mode set to %r" % (mode,)) config._assertstate.trace("configured with mode set to %r" % (mode,))
def pytest_unconfigure(config): def pytest_unconfigure(config):
if config._assertstate.mode == "rewrite":
rewrite._drain_pycs(config._assertstate)
hook = config._assertstate.hook hook = config._assertstate.hook
if hook is not None: if hook is not None:
sys.meta_path.remove(hook) sys.meta_path.remove(hook)
@ -77,8 +74,6 @@ def pytest_collection(session):
hook.set_session(session) hook.set_session(session)
def pytest_sessionfinish(session): def pytest_sessionfinish(session):
if session.config._assertstate.mode == "rewrite":
rewrite._drain_pycs(session.config._assertstate)
hook = session.config._assertstate.hook hook = session.config._assertstate.hook
if hook is not None: if hook is not None:
hook.session = None hook.session = None

View File

@ -8,6 +8,7 @@ import marshal
import os import os
import struct import struct
import sys import sys
import types
import py import py
from _pytest.assertion import util from _pytest.assertion import util
@ -22,14 +23,11 @@ else:
del ver del ver
class AssertionRewritingHook(object): class AssertionRewritingHook(object):
"""Import hook which rewrites asserts. """Import hook which rewrites asserts."""
Note this hook doesn't load modules itself. It uses find_module to write a
fake pyc, so the normal import system will find it.
"""
def __init__(self): def __init__(self):
self.session = None self.session = None
self.modules = {}
def set_session(self, session): def set_session(self, session):
self.fnpats = session.config.getini("python_files") self.fnpats = session.config.getini("python_files")
@ -77,37 +75,53 @@ class AssertionRewritingHook(object):
return None return None
finally: finally:
self.session = sess self.session = sess
# This looks like a test file, so rewrite it. This is the most magical # The requested module looks like a test file, so rewrite it. This is
# part of the process: load the source, rewrite the asserts, and write a # the most magical part of the process: load the source, rewrite the
# fake pyc, so that it'll be loaded when the module is imported. This is # asserts, and load the rewritten source. We also cache the rewritten
# complicated by the fact we cache rewritten pycs. # module code in a special pyc. We must be aware of the possibility of
pyc = _compute_pyc_location(fn_pypath) # concurrent py.test processes rewriting and loading pycs. To avoid
state.pycs.append(pyc) # tricky race conditions, we maintain the following invariant: The
cache_fn = fn_pypath.basename[:-3] + "." + PYTEST_TAG + ".pyc" # cached pyc is always a complete, valid pyc. Operations on it must be
cache = py.path.local(fn_pypath.dirname).join("__pycache__", cache_fn) # atomic. POSIX's atomic rename comes in handy.
if _use_cached_pyc(fn_pypath, cache): cache_dir = os.path.join(fn_pypath.dirname, "__pycache__")
state.trace("found cached rewritten pyc for %r" % (fn,)) py.path.local(cache_dir).ensure(dir=True)
_atomic_copy(cache, pyc) cache_name = fn_pypath.basename[:-3] + "." + PYTEST_TAG + ".pyc"
else: pyc = os.path.join(cache_dir, cache_name)
co = _read_pyc(fn_pypath, pyc)
if co is None:
state.trace("rewriting %r" % (fn,)) state.trace("rewriting %r" % (fn,))
_make_rewritten_pyc(state, fn_pypath, pyc) co = _make_rewritten_pyc(state, fn_pypath, pyc)
# Try cache it in the __pycache__ directory. if co is None:
_cache_pyc(state, pyc, cache) # Probably a SyntaxError in the module.
return None return None
def _drain_pycs(state):
for pyc in state.pycs:
try:
pyc.remove()
except py.error.ENOENT:
state.trace("couldn't find pyc: %r" % (pyc,))
else: else:
state.trace("removed pyc: %r" % (pyc,)) state.trace("found cached rewritten pyc for %r" % (fn,))
del state.pycs[:] self.modules[name] = co, pyc
return self
def load_module(self, name):
co, pyc = self.modules.pop(name)
# I wish I could just call imp.load_compiled here, but __file__ has to
# be set properly. In Python 3.2+, this all would be handled correctly
# by load_compiled.
mod = sys.modules[name] = imp.new_module(name)
try:
mod.__file__ = co.co_filename
# Normally, this attribute is 3.2+.
mod.__cached__ = pyc
exec co in mod.__dict__
except:
del sys.modules[name]
raise
return sys.modules[name]
def _write_pyc(co, source_path, pyc): def _write_pyc(co, source_path, pyc):
# Technically, we don't have to have the same pyc format as (C)Python, since
# these "pycs" should never be seen by builtin import. However, there's
# little reason deviate, and I hope sometime to be able to use
# imp.load_compiled to load them. (See the comment in load_module above.)
mtime = int(source_path.mtime()) mtime = int(source_path.mtime())
fp = pyc.open("wb") fp = open(pyc, "wb")
try: try:
fp.write(imp.get_magic()) fp.write(imp.get_magic())
fp.write(struct.pack("<l", mtime)) fp.write(struct.pack("<l", mtime))
@ -116,6 +130,11 @@ def _write_pyc(co, source_path, pyc):
fp.close() fp.close()
def _make_rewritten_pyc(state, fn, pyc): def _make_rewritten_pyc(state, fn, pyc):
"""Try to rewrite *fn* and dump the rewritten code to *pyc*.
Return the code object of the rewritten module on success. Return None if
there are problems parsing or compiling the module.
"""
try: try:
source = fn.read("rb") source = fn.read("rb")
except EnvironmentError: except EnvironmentError:
@ -134,44 +153,40 @@ def _make_rewritten_pyc(state, fn, pyc):
# assertion rewriting, but I don't know of a fast way to tell. # assertion rewriting, but I don't know of a fast way to tell.
state.trace("failed to compile: %r" % (fn,)) state.trace("failed to compile: %r" % (fn,))
return None return None
_write_pyc(co, fn, pyc) # Dump the code object into a file specific to this process.
proc_pyc = pyc + "." + str(os.getpid())
_write_pyc(co, fn, proc_pyc)
# Atomically replace the pyc.
os.rename(proc_pyc, pyc)
return co
def _compute_pyc_location(source_path): def _read_pyc(source, pyc):
if hasattr(imp, "cache_from_source"): """Possibly read a py.test pyc containing rewritten code.
# Handle PEP 3147 pycs.
pyc = py.path.local(imp.cache_from_source(str(source_path)))
pyc.ensure()
else:
pyc = source_path + "c"
return pyc
def _use_cached_pyc(source, cache): Return rewritten code if successful or None if not.
"""
try:
fp = open(pyc, "rb")
except IOError:
return None
try: try:
mtime = int(source.mtime())
fp = cache.open("rb")
try: try:
mtime = int(source.mtime())
data = fp.read(8) data = fp.read(8)
finally: except EnvironmentError:
fp.close() return None
except EnvironmentError: # Check for invalid or out of date pyc file.
return False if (len(data) != 8 or
# Check for invalid or out of date pyc file. data[:4] != imp.get_magic() or
return (len(data) == 8 and struct.unpack("<l", data[4:])[0] != mtime):
data[:4] == imp.get_magic() and return None
struct.unpack("<l", data[4:])[0] == mtime) co = marshal.load(fp)
if not isinstance(co, types.CodeType):
def _cache_pyc(state, pyc, cache): # That's interesting....
try: return None
cache.dirpath().ensure(dir=True) return co
_atomic_copy(pyc, cache) finally:
except EnvironmentError: fp.close()
state.trace("failed to cache %r as %r" % (pyc, cache))
def _atomic_copy(orig, to):
"""An atomic copy (at least on POSIX platforms)"""
temp = py.path.local(orig.strpath + str(os.getpid()))
orig.copy(temp)
temp.rename(to)
def rewrite_asserts(mod): def rewrite_asserts(mod):

View File

@ -246,11 +246,3 @@ def test_warn_missing(testdir):
result.stderr.fnmatch_lines([ result.stderr.fnmatch_lines([
"*WARNING*assert statements are not executed*", "*WARNING*assert statements are not executed*",
]) ])
def test_load_fake_pyc(testdir):
rewrite = pytest.importorskip("_pytest.assertion.rewrite")
path = testdir.makepyfile(a_random_module="x = 'hello'")
co = compile("x = 'bye'", str(path), "exec")
rewrite._write_pyc(co, path, rewrite._compute_pyc_location(path))
mod = path.pyimport()
assert mod.x == "bye"