diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index aedf7efe7..3cc026a27 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -28,7 +28,6 @@ class AssertionState: def __init__(self, config, mode): self.mode = mode self.trace = config.trace.root.get("assertion") - self.pycs = [] def pytest_configure(config): mode = config.getvalue("assertmode") @@ -62,8 +61,6 @@ def pytest_configure(config): config._assertstate.trace("configured with mode set to %r" % (mode,)) def pytest_unconfigure(config): - if config._assertstate.mode == "rewrite": - rewrite._drain_pycs(config._assertstate) hook = config._assertstate.hook if hook is not None: sys.meta_path.remove(hook) @@ -77,8 +74,6 @@ def pytest_collection(session): hook.set_session(session) def pytest_sessionfinish(session): - if session.config._assertstate.mode == "rewrite": - rewrite._drain_pycs(session.config._assertstate) hook = session.config._assertstate.hook if hook is not None: hook.session = None diff --git a/_pytest/assertion/rewrite.py b/_pytest/assertion/rewrite.py index c36ede108..612dc8d35 100644 --- a/_pytest/assertion/rewrite.py +++ b/_pytest/assertion/rewrite.py @@ -8,6 +8,7 @@ import marshal import os import struct import sys +import types import py from _pytest.assertion import util @@ -22,14 +23,11 @@ else: del ver class AssertionRewritingHook(object): - """Import hook which rewrites asserts. - - Note this hook doesn't load modules itself. It uses find_module to write a - fake pyc, so the normal import system will find it. - """ + """Import hook which rewrites asserts.""" def __init__(self): self.session = None + self.modules = {} def set_session(self, session): self.fnpats = session.config.getini("python_files") @@ -77,37 +75,53 @@ class AssertionRewritingHook(object): return None finally: self.session = sess - # This looks like a test file, so rewrite it. This is the most magical - # part of the process: load the source, rewrite the asserts, and write a - # fake pyc, so that it'll be loaded when the module is imported. This is - # complicated by the fact we cache rewritten pycs. - pyc = _compute_pyc_location(fn_pypath) - state.pycs.append(pyc) - cache_fn = fn_pypath.basename[:-3] + "." + PYTEST_TAG + ".pyc" - cache = py.path.local(fn_pypath.dirname).join("__pycache__", cache_fn) - if _use_cached_pyc(fn_pypath, cache): - state.trace("found cached rewritten pyc for %r" % (fn,)) - _atomic_copy(cache, pyc) - else: + # The requested module looks like a test file, so rewrite it. This is + # the most magical part of the process: load the source, rewrite the + # asserts, and load the rewritten source. We also cache the rewritten + # module code in a special pyc. We must be aware of the possibility of + # concurrent py.test processes rewriting and loading pycs. To avoid + # tricky race conditions, we maintain the following invariant: The + # cached pyc is always a complete, valid pyc. Operations on it must be + # atomic. POSIX's atomic rename comes in handy. + cache_dir = os.path.join(fn_pypath.dirname, "__pycache__") + py.path.local(cache_dir).ensure(dir=True) + cache_name = fn_pypath.basename[:-3] + "." + PYTEST_TAG + ".pyc" + pyc = os.path.join(cache_dir, cache_name) + co = _read_pyc(fn_pypath, pyc) + if co is None: state.trace("rewriting %r" % (fn,)) - _make_rewritten_pyc(state, fn_pypath, pyc) - # Try cache it in the __pycache__ directory. - _cache_pyc(state, pyc, cache) - return None - -def _drain_pycs(state): - for pyc in state.pycs: - try: - pyc.remove() - except py.error.ENOENT: - state.trace("couldn't find pyc: %r" % (pyc,)) + co = _make_rewritten_pyc(state, fn_pypath, pyc) + if co is None: + # Probably a SyntaxError in the module. + return None else: - state.trace("removed pyc: %r" % (pyc,)) - del state.pycs[:] + state.trace("found cached rewritten pyc for %r" % (fn,)) + self.modules[name] = co, pyc + return self + + def load_module(self, name): + co, pyc = self.modules.pop(name) + # I wish I could just call imp.load_compiled here, but __file__ has to + # be set properly. In Python 3.2+, this all would be handled correctly + # by load_compiled. + mod = sys.modules[name] = imp.new_module(name) + try: + mod.__file__ = co.co_filename + # Normally, this attribute is 3.2+. + mod.__cached__ = pyc + exec co in mod.__dict__ + except: + del sys.modules[name] + raise + return sys.modules[name] def _write_pyc(co, source_path, pyc): + # Technically, we don't have to have the same pyc format as (C)Python, since + # these "pycs" should never be seen by builtin import. However, there's + # little reason deviate, and I hope sometime to be able to use + # imp.load_compiled to load them. (See the comment in load_module above.) mtime = int(source_path.mtime()) - fp = pyc.open("wb") + fp = open(pyc, "wb") try: fp.write(imp.get_magic()) fp.write(struct.pack("