diff --git a/py/code/code.py b/py/code/code.py index 147bb3433..56f8e4e23 100644 --- a/py/code/code.py +++ b/py/code/code.py @@ -1,4 +1,5 @@ import py +from py.__.code import source class Code(object): """ wrapper around Python code objects """ @@ -75,14 +76,8 @@ class Code(object): def fullsource(self): """ return a py.code.Source object for the full source file of the code """ - fn = self.raw.co_filename - try: - return fn.__source__ - except AttributeError: - path = self.path - if not isinstance(path, py.path.local): - return None - return py.code.Source(self.path.read(mode="rU")) + full, _ = source.findsource(self.raw) + return full fullsource = property(fullsource, None, None, "full source containing this code object") diff --git a/py/code/source.py b/py/code/source.py index 76ed8d9d1..6923012fe 100644 --- a/py/code/source.py +++ b/py/code/source.py @@ -223,6 +223,28 @@ def compile_(source, filename=None, mode='exec', flags= class MyStr(str): """ custom string which allows to add attributes. """ +def findsource(obj): + if hasattr(obj, 'func_code'): + obj = obj.func_code + elif hasattr(obj, 'f_code'): + obj = obj.f_code + try: + fullsource = obj.co_filename.__source__ + except AttributeError: + try: + sourcelines, lineno = py.std.inspect.findsource(obj) + except (KeyboardInterrupt, SystemExit): + raise + except: + return None, None + source = Source() + source.lines = map(str.rstrip, sourcelines) + return source, lineno + else: + lineno = obj.co_firstlineno - 1 + return fullsource, lineno + + def getsource(obj, **kwargs): if hasattr(obj, 'func_code'): obj = obj.func_code @@ -240,7 +262,7 @@ def getsource(obj, **kwargs): else: lineno = obj.co_firstlineno - 1 end = fullsource.getblockend(lineno) - return fullsource[lineno:end+1] + return Source(fullsource[lineno:end+1], deident=True) def deindent(lines, offset=None): diff --git a/py/code/testing/test_excinfo.py b/py/code/testing/test_excinfo.py index f87cb87bb..28a314432 100644 --- a/py/code/testing/test_excinfo.py +++ b/py/code/testing/test_excinfo.py @@ -329,7 +329,7 @@ raise ValueError() firstlineno = 5 def fullsource(self): - raise fail + return None fullsource = property(fullsource) class FakeFrame(object): diff --git a/py/code/testing/test_source.py b/py/code/testing/test_source.py index b19f46d23..eb67fcda7 100644 --- a/py/code/testing/test_source.py +++ b/py/code/testing/test_source.py @@ -314,3 +314,43 @@ def test_source_of_class_at_eof_without_newline(): path.write(source) s2 = py.code.Source(tmpdir.join("a.py").pyimport().A) assert str(source).strip() == str(s2).strip() + +if True: + def x(): + pass + +def test_getsource_fallback(): + from py.__.code.source import getsource + expected = """def x(): + pass""" + src = getsource(x) + assert src == expected + +def test_getsource___source__(): + from py.__.code.source import getsource + x = py.code.compile("""if 1: + def x(): + pass +""") + + expected = """def x(): + pass""" + src = getsource(x) + assert src == expected + +def test_findsource_fallback(): + from py.__.code.source import findsource + src, lineno = findsource(x) + assert 'test_findsource_simple' in str(src) + assert src[lineno] == ' def x():' + +def test_findsource___source__(): + from py.__.code.source import findsource + x = py.code.compile("""if 1: + def x(): + pass +""") + + src, lineno = findsource(x) + assert 'if 1:' in str(src) + assert src[lineno] == ' def x():' diff --git a/py/code/traceback2.py b/py/code/traceback2.py index ac6d26f39..d5d40e8be 100644 --- a/py/code/traceback2.py +++ b/py/code/traceback2.py @@ -51,19 +51,9 @@ class TracebackEntry(object): def getsource(self): """ return failing source code. """ - try: - source = self.frame.code.fullsource - except (IOError, py.error.ENOENT): - return None + source = self.frame.code.fullsource if source is None: - try: - sourcelines, lineno = py.std.inspect.findsource(self.frame.code.raw) - except (KeyboardInterrupt, SystemExit): - raise - except: - return None - source = py.code.Source() - source.lines = map(str.rstrip, sourcelines) + return None start = self.getfirstlinesource() end = self.lineno try: diff --git a/py/test/pycollect.py b/py/test/pycollect.py index db5bb4702..9cf6182bc 100644 --- a/py/test/pycollect.py +++ b/py/test/pycollect.py @@ -18,6 +18,7 @@ a tree of collectors and test items that this modules provides:: """ import py from py.__.test.collect import configproperty, warnoldcollect +from py.__.code.source import findsource class PyobjMixin(object): def obj(): @@ -68,7 +69,7 @@ class PyobjMixin(object): fspath = fn and py.path.local(fn) or None if fspath: try: - lines, lineno = py.std.inspect.findsource(self.obj) + _, lineno = findsource(self.obj) except IOError: lineno = None else: