rewrite test modules on import
This commit is contained in:
parent
d52ff3e2b9
commit
48b76c7544
|
@ -2,20 +2,12 @@
|
|||
support for presenting detailed information in failing assertions.
|
||||
"""
|
||||
import py
|
||||
import imp
|
||||
import marshal
|
||||
import struct
|
||||
import sys
|
||||
import pytest
|
||||
from _pytest.monkeypatch import monkeypatch
|
||||
from _pytest.assertion import reinterpret, util
|
||||
from _pytest.assertion import util
|
||||
|
||||
try:
|
||||
from _pytest.assertion.rewrite import rewrite_asserts
|
||||
except ImportError:
|
||||
rewrite_asserts = None
|
||||
else:
|
||||
import ast
|
||||
REWRITING_AVAILABLE = "_ast" in sys.builtin_module_names
|
||||
|
||||
def pytest_addoption(parser):
|
||||
group = parser.getgroup("debugconfig")
|
||||
|
@ -38,9 +30,9 @@ class AssertionState:
|
|||
def __init__(self, config, mode):
|
||||
self.mode = mode
|
||||
self.trace = config.trace.root.get("assertion")
|
||||
self.pycs = []
|
||||
|
||||
def pytest_configure(config):
|
||||
warn_about_missing_assertion()
|
||||
mode = config.getvalue("assertmode")
|
||||
if config.getvalue("noassert") or config.getvalue("nomagic"):
|
||||
if mode not in ("off", "default"):
|
||||
|
@ -48,7 +40,10 @@ def pytest_configure(config):
|
|||
mode = "off"
|
||||
elif mode == "default":
|
||||
mode = "on"
|
||||
if mode == "on" and not REWRITING_AVAILABLE:
|
||||
mode = "old"
|
||||
if mode != "off":
|
||||
_load_modules(mode)
|
||||
def callbinrepr(op, left, right):
|
||||
hook_result = config.hook.pytest_assertrepr_compare(
|
||||
config=config, op=op, left=left, right=right)
|
||||
|
@ -60,69 +55,55 @@ def pytest_configure(config):
|
|||
m.setattr(py.builtin.builtins, 'AssertionError',
|
||||
reinterpret.AssertionError)
|
||||
m.setattr(util, '_reprcompare', callbinrepr)
|
||||
if mode == "on" and rewrite_asserts is None:
|
||||
mode = "old"
|
||||
hook = None
|
||||
if mode == "on":
|
||||
hook = rewrite.AssertionRewritingHook()
|
||||
sys.meta_path.append(hook)
|
||||
warn_about_missing_assertion(mode)
|
||||
config._assertstate = AssertionState(config, mode)
|
||||
config._assertstate.hook = hook
|
||||
config._assertstate.trace("configured with mode set to %r" % (mode,))
|
||||
|
||||
def _write_pyc(co, source_path):
|
||||
if hasattr(imp, "cache_from_source"):
|
||||
# Handle PEP 3147 pycs.
|
||||
pyc = py.path.local(imp.cache_from_source(str(source_path)))
|
||||
pyc.ensure()
|
||||
else:
|
||||
pyc = source_path + "c"
|
||||
mtime = int(source_path.mtime())
|
||||
fp = pyc.open("wb")
|
||||
try:
|
||||
fp.write(imp.get_magic())
|
||||
fp.write(struct.pack("<l", mtime))
|
||||
marshal.dump(co, fp)
|
||||
finally:
|
||||
fp.close()
|
||||
return pyc
|
||||
def pytest_unconfigure(config):
|
||||
if config._assertstate.mode == "on":
|
||||
rewrite._drain_pycs(config._assertstate)
|
||||
hook = config._assertstate.hook
|
||||
if hook is not None:
|
||||
sys.meta_path.remove(hook)
|
||||
|
||||
def before_module_import(mod):
|
||||
if mod.config._assertstate.mode != "on":
|
||||
return
|
||||
# Some deep magic: load the source, rewrite the asserts, and write a
|
||||
# fake pyc, so that it'll be loaded when the module is imported.
|
||||
source = mod.fspath.read()
|
||||
try:
|
||||
tree = ast.parse(source)
|
||||
except SyntaxError:
|
||||
# Let this pop up again in the real import.
|
||||
mod.config._assertstate.trace("failed to parse: %r" % (mod.fspath,))
|
||||
return
|
||||
rewrite_asserts(tree)
|
||||
try:
|
||||
co = compile(tree, str(mod.fspath), "exec")
|
||||
except SyntaxError:
|
||||
# It's possible that this error is from some bug in the assertion
|
||||
# rewriting, but I don't know of a fast way to tell.
|
||||
mod.config._assertstate.trace("failed to compile: %r" % (mod.fspath,))
|
||||
return
|
||||
mod._pyc = _write_pyc(co, mod.fspath)
|
||||
mod.config._assertstate.trace("wrote pyc: %r" % (mod._pyc,))
|
||||
def pytest_sessionstart(session):
|
||||
hook = session.config._assertstate.hook
|
||||
if hook is not None:
|
||||
hook.set_session(session)
|
||||
|
||||
def after_module_import(mod):
|
||||
if not hasattr(mod, "_pyc"):
|
||||
return
|
||||
state = mod.config._assertstate
|
||||
try:
|
||||
mod._pyc.remove()
|
||||
except py.error.ENOENT:
|
||||
state.trace("couldn't find pyc: %r" % (mod._pyc,))
|
||||
else:
|
||||
state.trace("removed pyc: %r" % (mod._pyc,))
|
||||
def pytest_sessionfinish(session):
|
||||
if session.config._assertstate.mode == "on":
|
||||
rewrite._drain_pycs(session.config._assertstate)
|
||||
hook = session.config._assertstate.hook
|
||||
if hook is not None:
|
||||
hook.session = None
|
||||
|
||||
def warn_about_missing_assertion():
|
||||
def _load_modules(mode):
|
||||
"""Lazily import assertion related code."""
|
||||
global rewrite, reinterpret
|
||||
from _pytest.assertion import reinterpret
|
||||
if mode == "on":
|
||||
from _pytest.assertion import rewrite
|
||||
|
||||
def warn_about_missing_assertion(mode):
|
||||
try:
|
||||
assert False
|
||||
except AssertionError:
|
||||
pass
|
||||
else:
|
||||
sys.stderr.write("WARNING: failing tests may report as passing because "
|
||||
"assertions are turned off! (are you using python -O?)\n")
|
||||
if mode == "on":
|
||||
specifically = ("assertions which are not in test modules "
|
||||
"will be ignored")
|
||||
else:
|
||||
specifically = "failing tests may report as passing"
|
||||
|
||||
sys.stderr.write("WARNING: " + specifically +
|
||||
" because assertions are turned off "
|
||||
"(are you using python -O?)\n")
|
||||
|
||||
pytest_assertrepr_compare = util.assertrepr_compare
|
||||
|
|
|
@ -3,12 +3,173 @@
|
|||
import ast
|
||||
import collections
|
||||
import itertools
|
||||
import imp
|
||||
import marshal
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
|
||||
import py
|
||||
from _pytest.assertion import util
|
||||
|
||||
|
||||
# py.test caches rewritten pycs in __pycache__.
|
||||
if hasattr(imp, "get_tag"):
|
||||
PYTEST_TAG = imp.get_tag() + "-PYTEST"
|
||||
else:
|
||||
ver = sys.version_info
|
||||
PYTEST_TAG = "cpython-" + str(ver[0]) + str(ver[1]) + "-PYTEST"
|
||||
del ver
|
||||
|
||||
class AssertionRewritingHook(object):
|
||||
"""Import hook which rewrites asserts.
|
||||
|
||||
Note this hook doesn't load modules itself. It uses find_module to write a
|
||||
fake pyc, so the normal import system will find it.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.session = None
|
||||
|
||||
def set_session(self, session):
|
||||
self.fnpats = session.config.getini("python_files")
|
||||
self.session = session
|
||||
|
||||
def find_module(self, name, path=None):
|
||||
if self.session is None:
|
||||
return None
|
||||
sess = self.session
|
||||
state = sess.config._assertstate
|
||||
names = name.rsplit(".", 1)
|
||||
lastname = names[-1]
|
||||
pth = None
|
||||
if path is not None and len(path) == 1:
|
||||
pth = path[0]
|
||||
if pth is None:
|
||||
try:
|
||||
fd, fn, desc = imp.find_module(lastname, path)
|
||||
except ImportError:
|
||||
return None
|
||||
if fd is not None:
|
||||
fd.close()
|
||||
tp = desc[2]
|
||||
if tp == imp.PY_COMPILED:
|
||||
if hasattr(imp, "source_from_cache"):
|
||||
fn = imp.source_from_cache(fn)
|
||||
else:
|
||||
fn = fn[:-1]
|
||||
elif tp != imp.PY_SOURCE:
|
||||
# Don't know what this is.
|
||||
return None
|
||||
else:
|
||||
fn = os.path.join(pth, name + ".py")
|
||||
fn_pypath = py.path.local(fn)
|
||||
# Is this a test file?
|
||||
if not sess.isinitpath(fn):
|
||||
# We have to be very careful here because imports in this code can
|
||||
# trigger a cycle.
|
||||
self.session = None
|
||||
try:
|
||||
for pat in self.fnpats:
|
||||
if fn_pypath.fnmatch(pat):
|
||||
break
|
||||
else:
|
||||
return None
|
||||
finally:
|
||||
self.session = sess
|
||||
# This looks like a test file, so rewrite it. This is the most magical
|
||||
# part of the process: load the source, rewrite the asserts, and write a
|
||||
# fake pyc, so that it'll be loaded when the module is imported. This is
|
||||
# complicated by the fact we cache rewritten pycs.
|
||||
pyc = _compute_pyc_location(fn_pypath)
|
||||
state.pycs.append(pyc)
|
||||
cache_fn = fn_pypath.basename[:-3] + "." + PYTEST_TAG + ".pyc"
|
||||
cache = py.path.local(fn_pypath.dirname).join("__pycache__", cache_fn)
|
||||
if _use_cached_pyc(fn_pypath, cache):
|
||||
state.trace("found cached rewritten pyc for %r" % (fn,))
|
||||
cache.copy(pyc)
|
||||
else:
|
||||
state.trace("rewriting %r" % (fn,))
|
||||
_make_rewritten_pyc(state, fn_pypath, pyc)
|
||||
# Try cache it in the __pycache__ directory.
|
||||
_cache_pyc(state, pyc, cache)
|
||||
return None
|
||||
|
||||
def _drain_pycs(state):
|
||||
for pyc in state.pycs:
|
||||
try:
|
||||
pyc.remove()
|
||||
except py.error.ENOENT:
|
||||
state.trace("couldn't find pyc: %r" % (pyc,))
|
||||
else:
|
||||
state.trace("removed pyc: %r" % (pyc,))
|
||||
|
||||
def _write_pyc(co, source_path, pyc):
|
||||
mtime = int(source_path.mtime())
|
||||
fp = pyc.open("wb")
|
||||
try:
|
||||
fp.write(imp.get_magic())
|
||||
fp.write(struct.pack("<l", mtime))
|
||||
marshal.dump(co, fp)
|
||||
finally:
|
||||
fp.close()
|
||||
|
||||
def _make_rewritten_pyc(state, fn, pyc):
|
||||
try:
|
||||
source = fn.read("rb")
|
||||
except EnvironmentError:
|
||||
return None
|
||||
try:
|
||||
tree = ast.parse(source)
|
||||
except SyntaxError:
|
||||
# Let this pop up again in the real import.
|
||||
state.trace("failed to parse: %r" % (fn,))
|
||||
return None
|
||||
rewrite_asserts(tree)
|
||||
try:
|
||||
co = compile(tree, fn.strpath, "exec")
|
||||
except SyntaxError:
|
||||
# It's possible that this error is from some bug in the
|
||||
# assertion rewriting, but I don't know of a fast way to tell.
|
||||
state.trace("failed to compile: %r" % (fn,))
|
||||
return None
|
||||
_write_pyc(co, fn, pyc)
|
||||
|
||||
def _compute_pyc_location(source_path):
|
||||
if hasattr(imp, "cache_from_source"):
|
||||
# Handle PEP 3147 pycs.
|
||||
pyc = py.path.local(imp.cache_from_source(str(source_path)))
|
||||
pyc.ensure()
|
||||
else:
|
||||
pyc = source_path + "c"
|
||||
return pyc
|
||||
|
||||
def _use_cached_pyc(source, cache):
|
||||
try:
|
||||
mtime = source.mtime()
|
||||
fp = cache.open("rb")
|
||||
try:
|
||||
data = fp.read(8)
|
||||
finally:
|
||||
fp.close()
|
||||
except EnvironmentError:
|
||||
return False
|
||||
if (len(data) != 8 or
|
||||
data[:4] != imp.get_magic() or
|
||||
struct.unpack("<l", data[4:])[0] != mtime):
|
||||
# Invalid or out of date.
|
||||
return False
|
||||
# The cached pyc exists and is up to date.
|
||||
return True
|
||||
|
||||
def _cache_pyc(state, pyc, cache):
|
||||
try:
|
||||
cache.dirpath().ensure(dir=True)
|
||||
pyc.copy(cache)
|
||||
except EnvironmentError:
|
||||
state.trace("failed to cache %r as %r" % (pyc, cache))
|
||||
|
||||
|
||||
def rewrite_asserts(mod):
|
||||
"""Rewrite the assert statements in mod."""
|
||||
AssertionRewriter().run(mod)
|
||||
|
|
|
@ -226,13 +226,8 @@ class Module(pytest.File, PyCollectorMixin):
|
|||
|
||||
def _importtestmodule(self):
|
||||
# we assume we are only called once per module
|
||||
from _pytest import assertion
|
||||
assertion.before_module_import(self)
|
||||
try:
|
||||
try:
|
||||
mod = self.fspath.pyimport(ensuresyspath=True)
|
||||
finally:
|
||||
assertion.after_module_import(self)
|
||||
except SyntaxError:
|
||||
excinfo = py.code.ExceptionInfo()
|
||||
raise self.CollectError(excinfo.getrepr(style="short"))
|
||||
|
|
|
@ -250,8 +250,9 @@ def test_warn_missing(testdir):
|
|||
])
|
||||
|
||||
def test_load_fake_pyc(testdir):
|
||||
path = testdir.makepyfile("x = 'hello'")
|
||||
rewrite = pytest.importorskip("_pytest.assertion.rewrite")
|
||||
path = testdir.makepyfile(a_random_module="x = 'hello'")
|
||||
co = compile("x = 'bye'", str(path), "exec")
|
||||
plugin._write_pyc(co, path)
|
||||
rewrite._write_pyc(co, path, rewrite._compute_pyc_location(path))
|
||||
mod = path.pyimport()
|
||||
assert mod.x == "bye"
|
||||
|
|
Loading…
Reference in New Issue