rewrite test modules on import

This commit is contained in:
Benjamin Peterson 2011-06-28 21:13:12 -05:00
parent d52ff3e2b9
commit 48b76c7544
4 changed files with 210 additions and 72 deletions

View File

@ -2,20 +2,12 @@
support for presenting detailed information in failing assertions.
"""
import py
import imp
import marshal
import struct
import sys
import pytest
from _pytest.monkeypatch import monkeypatch
from _pytest.assertion import reinterpret, util
from _pytest.assertion import util
try:
from _pytest.assertion.rewrite import rewrite_asserts
except ImportError:
rewrite_asserts = None
else:
import ast
REWRITING_AVAILABLE = "_ast" in sys.builtin_module_names
def pytest_addoption(parser):
group = parser.getgroup("debugconfig")
@ -38,9 +30,9 @@ class AssertionState:
def __init__(self, config, mode):
self.mode = mode
self.trace = config.trace.root.get("assertion")
self.pycs = []
def pytest_configure(config):
warn_about_missing_assertion()
mode = config.getvalue("assertmode")
if config.getvalue("noassert") or config.getvalue("nomagic"):
if mode not in ("off", "default"):
@ -48,7 +40,10 @@ def pytest_configure(config):
mode = "off"
elif mode == "default":
mode = "on"
if mode == "on" and not REWRITING_AVAILABLE:
mode = "old"
if mode != "off":
_load_modules(mode)
def callbinrepr(op, left, right):
hook_result = config.hook.pytest_assertrepr_compare(
config=config, op=op, left=left, right=right)
@ -60,69 +55,55 @@ def pytest_configure(config):
m.setattr(py.builtin.builtins, 'AssertionError',
reinterpret.AssertionError)
m.setattr(util, '_reprcompare', callbinrepr)
if mode == "on" and rewrite_asserts is None:
mode = "old"
hook = None
if mode == "on":
hook = rewrite.AssertionRewritingHook()
sys.meta_path.append(hook)
warn_about_missing_assertion(mode)
config._assertstate = AssertionState(config, mode)
config._assertstate.hook = hook
config._assertstate.trace("configured with mode set to %r" % (mode,))
def _write_pyc(co, source_path):
if hasattr(imp, "cache_from_source"):
# Handle PEP 3147 pycs.
pyc = py.path.local(imp.cache_from_source(str(source_path)))
pyc.ensure()
else:
pyc = source_path + "c"
mtime = int(source_path.mtime())
fp = pyc.open("wb")
try:
fp.write(imp.get_magic())
fp.write(struct.pack("<l", mtime))
marshal.dump(co, fp)
finally:
fp.close()
return pyc
def pytest_unconfigure(config):
if config._assertstate.mode == "on":
rewrite._drain_pycs(config._assertstate)
hook = config._assertstate.hook
if hook is not None:
sys.meta_path.remove(hook)
def before_module_import(mod):
if mod.config._assertstate.mode != "on":
return
# Some deep magic: load the source, rewrite the asserts, and write a
# fake pyc, so that it'll be loaded when the module is imported.
source = mod.fspath.read()
try:
tree = ast.parse(source)
except SyntaxError:
# Let this pop up again in the real import.
mod.config._assertstate.trace("failed to parse: %r" % (mod.fspath,))
return
rewrite_asserts(tree)
try:
co = compile(tree, str(mod.fspath), "exec")
except SyntaxError:
# It's possible that this error is from some bug in the assertion
# rewriting, but I don't know of a fast way to tell.
mod.config._assertstate.trace("failed to compile: %r" % (mod.fspath,))
return
mod._pyc = _write_pyc(co, mod.fspath)
mod.config._assertstate.trace("wrote pyc: %r" % (mod._pyc,))
def pytest_sessionstart(session):
hook = session.config._assertstate.hook
if hook is not None:
hook.set_session(session)
def after_module_import(mod):
if not hasattr(mod, "_pyc"):
return
state = mod.config._assertstate
try:
mod._pyc.remove()
except py.error.ENOENT:
state.trace("couldn't find pyc: %r" % (mod._pyc,))
else:
state.trace("removed pyc: %r" % (mod._pyc,))
def pytest_sessionfinish(session):
if session.config._assertstate.mode == "on":
rewrite._drain_pycs(session.config._assertstate)
hook = session.config._assertstate.hook
if hook is not None:
hook.session = None
def warn_about_missing_assertion():
def _load_modules(mode):
"""Lazily import assertion related code."""
global rewrite, reinterpret
from _pytest.assertion import reinterpret
if mode == "on":
from _pytest.assertion import rewrite
def warn_about_missing_assertion(mode):
try:
assert False
except AssertionError:
pass
else:
sys.stderr.write("WARNING: failing tests may report as passing because "
"assertions are turned off! (are you using python -O?)\n")
if mode == "on":
specifically = ("assertions which are not in test modules "
"will be ignored")
else:
specifically = "failing tests may report as passing"
sys.stderr.write("WARNING: " + specifically +
" because assertions are turned off "
"(are you using python -O?)\n")
pytest_assertrepr_compare = util.assertrepr_compare

View File

@ -3,12 +3,173 @@
import ast
import collections
import itertools
import imp
import marshal
import os
import struct
import sys
import py
from _pytest.assertion import util
# py.test caches rewritten pycs in __pycache__.
if hasattr(imp, "get_tag"):
PYTEST_TAG = imp.get_tag() + "-PYTEST"
else:
ver = sys.version_info
PYTEST_TAG = "cpython-" + str(ver[0]) + str(ver[1]) + "-PYTEST"
del ver
class AssertionRewritingHook(object):
"""Import hook which rewrites asserts.
Note this hook doesn't load modules itself. It uses find_module to write a
fake pyc, so the normal import system will find it.
"""
def __init__(self):
self.session = None
def set_session(self, session):
self.fnpats = session.config.getini("python_files")
self.session = session
def find_module(self, name, path=None):
if self.session is None:
return None
sess = self.session
state = sess.config._assertstate
names = name.rsplit(".", 1)
lastname = names[-1]
pth = None
if path is not None and len(path) == 1:
pth = path[0]
if pth is None:
try:
fd, fn, desc = imp.find_module(lastname, path)
except ImportError:
return None
if fd is not None:
fd.close()
tp = desc[2]
if tp == imp.PY_COMPILED:
if hasattr(imp, "source_from_cache"):
fn = imp.source_from_cache(fn)
else:
fn = fn[:-1]
elif tp != imp.PY_SOURCE:
# Don't know what this is.
return None
else:
fn = os.path.join(pth, name + ".py")
fn_pypath = py.path.local(fn)
# Is this a test file?
if not sess.isinitpath(fn):
# We have to be very careful here because imports in this code can
# trigger a cycle.
self.session = None
try:
for pat in self.fnpats:
if fn_pypath.fnmatch(pat):
break
else:
return None
finally:
self.session = sess
# This looks like a test file, so rewrite it. This is the most magical
# part of the process: load the source, rewrite the asserts, and write a
# fake pyc, so that it'll be loaded when the module is imported. This is
# complicated by the fact we cache rewritten pycs.
pyc = _compute_pyc_location(fn_pypath)
state.pycs.append(pyc)
cache_fn = fn_pypath.basename[:-3] + "." + PYTEST_TAG + ".pyc"
cache = py.path.local(fn_pypath.dirname).join("__pycache__", cache_fn)
if _use_cached_pyc(fn_pypath, cache):
state.trace("found cached rewritten pyc for %r" % (fn,))
cache.copy(pyc)
else:
state.trace("rewriting %r" % (fn,))
_make_rewritten_pyc(state, fn_pypath, pyc)
# Try cache it in the __pycache__ directory.
_cache_pyc(state, pyc, cache)
return None
def _drain_pycs(state):
for pyc in state.pycs:
try:
pyc.remove()
except py.error.ENOENT:
state.trace("couldn't find pyc: %r" % (pyc,))
else:
state.trace("removed pyc: %r" % (pyc,))
def _write_pyc(co, source_path, pyc):
mtime = int(source_path.mtime())
fp = pyc.open("wb")
try:
fp.write(imp.get_magic())
fp.write(struct.pack("<l", mtime))
marshal.dump(co, fp)
finally:
fp.close()
def _make_rewritten_pyc(state, fn, pyc):
try:
source = fn.read("rb")
except EnvironmentError:
return None
try:
tree = ast.parse(source)
except SyntaxError:
# Let this pop up again in the real import.
state.trace("failed to parse: %r" % (fn,))
return None
rewrite_asserts(tree)
try:
co = compile(tree, fn.strpath, "exec")
except SyntaxError:
# It's possible that this error is from some bug in the
# assertion rewriting, but I don't know of a fast way to tell.
state.trace("failed to compile: %r" % (fn,))
return None
_write_pyc(co, fn, pyc)
def _compute_pyc_location(source_path):
if hasattr(imp, "cache_from_source"):
# Handle PEP 3147 pycs.
pyc = py.path.local(imp.cache_from_source(str(source_path)))
pyc.ensure()
else:
pyc = source_path + "c"
return pyc
def _use_cached_pyc(source, cache):
try:
mtime = source.mtime()
fp = cache.open("rb")
try:
data = fp.read(8)
finally:
fp.close()
except EnvironmentError:
return False
if (len(data) != 8 or
data[:4] != imp.get_magic() or
struct.unpack("<l", data[4:])[0] != mtime):
# Invalid or out of date.
return False
# The cached pyc exists and is up to date.
return True
def _cache_pyc(state, pyc, cache):
try:
cache.dirpath().ensure(dir=True)
pyc.copy(cache)
except EnvironmentError:
state.trace("failed to cache %r as %r" % (pyc, cache))
def rewrite_asserts(mod):
"""Rewrite the assert statements in mod."""
AssertionRewriter().run(mod)

View File

@ -226,13 +226,8 @@ class Module(pytest.File, PyCollectorMixin):
def _importtestmodule(self):
# we assume we are only called once per module
from _pytest import assertion
assertion.before_module_import(self)
try:
try:
mod = self.fspath.pyimport(ensuresyspath=True)
finally:
assertion.after_module_import(self)
mod = self.fspath.pyimport(ensuresyspath=True)
except SyntaxError:
excinfo = py.code.ExceptionInfo()
raise self.CollectError(excinfo.getrepr(style="short"))

View File

@ -250,8 +250,9 @@ def test_warn_missing(testdir):
])
def test_load_fake_pyc(testdir):
path = testdir.makepyfile("x = 'hello'")
rewrite = pytest.importorskip("_pytest.assertion.rewrite")
path = testdir.makepyfile(a_random_module="x = 'hello'")
co = compile("x = 'bye'", str(path), "exec")
plugin._write_pyc(co, path)
rewrite._write_pyc(co, path, rewrite._compute_pyc_location(path))
mod = path.pyimport()
assert mod.x == "bye"