simplify rewrite-on-import
Use load_module on the import hook to load the rewritten module. This allows the removal of the complicated code related to copying pyc files in and out of the cache location. It also plays more nicely with parallel py.test processes like the ones found in xdist.
This commit is contained in:
parent
df85ddf0d2
commit
c13fa886d9
|
@ -28,7 +28,6 @@ class AssertionState:
|
||||||
def __init__(self, config, mode):
|
def __init__(self, config, mode):
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
self.trace = config.trace.root.get("assertion")
|
self.trace = config.trace.root.get("assertion")
|
||||||
self.pycs = []
|
|
||||||
|
|
||||||
def pytest_configure(config):
|
def pytest_configure(config):
|
||||||
mode = config.getvalue("assertmode")
|
mode = config.getvalue("assertmode")
|
||||||
|
@ -62,8 +61,6 @@ def pytest_configure(config):
|
||||||
config._assertstate.trace("configured with mode set to %r" % (mode,))
|
config._assertstate.trace("configured with mode set to %r" % (mode,))
|
||||||
|
|
||||||
def pytest_unconfigure(config):
|
def pytest_unconfigure(config):
|
||||||
if config._assertstate.mode == "rewrite":
|
|
||||||
rewrite._drain_pycs(config._assertstate)
|
|
||||||
hook = config._assertstate.hook
|
hook = config._assertstate.hook
|
||||||
if hook is not None:
|
if hook is not None:
|
||||||
sys.meta_path.remove(hook)
|
sys.meta_path.remove(hook)
|
||||||
|
@ -77,8 +74,6 @@ def pytest_collection(session):
|
||||||
hook.set_session(session)
|
hook.set_session(session)
|
||||||
|
|
||||||
def pytest_sessionfinish(session):
|
def pytest_sessionfinish(session):
|
||||||
if session.config._assertstate.mode == "rewrite":
|
|
||||||
rewrite._drain_pycs(session.config._assertstate)
|
|
||||||
hook = session.config._assertstate.hook
|
hook = session.config._assertstate.hook
|
||||||
if hook is not None:
|
if hook is not None:
|
||||||
hook.session = None
|
hook.session = None
|
||||||
|
|
|
@ -8,6 +8,7 @@ import marshal
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
|
import types
|
||||||
|
|
||||||
import py
|
import py
|
||||||
from _pytest.assertion import util
|
from _pytest.assertion import util
|
||||||
|
@ -22,14 +23,11 @@ else:
|
||||||
del ver
|
del ver
|
||||||
|
|
||||||
class AssertionRewritingHook(object):
|
class AssertionRewritingHook(object):
|
||||||
"""Import hook which rewrites asserts.
|
"""Import hook which rewrites asserts."""
|
||||||
|
|
||||||
Note this hook doesn't load modules itself. It uses find_module to write a
|
|
||||||
fake pyc, so the normal import system will find it.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.session = None
|
self.session = None
|
||||||
|
self.modules = {}
|
||||||
|
|
||||||
def set_session(self, session):
|
def set_session(self, session):
|
||||||
self.fnpats = session.config.getini("python_files")
|
self.fnpats = session.config.getini("python_files")
|
||||||
|
@ -77,37 +75,53 @@ class AssertionRewritingHook(object):
|
||||||
return None
|
return None
|
||||||
finally:
|
finally:
|
||||||
self.session = sess
|
self.session = sess
|
||||||
# This looks like a test file, so rewrite it. This is the most magical
|
# The requested module looks like a test file, so rewrite it. This is
|
||||||
# part of the process: load the source, rewrite the asserts, and write a
|
# the most magical part of the process: load the source, rewrite the
|
||||||
# fake pyc, so that it'll be loaded when the module is imported. This is
|
# asserts, and load the rewritten source. We also cache the rewritten
|
||||||
# complicated by the fact we cache rewritten pycs.
|
# module code in a special pyc. We must be aware of the possibility of
|
||||||
pyc = _compute_pyc_location(fn_pypath)
|
# concurrent py.test processes rewriting and loading pycs. To avoid
|
||||||
state.pycs.append(pyc)
|
# tricky race conditions, we maintain the following invariant: The
|
||||||
cache_fn = fn_pypath.basename[:-3] + "." + PYTEST_TAG + ".pyc"
|
# cached pyc is always a complete, valid pyc. Operations on it must be
|
||||||
cache = py.path.local(fn_pypath.dirname).join("__pycache__", cache_fn)
|
# atomic. POSIX's atomic rename comes in handy.
|
||||||
if _use_cached_pyc(fn_pypath, cache):
|
cache_dir = os.path.join(fn_pypath.dirname, "__pycache__")
|
||||||
state.trace("found cached rewritten pyc for %r" % (fn,))
|
py.path.local(cache_dir).ensure(dir=True)
|
||||||
_atomic_copy(cache, pyc)
|
cache_name = fn_pypath.basename[:-3] + "." + PYTEST_TAG + ".pyc"
|
||||||
else:
|
pyc = os.path.join(cache_dir, cache_name)
|
||||||
|
co = _read_pyc(fn_pypath, pyc)
|
||||||
|
if co is None:
|
||||||
state.trace("rewriting %r" % (fn,))
|
state.trace("rewriting %r" % (fn,))
|
||||||
_make_rewritten_pyc(state, fn_pypath, pyc)
|
co = _make_rewritten_pyc(state, fn_pypath, pyc)
|
||||||
# Try cache it in the __pycache__ directory.
|
if co is None:
|
||||||
_cache_pyc(state, pyc, cache)
|
# Probably a SyntaxError in the module.
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _drain_pycs(state):
|
|
||||||
for pyc in state.pycs:
|
|
||||||
try:
|
|
||||||
pyc.remove()
|
|
||||||
except py.error.ENOENT:
|
|
||||||
state.trace("couldn't find pyc: %r" % (pyc,))
|
|
||||||
else:
|
else:
|
||||||
state.trace("removed pyc: %r" % (pyc,))
|
state.trace("found cached rewritten pyc for %r" % (fn,))
|
||||||
del state.pycs[:]
|
self.modules[name] = co, pyc
|
||||||
|
return self
|
||||||
|
|
||||||
|
def load_module(self, name):
|
||||||
|
co, pyc = self.modules.pop(name)
|
||||||
|
# I wish I could just call imp.load_compiled here, but __file__ has to
|
||||||
|
# be set properly. In Python 3.2+, this all would be handled correctly
|
||||||
|
# by load_compiled.
|
||||||
|
mod = sys.modules[name] = imp.new_module(name)
|
||||||
|
try:
|
||||||
|
mod.__file__ = co.co_filename
|
||||||
|
# Normally, this attribute is 3.2+.
|
||||||
|
mod.__cached__ = pyc
|
||||||
|
exec co in mod.__dict__
|
||||||
|
except:
|
||||||
|
del sys.modules[name]
|
||||||
|
raise
|
||||||
|
return sys.modules[name]
|
||||||
|
|
||||||
def _write_pyc(co, source_path, pyc):
|
def _write_pyc(co, source_path, pyc):
|
||||||
|
# Technically, we don't have to have the same pyc format as (C)Python, since
|
||||||
|
# these "pycs" should never be seen by builtin import. However, there's
|
||||||
|
# little reason deviate, and I hope sometime to be able to use
|
||||||
|
# imp.load_compiled to load them. (See the comment in load_module above.)
|
||||||
mtime = int(source_path.mtime())
|
mtime = int(source_path.mtime())
|
||||||
fp = pyc.open("wb")
|
fp = open(pyc, "wb")
|
||||||
try:
|
try:
|
||||||
fp.write(imp.get_magic())
|
fp.write(imp.get_magic())
|
||||||
fp.write(struct.pack("<l", mtime))
|
fp.write(struct.pack("<l", mtime))
|
||||||
|
@ -116,6 +130,11 @@ def _write_pyc(co, source_path, pyc):
|
||||||
fp.close()
|
fp.close()
|
||||||
|
|
||||||
def _make_rewritten_pyc(state, fn, pyc):
|
def _make_rewritten_pyc(state, fn, pyc):
|
||||||
|
"""Try to rewrite *fn* and dump the rewritten code to *pyc*.
|
||||||
|
|
||||||
|
Return the code object of the rewritten module on success. Return None if
|
||||||
|
there are problems parsing or compiling the module.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
source = fn.read("rb")
|
source = fn.read("rb")
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
|
@ -134,44 +153,40 @@ def _make_rewritten_pyc(state, fn, pyc):
|
||||||
# assertion rewriting, but I don't know of a fast way to tell.
|
# assertion rewriting, but I don't know of a fast way to tell.
|
||||||
state.trace("failed to compile: %r" % (fn,))
|
state.trace("failed to compile: %r" % (fn,))
|
||||||
return None
|
return None
|
||||||
_write_pyc(co, fn, pyc)
|
# Dump the code object into a file specific to this process.
|
||||||
|
proc_pyc = pyc + "." + str(os.getpid())
|
||||||
|
_write_pyc(co, fn, proc_pyc)
|
||||||
|
# Atomically replace the pyc.
|
||||||
|
os.rename(proc_pyc, pyc)
|
||||||
|
return co
|
||||||
|
|
||||||
def _compute_pyc_location(source_path):
|
def _read_pyc(source, pyc):
|
||||||
if hasattr(imp, "cache_from_source"):
|
"""Possibly read a py.test pyc containing rewritten code.
|
||||||
# Handle PEP 3147 pycs.
|
|
||||||
pyc = py.path.local(imp.cache_from_source(str(source_path)))
|
|
||||||
pyc.ensure()
|
|
||||||
else:
|
|
||||||
pyc = source_path + "c"
|
|
||||||
return pyc
|
|
||||||
|
|
||||||
def _use_cached_pyc(source, cache):
|
Return rewritten code if successful or None if not.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
fp = open(pyc, "rb")
|
||||||
|
except IOError:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
try:
|
try:
|
||||||
mtime = int(source.mtime())
|
mtime = int(source.mtime())
|
||||||
fp = cache.open("rb")
|
|
||||||
try:
|
|
||||||
data = fp.read(8)
|
data = fp.read(8)
|
||||||
|
except EnvironmentError:
|
||||||
|
return None
|
||||||
|
# Check for invalid or out of date pyc file.
|
||||||
|
if (len(data) != 8 or
|
||||||
|
data[:4] != imp.get_magic() or
|
||||||
|
struct.unpack("<l", data[4:])[0] != mtime):
|
||||||
|
return None
|
||||||
|
co = marshal.load(fp)
|
||||||
|
if not isinstance(co, types.CodeType):
|
||||||
|
# That's interesting....
|
||||||
|
return None
|
||||||
|
return co
|
||||||
finally:
|
finally:
|
||||||
fp.close()
|
fp.close()
|
||||||
except EnvironmentError:
|
|
||||||
return False
|
|
||||||
# Check for invalid or out of date pyc file.
|
|
||||||
return (len(data) == 8 and
|
|
||||||
data[:4] == imp.get_magic() and
|
|
||||||
struct.unpack("<l", data[4:])[0] == mtime)
|
|
||||||
|
|
||||||
def _cache_pyc(state, pyc, cache):
|
|
||||||
try:
|
|
||||||
cache.dirpath().ensure(dir=True)
|
|
||||||
_atomic_copy(pyc, cache)
|
|
||||||
except EnvironmentError:
|
|
||||||
state.trace("failed to cache %r as %r" % (pyc, cache))
|
|
||||||
|
|
||||||
def _atomic_copy(orig, to):
|
|
||||||
"""An atomic copy (at least on POSIX platforms)"""
|
|
||||||
temp = py.path.local(orig.strpath + str(os.getpid()))
|
|
||||||
orig.copy(temp)
|
|
||||||
temp.rename(to)
|
|
||||||
|
|
||||||
|
|
||||||
def rewrite_asserts(mod):
|
def rewrite_asserts(mod):
|
||||||
|
|
|
@ -246,11 +246,3 @@ def test_warn_missing(testdir):
|
||||||
result.stderr.fnmatch_lines([
|
result.stderr.fnmatch_lines([
|
||||||
"*WARNING*assert statements are not executed*",
|
"*WARNING*assert statements are not executed*",
|
||||||
])
|
])
|
||||||
|
|
||||||
def test_load_fake_pyc(testdir):
|
|
||||||
rewrite = pytest.importorskip("_pytest.assertion.rewrite")
|
|
||||||
path = testdir.makepyfile(a_random_module="x = 'hello'")
|
|
||||||
co = compile("x = 'bye'", str(path), "exec")
|
|
||||||
rewrite._write_pyc(co, path, rewrite._compute_pyc_location(path))
|
|
||||||
mod = path.pyimport()
|
|
||||||
assert mod.x == "bye"
|
|
||||||
|
|
Loading…
Reference in New Issue