rewrite test modules on import

This commit is contained in:
Benjamin Peterson 2011-06-28 21:13:12 -05:00
parent d52ff3e2b9
commit 48b76c7544
4 changed files with 210 additions and 72 deletions

View File

@ -2,20 +2,12 @@
support for presenting detailed information in failing assertions. support for presenting detailed information in failing assertions.
""" """
import py import py
import imp
import marshal
import struct
import sys import sys
import pytest import pytest
from _pytest.monkeypatch import monkeypatch from _pytest.monkeypatch import monkeypatch
from _pytest.assertion import reinterpret, util from _pytest.assertion import util
try: REWRITING_AVAILABLE = "_ast" in sys.builtin_module_names
from _pytest.assertion.rewrite import rewrite_asserts
except ImportError:
rewrite_asserts = None
else:
import ast
def pytest_addoption(parser): def pytest_addoption(parser):
group = parser.getgroup("debugconfig") group = parser.getgroup("debugconfig")
@ -38,9 +30,9 @@ class AssertionState:
def __init__(self, config, mode): def __init__(self, config, mode):
self.mode = mode self.mode = mode
self.trace = config.trace.root.get("assertion") self.trace = config.trace.root.get("assertion")
self.pycs = []
def pytest_configure(config): def pytest_configure(config):
warn_about_missing_assertion()
mode = config.getvalue("assertmode") mode = config.getvalue("assertmode")
if config.getvalue("noassert") or config.getvalue("nomagic"): if config.getvalue("noassert") or config.getvalue("nomagic"):
if mode not in ("off", "default"): if mode not in ("off", "default"):
@ -48,7 +40,10 @@ def pytest_configure(config):
mode = "off" mode = "off"
elif mode == "default": elif mode == "default":
mode = "on" mode = "on"
if mode == "on" and not REWRITING_AVAILABLE:
mode = "old"
if mode != "off": if mode != "off":
_load_modules(mode)
def callbinrepr(op, left, right): def callbinrepr(op, left, right):
hook_result = config.hook.pytest_assertrepr_compare( hook_result = config.hook.pytest_assertrepr_compare(
config=config, op=op, left=left, right=right) config=config, op=op, left=left, right=right)
@ -60,69 +55,55 @@ def pytest_configure(config):
m.setattr(py.builtin.builtins, 'AssertionError', m.setattr(py.builtin.builtins, 'AssertionError',
reinterpret.AssertionError) reinterpret.AssertionError)
m.setattr(util, '_reprcompare', callbinrepr) m.setattr(util, '_reprcompare', callbinrepr)
if mode == "on" and rewrite_asserts is None: hook = None
mode = "old" if mode == "on":
hook = rewrite.AssertionRewritingHook()
sys.meta_path.append(hook)
warn_about_missing_assertion(mode)
config._assertstate = AssertionState(config, mode) config._assertstate = AssertionState(config, mode)
config._assertstate.hook = hook
config._assertstate.trace("configured with mode set to %r" % (mode,)) config._assertstate.trace("configured with mode set to %r" % (mode,))
def _write_pyc(co, source_path): def pytest_unconfigure(config):
if hasattr(imp, "cache_from_source"): if config._assertstate.mode == "on":
# Handle PEP 3147 pycs. rewrite._drain_pycs(config._assertstate)
pyc = py.path.local(imp.cache_from_source(str(source_path))) hook = config._assertstate.hook
pyc.ensure() if hook is not None:
else: sys.meta_path.remove(hook)
pyc = source_path + "c"
mtime = int(source_path.mtime())
fp = pyc.open("wb")
try:
fp.write(imp.get_magic())
fp.write(struct.pack("<l", mtime))
marshal.dump(co, fp)
finally:
fp.close()
return pyc
def before_module_import(mod): def pytest_sessionstart(session):
if mod.config._assertstate.mode != "on": hook = session.config._assertstate.hook
return if hook is not None:
# Some deep magic: load the source, rewrite the asserts, and write a hook.set_session(session)
# fake pyc, so that it'll be loaded when the module is imported.
source = mod.fspath.read()
try:
tree = ast.parse(source)
except SyntaxError:
# Let this pop up again in the real import.
mod.config._assertstate.trace("failed to parse: %r" % (mod.fspath,))
return
rewrite_asserts(tree)
try:
co = compile(tree, str(mod.fspath), "exec")
except SyntaxError:
# It's possible that this error is from some bug in the assertion
# rewriting, but I don't know of a fast way to tell.
mod.config._assertstate.trace("failed to compile: %r" % (mod.fspath,))
return
mod._pyc = _write_pyc(co, mod.fspath)
mod.config._assertstate.trace("wrote pyc: %r" % (mod._pyc,))
def after_module_import(mod): def pytest_sessionfinish(session):
if not hasattr(mod, "_pyc"): if session.config._assertstate.mode == "on":
return rewrite._drain_pycs(session.config._assertstate)
state = mod.config._assertstate hook = session.config._assertstate.hook
try: if hook is not None:
mod._pyc.remove() hook.session = None
except py.error.ENOENT:
state.trace("couldn't find pyc: %r" % (mod._pyc,))
else:
state.trace("removed pyc: %r" % (mod._pyc,))
def warn_about_missing_assertion(): def _load_modules(mode):
"""Lazily import assertion related code."""
global rewrite, reinterpret
from _pytest.assertion import reinterpret
if mode == "on":
from _pytest.assertion import rewrite
def warn_about_missing_assertion(mode):
try: try:
assert False assert False
except AssertionError: except AssertionError:
pass pass
else: else:
sys.stderr.write("WARNING: failing tests may report as passing because " if mode == "on":
"assertions are turned off! (are you using python -O?)\n") specifically = ("assertions which are not in test modules "
"will be ignored")
else:
specifically = "failing tests may report as passing"
sys.stderr.write("WARNING: " + specifically +
" because assertions are turned off "
"(are you using python -O?)\n")
pytest_assertrepr_compare = util.assertrepr_compare pytest_assertrepr_compare = util.assertrepr_compare

