simplify rewrite-on-import
Use load_module on the import hook to load the rewritten module. This allows the removal of the complicated code related to copying pyc files in and out of the cache location. It also plays more nicely with parallel py.test processes like the ones found in xdist.
This commit is contained in:
parent
df85ddf0d2
commit
c13fa886d9
|
@ -28,7 +28,6 @@ class AssertionState:
|
|||
def __init__(self, config, mode):
|
||||
self.mode = mode
|
||||
self.trace = config.trace.root.get("assertion")
|
||||
self.pycs = []
|
||||
|
||||
def pytest_configure(config):
|
||||
mode = config.getvalue("assertmode")
|
||||
|
@ -62,8 +61,6 @@ def pytest_configure(config):
|
|||
config._assertstate.trace("configured with mode set to %r" % (mode,))
|
||||
|
||||
def pytest_unconfigure(config):
|
||||
if config._assertstate.mode == "rewrite":
|
||||
rewrite._drain_pycs(config._assertstate)
|
||||
hook = config._assertstate.hook
|
||||
if hook is not None:
|
||||
sys.meta_path.remove(hook)
|
||||
|
@ -77,8 +74,6 @@ def pytest_collection(session):
|
|||
hook.set_session(session)
|
||||
|
||||
def pytest_sessionfinish(session):
|
||||
if session.config._assertstate.mode == "rewrite":
|
||||
rewrite._drain_pycs(session.config._assertstate)
|
||||
hook = session.config._assertstate.hook
|
||||
if hook is not None:
|
||||
hook.session = None
|
||||
|
|
|
@ -8,6 +8,7 @@ import marshal
|
|||
import os
|
||||
import struct
|
||||
import sys
|
||||
import types
|
||||
|
||||
import py
|
||||
from _pytest.assertion import util
|
||||
|
@ -22,14 +23,11 @@ else:
|
|||
del ver
|
||||
|
||||
class AssertionRewritingHook(object):
|
||||
"""Import hook which rewrites asserts.
|
||||
|
||||
Note this hook doesn't load modules itself. It uses find_module to write a
|
||||
fake pyc, so the normal import system will find it.
|
||||
"""
|
||||
"""Import hook which rewrites asserts."""
|
||||
|
||||
def __init__(self):
|
||||
self.session = None
|
||||
self.modules = {}
|
||||
|
||||
def set_session(self, session):
|
||||
self.fnpats = session.config.getini("python_files")
|
||||
|
@ -77,37 +75,53 @@ class AssertionRewritingHook(object):
|
|||
return None
|
||||
finally:
|
||||
self.session = sess
|
||||
# This looks like a test file, so rewrite it. This is the most magical
|
||||
# part of the process: load the source, rewrite the asserts, and write a
|
||||
# fake pyc, so that it'll be loaded when the module is imported. This is
|
||||
# complicated by the fact we cache rewritten pycs.
|
||||
pyc = _compute_pyc_location(fn_pypath)
|
||||
state.pycs.append(pyc)
|
||||
cache_fn = fn_pypath.basename[:-3] + "." + PYTEST_TAG + ".pyc"
|
||||
cache = py.path.local(fn_pypath.dirname).join("__pycache__", cache_fn)
|
||||
if _use_cached_pyc(fn_pypath, cache):
|
||||
state.trace("found cached rewritten pyc for %r" % (fn,))
|
||||
_atomic_copy(cache, pyc)
|
||||
else:
|
||||
# The requested module looks like a test file, so rewrite it. This is
|
||||
# the most magical part of the process: load the source, rewrite the
|
||||
# asserts, and load the rewritten source. We also cache the rewritten
|
||||
# module code in a special pyc. We must be aware of the possibility of
|
||||
# concurrent py.test processes rewriting and loading pycs. To avoid
|
||||
# tricky race conditions, we maintain the following invariant: The
|
||||
# cached pyc is always a complete, valid pyc. Operations on it must be
|
||||
# atomic. POSIX's atomic rename comes in handy.
|
||||
cache_dir = os.path.join(fn_pypath.dirname, "__pycache__")
|
||||
py.path.local(cache_dir).ensure(dir=True)
|
||||
cache_name = fn_pypath.basename[:-3] + "." + PYTEST_TAG + ".pyc"
|
||||
pyc = os.path.join(cache_dir, cache_name)
|
||||
co = _read_pyc(fn_pypath, pyc)
|
||||
if co is None:
|
||||
state.trace("rewriting %r" % (fn,))
|
||||
_make_rewritten_pyc(state, fn_pypath, pyc)
|
||||
# Try cache it in the __pycache__ directory.
|
||||
_cache_pyc(state, pyc, cache)
|
||||
return None
|
||||
|
||||
def _drain_pycs(state):
|
||||
for pyc in state.pycs:
|
||||
try:
|
||||
pyc.remove()
|
||||
except py.error.ENOENT:
|
||||
state.trace("couldn't find pyc: %r" % (pyc,))
|
||||
co = _make_rewritten_pyc(state, fn_pypath, pyc)
|
||||
if co is None:
|
||||
# Probably a SyntaxError in the module.
|
||||
return None
|
||||
else:
|
||||
state.trace("removed pyc: %r" % (pyc,))
|
||||
del state.pycs[:]
|
||||
state.trace("found cached rewritten pyc for %r" % (fn,))
|
||||
self.modules[name] = co, pyc
|
||||
return self
|
||||
|
||||
def load_module(self, name):
|
||||
co, pyc = self.modules.pop(name)
|
||||
# I wish I could just call imp.load_compiled here, but __file__ has to
|
||||
# be set properly. In Python 3.2+, this all would be handled correctly
|
||||
# by load_compiled.
|
||||
mod = sys.modules[name] = imp.new_module(name)
|
||||
try:
|
||||
mod.__file__ = co.co_filename
|
||||
# Normally, this attribute is 3.2+.
|
||||
mod.__cached__ = pyc
|
||||
exec co in mod.__dict__
|
||||
except:
|
||||
del sys.modules[name]
|
||||
raise
|
||||
return sys.modules[name]
|
||||
|
||||
def _write_pyc(co, source_path, pyc):
|
||||
# Technically, we don't have to have the same pyc format as (C)Python, since
|
||||
# these "pycs" should never be seen by builtin import. However, there's
|
||||
# little reason deviate, and I hope sometime to be able to use
|
||||
# imp.load_compiled to load them. (See the comment in load_module above.)
|
||||
mtime = int(source_path.mtime())
|
||||
fp = pyc.open("wb")
|
||||
fp = open(pyc, "wb")
|
||||
try:
|
||||
fp.write(imp.get_magic())
|
||||
fp.write(struct.pack("<l", mtime))
|
||||
|
@ -116,6 +130,11 @@ def _write_pyc(co, source_path, pyc):
|
|||
fp.close()
|
||||
|
||||
def _make_rewritten_pyc(state, fn, pyc):
|
||||
"""Try to rewrite *fn* and dump the rewritten code to *pyc*.
|
||||
|
||||
Return the code object of the rewritten module on success. Return None if
|
||||
there are problems parsing or compiling the module.
|
||||
"""
|
||||
try:
|
||||
source = fn.read("rb")
|
||||
except EnvironmentError:
|
||||
|
@ -134,44 +153,40 @@ def _make_rewritten_pyc(state, fn, pyc):
|
|||
# assertion rewriting, but I don't know of a fast way to tell.
|
||||
state.trace("failed to compile: %r" % (fn,))
|
||||
return None
|
||||
_write_pyc(co, fn, pyc)
|
||||
# Dump the code object into a file specific to this process.
|
||||
proc_pyc = pyc + "." + str(os.getpid())
|
||||
_write_pyc(co, fn, proc_pyc)
|
||||
# Atomically replace the pyc.
|
||||
os.rename(proc_pyc, pyc)
|
||||
return co
|
||||
|
||||
def _compute_pyc_location(source_path):
|
||||
if hasattr(imp, "cache_from_source"):
|
||||
# Handle PEP 3147 pycs.
|
||||
pyc = py.path.local(imp.cache_from_source(str(source_path)))
|
||||
pyc.ensure()
|
||||
else:
|
||||
pyc = source_path + "c"
|
||||
return pyc
|
||||
def _read_pyc(source, pyc):
|
||||
"""Possibly read a py.test pyc containing rewritten code.
|
||||
|
||||
def _use_cached_pyc(source, cache):
|
||||
Return rewritten code if successful or None if not.
|
||||
"""
|
||||
try:
|
||||
fp = open(pyc, "rb")
|
||||
except IOError:
|
||||
return None
|
||||
try:
|
||||
mtime = int(source.mtime())
|
||||
fp = cache.open("rb")
|
||||
try:
|
||||
mtime = int(source.mtime())
|
||||
data = fp.read(8)
|
||||
finally:
|
||||
fp.close()
|
||||
except EnvironmentError:
|
||||
return False
|
||||
# Check for invalid or out of date pyc file.
|
||||
return (len(data) == 8 and
|
||||
data[:4] == imp.get_magic() and
|
||||
struct.unpack("<l", data[4:])[0] == mtime)
|
||||
|
||||
def _cache_pyc(state, pyc, cache):
|
||||
try:
|
||||
cache.dirpath().ensure(dir=True)
|
||||
_atomic_copy(pyc, cache)
|
||||
except EnvironmentError:
|
||||
state.trace("failed to cache %r as %r" % (pyc, cache))
|
||||
|
||||
def _atomic_copy(orig, to):
|
||||
"""An atomic copy (at least on POSIX platforms)"""
|
||||
temp = py.path.local(orig.strpath + str(os.getpid()))
|
||||
orig.copy(temp)
|
||||
temp.rename(to)
|
||||
except EnvironmentError:
|
||||
return None
|
||||
# Check for invalid or out of date pyc file.
|
||||
if (len(data) != 8 or
|
||||
data[:4] != imp.get_magic() or
|
||||
struct.unpack("<l", data[4:])[0] != mtime):
|
||||
return None
|
||||
co = marshal.load(fp)
|
||||
if not isinstance(co, types.CodeType):
|
||||
# That's interesting....
|
||||
return None
|
||||
return co
|
||||
finally:
|
||||
fp.close()
|
||||
|
||||
|
||||
def rewrite_asserts(mod):
|
||||
|
|
|
@ -246,11 +246,3 @@ def test_warn_missing(testdir):
|
|||
result.stderr.fnmatch_lines([
|
||||
"*WARNING*assert statements are not executed*",
|
||||
])
|
||||
|
||||
def test_load_fake_pyc(testdir):
|
||||
rewrite = pytest.importorskip("_pytest.assertion.rewrite")
|
||||
path = testdir.makepyfile(a_random_module="x = 'hello'")
|
||||
co = compile("x = 'bye'", str(path), "exec")
|
||||
rewrite._write_pyc(co, path, rewrite._compute_pyc_location(path))
|
||||
mod = path.pyimport()
|
||||
assert mod.x == "bye"
|
||||
|
|
Loading…
Reference in New Issue