Merge pull request #7438 from bluetech/source-cleanups

code/source: some cleanups
This commit is contained in:
Ran Benita 2020-07-04 12:57:32 +03:00 committed by GitHub
commit 36b958c99e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 106 additions and 416 deletions

View File

@ -0,0 +1,11 @@
Some changes were made to the internal ``_pytest._code.source``, listed here
for the benefit of plugin authors who may be using it:
- The ``deindent`` argument to ``Source()`` has been removed, now it is always true.
- Support for zero or multiple arguments to ``Source()`` has been removed.
- Support for comparing ``Source`` with an ``str`` has been removed.
- The methods ``Source.isparseable()`` and ``Source.putaround()`` have been removed.
- The method ``Source.compile()`` and function ``_pytest._code.compile()`` have
been removed; use plain ``compile()`` instead.
- The function ``_pytest._code.source.getsource()`` has been removed; use
``Source()`` directly instead.

View File

@ -7,7 +7,6 @@ from .code import getfslineno
from .code import getrawcode
from .code import Traceback
from .code import TracebackEntry
from .source import compile_ as compile
from .source import Source
__all__ = [
@ -19,6 +18,5 @@ __all__ = [
"getrawcode",
"Traceback",
"TracebackEntry",
"compile",
"Source",
]

View File

@ -1,61 +1,43 @@
import ast
import inspect
import linecache
import sys
import textwrap
import tokenize
import warnings
from bisect import bisect_right
from types import CodeType
from types import FrameType
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import py
from _pytest.compat import overload
from _pytest.compat import TYPE_CHECKING
if TYPE_CHECKING:
from typing_extensions import Literal
class Source:
""" an immutable object holding a source code fragment,
possibly deindenting it.
"""An immutable object holding a source code fragment.
When using Source(...), the source lines are deindented.
"""
_compilecounter = 0
def __init__(self, obj: object = None) -> None:
if not obj:
self.lines = [] # type: List[str]
elif isinstance(obj, Source):
self.lines = obj.lines
elif isinstance(obj, (tuple, list)):
self.lines = deindent(x.rstrip("\n") for x in obj)
elif isinstance(obj, str):
self.lines = deindent(obj.split("\n"))
else:
rawcode = getrawcode(obj)
src = inspect.getsource(rawcode)
self.lines = deindent(src.split("\n"))
def __init__(self, *parts, **kwargs) -> None:
self.lines = lines = [] # type: List[str]
de = kwargs.get("deindent", True)
for part in parts:
if not part:
partlines = [] # type: List[str]
elif isinstance(part, Source):
partlines = part.lines
elif isinstance(part, (tuple, list)):
partlines = [x.rstrip("\n") for x in part]
elif isinstance(part, str):
partlines = part.split("\n")
else:
partlines = getsource(part, deindent=de).lines
if de:
partlines = deindent(partlines)
lines.extend(partlines)
def __eq__(self, other):
try:
return self.lines == other.lines
except AttributeError:
if isinstance(other, str):
return str(self) == other
return False
def __eq__(self, other: object) -> bool:
if not isinstance(other, Source):
return NotImplemented
return self.lines == other.lines
# Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore
@ -97,19 +79,6 @@ class Source:
source.lines[:] = self.lines[start:end]
return source
def putaround(
self, before: str = "", after: str = "", indent: str = " " * 4
) -> "Source":
""" return a copy of the source object with
'before' and 'after' wrapped around it.
"""
beforesource = Source(before)
aftersource = Source(after)
newsource = Source()
lines = [(indent + line) for line in self.lines]
newsource.lines = beforesource.lines + lines + aftersource.lines
return newsource
def indent(self, indent: str = " " * 4) -> "Source":
""" return a copy of the source object with
all lines indented by the given indent-string.
@ -140,142 +109,9 @@ class Source:
newsource.lines[:] = deindent(self.lines)
return newsource
def isparseable(self, deindent: bool = True) -> bool:
""" return True if source is parseable, heuristically
deindenting it by default.
"""
if deindent:
source = str(self.deindent())
else:
source = str(self)
try:
ast.parse(source)
except (SyntaxError, ValueError, TypeError):
return False
else:
return True
def __str__(self) -> str:
return "\n".join(self.lines)
@overload
def compile(
self,
filename: Optional[str] = ...,
mode: str = ...,
flag: "Literal[0]" = ...,
dont_inherit: int = ...,
_genframe: Optional[FrameType] = ...,
) -> CodeType:
raise NotImplementedError()
@overload # noqa: F811
def compile( # noqa: F811
self,
filename: Optional[str] = ...,
mode: str = ...,
flag: int = ...,
dont_inherit: int = ...,
_genframe: Optional[FrameType] = ...,
) -> Union[CodeType, ast.AST]:
raise NotImplementedError()
def compile( # noqa: F811
self,
filename: Optional[str] = None,
mode: str = "exec",
flag: int = 0,
dont_inherit: int = 0,
_genframe: Optional[FrameType] = None,
) -> Union[CodeType, ast.AST]:
""" return compiled code object. if filename is None
invent an artificial filename which displays
the source/line position of the caller frame.
"""
if not filename or py.path.local(filename).check(file=0):
if _genframe is None:
_genframe = sys._getframe(1) # the caller
fn, lineno = _genframe.f_code.co_filename, _genframe.f_lineno
base = "<%d-codegen " % self._compilecounter
self.__class__._compilecounter += 1
if not filename:
filename = base + "%s:%d>" % (fn, lineno)
else:
filename = base + "%r %s:%d>" % (filename, fn, lineno)
source = "\n".join(self.lines) + "\n"
try:
co = compile(source, filename, mode, flag)
except SyntaxError as ex:
# re-represent syntax errors from parsing python strings
msglines = self.lines[: ex.lineno]
if ex.offset:
msglines.append(" " * ex.offset + "^")
msglines.append("(code was compiled probably from here: %s)" % filename)
newex = SyntaxError("\n".join(msglines))
newex.offset = ex.offset
newex.lineno = ex.lineno
newex.text = ex.text
raise newex from ex
else:
if flag & ast.PyCF_ONLY_AST:
assert isinstance(co, ast.AST)
return co
assert isinstance(co, CodeType)
lines = [(x + "\n") for x in self.lines]
# Type ignored because linecache.cache is private.
linecache.cache[filename] = (1, None, lines, filename) # type: ignore
return co
#
# public API shortcut functions
#
@overload
def compile_(
source: Union[str, bytes, ast.mod, ast.AST],
filename: Optional[str] = ...,
mode: str = ...,
flags: "Literal[0]" = ...,
dont_inherit: int = ...,
) -> CodeType:
raise NotImplementedError()
@overload # noqa: F811
def compile_( # noqa: F811
source: Union[str, bytes, ast.mod, ast.AST],
filename: Optional[str] = ...,
mode: str = ...,
flags: int = ...,
dont_inherit: int = ...,
) -> Union[CodeType, ast.AST]:
raise NotImplementedError()
def compile_( # noqa: F811
source: Union[str, bytes, ast.mod, ast.AST],
filename: Optional[str] = None,
mode: str = "exec",
flags: int = 0,
dont_inherit: int = 0,
) -> Union[CodeType, ast.AST]:
""" compile the given source to a raw code object,
and maintain an internal cache which allows later
retrieval of the source code for the code object
and any recursively created code objects.
"""
if isinstance(source, ast.AST):
# XXX should Source support having AST?
assert filename is not None
co = compile(source, filename, mode, flags, dont_inherit)
assert isinstance(co, (CodeType, ast.AST))
return co
_genframe = sys._getframe(1) # the caller
s = Source(source)
return s.compile(filename, mode, flags, _genframe=_genframe)
#
# helper functions
@ -307,17 +143,7 @@ def getrawcode(obj, trycall: bool = True):
return obj
def getsource(obj, **kwargs) -> Source:
obj = getrawcode(obj)
try:
strsrc = inspect.getsource(obj)
except IndentationError:
strsrc = '"Buggy python version consider upgrading, cannot get source"'
assert isinstance(strsrc, str)
return Source(strsrc, **kwargs)
def deindent(lines: Sequence[str]) -> List[str]:
def deindent(lines: Iterable[str]) -> List[str]:
return textwrap.dedent("\n".join(lines)).splitlines()

View File

@ -9,7 +9,6 @@ from typing import Tuple
import attr
import _pytest._code
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config import hookimpl
@ -105,7 +104,8 @@ def evaluate_condition(item: Item, mark: Mark, condition: object) -> Tuple[bool,
if hasattr(item, "obj"):
globals_.update(item.obj.__globals__) # type: ignore[attr-defined]
try:
condition_code = _pytest._code.compile(condition, mode="eval")
filename = "<{} condition>".format(mark.name)
condition_code = compile(condition, filename, "eval")
result = eval(condition_code, globals_)
except SyntaxError as exc:
msglines = [

View File

@ -6,6 +6,7 @@ import pytest
from _pytest._code import Code
from _pytest._code import ExceptionInfo
from _pytest._code import Frame
from _pytest._code import Source
from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ReprFuncArgs
@ -67,7 +68,7 @@ def test_getstatement_empty_fullsource() -> None:
f = Frame(func())
with mock.patch.object(f.code.__class__, "fullsource", None):
assert f.statement == ""
assert f.statement == Source("")
def test_code_from_func() -> None:

View File

@ -127,24 +127,28 @@ class TestTraceback_f_g_h:
assert s.endswith("raise ValueError")
def test_traceback_entry_getsource_in_construct(self):
source = _pytest._code.Source(
"""\
def xyz():
try:
raise ValueError
except somenoname:
pass
xyz()
"""
)
def xyz():
try:
raise ValueError
except somenoname: # type: ignore[name-defined] # noqa: F821
pass # pragma: no cover
try:
exec(source.compile())
xyz()
except NameError:
tb = _pytest._code.ExceptionInfo.from_current().traceback
print(tb[-1].getsource())
s = str(tb[-1].getsource())
assert s.startswith("def xyz():\n try:")
assert s.strip().endswith("except somenoname:")
excinfo = _pytest._code.ExceptionInfo.from_current()
else:
assert False, "did not raise NameError"
tb = excinfo.traceback
source = tb[-1].getsource()
assert source is not None
assert source.deindent().lines == [
"def xyz():",
" try:",
" raise ValueError",
" except somenoname: # type: ignore[name-defined] # noqa: F821",
]
def test_traceback_cut(self):
co = _pytest._code.Code(f)
@ -445,16 +449,6 @@ class TestFormattedExcinfo:
return importasmod
def excinfo_from_exec(self, source):
source = _pytest._code.Source(source).strip()
try:
exec(source.compile())
except KeyboardInterrupt:
raise
except BaseException:
return _pytest._code.ExceptionInfo.from_current()
assert 0, "did not raise"
def test_repr_source(self):
pr = FormattedExcinfo()
source = _pytest._code.Source(
@ -471,19 +465,29 @@ class TestFormattedExcinfo:
def test_repr_source_excinfo(self) -> None:
""" check if indentation is right """
pr = FormattedExcinfo()
excinfo = self.excinfo_from_exec(
"""
def f():
assert 0
f()
"""
)
try:
def f():
1 / 0
f()
except BaseException:
excinfo = _pytest._code.ExceptionInfo.from_current()
else:
assert False, "did not raise"
pr = FormattedExcinfo()
source = pr._getentrysource(excinfo.traceback[-1])
assert source is not None
lines = pr.get_source(source, 1, excinfo)
assert lines == [" def f():", "> assert 0", "E AssertionError"]
for line in lines:
print(line)
assert lines == [
" def f():",
"> 1 / 0",
"E ZeroDivisionError: division by zero",
]
def test_repr_source_not_existing(self):
pr = FormattedExcinfo()

View File

@ -3,7 +3,9 @@
# or redundant on purpose and can't be disable on a line-by-line basis
import ast
import inspect
import linecache
import sys
import textwrap
from types import CodeType
from typing import Any
from typing import Dict
@ -32,14 +34,6 @@ def test_source_str_function() -> None:
assert str(x) == "\n3"
def test_unicode() -> None:
x = Source("4")
assert str(x) == "4"
co = _pytest._code.compile('"å"', mode="eval")
val = eval(co)
assert isinstance(val, str)
def test_source_from_function() -> None:
source = _pytest._code.Source(test_source_str_function)
assert str(source).startswith("def test_source_str_function() -> None:")
@ -62,47 +56,12 @@ def test_source_from_lines() -> None:
def test_source_from_inner_function() -> None:
def f():
pass
raise NotImplementedError()
source = _pytest._code.Source(f, deindent=False)
assert str(source).startswith(" def f():")
source = _pytest._code.Source(f)
assert str(source).startswith("def f():")
def test_source_putaround_simple() -> None:
source = Source("raise ValueError")
source = source.putaround(
"try:",
"""\
except ValueError:
x = 42
else:
x = 23""",
)
assert (
str(source)
== """\
try:
raise ValueError
except ValueError:
x = 42
else:
x = 23"""
)
def test_source_putaround() -> None:
source = Source()
source = source.putaround(
"""
if 1:
x=1
"""
)
assert str(source).strip() == "if 1:\n x=1"
def test_source_strips() -> None:
source = Source("")
assert source == Source()
@ -117,24 +76,6 @@ def test_source_strip_multiline() -> None:
assert source2.lines == [" hello"]
def test_syntaxerror_rerepresentation() -> None:
ex = pytest.raises(SyntaxError, _pytest._code.compile, "xyz xyz")
assert ex is not None
assert ex.value.lineno == 1
assert ex.value.offset in {5, 7} # cpython: 7, pypy3.6 7.1.1: 5
assert ex.value.text
assert ex.value.text.rstrip("\n") == "xyz xyz"
def test_isparseable() -> None:
assert Source("hello").isparseable()
assert Source("if 1:\n pass").isparseable()
assert Source(" \nif 1:\n pass").isparseable()
assert not Source("if 1:\n").isparseable()
assert not Source(" \nif 1:\npass").isparseable()
assert not Source(chr(0)).isparseable()
class TestAccesses:
def setup_class(self) -> None:
self.source = Source(
@ -148,7 +89,6 @@ class TestAccesses:
def test_getrange(self) -> None:
x = self.source[0:2]
assert x.isparseable()
assert len(x.lines) == 2
assert str(x) == "def f(x):\n pass"
@ -168,7 +108,7 @@ class TestAccesses:
assert len(values) == 4
class TestSourceParsingAndCompiling:
class TestSourceParsing:
def setup_class(self) -> None:
self.source = Source(
"""\
@ -179,39 +119,6 @@ class TestSourceParsingAndCompiling:
"""
).strip()
def test_compile(self) -> None:
co = _pytest._code.compile("x=3")
d = {} # type: Dict[str, Any]
exec(co, d)
assert d["x"] == 3
def test_compile_and_getsource_simple(self) -> None:
co = _pytest._code.compile("x=3")
exec(co)
source = _pytest._code.Source(co)
assert str(source) == "x=3"
def test_compile_and_getsource_through_same_function(self) -> None:
def gensource(source):
return _pytest._code.compile(source)
co1 = gensource(
"""
def f():
raise KeyError()
"""
)
co2 = gensource(
"""
def f():
raise ValueError()
"""
)
source1 = inspect.getsource(co1)
assert "KeyError" in source1
source2 = inspect.getsource(co2)
assert "ValueError" in source2
def test_getstatement(self) -> None:
# print str(self.source)
ass = str(self.source[1:])
@ -228,9 +135,9 @@ class TestSourceParsingAndCompiling:
''')"""
)
s = source.getstatement(0)
assert s == str(source)
assert s == source
s = source.getstatement(1)
assert s == str(source)
assert s == source
def test_getstatementrange_within_constructs(self) -> None:
source = Source(
@ -308,44 +215,6 @@ class TestSourceParsingAndCompiling:
source = Source(":")
pytest.raises(SyntaxError, lambda: source.getstatementrange(0))
def test_compile_to_ast(self) -> None:
source = Source("x = 4")
mod = source.compile(flag=ast.PyCF_ONLY_AST)
assert isinstance(mod, ast.Module)
compile(mod, "<filename>", "exec")
def test_compile_and_getsource(self) -> None:
co = self.source.compile()
exec(co, globals())
f(7) # type: ignore
excinfo = pytest.raises(AssertionError, f, 6) # type: ignore
assert excinfo is not None
frame = excinfo.traceback[-1].frame
assert isinstance(frame.code.fullsource, Source)
stmt = frame.code.fullsource.getstatement(frame.lineno)
assert str(stmt).strip().startswith("assert")
@pytest.mark.parametrize("name", ["", None, "my"])
def test_compilefuncs_and_path_sanity(self, name: Optional[str]) -> None:
def check(comp, name) -> None:
co = comp(self.source, name)
if not name:
expected = "codegen %s:%d>" % (mypath, mylineno + 2 + 2) # type: ignore
else:
expected = "codegen %r %s:%d>" % (name, mypath, mylineno + 2 + 2) # type: ignore
fn = co.co_filename
assert fn.endswith(expected)
mycode = _pytest._code.Code(self.test_compilefuncs_and_path_sanity)
mylineno = mycode.firstlineno
mypath = mycode.path
for comp in _pytest._code.compile, _pytest._code.Source.compile:
check(comp, name)
def test_offsetless_synerr(self):
pytest.raises(SyntaxError, _pytest._code.compile, "lambda a,a: 0", mode="eval")
def test_getstartingblock_singleline() -> None:
class A:
@ -375,18 +244,16 @@ def test_getline_finally() -> None:
def test_getfuncsource_dynamic() -> None:
source = """
def f():
raise ValueError
def f():
raise NotImplementedError()
def g(): pass
"""
co = _pytest._code.compile(source)
exec(co, globals())
f_source = _pytest._code.Source(f) # type: ignore
g_source = _pytest._code.Source(g) # type: ignore
assert str(f_source).strip() == "def f():\n raise ValueError"
assert str(g_source).strip() == "def g(): pass"
def g():
pass # pragma: no cover
f_source = _pytest._code.Source(f)
g_source = _pytest._code.Source(g)
assert str(f_source).strip() == "def f():\n raise NotImplementedError()"
assert str(g_source).strip() == "def g():\n pass # pragma: no cover"
def test_getfuncsource_with_multine_string() -> None:
@ -440,30 +307,11 @@ if True:
pass
def test_getsource_fallback() -> None:
from _pytest._code.source import getsource
def test_source_fallback() -> None:
src = Source(x)
expected = """def x():
pass"""
src = getsource(x)
assert src == expected
def test_idem_compile_and_getsource() -> None:
from _pytest._code.source import getsource
expected = "def x(): pass"
co = _pytest._code.compile(expected)
src = getsource(co)
assert src == expected
def test_compile_ast() -> None:
# We don't necessarily want to support this.
# This test was added just for coverage.
stmt = ast.parse("def x(): pass")
co = _pytest._code.compile(stmt, filename="foo.py")
assert isinstance(co, CodeType)
assert str(src) == expected
def test_findsource_fallback() -> None:
@ -475,15 +323,15 @@ def test_findsource_fallback() -> None:
assert src[lineno] == " def x():"
def test_findsource() -> None:
def test_findsource(monkeypatch) -> None:
from _pytest._code.source import findsource
co = _pytest._code.compile(
"""if 1:
def x():
pass
"""
)
filename = "<pytest-test_findsource>"
lines = ["if 1:\n", " def x():\n", " pass\n"]
co = compile("".join(lines), filename, "exec")
# Type ignored because linecache.cache is private.
monkeypatch.setitem(linecache.cache, filename, (1, None, lines, filename)) # type: ignore[attr-defined]
src, lineno = findsource(co)
assert src is not None
@ -557,7 +405,7 @@ def test_code_of_object_instance_with_call() -> None:
def getstatement(lineno: int, source) -> Source:
from _pytest._code.source import getstatementrange_ast
src = _pytest._code.Source(source, deindent=False)
src = _pytest._code.Source(source)
ast, start, end = getstatementrange_ast(lineno, src)
return src[start:end]
@ -633,7 +481,7 @@ def test_source_with_decorator() -> None:
assert False
src = inspect.getsource(deco_mark)
assert str(Source(deco_mark, deindent=False)) == src
assert textwrap.indent(str(Source(deco_mark)), " ") + "\n" == src
assert src.startswith(" @pytest.mark.foo")
@pytest.fixture
@ -646,7 +494,9 @@ def test_source_with_decorator() -> None:
# existing behavior here for explicitness, but perhaps we should revisit/change this
# in the future
assert str(Source(deco_fixture)).startswith("@functools.wraps(function)")
assert str(Source(get_real_func(deco_fixture), deindent=False)) == src
assert (
textwrap.indent(str(Source(get_real_func(deco_fixture))), " ") + "\n" == src
)
def test_single_line_else() -> None: