return to the old scheme of rewriting test modules from _importtestmodule
This commit is contained in:
parent
6fdcecb864
commit
5e31624315
|
@ -65,10 +65,6 @@ def pytest_configure(config):
|
||||||
config._assertstate = AssertionState(config, mode)
|
config._assertstate = AssertionState(config, mode)
|
||||||
config._assertstate.trace("configured with mode set to %r" % (mode,))
|
config._assertstate.trace("configured with mode set to %r" % (mode,))
|
||||||
|
|
||||||
def pytest_collectstart(collector):
|
|
||||||
if isinstance(collector, pytest.Session):
|
|
||||||
collector._rewritten_pycs = []
|
|
||||||
|
|
||||||
def _write_pyc(co, source_path):
|
def _write_pyc(co, source_path):
|
||||||
if hasattr(imp, "cache_from_source"):
|
if hasattr(imp, "cache_from_source"):
|
||||||
# Handle PEP 3147 pycs.
|
# Handle PEP 3147 pycs.
|
||||||
|
@ -86,9 +82,9 @@ def _write_pyc(co, source_path):
|
||||||
fp.close()
|
fp.close()
|
||||||
return pyc
|
return pyc
|
||||||
|
|
||||||
def pytest_pycollect_onmodule(mod):
|
def before_module_import(mod):
|
||||||
if mod is None or mod.config._assertstate.mode != "on":
|
if mod.config._assertstate.mode != "on":
|
||||||
return mod
|
return
|
||||||
# Some deep magic: load the source, rewrite the asserts, and write a
|
# Some deep magic: load the source, rewrite the asserts, and write a
|
||||||
# fake pyc, so that it'll be loaded when the module is imported.
|
# fake pyc, so that it'll be loaded when the module is imported.
|
||||||
source = mod.fspath.read()
|
source = mod.fspath.read()
|
||||||
|
@ -97,7 +93,7 @@ def pytest_pycollect_onmodule(mod):
|
||||||
except SyntaxError:
|
except SyntaxError:
|
||||||
# Let this pop up again in the real import.
|
# Let this pop up again in the real import.
|
||||||
mod.config._assertstate.trace("failed to parse: %r" % (mod.fspath,))
|
mod.config._assertstate.trace("failed to parse: %r" % (mod.fspath,))
|
||||||
return mod
|
return
|
||||||
rewrite_asserts(tree)
|
rewrite_asserts(tree)
|
||||||
try:
|
try:
|
||||||
co = compile(tree, str(mod.fspath), "exec")
|
co = compile(tree, str(mod.fspath), "exec")
|
||||||
|
@ -105,25 +101,20 @@ def pytest_pycollect_onmodule(mod):
|
||||||
# It's possible that this error is from some bug in the assertion
|
# It's possible that this error is from some bug in the assertion
|
||||||
# rewriting, but I don't know of a fast way to tell.
|
# rewriting, but I don't know of a fast way to tell.
|
||||||
mod.config._assertstate.trace("failed to compile: %r" % (mod.fspath,))
|
mod.config._assertstate.trace("failed to compile: %r" % (mod.fspath,))
|
||||||
return mod
|
|
||||||
pyc = _write_pyc(co, mod.fspath)
|
|
||||||
mod.session._rewritten_pycs.append(pyc)
|
|
||||||
mod.config._assertstate.trace("wrote pyc: %r" % (pyc,))
|
|
||||||
return mod
|
|
||||||
|
|
||||||
def pytest_collection_finish(session):
|
|
||||||
if not hasattr(session, "_rewritten_pycs"):
|
|
||||||
return
|
return
|
||||||
state = session.config._assertstate
|
mod._pyc = _write_pyc(co, mod.fspath)
|
||||||
# Remove our tweaked pycs to avoid subtle bugs.
|
mod.config._assertstate.trace("wrote pyc: %r" % (mod._pyc,))
|
||||||
for pyc in session._rewritten_pycs:
|
|
||||||
try:
|
def after_module_import(mod):
|
||||||
pyc.remove()
|
if not hasattr(mod, "_pyc"):
|
||||||
except py.error.ENOENT:
|
return
|
||||||
state.trace("couldn't find pyc: %r" % (pyc,))
|
state = mod.config._assertstate
|
||||||
else:
|
try:
|
||||||
state.trace("removed pyc: %r" % (pyc,))
|
mod._pyc.remove()
|
||||||
del session._rewritten_pycs[:]
|
except py.error.ENOENT:
|
||||||
|
state.trace("couldn't find pyc: %r" % (mod._pyc,))
|
||||||
|
else:
|
||||||
|
state.trace("removed pyc: %r" % (mod._pyc,))
|
||||||
|
|
||||||
def warn_about_missing_assertion():
|
def warn_about_missing_assertion():
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -104,9 +104,6 @@ def pytest_pycollect_makemodule(path, parent):
|
||||||
"""
|
"""
|
||||||
pytest_pycollect_makemodule.firstresult = True
|
pytest_pycollect_makemodule.firstresult = True
|
||||||
|
|
||||||
def pytest_pycollect_onmodule(mod):
|
|
||||||
""" Called when a module is collected."""
|
|
||||||
|
|
||||||
def pytest_pycollect_makeitem(collector, name, obj):
|
def pytest_pycollect_makeitem(collector, name, obj):
|
||||||
""" return custom item/collector for a python object in a module, or None. """
|
""" return custom item/collector for a python object in a module, or None. """
|
||||||
pytest_pycollect_makeitem.firstresult = True
|
pytest_pycollect_makeitem.firstresult = True
|
||||||
|
|
|
@ -4,6 +4,7 @@ import inspect
|
||||||
import sys
|
import sys
|
||||||
import pytest
|
import pytest
|
||||||
from py._code.code import TerminalRepr
|
from py._code.code import TerminalRepr
|
||||||
|
from _pytest import assertion
|
||||||
|
|
||||||
import _pytest
|
import _pytest
|
||||||
cutdir = py.path.local(_pytest.__file__).dirpath()
|
cutdir = py.path.local(_pytest.__file__).dirpath()
|
||||||
|
@ -60,11 +61,8 @@ def pytest_collect_file(path, parent):
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
mod = parent.ihook.pytest_pycollect_makemodule(
|
return parent.ihook.pytest_pycollect_makemodule(
|
||||||
path=path, parent=parent)
|
path=path, parent=parent)
|
||||||
if mod is not None:
|
|
||||||
parent.ihook.pytest_pycollect_onmodule(mod=mod)
|
|
||||||
return mod
|
|
||||||
|
|
||||||
def pytest_pycollect_makemodule(path, parent):
|
def pytest_pycollect_makemodule(path, parent):
|
||||||
return Module(path, parent)
|
return Module(path, parent)
|
||||||
|
@ -229,8 +227,12 @@ class Module(pytest.File, PyCollectorMixin):
|
||||||
|
|
||||||
def _importtestmodule(self):
|
def _importtestmodule(self):
|
||||||
# we assume we are only called once per module
|
# we assume we are only called once per module
|
||||||
|
assertion.before_module_import(self)
|
||||||
try:
|
try:
|
||||||
mod = self.fspath.pyimport(ensuresyspath=True)
|
try:
|
||||||
|
mod = self.fspath.pyimport(ensuresyspath=True)
|
||||||
|
finally:
|
||||||
|
assertion.after_module_import(self)
|
||||||
except SyntaxError:
|
except SyntaxError:
|
||||||
excinfo = py.code.ExceptionInfo()
|
excinfo = py.code.ExceptionInfo()
|
||||||
raise self.CollectError(excinfo.getrepr(style="short"))
|
raise self.CollectError(excinfo.getrepr(style="short"))
|
||||||
|
|
Loading…
Reference in New Issue