View File

@ -3,12 +3,173 @@
import ast import ast
import collections import collections
import itertools import itertools
import imp
import marshal
import os
import struct
import sys import sys
import py import py
from _pytest.assertion import util from _pytest.assertion import util
# py.test caches rewritten pycs in __pycache__.
if hasattr(imp, "get_tag"):
PYTEST_TAG = imp.get_tag() + "-PYTEST"
else:
ver = sys.version_info
PYTEST_TAG = "cpython-" + str(ver[0]) + str(ver[1]) + "-PYTEST"
del ver
class AssertionRewritingHook(object):
"""Import hook which rewrites asserts.
Note this hook doesn't load modules itself. It uses find_module to write a
fake pyc, so the normal import system will find it.
"""
def __init__(self):
self.session = None
def set_session(self, session):
self.fnpats = session.config.getini("python_files")
self.session = session
def find_module(self, name, path=None):
if self.session is None:
return None
sess = self.session
state = sess.config._assertstate
names = name.rsplit(".", 1)
lastname = names[-1]
pth = None
if path is not None and len(path) == 1:
pth = path[0]
if pth is None:
try:
fd, fn, desc = imp.find_module(lastname, path)
except ImportError:
return None
if fd is not None:
fd.close()
tp = desc[2]
if tp == imp.PY_COMPILED:
if hasattr(imp, "source_from_cache"):
fn = imp.source_from_cache(fn)
else:
fn = fn[:-1]
elif tp != imp.PY_SOURCE:
# Don't know what this is.
return None
else:
fn = os.path.join(pth, name + ".py")
fn_pypath = py.path.local(fn)
# Is this a test file?
if not sess.isinitpath(fn):
# We have to be very careful here because imports in this code can
# trigger a cycle.
self.session = None
try:
for pat in self.fnpats:
if fn_pypath.fnmatch(pat):
break
else:
return None
finally:
self.session = sess
# This looks like a test file, so rewrite it. This is the most magical
# part of the process: load the source, rewrite the asserts, and write a
# fake pyc, so that it'll be loaded when the module is imported. This is
# complicated by the fact we cache rewritten pycs.
pyc = _compute_pyc_location(fn_pypath)
state.pycs.append(pyc)
cache_fn = fn_pypath.basename[:-3] + "." + PYTEST_TAG + ".pyc"
cache = py.path.local(fn_pypath.dirname).join("__pycache__", cache_fn)
if _use_cached_pyc(fn_pypath, cache):
state.trace("found cached rewritten pyc for %r" % (fn,))
cache.copy(pyc)
else:
state.trace("rewriting %r" % (fn,))
_make_rewritten_pyc(state, fn_pypath, pyc)
# Try cache it in the __pycache__ directory.
_cache_pyc(state, pyc, cache)
return None
def _drain_pycs(state):
for pyc in state.pycs:
try:
pyc.remove()
except py.error.ENOENT:
state.trace("couldn't find pyc: %r" % (pyc,))
else:
state.trace("removed pyc: %r" % (pyc,))
def _write_pyc(co, source_path, pyc):
mtime = int(source_path.mtime())
fp = pyc.open("wb")
try:
fp.write(imp.get_magic())
fp.write(struct.pack("<l", mtime))
marshal.dump(co, fp)
finally:
fp.close()
def _make_rewritten_pyc(state, fn, pyc):
try:
source = fn.read("rb")
except EnvironmentError:
return None
try:
tree = ast.parse(source)
except SyntaxError:
# Let this pop up again in the real import.
state.trace("failed to parse: %r" % (fn,))
return None
rewrite_asserts(tree)
try:
co = compile(tree, fn.strpath, "exec")
except SyntaxError:
# It's possible that this error is from some bug in the
# assertion rewriting, but I don't know of a fast way to tell.
state.trace("failed to compile: %r" % (fn,))
return None
_write_pyc(co, fn, pyc)
def _compute_pyc_location(source_path):
if hasattr(imp, "cache_from_source"):
# Handle PEP 3147 pycs.
pyc = py.path.local(imp.cache_from_source(str(source_path)))
pyc.ensure()
else:
pyc = source_path + "c"
return pyc
def _use_cached_pyc(source, cache):
try:
mtime = source.mtime()
fp = cache.open("rb")
try:
data = fp.read(8)
finally:
fp.close()
except EnvironmentError:
return False
if (len(data) != 8 or
data[:4] != imp.get_magic() or
struct.unpack("<l", data[4:])[0] != mtime):
# Invalid or out of date.
return False
# The cached pyc exists and is up to date.
return True
def _cache_pyc(state, pyc, cache):
try:
cache.dirpath().ensure(dir=True)
pyc.copy(cache)
except EnvironmentError:
state.trace("failed to cache %r as %r" % (pyc, cache))
def rewrite_asserts(mod): def rewrite_asserts(mod):
"""Rewrite the assert statements in mod.""" """Rewrite the assert statements in mod."""
AssertionRewriter().run(mod) AssertionRewriter().run(mod)

