This commit is contained in:
holger krekel 2013-10-11 09:30:08 +02:00
commit 124e58e42d
2 changed files with 51 additions and 0 deletions

View File

@ -41,6 +41,7 @@ class AssertionRewritingHook(object):
def __init__(self):
self.session = None
self.modules = {}
self._register_with_pkg_resources()
def set_session(self, session):
self.fnpats = session.config.getini("python_files")
@ -169,6 +170,24 @@ class AssertionRewritingHook(object):
tp = desc[2]
return tp == imp.PKG_DIRECTORY
@classmethod
def _register_with_pkg_resources(cls):
"""
Ensure package resources can be loaded from this loader. May be called
multiple times, as the operation is idempotent.
"""
try:
import pkg_resources
# access an attribute in case a deferred importer is present
pkg_resources.__name__
except ImportError:
return
# Since pytest tests are always located in the file system, the
# DefaultProvider is appropriate.
pkg_resources.register_loader_type(cls, pkg_resources.DefaultProvider)
def _write_pyc(state, co, source_path, pyc):
# Technically, we don't have to have the same pyc format as
# (C)Python, since these "pycs" should never be seen by builtin

View File

@ -493,3 +493,35 @@ def test_rewritten():
raise e
monkeypatch.setattr(b, "open", open)
assert not _write_pyc(state, [1], source_path, pycpath)
def test_resources_provider_for_loader(self, testdir):
"""
Attempts to load resources from a package should succeed normally,
even when the AssertionRewriteHook is used to load the modules.
See #366 for details.
"""
pytest.importorskip("pkg_resources")
testdir.mkpydir('testpkg')
contents = {
'testpkg/test_pkg': """
import pkg_resources
import pytest
from _pytest.assertion.rewrite import AssertionRewritingHook
def test_load_resource():
assert isinstance(__loader__, AssertionRewritingHook)
res = pkg_resources.resource_string(__name__, 'resource.txt')
res = res.decode('ascii')
assert res == 'Load me please.'
""",
}
testdir.makepyfile(**contents)
testdir.maketxtfile(**{'testpkg/resource': "Load me please."})
result = testdir.runpytest()
result.stdout.fnmatch_lines([
'* 1 passed*',
])