From 5e31624315be819497ea0180399387505400c288 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Sat, 28 May 2011 18:47:16 -0500 Subject: [PATCH] return to the old scheme of rewriting test modules from _importtestmodule --- _pytest/assertion/__init__.py | 43 ++++++++++++++--------------------- _pytest/hookspec.py | 3 --- _pytest/python.py | 12 ++++++---- 3 files changed, 24 insertions(+), 34 deletions(-) diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index 535f5b2ca..6daad05c4 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -65,10 +65,6 @@ def pytest_configure(config): config._assertstate = AssertionState(config, mode) config._assertstate.trace("configured with mode set to %r" % (mode,)) -def pytest_collectstart(collector): - if isinstance(collector, pytest.Session): - collector._rewritten_pycs = [] - def _write_pyc(co, source_path): if hasattr(imp, "cache_from_source"): # Handle PEP 3147 pycs. @@ -86,9 +82,9 @@ def _write_pyc(co, source_path): fp.close() return pyc -def pytest_pycollect_onmodule(mod): - if mod is None or mod.config._assertstate.mode != "on": - return mod +def before_module_import(mod): + if mod.config._assertstate.mode != "on": + return # Some deep magic: load the source, rewrite the asserts, and write a # fake pyc, so that it'll be loaded when the module is imported. source = mod.fspath.read() @@ -97,7 +93,7 @@ def pytest_pycollect_onmodule(mod): except SyntaxError: # Let this pop up again in the real import. mod.config._assertstate.trace("failed to parse: %r" % (mod.fspath,)) - return mod + return rewrite_asserts(tree) try: co = compile(tree, str(mod.fspath), "exec") @@ -105,25 +101,20 @@ def pytest_pycollect_onmodule(mod): # It's possible that this error is from some bug in the assertion # rewriting, but I don't know of a fast way to tell. mod.config._assertstate.trace("failed to compile: %r" % (mod.fspath,)) - return mod - pyc = _write_pyc(co, mod.fspath) - mod.session._rewritten_pycs.append(pyc) - mod.config._assertstate.trace("wrote pyc: %r" % (pyc,)) - return mod - -def pytest_collection_finish(session): - if not hasattr(session, "_rewritten_pycs"): return - state = session.config._assertstate - # Remove our tweaked pycs to avoid subtle bugs. - for pyc in session._rewritten_pycs: - try: - pyc.remove() - except py.error.ENOENT: - state.trace("couldn't find pyc: %r" % (pyc,)) - else: - state.trace("removed pyc: %r" % (pyc,)) - del session._rewritten_pycs[:] + mod._pyc = _write_pyc(co, mod.fspath) + mod.config._assertstate.trace("wrote pyc: %r" % (mod._pyc,)) + +def after_module_import(mod): + if not hasattr(mod, "_pyc"): + return + state = mod.config._assertstate + try: + mod._pyc.remove() + except py.error.ENOENT: + state.trace("couldn't find pyc: %r" % (mod._pyc,)) + else: + state.trace("removed pyc: %r" % (mod._pyc,)) def warn_about_missing_assertion(): try: diff --git a/_pytest/hookspec.py b/_pytest/hookspec.py index a3f79fd01..898ffee2a 100644 --- a/_pytest/hookspec.py +++ b/_pytest/hookspec.py @@ -104,9 +104,6 @@ def pytest_pycollect_makemodule(path, parent): """ pytest_pycollect_makemodule.firstresult = True -def pytest_pycollect_onmodule(mod): - """ Called when a module is collected.""" - def pytest_pycollect_makeitem(collector, name, obj): """ return custom item/collector for a python object in a module, or None. """ pytest_pycollect_makeitem.firstresult = True diff --git a/_pytest/python.py b/_pytest/python.py index 2b47476c9..ae964f0b8 100644 --- a/_pytest/python.py +++ b/_pytest/python.py @@ -4,6 +4,7 @@ import inspect import sys import pytest from py._code.code import TerminalRepr +from _pytest import assertion import _pytest cutdir = py.path.local(_pytest.__file__).dirpath() @@ -60,11 +61,8 @@ def pytest_collect_file(path, parent): break else: return - mod = parent.ihook.pytest_pycollect_makemodule( + return parent.ihook.pytest_pycollect_makemodule( path=path, parent=parent) - if mod is not None: - parent.ihook.pytest_pycollect_onmodule(mod=mod) - return mod def pytest_pycollect_makemodule(path, parent): return Module(path, parent) @@ -229,8 +227,12 @@ class Module(pytest.File, PyCollectorMixin): def _importtestmodule(self): # we assume we are only called once per module + assertion.before_module_import(self) try: - mod = self.fspath.pyimport(ensuresyspath=True) + try: + mod = self.fspath.pyimport(ensuresyspath=True) + finally: + assertion.after_module_import(self) except SyntaxError: excinfo = py.code.ExceptionInfo() raise self.CollectError(excinfo.getrepr(style="short"))