View File

@ -226,13 +226,8 @@ class Module(pytest.File, PyCollectorMixin):
def _importtestmodule(self): def _importtestmodule(self):
# we assume we are only called once per module # we assume we are only called once per module
from _pytest import assertion
assertion.before_module_import(self)
try:
try: try:
mod = self.fspath.pyimport(ensuresyspath=True) mod = self.fspath.pyimport(ensuresyspath=True)
finally:
assertion.after_module_import(self)
except SyntaxError: except SyntaxError:
excinfo = py.code.ExceptionInfo() excinfo = py.code.ExceptionInfo()
raise self.CollectError(excinfo.getrepr(style="short")) raise self.CollectError(excinfo.getrepr(style="short"))

View File

@ -250,8 +250,9 @@ def test_warn_missing(testdir):
]) ])
def test_load_fake_pyc(testdir): def test_load_fake_pyc(testdir):
path = testdir.makepyfile("x = 'hello'") rewrite = pytest.importorskip("_pytest.assertion.rewrite")
path = testdir.makepyfile(a_random_module="x = 'hello'")
co = compile("x = 'bye'", str(path), "exec") co = compile("x = 'bye'", str(path), "exec")
plugin._write_pyc(co, path) rewrite._write_pyc(co, path, rewrite._compute_pyc_location(path))
mod = path.pyimport() mod = path.pyimport()
assert mod.x == "bye" assert mod.x == "bye"