Merge pull request #7438 from bluetech/source-cleanups

code/source: some cleanups
This commit is contained in:
Ran Benita 2020-07-04 12:57:32 +03:00 committed by GitHub
commit 36b958c99e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 106 additions and 416 deletions

View File

@ -0,0 +1,11 @@
Some changes were made to the internal ``_pytest._code.source``, listed here
for the benefit of plugin authors who may be using it:
- The ``deindent`` argument to ``Source()`` has been removed, now it is always true.
- Support for zero or multiple arguments to ``Source()`` has been removed.
- Support for comparing ``Source`` with an ``str`` has been removed.
- The methods ``Source.isparseable()`` and ``Source.putaround()`` have been removed.
- The method ``Source.compile()`` and function ``_pytest._code.compile()`` have
been removed; use plain ``compile()`` instead.
- The function ``_pytest._code.source.getsource()`` has been removed; use
``Source()`` directly instead.

View File

@ -7,7 +7,6 @@ from .code import getfslineno
from .code import getrawcode from .code import getrawcode
from .code import Traceback from .code import Traceback
from .code import TracebackEntry from .code import TracebackEntry
from .source import compile_ as compile
from .source import Source from .source import Source
__all__ = [ __all__ = [
@ -19,6 +18,5 @@ __all__ = [
"getrawcode", "getrawcode",
"Traceback", "Traceback",
"TracebackEntry", "TracebackEntry",
"compile",
"Source", "Source",
] ]

View File

@ -1,61 +1,43 @@
import ast import ast
import inspect import inspect
import linecache
import sys
import textwrap import textwrap
import tokenize import tokenize
import warnings import warnings
from bisect import bisect_right from bisect import bisect_right
from types import CodeType from typing import Iterable
from types import FrameType
from typing import Iterator from typing import Iterator
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Sequence
from typing import Tuple from typing import Tuple
from typing import Union from typing import Union
import py
from _pytest.compat import overload from _pytest.compat import overload
from _pytest.compat import TYPE_CHECKING
if TYPE_CHECKING:
from typing_extensions import Literal
class Source: class Source:
""" an immutable object holding a source code fragment, """An immutable object holding a source code fragment.
possibly deindenting it.
When using Source(...), the source lines are deindented.
""" """
_compilecounter = 0 def __init__(self, obj: object = None) -> None:
if not obj:
def __init__(self, *parts, **kwargs) -> None: self.lines = [] # type: List[str]
self.lines = lines = [] # type: List[str] elif isinstance(obj, Source):
de = kwargs.get("deindent", True) self.lines = obj.lines
for part in parts: elif isinstance(obj, (tuple, list)):
if not part: self.lines = deindent(x.rstrip("\n") for x in obj)
partlines = [] # type: List[str] elif isinstance(obj, str):
elif isinstance(part, Source): self.lines = deindent(obj.split("\n"))
partlines = part.lines
elif isinstance(part, (tuple, list)):
partlines = [x.rstrip("\n") for x in part]
elif isinstance(part, str):
partlines = part.split("\n")
else: else:
partlines = getsource(part, deindent=de).lines rawcode = getrawcode(obj)
if de: src = inspect.getsource(rawcode)
partlines = deindent(partlines) self.lines = deindent(src.split("\n"))
lines.extend(partlines)
def __eq__(self, other): def __eq__(self, other: object) -> bool:
try: if not isinstance(other, Source):
return NotImplemented
return self.lines == other.lines return self.lines == other.lines
except AttributeError:
if isinstance(other, str):
return str(self) == other
return False
# Ignore type because of https://github.com/python/mypy/issues/4266. # Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore __hash__ = None # type: ignore
@ -97,19 +79,6 @@ class Source:
source.lines[:] = self.lines[start:end] source.lines[:] = self.lines[start:end]
return source return source
def putaround(
self, before: str = "", after: str = "", indent: str = " " * 4
) -> "Source":
""" return a copy of the source object with
'before' and 'after' wrapped around it.
"""
beforesource = Source(before)
aftersource = Source(after)
newsource = Source()
lines = [(indent + line) for line in self.lines]
newsource.lines = beforesource.lines + lines + aftersource.lines
return newsource
def indent(self, indent: str = " " * 4) -> "Source": def indent(self, indent: str = " " * 4) -> "Source":
""" return a copy of the source object with """ return a copy of the source object with
all lines indented by the given indent-string. all lines indented by the given indent-string.
@ -140,142 +109,9 @@ class Source:
newsource.lines[:] = deindent(self.lines) newsource.lines[:] = deindent(self.lines)
return newsource return newsource
def isparseable(self, deindent: bool = True) -> bool:
""" return True if source is parseable, heuristically
deindenting it by default.
"""
if deindent:
source = str(self.deindent())
else:
source = str(self)
try:
ast.parse(source)
except (SyntaxError, ValueError, TypeError):
return False
else:
return True
def __str__(self) -> str: def __str__(self) -> str:
return "\n".join(self.lines) return "\n".join(self.lines)
@overload
def compile(
self,
filename: Optional[str] = ...,
mode: str = ...,
flag: "Literal[0]" = ...,
dont_inherit: int = ...,
_genframe: Optional[FrameType] = ...,
) -> CodeType:
raise NotImplementedError()
@overload # noqa: F811
def compile( # noqa: F811
self,
filename: Optional[str] = ...,
mode: str = ...,
flag: int = ...,
dont_inherit: int = ...,
_genframe: Optional[FrameType] = ...,
) -> Union[CodeType, ast.AST]:
raise NotImplementedError()
def compile( # noqa: F811
self,
filename: Optional[str] = None,
mode: str = "exec",
flag: int = 0,
dont_inherit: int = 0,
_genframe: Optional[FrameType] = None,
) -> Union[CodeType, ast.AST]:
""" return compiled code object. if filename is None
invent an artificial filename which displays
the source/line position of the caller frame.
"""
if not filename or py.path.local(filename).check(file=0):
if _genframe is None:
_genframe = sys._getframe(1) # the caller
fn, lineno = _genframe.f_code.co_filename, _genframe.f_lineno
base = "<%d-codegen " % self._compilecounter
self.__class__._compilecounter += 1
if not filename:
filename = base + "%s:%d>" % (fn, lineno)
else:
filename = base + "%r %s:%d>" % (filename, fn, lineno)
source = "\n".join(self.lines) + "\n"
try:
co = compile(source, filename, mode, flag)
except SyntaxError as ex:
# re-represent syntax errors from parsing python strings
msglines = self.lines[: ex.lineno]
if ex.offset:
msglines.append(" " * ex.offset + "^")
msglines.append("(code was compiled probably from here: %s)" % filename)
newex = SyntaxError("\n".join(msglines))
newex.offset = ex.offset
newex.lineno = ex.lineno
newex.text = ex.text
raise newex from ex
else:
if flag & ast.PyCF_ONLY_AST:
assert isinstance(co, ast.AST)
return co
assert isinstance(co, CodeType)
lines = [(x + "\n") for x in self.lines]
# Type ignored because linecache.cache is private.
linecache.cache[filename] = (1, None, lines, filename) # type: ignore
return co
#
# public API shortcut functions
#
@overload
def compile_(
source: Union[str, bytes, ast.mod, ast.AST],
filename: Optional[str] = ...,
mode: str = ...,
flags: "Literal[0]" = ...,
dont_inherit: int = ...,
) -> CodeType:
raise NotImplementedError()
@overload # noqa: F811
def compile_( # noqa: F811
source: Union[str, bytes, ast.mod, ast.AST],
filename: Optional[str] = ...,
mode: str = ...,
flags: int = ...,
dont_inherit: int = ...,
) -> Union[CodeType, ast.AST]:
raise NotImplementedError()
def compile_( # noqa: F811
source: Union[str, bytes, ast.mod, ast.AST],
filename: Optional[str] = None,
mode: str = "exec",
flags: int = 0,
dont_inherit: int = 0,
) -> Union[CodeType, ast.AST]:
""" compile the given source to a raw code object,
and maintain an internal cache which allows later
retrieval of the source code for the code object
and any recursively created code objects.
"""
if isinstance(source, ast.AST):
# XXX should Source support having AST?
assert filename is not None
co = compile(source, filename, mode, flags, dont_inherit)
assert isinstance(co, (CodeType, ast.AST))
return co
_genframe = sys._getframe(1) # the caller
s = Source(source)
return s.compile(filename, mode, flags, _genframe=_genframe)
# #
# helper functions # helper functions
@ -307,17 +143,7 @@ def getrawcode(obj, trycall: bool = True):
return obj return obj
def getsource(obj, **kwargs) -> Source: def deindent(lines: Iterable[str]) -> List[str]:
obj = getrawcode(obj)
try:
strsrc = inspect.getsource(obj)
except IndentationError:
strsrc = '"Buggy python version consider upgrading, cannot get source"'
assert isinstance(strsrc, str)
return Source(strsrc, **kwargs)
def deindent(lines: Sequence[str]) -> List[str]:
return textwrap.dedent("\n".join(lines)).splitlines() return textwrap.dedent("\n".join(lines)).splitlines()

View File

@ -9,7 +9,6 @@ from typing import Tuple
import attr import attr
import _pytest._code
from _pytest.compat import TYPE_CHECKING from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config from _pytest.config import Config
from _pytest.config import hookimpl from _pytest.config import hookimpl
@ -105,7 +104,8 @@ def evaluate_condition(item: Item, mark: Mark, condition: object) -> Tuple[bool,
if hasattr(item, "obj"): if hasattr(item, "obj"):
globals_.update(item.obj.__globals__) # type: ignore[attr-defined] globals_.update(item.obj.__globals__) # type: ignore[attr-defined]
try: try:
condition_code = _pytest._code.compile(condition, mode="eval") filename = "<{} condition>".format(mark.name)
condition_code = compile(condition, filename, "eval")
result = eval(condition_code, globals_) result = eval(condition_code, globals_)
except SyntaxError as exc: except SyntaxError as exc:
msglines = [ msglines = [

View File

@ -6,6 +6,7 @@ import pytest
from _pytest._code import Code from _pytest._code import Code
from _pytest._code import ExceptionInfo from _pytest._code import ExceptionInfo
from _pytest._code import Frame from _pytest._code import Frame
from _pytest._code import Source
from _pytest._code.code import ExceptionChainRepr from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ReprFuncArgs from _pytest._code.code import ReprFuncArgs
@ -67,7 +68,7 @@ def test_getstatement_empty_fullsource() -> None:
f = Frame(func()) f = Frame(func())
with mock.patch.object(f.code.__class__, "fullsource", None): with mock.patch.object(f.code.__class__, "fullsource", None):
assert f.statement == "" assert f.statement == Source("")
def test_code_from_func() -> None: def test_code_from_func() -> None:

View File

@ -127,24 +127,28 @@ class TestTraceback_f_g_h:
assert s.endswith("raise ValueError") assert s.endswith("raise ValueError")
def test_traceback_entry_getsource_in_construct(self): def test_traceback_entry_getsource_in_construct(self):
source = _pytest._code.Source(
"""\
def xyz(): def xyz():
try: try:
raise ValueError raise ValueError
except somenoname: except somenoname: # type: ignore[name-defined] # noqa: F821
pass pass # pragma: no cover
xyz()
"""
)
try: try:
exec(source.compile()) xyz()
except NameError: except NameError:
tb = _pytest._code.ExceptionInfo.from_current().traceback excinfo = _pytest._code.ExceptionInfo.from_current()
print(tb[-1].getsource()) else:
s = str(tb[-1].getsource()) assert False, "did not raise NameError"
assert s.startswith("def xyz():\n try:")
assert s.strip().endswith("except somenoname:") tb = excinfo.traceback
source = tb[-1].getsource()
assert source is not None
assert source.deindent().lines == [
"def xyz():",
" try:",
" raise ValueError",
" except somenoname: # type: ignore[name-defined] # noqa: F821",
]
def test_traceback_cut(self): def test_traceback_cut(self):
co = _pytest._code.Code(f) co = _pytest._code.Code(f)
@ -445,16 +449,6 @@ class TestFormattedExcinfo:
return importasmod return importasmod
def excinfo_from_exec(self, source):
source = _pytest._code.Source(source).strip()
try:
exec(source.compile())
except KeyboardInterrupt:
raise
except BaseException:
return _pytest._code.ExceptionInfo.from_current()
assert 0, "did not raise"
def test_repr_source(self): def test_repr_source(self):
pr = FormattedExcinfo() pr = FormattedExcinfo()
source = _pytest._code.Source( source = _pytest._code.Source(
@ -471,19 +465,29 @@ class TestFormattedExcinfo:
def test_repr_source_excinfo(self) -> None: def test_repr_source_excinfo(self) -> None:
""" check if indentation is right """ """ check if indentation is right """
pr = FormattedExcinfo() try:
excinfo = self.excinfo_from_exec(
"""
def f(): def f():
assert 0 1 / 0
f() f()
"""
) except BaseException:
excinfo = _pytest._code.ExceptionInfo.from_current()
else:
assert False, "did not raise"
pr = FormattedExcinfo() pr = FormattedExcinfo()
source = pr._getentrysource(excinfo.traceback[-1]) source = pr._getentrysource(excinfo.traceback[-1])
assert source is not None assert source is not None
lines = pr.get_source(source, 1, excinfo) lines = pr.get_source(source, 1, excinfo)
assert lines == [" def f():", "> assert 0", "E AssertionError"] for line in lines:
print(line)
assert lines == [
" def f():",
"> 1 / 0",
"E ZeroDivisionError: division by zero",
]
def test_repr_source_not_existing(self): def test_repr_source_not_existing(self):
pr = FormattedExcinfo() pr = FormattedExcinfo()

View File

@ -3,7 +3,9 @@
# or redundant on purpose and can't be disable on a line-by-line basis # or redundant on purpose and can't be disable on a line-by-line basis
import ast import ast
import inspect import inspect
import linecache
import sys import sys
import textwrap
from types import CodeType from types import CodeType
from typing import Any from typing import Any
from typing import Dict from typing import Dict
@ -32,14 +34,6 @@ def test_source_str_function() -> None:
assert str(x) == "\n3" assert str(x) == "\n3"
def test_unicode() -> None:
x = Source("4")
assert str(x) == "4"
co = _pytest._code.compile('"å"', mode="eval")
val = eval(co)
assert isinstance(val, str)
def test_source_from_function() -> None: def test_source_from_function() -> None:
source = _pytest._code.Source(test_source_str_function) source = _pytest._code.Source(test_source_str_function)
assert str(source).startswith("def test_source_str_function() -> None:") assert str(source).startswith("def test_source_str_function() -> None:")
@ -62,47 +56,12 @@ def test_source_from_lines() -> None:
def test_source_from_inner_function() -> None: def test_source_from_inner_function() -> None:
def f(): def f():
pass raise NotImplementedError()
source = _pytest._code.Source(f, deindent=False)
assert str(source).startswith(" def f():")
source = _pytest._code.Source(f) source = _pytest._code.Source(f)
assert str(source).startswith("def f():") assert str(source).startswith("def f():")
def test_source_putaround_simple() -> None:
source = Source("raise ValueError")
source = source.putaround(
"try:",
"""\
except ValueError:
x = 42
else:
x = 23""",
)
assert (
str(source)
== """\
try:
raise ValueError
except ValueError:
x = 42
else:
x = 23"""
)
def test_source_putaround() -> None:
source = Source()
source = source.putaround(
"""
if 1:
x=1
"""
)
assert str(source).strip() == "if 1:\n x=1"
def test_source_strips() -> None: def test_source_strips() -> None:
source = Source("") source = Source("")
assert source == Source() assert source == Source()
@ -117,24 +76,6 @@ def test_source_strip_multiline() -> None:
assert source2.lines == [" hello"] assert source2.lines == [" hello"]
def test_syntaxerror_rerepresentation() -> None:
ex = pytest.raises(SyntaxError, _pytest._code.compile, "xyz xyz")
assert ex is not None
assert ex.value.lineno == 1
assert ex.value.offset in {5, 7} # cpython: 7, pypy3.6 7.1.1: 5
assert ex.value.text
assert ex.value.text.rstrip("\n") == "xyz xyz"
def test_isparseable() -> None:
assert Source("hello").isparseable()
assert Source("if 1:\n pass").isparseable()
assert Source(" \nif 1:\n pass").isparseable()
assert not Source("if 1:\n").isparseable()
assert not Source(" \nif 1:\npass").isparseable()
assert not Source(chr(0)).isparseable()
class TestAccesses: class TestAccesses:
def setup_class(self) -> None: def setup_class(self) -> None:
self.source = Source( self.source = Source(
@ -148,7 +89,6 @@ class TestAccesses:
def test_getrange(self) -> None: def test_getrange(self) -> None:
x = self.source[0:2] x = self.source[0:2]
assert x.isparseable()
assert len(x.lines) == 2 assert len(x.lines) == 2
assert str(x) == "def f(x):\n pass" assert str(x) == "def f(x):\n pass"
@ -168,7 +108,7 @@ class TestAccesses:
assert len(values) == 4 assert len(values) == 4
class TestSourceParsingAndCompiling: class TestSourceParsing:
def setup_class(self) -> None: def setup_class(self) -> None:
self.source = Source( self.source = Source(
"""\ """\
@ -179,39 +119,6 @@ class TestSourceParsingAndCompiling:
""" """
).strip() ).strip()
def test_compile(self) -> None:
co = _pytest._code.compile("x=3")
d = {} # type: Dict[str, Any]
exec(co, d)
assert d["x"] == 3
def test_compile_and_getsource_simple(self) -> None:
co = _pytest._code.compile("x=3")
exec(co)
source = _pytest._code.Source(co)
assert str(source) == "x=3"
def test_compile_and_getsource_through_same_function(self) -> None:
def gensource(source):
return _pytest._code.compile(source)
co1 = gensource(
"""
def f():
raise KeyError()
"""
)
co2 = gensource(
"""
def f():
raise ValueError()
"""
)
source1 = inspect.getsource(co1)
assert "KeyError" in source1
source2 = inspect.getsource(co2)
assert "ValueError" in source2
def test_getstatement(self) -> None: def test_getstatement(self) -> None:
# print str(self.source) # print str(self.source)
ass = str(self.source[1:]) ass = str(self.source[1:])
@ -228,9 +135,9 @@ class TestSourceParsingAndCompiling:
''')""" ''')"""
) )
s = source.getstatement(0) s = source.getstatement(0)
assert s == str(source) assert s == source
s = source.getstatement(1) s = source.getstatement(1)
assert s == str(source) assert s == source
def test_getstatementrange_within_constructs(self) -> None: def test_getstatementrange_within_constructs(self) -> None:
source = Source( source = Source(
@ -308,44 +215,6 @@ class TestSourceParsingAndCompiling:
source = Source(":") source = Source(":")
pytest.raises(SyntaxError, lambda: source.getstatementrange(0)) pytest.raises(SyntaxError, lambda: source.getstatementrange(0))
def test_compile_to_ast(self) -> None:
source = Source("x = 4")
mod = source.compile(flag=ast.PyCF_ONLY_AST)
assert isinstance(mod, ast.Module)
compile(mod, "<filename>", "exec")
def test_compile_and_getsource(self) -> None:
co = self.source.compile()
exec(co, globals())
f(7) # type: ignore
excinfo = pytest.raises(AssertionError, f, 6) # type: ignore
assert excinfo is not None
frame = excinfo.traceback[-1].frame
assert isinstance(frame.code.fullsource, Source)
stmt = frame.code.fullsource.getstatement(frame.lineno)
assert str(stmt).strip().startswith("assert")
@pytest.mark.parametrize("name", ["", None, "my"])
def test_compilefuncs_and_path_sanity(self, name: Optional[str]) -> None:
def check(comp, name) -> None:
co = comp(self.source, name)
if not name:
expected = "codegen %s:%d>" % (mypath, mylineno + 2 + 2) # type: ignore
else:
expected = "codegen %r %s:%d>" % (name, mypath, mylineno + 2 + 2) # type: ignore
fn = co.co_filename
assert fn.endswith(expected)
mycode = _pytest._code.Code(self.test_compilefuncs_and_path_sanity)
mylineno = mycode.firstlineno
mypath = mycode.path
for comp in _pytest._code.compile, _pytest._code.Source.compile:
check(comp, name)
def test_offsetless_synerr(self):
pytest.raises(SyntaxError, _pytest._code.compile, "lambda a,a: 0", mode="eval")
def test_getstartingblock_singleline() -> None: def test_getstartingblock_singleline() -> None:
class A: class A:
@ -375,18 +244,16 @@ def test_getline_finally() -> None:
def test_getfuncsource_dynamic() -> None: def test_getfuncsource_dynamic() -> None:
source = """
def f(): def f():
raise ValueError raise NotImplementedError()
def g(): pass def g():
""" pass # pragma: no cover
co = _pytest._code.compile(source)
exec(co, globals()) f_source = _pytest._code.Source(f)
f_source = _pytest._code.Source(f) # type: ignore g_source = _pytest._code.Source(g)
g_source = _pytest._code.Source(g) # type: ignore assert str(f_source).strip() == "def f():\n raise NotImplementedError()"
assert str(f_source).strip() == "def f():\n raise ValueError" assert str(g_source).strip() == "def g():\n pass # pragma: no cover"
assert str(g_source).strip() == "def g(): pass"
def test_getfuncsource_with_multine_string() -> None: def test_getfuncsource_with_multine_string() -> None:
@ -440,30 +307,11 @@ if True:
pass pass
def test_getsource_fallback() -> None: def test_source_fallback() -> None:
from _pytest._code.source import getsource src = Source(x)
expected = """def x(): expected = """def x():
pass""" pass"""
src = getsource(x) assert str(src) == expected
assert src == expected
def test_idem_compile_and_getsource() -> None:
from _pytest._code.source import getsource
expected = "def x(): pass"
co = _pytest._code.compile(expected)
src = getsource(co)
assert src == expected
def test_compile_ast() -> None:
# We don't necessarily want to support this.
# This test was added just for coverage.
stmt = ast.parse("def x(): pass")
co = _pytest._code.compile(stmt, filename="foo.py")
assert isinstance(co, CodeType)
def test_findsource_fallback() -> None: def test_findsource_fallback() -> None:
@ -475,15 +323,15 @@ def test_findsource_fallback() -> None:
assert src[lineno] == " def x():" assert src[lineno] == " def x():"
def test_findsource() -> None: def test_findsource(monkeypatch) -> None:
from _pytest._code.source import findsource from _pytest._code.source import findsource
co = _pytest._code.compile( filename = "<pytest-test_findsource>"
"""if 1: lines = ["if 1:\n", " def x():\n", " pass\n"]
def x(): co = compile("".join(lines), filename, "exec")
pass
""" # Type ignored because linecache.cache is private.
) monkeypatch.setitem(linecache.cache, filename, (1, None, lines, filename)) # type: ignore[attr-defined]
src, lineno = findsource(co) src, lineno = findsource(co)
assert src is not None assert src is not None
@ -557,7 +405,7 @@ def test_code_of_object_instance_with_call() -> None:
def getstatement(lineno: int, source) -> Source: def getstatement(lineno: int, source) -> Source:
from _pytest._code.source import getstatementrange_ast from _pytest._code.source import getstatementrange_ast
src = _pytest._code.Source(source, deindent=False) src = _pytest._code.Source(source)
ast, start, end = getstatementrange_ast(lineno, src) ast, start, end = getstatementrange_ast(lineno, src)
return src[start:end] return src[start:end]
@ -633,7 +481,7 @@ def test_source_with_decorator() -> None:
assert False assert False
src = inspect.getsource(deco_mark) src = inspect.getsource(deco_mark)
assert str(Source(deco_mark, deindent=False)) == src assert textwrap.indent(str(Source(deco_mark)), " ") + "\n" == src
assert src.startswith(" @pytest.mark.foo") assert src.startswith(" @pytest.mark.foo")
@pytest.fixture @pytest.fixture
@ -646,7 +494,9 @@ def test_source_with_decorator() -> None:
# existing behavior here for explicitness, but perhaps we should revisit/change this # existing behavior here for explicitness, but perhaps we should revisit/change this
# in the future # in the future
assert str(Source(deco_fixture)).startswith("@functools.wraps(function)") assert str(Source(deco_fixture)).startswith("@functools.wraps(function)")
assert str(Source(get_real_func(deco_fixture), deindent=False)) == src assert (
textwrap.indent(str(Source(get_real_func(deco_fixture))), " ") + "\n" == src
)
def test_single_line_else() -> None: def test_single_line_else() -> None: