From 531416cc5a85e7e90c03ad75962fa5caf92fcf36 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Tue, 27 Oct 2020 16:07:03 +0200 Subject: [PATCH] code: simplify Code construction --- src/_pytest/_code/code.py | 14 +++++++------- src/_pytest/_code/source.py | 26 ++++++++++++++------------ src/_pytest/python.py | 2 +- testing/code/test_code.py | 19 ++++++++++--------- testing/code/test_excinfo.py | 6 +++--- testing/code/test_source.py | 16 ++++------------ testing/test_assertrewrite.py | 2 +- 7 files changed, 40 insertions(+), 45 deletions(-) diff --git a/src/_pytest/_code/code.py b/src/_pytest/_code/code.py index 430e45242..423069330 100644 --- a/src/_pytest/_code/code.py +++ b/src/_pytest/_code/code.py @@ -56,12 +56,12 @@ class Code: __slots__ = ("raw",) - def __init__(self, rawcode) -> None: - if not hasattr(rawcode, "co_filename"): - rawcode = getrawcode(rawcode) - if not isinstance(rawcode, CodeType): - raise TypeError(f"not a code object: {rawcode!r}") - self.raw = rawcode + def __init__(self, obj: CodeType) -> None: + self.raw = obj + + @classmethod + def from_function(cls, obj: object) -> "Code": + return cls(getrawcode(obj)) def __eq__(self, other): return self.raw == other.raw @@ -1196,7 +1196,7 @@ def getfslineno(obj: object) -> Tuple[Union[str, py.path.local], int]: obj = obj.place_as # type: ignore[attr-defined] try: - code = Code(obj) + code = Code.from_function(obj) except TypeError: try: fn = inspect.getsourcefile(obj) or inspect.getfile(obj) # type: ignore[arg-type] diff --git a/src/_pytest/_code/source.py b/src/_pytest/_code/source.py index c63a42360..6f54057c0 100644 --- a/src/_pytest/_code/source.py +++ b/src/_pytest/_code/source.py @@ -2,6 +2,7 @@ import ast import inspect import textwrap import tokenize +import types import warnings from bisect import bisect_right from typing import Iterable @@ -29,8 +30,11 @@ class Source: elif isinstance(obj, str): self.lines = deindent(obj.split("\n")) else: - rawcode = getrawcode(obj) - src = inspect.getsource(rawcode) + try: + rawcode = getrawcode(obj) + src = inspect.getsource(rawcode) + except TypeError: + src = inspect.getsource(obj) # type: ignore[arg-type] self.lines = deindent(src.split("\n")) def __eq__(self, other: object) -> bool: @@ -122,19 +126,17 @@ def findsource(obj) -> Tuple[Optional[Source], int]: return source, lineno -def getrawcode(obj, trycall: bool = True): +def getrawcode(obj: object, trycall: bool = True) -> types.CodeType: """Return code object for given function.""" try: - return obj.__code__ + return obj.__code__ # type: ignore[attr-defined,no-any-return] except AttributeError: - obj = getattr(obj, "f_code", obj) - obj = getattr(obj, "__code__", obj) - if trycall and not hasattr(obj, "co_firstlineno"): - if hasattr(obj, "__call__") and not inspect.isclass(obj): - x = getrawcode(obj.__call__, trycall=False) - if hasattr(x, "co_firstlineno"): - return x - return obj + pass + if trycall: + call = getattr(obj, "__call__", None) + if call and not isinstance(obj, type): + return getrawcode(call, trycall=False) + raise TypeError(f"could not get code object for {obj!r}") def deindent(lines: Iterable[str]) -> List[str]: diff --git a/src/_pytest/python.py b/src/_pytest/python.py index 35797cc07..e477b8b45 100644 --- a/src/_pytest/python.py +++ b/src/_pytest/python.py @@ -1647,7 +1647,7 @@ class Function(PyobjMixin, nodes.Item): def _prunetraceback(self, excinfo: ExceptionInfo[BaseException]) -> None: if hasattr(self, "_obj") and not self.config.getoption("fulltrace", False): - code = _pytest._code.Code(get_real_func(self.obj)) + code = _pytest._code.Code.from_function(get_real_func(self.obj)) path, firstlineno = code.path, code.firstlineno traceback = excinfo.traceback ntraceback = traceback.cut(path=path, firstlineno=firstlineno) diff --git a/testing/code/test_code.py b/testing/code/test_code.py index bae86be34..33809528a 100644 --- a/testing/code/test_code.py +++ b/testing/code/test_code.py @@ -28,11 +28,12 @@ def test_code_gives_back_name_for_not_existing_file() -> None: assert code.fullsource is None -def test_code_with_class() -> None: +def test_code_from_function_with_class() -> None: class A: pass - pytest.raises(TypeError, Code, A) + with pytest.raises(TypeError): + Code.from_function(A) def x() -> None: @@ -40,13 +41,13 @@ def x() -> None: def test_code_fullsource() -> None: - code = Code(x) + code = Code.from_function(x) full = code.fullsource assert "test_code_fullsource()" in str(full) def test_code_source() -> None: - code = Code(x) + code = Code.from_function(x) src = code.source() expected = """def x() -> None: raise NotImplementedError()""" @@ -73,7 +74,7 @@ def test_getstatement_empty_fullsource() -> None: def test_code_from_func() -> None: - co = Code(test_frame_getsourcelineno_myself) + co = Code.from_function(test_frame_getsourcelineno_myself) assert co.firstlineno assert co.path @@ -92,25 +93,25 @@ def test_code_getargs() -> None: def f1(x): raise NotImplementedError() - c1 = Code(f1) + c1 = Code.from_function(f1) assert c1.getargs(var=True) == ("x",) def f2(x, *y): raise NotImplementedError() - c2 = Code(f2) + c2 = Code.from_function(f2) assert c2.getargs(var=True) == ("x", "y") def f3(x, **z): raise NotImplementedError() - c3 = Code(f3) + c3 = Code.from_function(f3) assert c3.getargs(var=True) == ("x", "z") def f4(x, *y, **z): raise NotImplementedError() - c4 = Code(f4) + c4 = Code.from_function(f4) assert c4.getargs(var=True) == ("x", "y", "z") diff --git a/testing/code/test_excinfo.py b/testing/code/test_excinfo.py index a43704ff0..5b9e3eda5 100644 --- a/testing/code/test_excinfo.py +++ b/testing/code/test_excinfo.py @@ -147,7 +147,7 @@ class TestTraceback_f_g_h: ] def test_traceback_cut(self): - co = _pytest._code.Code(f) + co = _pytest._code.Code.from_function(f) path, firstlineno = co.path, co.firstlineno traceback = self.excinfo.traceback newtraceback = traceback.cut(path=path, firstlineno=firstlineno) @@ -290,7 +290,7 @@ class TestTraceback_f_g_h: excinfo = pytest.raises(ValueError, f) tb = excinfo.traceback entry = tb.getcrashentry() - co = _pytest._code.Code(h) + co = _pytest._code.Code.from_function(h) assert entry.frame.code.path == co.path assert entry.lineno == co.firstlineno + 1 assert entry.frame.code.name == "h" @@ -307,7 +307,7 @@ class TestTraceback_f_g_h: excinfo = pytest.raises(ValueError, f) tb = excinfo.traceback entry = tb.getcrashentry() - co = _pytest._code.Code(g) + co = _pytest._code.Code.from_function(g) assert entry.frame.code.path == co.path assert entry.lineno == co.firstlineno + 2 assert entry.frame.code.name == "g" diff --git a/testing/code/test_source.py b/testing/code/test_source.py index fa2136ef1..04d0ea932 100644 --- a/testing/code/test_source.py +++ b/testing/code/test_source.py @@ -16,8 +16,8 @@ import py.path import pytest from _pytest._code import Code from _pytest._code import Frame -from _pytest._code import Source from _pytest._code import getfslineno +from _pytest._code import Source def test_source_str_function() -> None: @@ -291,7 +291,7 @@ def test_source_of_class_at_eof_without_newline(tmpdir, _sys_snapshot) -> None: # does not return the "x = 1" last line. source = Source( """ - class A(object): + class A: def method(self): x = 1 """ @@ -374,14 +374,6 @@ def test_getfslineno() -> None: B.__name__ = B.__qualname__ = "B2" assert getfslineno(B)[1] == -1 - co = compile("...", "", "eval") - assert co.co_filename == "" - - if hasattr(sys, "pypy_version_info"): - assert getfslineno(co) == ("", -1) - else: - assert getfslineno(co) == ("", 0) - def test_code_of_object_instance_with_call() -> None: class A: @@ -393,14 +385,14 @@ def test_code_of_object_instance_with_call() -> None: def __call__(self) -> None: pass - code = Code(WithCall()) + code = Code.from_function(WithCall()) assert "pass" in str(code.source()) class Hello: def __call__(self) -> None: pass - pytest.raises(TypeError, lambda: Code(Hello)) + pytest.raises(TypeError, lambda: Code.from_function(Hello)) def getstatement(lineno: int, source) -> Source: diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 58a31ab8d..09383cafe 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -42,7 +42,7 @@ def getmsg( f, extra_ns: Optional[Mapping[str, object]] = None, *, must_pass: bool = False ) -> Optional[str]: """Rewrite the assertions in f, run it, and get the failure message.""" - src = "\n".join(_pytest._code.Code(f).source().lines) + src = "\n".join(_pytest._code.Code.from_function(f).source().lines) mod = rewrite(src) code = compile(mod, "", "exec") ns: Dict[str, object] = {}