fix issue317: assertion rewriter support for the is_package method

This commit is contained in:
Ronny Pfannschmidt 2013-08-01 22:11:18 +02:00
parent 8f24e10571
commit 743711cd1f
3 changed files with 45 additions and 0 deletions

View File

@ -11,6 +11,8 @@ Changes between 2.3.5 and 2.4.DEV
for standard datatypes and recognise collections.abc. Thanks to
Brianna Laugher and Mathieu Agopian.
- fix issue317: assertion rewriter support for the is_package method
- fix issue335: document py.code.ExceptionInfo() object returned
from pytest.raises(), thanks Mathieu Agopian.

View File

@ -150,12 +150,25 @@ class AssertionRewritingHook(object):
mod.__file__ = co.co_filename
# Normally, this attribute is 3.2+.
mod.__cached__ = pyc
mod.__loader__ = self
py.builtin.exec_(co, mod.__dict__)
except:
del sys.modules[name]
raise
return sys.modules[name]
def is_package(self, name):
try:
fd, fn, desc = imp.find_module(name)
except ImportError:
return False
if fd is not None:
fd.close()
tp = desc[2]
return tp == imp.PKG_DIRECTORY
def _write_pyc(state, co, source_path, pyc):
# Technically, we don't have to have the same pyc format as
# (C)Python, since these "pycs" should never be seen by builtin

View File

@ -412,6 +412,36 @@ def test_rewritten():
testdir.tmpdir.join("test_newlines.py").write(b, "wb")
assert testdir.runpytest().ret == 0
class TestAssertionRewriteHookDetails(object):
def test_loader_is_package_false_for_module(self, testdir):
testdir.makepyfile(test_fun="""
def test_loader():
assert not __loader__.is_package(__name__)
""")
result = testdir.runpytest()
result.stdout.fnmatch_lines([
"* 1 passed*",
])
def test_loader_is_package_true_for_package(self, testdir):
testdir.makepyfile(test_fun="""
def test_loader():
assert not __loader__.is_package(__name__)
def test_fun():
assert __loader__.is_package('fun')
def test_missing():
assert not __loader__.is_package('pytest_not_there')
""")
pkg = testdir.mkpydir('fun')
result = testdir.runpytest()
result.stdout.fnmatch_lines([
'* 3 passed*',
])
@pytest.mark.skipif("sys.version_info[0] >= 3")
def test_assume_ascii(self, testdir):
content = "u'\xe2\x99\xa5'"