diff --git a/py/_plugin/standalonetemplate.py b/py/_plugin/standalonetemplate.py index cf1b26d1b..2d238b578 100755 --- a/py/_plugin/standalonetemplate.py +++ b/py/_plugin/standalonetemplate.py @@ -8,18 +8,10 @@ import base64 import zlib import imp -if sys.version_info >= (3,0): - exec("def do_exec(co, loc): exec(co, loc)\n") - import pickle - sources = sources.encode("ascii") # ensure bytes - sources = pickle.loads(zlib.decompress(base64.decodebytes(sources))) -else: - import cPickle as pickle - exec("def do_exec(co, loc): exec co in loc\n") - sources = pickle.loads(zlib.decompress(base64.decodestring(sources))) - class DictImporter(object): - sources = sources + def __init__(self, sources): + self.sources = sources + def find_module(self, fullname, path=None): if fullname in self.sources: return self @@ -53,12 +45,19 @@ class DictImporter(object): res = self.sources.get(name+'.__init__') return res - - -importer = DictImporter() - -sys.meta_path.append(importer) - if __name__ == "__main__": + if sys.version_info >= (3,0): + exec("def do_exec(co, loc): exec(co, loc)\n") + import pickle + sources = sources.encode("ascii") # ensure bytes + sources = pickle.loads(zlib.decompress(base64.decodebytes(sources))) + else: + import cPickle as pickle + exec("def do_exec(co, loc): exec co in loc\n") + sources = pickle.loads(zlib.decompress(base64.decodestring(sources))) + + importer = DictImporter(sources) + sys.meta_path.append(importer) + import py py.cmdline.pytest() diff --git a/testing/root/test_py_imports.py b/testing/root/test_py_imports.py index 1f2d56d45..81b1a1975 100644 --- a/testing/root/test_py_imports.py +++ b/testing/root/test_py_imports.py @@ -1,6 +1,7 @@ import py import types import sys +from py._test.outcome import Skipped def checksubpackage(name): obj = getattr(py, name) @@ -29,7 +30,6 @@ def test_importall(): nodirs = [ base.join('_path', 'gateway',), base.join('_code', 'oldmagic.py'), - base.join('_compat', 'testing'), ] if sys.version_info >= (3,0): nodirs.append(base.join('_code', '_assertionold.py')) @@ -50,7 +50,10 @@ def test_importall(): else: relpath = relpath.replace(base.sep, '.') modpath = 'py.%s' % relpath - check_import(modpath) + try: + check_import(modpath) + except Skipped: + pass def check_import(modpath): py.builtin.print_("checking import", modpath)