diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index 3be766d9b..a23e78ae4 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -62,6 +62,10 @@ def pytest_configure(config): config._assertstate = AssertionState(config, mode) config._assertstate.trace("configured with mode set to %r" % (mode,)) +def pytest_collectstart(collector): + if isinstance(collector, pytest.Session): + collector._rewritten_pycs = [] + def _write_pyc(co, source_path): if hasattr(imp, "cache_from_source"): # Handle PEP 3147 pycs. @@ -79,9 +83,9 @@ def _write_pyc(co, source_path): fp.close() return pyc -def pytest_pycollect_before_module_import(mod): - if mod.config._assertstate.mode != "on": - return +def pytest_pycollect_onmodule(mod): + if mod is None or mod.config._assertstate.mode != "on": + return mod # Some deep magic: load the source, rewrite the asserts, and write a # fake pyc, so that it'll be loaded when the module is imported. source = mod.fspath.read() @@ -90,7 +94,7 @@ def pytest_pycollect_before_module_import(mod): except SyntaxError: # Let this pop up again in the real import. mod.config._assertstate.trace("failed to parse: %r" % (mod.fspath,)) - return + return mod rewrite_asserts(tree) try: co = compile(tree, str(mod.fspath), "exec") @@ -98,20 +102,25 @@ def pytest_pycollect_before_module_import(mod): # It's possible that this error is from some bug in the assertion # rewriting, but I don't know of a fast way to tell. mod.config._assertstate.trace("failed to compile: %r" % (mod.fspath,)) - return - mod._pyc = _write_pyc(co, mod.fspath) - mod.config._assertstate.trace("wrote pyc: %r" % (mod._pyc,)) + return mod + pyc = _write_pyc(co, mod.fspath) + mod.session._rewritten_pycs.append(pyc) + mod.config._assertstate.trace("wrote pyc: %r" % (pyc,)) + return mod -def pytest_pycollect_after_module_import(mod): - if mod.config._assertstate.mode != "on" or not hasattr(mod, "_pyc"): +def pytest_collection_finish(session): + if not hasattr(session, "_rewritten_pycs"): return - # Remove our tweaked pyc to avoid subtle bugs. - try: - mod._pyc.remove() - except py.error.ENOENT: - mod.config._assertstate.trace("couldn't find pyc: %r" % (mod._pyc,)) - else: - mod.config._assertstate.trace("removed pyc: %r" % (mod._pyc,)) + state = session.config._assertstate + # Remove our tweaked pycs to avoid subtle bugs. + for pyc in session._rewritten_pycs: + try: + pyc.remove() + except py.error.ENOENT: + state.trace("couldn't find pyc: %r" % (pyc,)) + else: + state.trace("removed pyc: %r" % (pyc,)) + del session._rewritten_pycs[:] def warn_about_missing_assertion(): try: