From c13fa886d946e3fb87fdd2884883048e8e35389a Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Wed, 6 Jul 2011 23:24:04 -0500 Subject: [PATCH] simplify rewrite-on-import Use load_module on the import hook to load the rewritten module. This allows the removal of the complicated code related to copying pyc files in and out of the cache location. It also plays more nicely with parallel py.test processes like the ones found in xdist. --- _pytest/assertion/__init__.py | 5 -- _pytest/assertion/rewrite.py | 143 +++++++++++++++++++--------------- testing/test_assertion.py | 8 -- 3 files changed, 79 insertions(+), 77 deletions(-) diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index aedf7efe7..3cc026a27 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -28,7 +28,6 @@ class AssertionState: def __init__(self, config, mode): self.mode = mode self.trace = config.trace.root.get("assertion") - self.pycs = [] def pytest_configure(config): mode = config.getvalue("assertmode") @@ -62,8 +61,6 @@ def pytest_configure(config): config._assertstate.trace("configured with mode set to %r" % (mode,)) def pytest_unconfigure(config): - if config._assertstate.mode == "rewrite": - rewrite._drain_pycs(config._assertstate) hook = config._assertstate.hook if hook is not None: sys.meta_path.remove(hook) @@ -77,8 +74,6 @@ def pytest_collection(session): hook.set_session(session) def pytest_sessionfinish(session): - if session.config._assertstate.mode == "rewrite": - rewrite._drain_pycs(session.config._assertstate) hook = session.config._assertstate.hook if hook is not None: hook.session = None diff --git a/_pytest/assertion/rewrite.py b/_pytest/assertion/rewrite.py index c36ede108..612dc8d35 100644 --- a/_pytest/assertion/rewrite.py +++ b/_pytest/assertion/rewrite.py @@ -8,6 +8,7 @@ import marshal import os import struct import sys +import types import py from _pytest.assertion import util @@ -22,14 +23,11 @@ else: del ver class AssertionRewritingHook(object): - """Import hook which rewrites asserts. - - Note this hook doesn't load modules itself. It uses find_module to write a - fake pyc, so the normal import system will find it. - """ + """Import hook which rewrites asserts.""" def __init__(self): self.session = None + self.modules = {} def set_session(self, session): self.fnpats = session.config.getini("python_files") @@ -77,37 +75,53 @@ class AssertionRewritingHook(object): return None finally: self.session = sess - # This looks like a test file, so rewrite it. This is the most magical - # part of the process: load the source, rewrite the asserts, and write a - # fake pyc, so that it'll be loaded when the module is imported. This is - # complicated by the fact we cache rewritten pycs. - pyc = _compute_pyc_location(fn_pypath) - state.pycs.append(pyc) - cache_fn = fn_pypath.basename[:-3] + "." + PYTEST_TAG + ".pyc" - cache = py.path.local(fn_pypath.dirname).join("__pycache__", cache_fn) - if _use_cached_pyc(fn_pypath, cache): - state.trace("found cached rewritten pyc for %r" % (fn,)) - _atomic_copy(cache, pyc) - else: + # The requested module looks like a test file, so rewrite it. This is + # the most magical part of the process: load the source, rewrite the + # asserts, and load the rewritten source. We also cache the rewritten + # module code in a special pyc. We must be aware of the possibility of + # concurrent py.test processes rewriting and loading pycs. To avoid + # tricky race conditions, we maintain the following invariant: The + # cached pyc is always a complete, valid pyc. Operations on it must be + # atomic. POSIX's atomic rename comes in handy. + cache_dir = os.path.join(fn_pypath.dirname, "__pycache__") + py.path.local(cache_dir).ensure(dir=True) + cache_name = fn_pypath.basename[:-3] + "." + PYTEST_TAG + ".pyc" + pyc = os.path.join(cache_dir, cache_name) + co = _read_pyc(fn_pypath, pyc) + if co is None: state.trace("rewriting %r" % (fn,)) - _make_rewritten_pyc(state, fn_pypath, pyc) - # Try cache it in the __pycache__ directory. - _cache_pyc(state, pyc, cache) - return None - -def _drain_pycs(state): - for pyc in state.pycs: - try: - pyc.remove() - except py.error.ENOENT: - state.trace("couldn't find pyc: %r" % (pyc,)) + co = _make_rewritten_pyc(state, fn_pypath, pyc) + if co is None: + # Probably a SyntaxError in the module. + return None else: - state.trace("removed pyc: %r" % (pyc,)) - del state.pycs[:] + state.trace("found cached rewritten pyc for %r" % (fn,)) + self.modules[name] = co, pyc + return self + + def load_module(self, name): + co, pyc = self.modules.pop(name) + # I wish I could just call imp.load_compiled here, but __file__ has to + # be set properly. In Python 3.2+, this all would be handled correctly + # by load_compiled. + mod = sys.modules[name] = imp.new_module(name) + try: + mod.__file__ = co.co_filename + # Normally, this attribute is 3.2+. + mod.__cached__ = pyc + exec co in mod.__dict__ + except: + del sys.modules[name] + raise + return sys.modules[name] def _write_pyc(co, source_path, pyc): + # Technically, we don't have to have the same pyc format as (C)Python, since + # these "pycs" should never be seen by builtin import. However, there's + # little reason deviate, and I hope sometime to be able to use + # imp.load_compiled to load them. (See the comment in load_module above.) mtime = int(source_path.mtime()) - fp = pyc.open("wb") + fp = open(pyc, "wb") try: fp.write(imp.get_magic()) fp.write(struct.pack("