new way to rewrite tests: do it all during fs collection

This should allow modules to be rewritten before some other test module loads
them.
This commit is contained in:
Benjamin Peterson 2011-05-26 19:57:30 -05:00
parent cf6949c9a3
commit abb07fc732
1 changed files with 25 additions and 16 deletions

View File

@ -62,6 +62,10 @@ def pytest_configure(config):
config._assertstate = AssertionState(config, mode) config._assertstate = AssertionState(config, mode)
config._assertstate.trace("configured with mode set to %r" % (mode,)) config._assertstate.trace("configured with mode set to %r" % (mode,))
def pytest_collectstart(collector):
if isinstance(collector, pytest.Session):
collector._rewritten_pycs = []
def _write_pyc(co, source_path): def _write_pyc(co, source_path):
if hasattr(imp, "cache_from_source"): if hasattr(imp, "cache_from_source"):
# Handle PEP 3147 pycs. # Handle PEP 3147 pycs.
@ -79,9 +83,9 @@ def _write_pyc(co, source_path):
fp.close() fp.close()
return pyc return pyc
def pytest_pycollect_before_module_import(mod): def pytest_pycollect_onmodule(mod):
if mod.config._assertstate.mode != "on": if mod is None or mod.config._assertstate.mode != "on":
return return mod
# Some deep magic: load the source, rewrite the asserts, and write a # Some deep magic: load the source, rewrite the asserts, and write a
# fake pyc, so that it'll be loaded when the module is imported. # fake pyc, so that it'll be loaded when the module is imported.
source = mod.fspath.read() source = mod.fspath.read()
@ -90,7 +94,7 @@ def pytest_pycollect_before_module_import(mod):
except SyntaxError: except SyntaxError:
# Let this pop up again in the real import. # Let this pop up again in the real import.
mod.config._assertstate.trace("failed to parse: %r" % (mod.fspath,)) mod.config._assertstate.trace("failed to parse: %r" % (mod.fspath,))
return return mod
rewrite_asserts(tree) rewrite_asserts(tree)
try: try:
co = compile(tree, str(mod.fspath), "exec") co = compile(tree, str(mod.fspath), "exec")
@ -98,20 +102,25 @@ def pytest_pycollect_before_module_import(mod):
# It's possible that this error is from some bug in the assertion # It's possible that this error is from some bug in the assertion
# rewriting, but I don't know of a fast way to tell. # rewriting, but I don't know of a fast way to tell.
mod.config._assertstate.trace("failed to compile: %r" % (mod.fspath,)) mod.config._assertstate.trace("failed to compile: %r" % (mod.fspath,))
return return mod
mod._pyc = _write_pyc(co, mod.fspath) pyc = _write_pyc(co, mod.fspath)
mod.config._assertstate.trace("wrote pyc: %r" % (mod._pyc,)) mod.session._rewritten_pycs.append(pyc)
mod.config._assertstate.trace("wrote pyc: %r" % (pyc,))
return mod
def pytest_pycollect_after_module_import(mod): def pytest_collection_finish(session):
if mod.config._assertstate.mode != "on" or not hasattr(mod, "_pyc"): if not hasattr(session, "_rewritten_pycs"):
return return
# Remove our tweaked pyc to avoid subtle bugs. state = session.config._assertstate
try: # Remove our tweaked pycs to avoid subtle bugs.
mod._pyc.remove() for pyc in session._rewritten_pycs:
except py.error.ENOENT: try:
mod.config._assertstate.trace("couldn't find pyc: %r" % (mod._pyc,)) pyc.remove()
else: except py.error.ENOENT:
mod.config._assertstate.trace("removed pyc: %r" % (mod._pyc,)) state.trace("couldn't find pyc: %r" % (pyc,))
else:
state.trace("removed pyc: %r" % (pyc,))
del session._rewritten_pycs[:]
def warn_about_missing_assertion(): def warn_about_missing_assertion():
try: try: