Add type annotations to _pytest._code.code

This commit is contained in:
Ran Benita 2019-11-15 23:02:55 +02:00
parent 562d4811d5
commit eaa34a9df0
4 changed files with 299 additions and 214 deletions

View File

@ -7,13 +7,17 @@ from inspect import CO_VARKEYWORDS
from io import StringIO
from traceback import format_exception_only
from types import CodeType
from types import FrameType
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generic
from typing import Iterable
from typing import List
from typing import Optional
from typing import Pattern
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import TypeVar
@ -27,9 +31,16 @@ import py
import _pytest
from _pytest._io.saferepr import safeformat
from _pytest._io.saferepr import saferepr
from _pytest.compat import overload
if False: # TYPE_CHECKING
from typing import Type
from typing_extensions import Literal
from weakref import ReferenceType # noqa: F401
from _pytest._code import Source
_TracebackStyle = Literal["long", "short", "no", "native"]
class Code:
@ -38,13 +49,12 @@ class Code:
def __init__(self, rawcode) -> None:
if not hasattr(rawcode, "co_filename"):
rawcode = getrawcode(rawcode)
try:
self.filename = rawcode.co_filename
self.firstlineno = rawcode.co_firstlineno - 1
self.name = rawcode.co_name
except AttributeError:
if not isinstance(rawcode, CodeType):
raise TypeError("not a code object: {!r}".format(rawcode))
self.raw = rawcode # type: CodeType
self.filename = rawcode.co_filename
self.firstlineno = rawcode.co_firstlineno - 1
self.name = rawcode.co_name
self.raw = rawcode
def __eq__(self, other):
return self.raw == other.raw
@ -72,7 +82,7 @@ class Code:
return p
@property
def fullsource(self):
def fullsource(self) -> Optional["Source"]:
""" return a _pytest._code.Source object for the full source file of the code
"""
from _pytest._code import source
@ -80,7 +90,7 @@ class Code:
full, _ = source.findsource(self.raw)
return full
def source(self):
def source(self) -> "Source":
""" return a _pytest._code.Source object for the code object's source only
"""
# return source only for that part of code
@ -88,7 +98,7 @@ class Code:
return _pytest._code.Source(self.raw)
def getargs(self, var=False):
def getargs(self, var: bool = False) -> Tuple[str, ...]:
""" return a tuple with the argument names for the code object
if 'var' is set True also return the names of the variable and
@ -107,7 +117,7 @@ class Frame:
"""Wrapper around a Python frame holding f_locals and f_globals
in which expressions can be evaluated."""
def __init__(self, frame):
def __init__(self, frame: FrameType) -> None:
self.lineno = frame.f_lineno - 1
self.f_globals = frame.f_globals
self.f_locals = frame.f_locals
@ -115,7 +125,7 @@ class Frame:
self.code = Code(frame.f_code)
@property
def statement(self):
def statement(self) -> "Source":
""" statement this frame is at """
import _pytest._code
@ -134,7 +144,7 @@ class Frame:
f_locals.update(vars)
return eval(code, self.f_globals, f_locals)
def exec_(self, code, **vars):
def exec_(self, code, **vars) -> None:
""" exec 'code' in the frame
'vars' are optional; additional local variables
@ -143,7 +153,7 @@ class Frame:
f_locals.update(vars)
exec(code, self.f_globals, f_locals)
def repr(self, object):
def repr(self, object: object) -> str:
""" return a 'safe' (non-recursive, one-line) string repr for 'object'
"""
return saferepr(object)
@ -151,7 +161,7 @@ class Frame:
def is_true(self, object):
return object
def getargs(self, var=False):
def getargs(self, var: bool = False):
""" return a list of tuples (name, value) for all arguments
if 'var' is set True also include the variable and keyword
@ -169,35 +179,34 @@ class Frame:
class TracebackEntry:
""" a single entry in a traceback """
_repr_style = None
_repr_style = None # type: Optional[Literal["short", "long"]]
exprinfo = None
def __init__(self, rawentry, excinfo=None):
def __init__(self, rawentry: TracebackType, excinfo=None) -> None:
self._excinfo = excinfo
self._rawentry = rawentry
self.lineno = rawentry.tb_lineno - 1
def set_repr_style(self, mode):
def set_repr_style(self, mode: "Literal['short', 'long']") -> None:
assert mode in ("short", "long")
self._repr_style = mode
@property
def frame(self):
import _pytest._code
return _pytest._code.Frame(self._rawentry.tb_frame)
def frame(self) -> Frame:
return Frame(self._rawentry.tb_frame)
@property
def relline(self):
def relline(self) -> int:
return self.lineno - self.frame.code.firstlineno
def __repr__(self):
def __repr__(self) -> str:
return "<TracebackEntry %s:%d>" % (self.frame.code.path, self.lineno + 1)
@property
def statement(self):
def statement(self) -> "Source":
""" _pytest._code.Source object for the current statement """
source = self.frame.code.fullsource
assert source is not None
return source.getstatement(self.lineno)
@property
@ -206,14 +215,14 @@ class TracebackEntry:
return self.frame.code.path
@property
def locals(self):
def locals(self) -> Dict[str, Any]:
""" locals of underlying frame """
return self.frame.f_locals
def getfirstlinesource(self):
def getfirstlinesource(self) -> int:
return self.frame.code.firstlineno
def getsource(self, astcache=None):
def getsource(self, astcache=None) -> Optional["Source"]:
""" return failing source code. """
# we use the passed in astcache to not reparse asttrees
# within exception info printing
@ -258,7 +267,7 @@ class TracebackEntry:
return tbh(None if self._excinfo is None else self._excinfo())
return tbh
def __str__(self):
def __str__(self) -> str:
try:
fn = str(self.path)
except py.error.Error:
@ -273,31 +282,42 @@ class TracebackEntry:
return " File %r:%d in %s\n %s\n" % (fn, self.lineno + 1, name, line)
@property
def name(self):
def name(self) -> str:
""" co_name of underlying code """
return self.frame.code.raw.co_name
class Traceback(list):
class Traceback(List[TracebackEntry]):
""" Traceback objects encapsulate and offer higher level
access to Traceback entries.
"""
def __init__(self, tb, excinfo=None):
def __init__(
self,
tb: Union[TracebackType, Iterable[TracebackEntry]],
excinfo: Optional["ReferenceType[ExceptionInfo]"] = None,
) -> None:
""" initialize from given python traceback object and ExceptionInfo """
self._excinfo = excinfo
if hasattr(tb, "tb_next"):
if isinstance(tb, TracebackType):
def f(cur):
while cur is not None:
yield TracebackEntry(cur, excinfo=excinfo)
cur = cur.tb_next
def f(cur: TracebackType) -> Iterable[TracebackEntry]:
cur_ = cur # type: Optional[TracebackType]
while cur_ is not None:
yield TracebackEntry(cur_, excinfo=excinfo)
cur_ = cur_.tb_next
list.__init__(self, f(tb))
super().__init__(f(tb))
else:
list.__init__(self, tb)
super().__init__(tb)
def cut(self, path=None, lineno=None, firstlineno=None, excludepath=None):
def cut(
self,
path=None,
lineno: Optional[int] = None,
firstlineno: Optional[int] = None,
excludepath=None,
) -> "Traceback":
""" return a Traceback instance wrapping part of this Traceback
by providing any combination of path, lineno and firstlineno, the
@ -323,13 +343,25 @@ class Traceback(list):
return Traceback(x._rawentry, self._excinfo)
return self
def __getitem__(self, key):
val = super().__getitem__(key)
if isinstance(key, type(slice(0))):
val = self.__class__(val)
return val
@overload
def __getitem__(self, key: int) -> TracebackEntry:
raise NotImplementedError()
def filter(self, fn=lambda x: not x.ishidden()):
@overload # noqa: F811
def __getitem__(self, key: slice) -> "Traceback": # noqa: F811
raise NotImplementedError()
def __getitem__( # noqa: F811
self, key: Union[int, slice]
) -> Union[TracebackEntry, "Traceback"]:
if isinstance(key, slice):
return self.__class__(super().__getitem__(key))
else:
return super().__getitem__(key)
def filter(
self, fn: Callable[[TracebackEntry], bool] = lambda x: not x.ishidden()
) -> "Traceback":
""" return a Traceback instance with certain items removed
fn is a function that gets a single argument, a TracebackEntry
@ -341,7 +373,7 @@ class Traceback(list):
"""
return Traceback(filter(fn, self), self._excinfo)
def getcrashentry(self):
def getcrashentry(self) -> TracebackEntry:
""" return last non-hidden traceback entry that lead
to the exception of a traceback.
"""
@ -351,7 +383,7 @@ class Traceback(list):
return entry
return self[-1]
def recursionindex(self):
def recursionindex(self) -> Optional[int]:
""" return the index of the frame/TracebackEntry where recursion
originates if appropriate, None if no recursion occurred
"""
@ -541,7 +573,7 @@ class ExceptionInfo(Generic[_E]):
def getrepr(
self,
showlocals: bool = False,
style: str = "long",
style: "_TracebackStyle" = "long",
abspath: bool = False,
tbfilter: bool = True,
funcargs: bool = False,
@ -619,16 +651,16 @@ class FormattedExcinfo:
flow_marker = ">"
fail_marker = "E"
showlocals = attr.ib(default=False)
style = attr.ib(default="long")
abspath = attr.ib(default=True)
tbfilter = attr.ib(default=True)
funcargs = attr.ib(default=False)
truncate_locals = attr.ib(default=True)
chain = attr.ib(default=True)
showlocals = attr.ib(type=bool, default=False)
style = attr.ib(type="_TracebackStyle", default="long")
abspath = attr.ib(type=bool, default=True)
tbfilter = attr.ib(type=bool, default=True)
funcargs = attr.ib(type=bool, default=False)
truncate_locals = attr.ib(type=bool, default=True)
chain = attr.ib(type=bool, default=True)
astcache = attr.ib(default=attr.Factory(dict), init=False, repr=False)
def _getindent(self, source):
def _getindent(self, source: "Source") -> int:
# figure out indent for given source
try:
s = str(source.getstatement(len(source) - 1))
@ -643,20 +675,27 @@ class FormattedExcinfo:
return 0
return 4 + (len(s) - len(s.lstrip()))
def _getentrysource(self, entry):
def _getentrysource(self, entry: TracebackEntry) -> Optional["Source"]:
source = entry.getsource(self.astcache)
if source is not None:
source = source.deindent()
return source
def repr_args(self, entry):
def repr_args(self, entry: TracebackEntry) -> Optional["ReprFuncArgs"]:
if self.funcargs:
args = []
for argname, argvalue in entry.frame.getargs(var=True):
args.append((argname, saferepr(argvalue)))
return ReprFuncArgs(args)
return None
def get_source(self, source, line_index=-1, excinfo=None, short=False) -> List[str]:
def get_source(
self,
source: "Source",
line_index: int = -1,
excinfo: Optional[ExceptionInfo] = None,
short: bool = False,
) -> List[str]:
""" return formatted and marked up source lines. """
import _pytest._code
@ -680,19 +719,21 @@ class FormattedExcinfo:
lines.extend(self.get_exconly(excinfo, indent=indent, markall=True))
return lines
def get_exconly(self, excinfo, indent=4, markall=False):
def get_exconly(
self, excinfo: ExceptionInfo, indent: int = 4, markall: bool = False
) -> List[str]:
lines = []
indent = " " * indent
indentstr = " " * indent
# get the real exception information out
exlines = excinfo.exconly(tryshort=True).split("\n")
failindent = self.fail_marker + indent[1:]
failindent = self.fail_marker + indentstr[1:]
for line in exlines:
lines.append(failindent + line)
if not markall:
failindent = indent
failindent = indentstr
return lines
def repr_locals(self, locals):
def repr_locals(self, locals: Dict[str, object]) -> Optional["ReprLocals"]:
if self.showlocals:
lines = []
keys = [loc for loc in locals if loc[0] != "@"]
@ -717,8 +758,11 @@ class FormattedExcinfo:
# # XXX
# pprint.pprint(value, stream=self.excinfowriter)
return ReprLocals(lines)
return None
def repr_traceback_entry(self, entry, excinfo=None):
def repr_traceback_entry(
self, entry: TracebackEntry, excinfo: Optional[ExceptionInfo] = None
) -> "ReprEntry":
import _pytest._code
source = self._getentrysource(entry)
@ -729,9 +773,7 @@ class FormattedExcinfo:
line_index = entry.lineno - entry.getfirstlinesource()
lines = [] # type: List[str]
style = entry._repr_style
if style is None:
style = self.style
style = entry._repr_style if entry._repr_style is not None else self.style
if style in ("short", "long"):
short = style == "short"
reprargs = self.repr_args(entry) if not short else None
@ -761,7 +803,7 @@ class FormattedExcinfo:
path = np
return path
def repr_traceback(self, excinfo):
def repr_traceback(self, excinfo: ExceptionInfo) -> "ReprTraceback":
traceback = excinfo.traceback
if self.tbfilter:
traceback = traceback.filter()
@ -779,7 +821,9 @@ class FormattedExcinfo:
entries.append(reprentry)
return ReprTraceback(entries, extraline, style=self.style)
def _truncate_recursive_traceback(self, traceback):
def _truncate_recursive_traceback(
self, traceback: Traceback
) -> Tuple[Traceback, Optional[str]]:
"""
Truncate the given recursive traceback trying to find the starting point
of the recursion.
@ -806,7 +850,9 @@ class FormattedExcinfo:
max_frames=max_frames,
total=len(traceback),
) # type: Optional[str]
traceback = traceback[:max_frames] + traceback[-max_frames:]
# Type ignored because adding two instaces of a List subtype
# currently incorrectly has type List instead of the subtype.
traceback = traceback[:max_frames] + traceback[-max_frames:] # type: ignore
else:
if recursionindex is not None:
extraline = "!!! Recursion detected (same locals & position)"
@ -863,7 +909,7 @@ class FormattedExcinfo:
class TerminalRepr:
def __str__(self):
def __str__(self) -> str:
# FYI this is called from pytest-xdist's serialization of exception
# information.
io = StringIO()
@ -871,7 +917,7 @@ class TerminalRepr:
self.toterminal(tw)
return io.getvalue().strip()
def __repr__(self):
def __repr__(self) -> str:
return "<{} instance at {:0x}>".format(self.__class__, id(self))
def toterminal(self, tw) -> None:
@ -882,7 +928,7 @@ class ExceptionRepr(TerminalRepr):
def __init__(self) -> None:
self.sections = [] # type: List[Tuple[str, str, str]]
def addsection(self, name, content, sep="-"):
def addsection(self, name: str, content: str, sep: str = "-") -> None:
self.sections.append((name, content, sep))
def toterminal(self, tw) -> None:
@ -892,7 +938,12 @@ class ExceptionRepr(TerminalRepr):
class ExceptionChainRepr(ExceptionRepr):
def __init__(self, chain):
def __init__(
self,
chain: Sequence[
Tuple["ReprTraceback", Optional["ReprFileLocation"], Optional[str]]
],
) -> None:
super().__init__()
self.chain = chain
# reprcrash and reprtraceback of the outermost (the newest) exception
@ -910,7 +961,9 @@ class ExceptionChainRepr(ExceptionRepr):
class ReprExceptionInfo(ExceptionRepr):
def __init__(self, reprtraceback, reprcrash):
def __init__(
self, reprtraceback: "ReprTraceback", reprcrash: "ReprFileLocation"
) -> None:
super().__init__()
self.reprtraceback = reprtraceback
self.reprcrash = reprcrash
@ -923,7 +976,12 @@ class ReprExceptionInfo(ExceptionRepr):
class ReprTraceback(TerminalRepr):
entrysep = "_ "
def __init__(self, reprentries, extraline, style):
def __init__(
self,
reprentries: Sequence[Union["ReprEntry", "ReprEntryNative"]],
extraline: Optional[str],
style: "_TracebackStyle",
) -> None:
self.reprentries = reprentries
self.extraline = extraline
self.style = style
@ -948,16 +1006,16 @@ class ReprTraceback(TerminalRepr):
class ReprTracebackNative(ReprTraceback):
def __init__(self, tblines):
def __init__(self, tblines: Sequence[str]) -> None:
self.style = "native"
self.reprentries = [ReprEntryNative(tblines)]
self.extraline = None
class ReprEntryNative(TerminalRepr):
style = "native"
style = "native" # type: _TracebackStyle
def __init__(self, tblines):
def __init__(self, tblines: Sequence[str]) -> None:
self.lines = tblines
def toterminal(self, tw) -> None:
@ -965,7 +1023,14 @@ class ReprEntryNative(TerminalRepr):
class ReprEntry(TerminalRepr):
def __init__(self, lines, reprfuncargs, reprlocals, filelocrepr, style):
def __init__(
self,
lines: Sequence[str],
reprfuncargs: Optional["ReprFuncArgs"],
reprlocals: Optional["ReprLocals"],
filelocrepr: Optional["ReprFileLocation"],
style: "_TracebackStyle",
) -> None:
self.lines = lines
self.reprfuncargs = reprfuncargs
self.reprlocals = reprlocals
@ -974,6 +1039,7 @@ class ReprEntry(TerminalRepr):
def toterminal(self, tw) -> None:
if self.style == "short":
assert self.reprfileloc is not None
self.reprfileloc.toterminal(tw)
for line in self.lines:
red = line.startswith("E ")
@ -992,14 +1058,14 @@ class ReprEntry(TerminalRepr):
tw.line("")
self.reprfileloc.toterminal(tw)
def __str__(self):
def __str__(self) -> str:
return "{}\n{}\n{}".format(
"\n".join(self.lines), self.reprlocals, self.reprfileloc
)
class ReprFileLocation(TerminalRepr):
def __init__(self, path, lineno, message):
def __init__(self, path, lineno: int, message: str) -> None:
self.path = str(path)
self.lineno = lineno
self.message = message
@ -1016,7 +1082,7 @@ class ReprFileLocation(TerminalRepr):
class ReprLocals(TerminalRepr):
def __init__(self, lines):
def __init__(self, lines: Sequence[str]) -> None:
self.lines = lines
def toterminal(self, tw) -> None:
@ -1025,7 +1091,7 @@ class ReprLocals(TerminalRepr):
class ReprFuncArgs(TerminalRepr):
def __init__(self, args):
def __init__(self, args: Sequence[Tuple[str, object]]) -> None:
self.args = args
def toterminal(self, tw) -> None:
@ -1047,7 +1113,7 @@ class ReprFuncArgs(TerminalRepr):
tw.line("")
def getrawcode(obj, trycall=True):
def getrawcode(obj, trycall: bool = True):
""" return code object for given function. """
try:
return obj.__code__
@ -1075,7 +1141,7 @@ _PYTEST_DIR = py.path.local(_pytest.__file__).dirpath()
_PY_DIR = py.path.local(py.__file__).dirpath()
def filter_traceback(entry):
def filter_traceback(entry: TracebackEntry) -> bool:
"""Return True if a TracebackEntry instance should be removed from tracebacks:
* dynamically generated code (no code to show up for it);
* internal traceback from pytest or its internal libraries, py and pluggy.

View File

@ -1,18 +1,19 @@
import sys
from types import FrameType
from unittest import mock
import _pytest._code
import pytest
def test_ne():
def test_ne() -> None:
code1 = _pytest._code.Code(compile('foo = "bar"', "", "exec"))
assert code1 == code1
code2 = _pytest._code.Code(compile('foo = "baz"', "", "exec"))
assert code2 != code1
def test_code_gives_back_name_for_not_existing_file():
def test_code_gives_back_name_for_not_existing_file() -> None:
name = "abc-123"
co_code = compile("pass\n", name, "exec")
assert co_code.co_filename == name
@ -21,68 +22,67 @@ def test_code_gives_back_name_for_not_existing_file():
assert code.fullsource is None
def test_code_with_class():
def test_code_with_class() -> None:
class A:
pass
pytest.raises(TypeError, _pytest._code.Code, A)
def x():
def x() -> None:
raise NotImplementedError()
def test_code_fullsource():
def test_code_fullsource() -> None:
code = _pytest._code.Code(x)
full = code.fullsource
assert "test_code_fullsource()" in str(full)
def test_code_source():
def test_code_source() -> None:
code = _pytest._code.Code(x)
src = code.source()
expected = """def x():
expected = """def x() -> None:
raise NotImplementedError()"""
assert str(src) == expected
def test_frame_getsourcelineno_myself():
def func():
def test_frame_getsourcelineno_myself() -> None:
def func() -> FrameType:
return sys._getframe(0)
f = func()
f = _pytest._code.Frame(f)
f = _pytest._code.Frame(func())
source, lineno = f.code.fullsource, f.lineno
assert source is not None
assert source[lineno].startswith(" return sys._getframe(0)")
def test_getstatement_empty_fullsource():
def func():
def test_getstatement_empty_fullsource() -> None:
def func() -> FrameType:
return sys._getframe(0)
f = func()
f = _pytest._code.Frame(f)
f = _pytest._code.Frame(func())
with mock.patch.object(f.code.__class__, "fullsource", None):
assert f.statement == ""
def test_code_from_func():
def test_code_from_func() -> None:
co = _pytest._code.Code(test_frame_getsourcelineno_myself)
assert co.firstlineno
assert co.path
def test_unicode_handling():
def test_unicode_handling() -> None:
value = "ąć".encode()
def f():
def f() -> None:
raise Exception(value)
excinfo = pytest.raises(Exception, f)
str(excinfo)
def test_code_getargs():
def test_code_getargs() -> None:
def f1(x):
raise NotImplementedError()
@ -108,26 +108,26 @@ def test_code_getargs():
assert c4.getargs(var=True) == ("x", "y", "z")
def test_frame_getargs():
def f1(x):
def test_frame_getargs() -> None:
def f1(x) -> FrameType:
return sys._getframe(0)
fr1 = _pytest._code.Frame(f1("a"))
assert fr1.getargs(var=True) == [("x", "a")]
def f2(x, *y):
def f2(x, *y) -> FrameType:
return sys._getframe(0)
fr2 = _pytest._code.Frame(f2("a", "b", "c"))
assert fr2.getargs(var=True) == [("x", "a"), ("y", ("b", "c"))]
def f3(x, **z):
def f3(x, **z) -> FrameType:
return sys._getframe(0)
fr3 = _pytest._code.Frame(f3("a", b="c"))
assert fr3.getargs(var=True) == [("x", "a"), ("z", {"b": "c"})]
def f4(x, *y, **z):
def f4(x, *y, **z) -> FrameType:
return sys._getframe(0)
fr4 = _pytest._code.Frame(f4("a", "b", c="d"))
@ -135,7 +135,7 @@ def test_frame_getargs():
class TestExceptionInfo:
def test_bad_getsource(self):
def test_bad_getsource(self) -> None:
try:
if False:
pass
@ -145,13 +145,13 @@ class TestExceptionInfo:
exci = _pytest._code.ExceptionInfo.from_current()
assert exci.getrepr()
def test_from_current_with_missing(self):
def test_from_current_with_missing(self) -> None:
with pytest.raises(AssertionError, match="no current exception"):
_pytest._code.ExceptionInfo.from_current()
class TestTracebackEntry:
def test_getsource(self):
def test_getsource(self) -> None:
try:
if False:
pass
@ -161,12 +161,13 @@ class TestTracebackEntry:
exci = _pytest._code.ExceptionInfo.from_current()
entry = exci.traceback[0]
source = entry.getsource()
assert source is not None
assert len(source) == 6
assert "assert False" in source[5]
class TestReprFuncArgs:
def test_not_raise_exception_with_mixed_encoding(self, tw_mock):
def test_not_raise_exception_with_mixed_encoding(self, tw_mock) -> None:
from _pytest._code.code import ReprFuncArgs
args = [("unicode_string", "São Paulo"), ("utf8_string", b"S\xc3\xa3o Paulo")]

View File

@ -3,6 +3,7 @@ import os
import queue
import sys
import textwrap
from typing import Union
import py
@ -224,23 +225,25 @@ class TestTraceback_f_g_h:
repr = excinfo.getrepr()
assert "RuntimeError: hello" in str(repr.reprcrash)
def test_traceback_no_recursion_index(self):
def do_stuff():
def test_traceback_no_recursion_index(self) -> None:
def do_stuff() -> None:
raise RuntimeError
def reraise_me():
def reraise_me() -> None:
import sys
exc, val, tb = sys.exc_info()
assert val is not None
raise val.with_traceback(tb)
def f(n):
def f(n: int) -> None:
try:
do_stuff()
except: # noqa
reraise_me()
excinfo = pytest.raises(RuntimeError, f, 8)
assert excinfo is not None
traceback = excinfo.traceback
recindex = traceback.recursionindex()
assert recindex is None
@ -596,7 +599,6 @@ raise ValueError()
assert lines[3] == "E world"
assert not lines[4:]
loc = repr_entry.reprlocals is not None
loc = repr_entry.reprfileloc
assert loc.path == mod.__file__
assert loc.lineno == 3
@ -1286,9 +1288,10 @@ raise ValueError()
@pytest.mark.parametrize("style", ["short", "long"])
@pytest.mark.parametrize("encoding", [None, "utf8", "utf16"])
def test_repr_traceback_with_unicode(style, encoding):
msg = ""
if encoding is not None:
msg = msg.encode(encoding)
if encoding is None:
msg = "" # type: Union[str, bytes]
else:
msg = "".encode(encoding)
try:
raise RuntimeError(msg)
except RuntimeError:

View File

@ -4,13 +4,16 @@
import ast
import inspect
import sys
from typing import Any
from typing import Dict
from typing import Optional
import _pytest._code
import pytest
from _pytest._code import Source
def test_source_str_function():
def test_source_str_function() -> None:
x = Source("3")
assert str(x) == "3"
@ -25,7 +28,7 @@ def test_source_str_function():
assert str(x) == "\n3"
def test_unicode():
def test_unicode() -> None:
x = Source("4")
assert str(x) == "4"
co = _pytest._code.compile('"å"', mode="eval")
@ -33,12 +36,12 @@ def test_unicode():
assert isinstance(val, str)
def test_source_from_function():
def test_source_from_function() -> None:
source = _pytest._code.Source(test_source_str_function)
assert str(source).startswith("def test_source_str_function():")
assert str(source).startswith("def test_source_str_function() -> None:")
def test_source_from_method():
def test_source_from_method() -> None:
class TestClass:
def test_method(self):
pass
@ -47,13 +50,13 @@ def test_source_from_method():
assert source.lines == ["def test_method(self):", " pass"]
def test_source_from_lines():
def test_source_from_lines() -> None:
lines = ["a \n", "b\n", "c"]
source = _pytest._code.Source(lines)
assert source.lines == ["a ", "b", "c"]
def test_source_from_inner_function():
def test_source_from_inner_function() -> None:
def f():
pass
@ -63,7 +66,7 @@ def test_source_from_inner_function():
assert str(source).startswith("def f():")
def test_source_putaround_simple():
def test_source_putaround_simple() -> None:
source = Source("raise ValueError")
source = source.putaround(
"try:",
@ -85,7 +88,7 @@ else:
)
def test_source_putaround():
def test_source_putaround() -> None:
source = Source()
source = source.putaround(
"""
@ -96,28 +99,29 @@ def test_source_putaround():
assert str(source).strip() == "if 1:\n x=1"
def test_source_strips():
def test_source_strips() -> None:
source = Source("")
assert source == Source()
assert str(source) == ""
assert source.strip() == source
def test_source_strip_multiline():
def test_source_strip_multiline() -> None:
source = Source()
source.lines = ["", " hello", " "]
source2 = source.strip()
assert source2.lines == [" hello"]
def test_syntaxerror_rerepresentation():
def test_syntaxerror_rerepresentation() -> None:
ex = pytest.raises(SyntaxError, _pytest._code.compile, "xyz xyz")
assert ex is not None
assert ex.value.lineno == 1
assert ex.value.offset in {5, 7} # cpython: 7, pypy3.6 7.1.1: 5
assert ex.value.text.strip(), "x x"
assert ex.value.text == "xyz xyz\n"
def test_isparseable():
def test_isparseable() -> None:
assert Source("hello").isparseable()
assert Source("if 1:\n pass").isparseable()
assert Source(" \nif 1:\n pass").isparseable()
@ -127,7 +131,7 @@ def test_isparseable():
class TestAccesses:
def setup_class(self):
def setup_class(self) -> None:
self.source = Source(
"""\
def f(x):
@ -137,26 +141,26 @@ class TestAccesses:
"""
)
def test_getrange(self):
def test_getrange(self) -> None:
x = self.source[0:2]
assert x.isparseable()
assert len(x.lines) == 2
assert str(x) == "def f(x):\n pass"
def test_getline(self):
def test_getline(self) -> None:
x = self.source[0]
assert x == "def f(x):"
def test_len(self):
def test_len(self) -> None:
assert len(self.source) == 4
def test_iter(self):
def test_iter(self) -> None:
values = [x for x in self.source]
assert len(values) == 4
class TestSourceParsingAndCompiling:
def setup_class(self):
def setup_class(self) -> None:
self.source = Source(
"""\
def f(x):
@ -166,19 +170,19 @@ class TestSourceParsingAndCompiling:
"""
).strip()
def test_compile(self):
def test_compile(self) -> None:
co = _pytest._code.compile("x=3")
d = {}
d = {} # type: Dict[str, Any]
exec(co, d)
assert d["x"] == 3
def test_compile_and_getsource_simple(self):
def test_compile_and_getsource_simple(self) -> None:
co = _pytest._code.compile("x=3")
exec(co)
source = _pytest._code.Source(co)
assert str(source) == "x=3"
def test_compile_and_getsource_through_same_function(self):
def test_compile_and_getsource_through_same_function(self) -> None:
def gensource(source):
return _pytest._code.compile(source)
@ -199,7 +203,7 @@ class TestSourceParsingAndCompiling:
source2 = inspect.getsource(co2)
assert "ValueError" in source2
def test_getstatement(self):
def test_getstatement(self) -> None:
# print str(self.source)
ass = str(self.source[1:])
for i in range(1, 4):
@ -208,7 +212,7 @@ class TestSourceParsingAndCompiling:
# x = s.deindent()
assert str(s) == ass
def test_getstatementrange_triple_quoted(self):
def test_getstatementrange_triple_quoted(self) -> None:
# print str(self.source)
source = Source(
"""hello('''
@ -219,7 +223,7 @@ class TestSourceParsingAndCompiling:
s = source.getstatement(1)
assert s == str(source)
def test_getstatementrange_within_constructs(self):
def test_getstatementrange_within_constructs(self) -> None:
source = Source(
"""\
try:
@ -241,7 +245,7 @@ class TestSourceParsingAndCompiling:
# assert source.getstatementrange(5) == (0, 7)
assert source.getstatementrange(6) == (6, 7)
def test_getstatementrange_bug(self):
def test_getstatementrange_bug(self) -> None:
source = Source(
"""\
try:
@ -255,7 +259,7 @@ class TestSourceParsingAndCompiling:
assert len(source) == 6
assert source.getstatementrange(2) == (1, 4)
def test_getstatementrange_bug2(self):
def test_getstatementrange_bug2(self) -> None:
source = Source(
"""\
assert (
@ -272,7 +276,7 @@ class TestSourceParsingAndCompiling:
assert len(source) == 9
assert source.getstatementrange(5) == (0, 9)
def test_getstatementrange_ast_issue58(self):
def test_getstatementrange_ast_issue58(self) -> None:
source = Source(
"""\
@ -286,38 +290,44 @@ class TestSourceParsingAndCompiling:
assert getstatement(2, source).lines == source.lines[2:3]
assert getstatement(3, source).lines == source.lines[3:4]
def test_getstatementrange_out_of_bounds_py3(self):
def test_getstatementrange_out_of_bounds_py3(self) -> None:
source = Source("if xxx:\n from .collections import something")
r = source.getstatementrange(1)
assert r == (1, 2)
def test_getstatementrange_with_syntaxerror_issue7(self):
def test_getstatementrange_with_syntaxerror_issue7(self) -> None:
source = Source(":")
pytest.raises(SyntaxError, lambda: source.getstatementrange(0))
def test_compile_to_ast(self):
def test_compile_to_ast(self) -> None:
source = Source("x = 4")
mod = source.compile(flag=ast.PyCF_ONLY_AST)
assert isinstance(mod, ast.Module)
compile(mod, "<filename>", "exec")
def test_compile_and_getsource(self):
def test_compile_and_getsource(self) -> None:
co = self.source.compile()
exec(co, globals())
f(7)
excinfo = pytest.raises(AssertionError, f, 6)
f(7) # type: ignore
excinfo = pytest.raises(AssertionError, f, 6) # type: ignore
assert excinfo is not None
frame = excinfo.traceback[-1].frame
assert isinstance(frame.code.fullsource, Source)
stmt = frame.code.fullsource.getstatement(frame.lineno)
assert str(stmt).strip().startswith("assert")
@pytest.mark.parametrize("name", ["", None, "my"])
def test_compilefuncs_and_path_sanity(self, name):
def test_compilefuncs_and_path_sanity(self, name: Optional[str]) -> None:
def check(comp, name):
co = comp(self.source, name)
if not name:
expected = "codegen %s:%d>" % (mypath, mylineno + 2 + 2)
expected = "codegen %s:%d>" % (mypath, mylineno + 2 + 2) # type: ignore
else:
expected = "codegen %r %s:%d>" % (name, mypath, mylineno + 2 + 2)
expected = "codegen %r %s:%d>" % (
name,
mypath, # type: ignore
mylineno + 2 + 2, # type: ignore
) # type: ignore
fn = co.co_filename
assert fn.endswith(expected)
@ -332,9 +342,9 @@ class TestSourceParsingAndCompiling:
pytest.raises(SyntaxError, _pytest._code.compile, "lambda a,a: 0", mode="eval")
def test_getstartingblock_singleline():
def test_getstartingblock_singleline() -> None:
class A:
def __init__(self, *args):
def __init__(self, *args) -> None:
frame = sys._getframe(1)
self.source = _pytest._code.Frame(frame).statement
@ -344,22 +354,22 @@ def test_getstartingblock_singleline():
assert len(values) == 1
def test_getline_finally():
def c():
def test_getline_finally() -> None:
def c() -> None:
pass
with pytest.raises(TypeError) as excinfo:
teardown = None
try:
c(1)
c(1) # type: ignore
finally:
if teardown:
teardown()
source = excinfo.traceback[-1].statement
assert str(source).strip() == "c(1)"
assert str(source).strip() == "c(1) # type: ignore"
def test_getfuncsource_dynamic():
def test_getfuncsource_dynamic() -> None:
source = """
def f():
raise ValueError
@ -368,11 +378,13 @@ def test_getfuncsource_dynamic():
"""
co = _pytest._code.compile(source)
exec(co, globals())
assert str(_pytest._code.Source(f)).strip() == "def f():\n raise ValueError"
assert str(_pytest._code.Source(g)).strip() == "def g(): pass"
f_source = _pytest._code.Source(f) # type: ignore
g_source = _pytest._code.Source(g) # type: ignore
assert str(f_source).strip() == "def f():\n raise ValueError"
assert str(g_source).strip() == "def g(): pass"
def test_getfuncsource_with_multine_string():
def test_getfuncsource_with_multine_string() -> None:
def f():
c = """while True:
pass
@ -387,7 +399,7 @@ def test_getfuncsource_with_multine_string():
assert str(_pytest._code.Source(f)) == expected.rstrip()
def test_deindent():
def test_deindent() -> None:
from _pytest._code.source import deindent as deindent
assert deindent(["\tfoo", "\tbar"]) == ["foo", "bar"]
@ -401,7 +413,7 @@ def test_deindent():
assert lines == ["def f():", " def g():", " pass"]
def test_source_of_class_at_eof_without_newline(tmpdir, _sys_snapshot):
def test_source_of_class_at_eof_without_newline(tmpdir, _sys_snapshot) -> None:
# this test fails because the implicit inspect.getsource(A) below
# does not return the "x = 1" last line.
source = _pytest._code.Source(
@ -423,7 +435,7 @@ if True:
pass
def test_getsource_fallback():
def test_getsource_fallback() -> None:
from _pytest._code.source import getsource
expected = """def x():
@ -432,7 +444,7 @@ def test_getsource_fallback():
assert src == expected
def test_idem_compile_and_getsource():
def test_idem_compile_and_getsource() -> None:
from _pytest._code.source import getsource
expected = "def x(): pass"
@ -441,15 +453,16 @@ def test_idem_compile_and_getsource():
assert src == expected
def test_findsource_fallback():
def test_findsource_fallback() -> None:
from _pytest._code.source import findsource
src, lineno = findsource(x)
assert src is not None
assert "test_findsource_simple" in str(src)
assert src[lineno] == " def x():"
def test_findsource():
def test_findsource() -> None:
from _pytest._code.source import findsource
co = _pytest._code.compile(
@ -460,19 +473,21 @@ def test_findsource():
)
src, lineno = findsource(co)
assert src is not None
assert "if 1:" in str(src)
d = {}
d = {} # type: Dict[str, Any]
eval(co, d)
src, lineno = findsource(d["x"])
assert src is not None
assert "if 1:" in str(src)
assert src[lineno] == " def x():"
def test_getfslineno():
def test_getfslineno() -> None:
from _pytest._code import getfslineno
def f(x):
def f(x) -> None:
pass
fspath, lineno = getfslineno(f)
@ -498,40 +513,40 @@ def test_getfslineno():
assert getfslineno(B)[1] == -1
def test_code_of_object_instance_with_call():
def test_code_of_object_instance_with_call() -> None:
class A:
pass
pytest.raises(TypeError, lambda: _pytest._code.Source(A()))
class WithCall:
def __call__(self):
def __call__(self) -> None:
pass
code = _pytest._code.Code(WithCall())
assert "pass" in str(code.source())
class Hello:
def __call__(self):
def __call__(self) -> None:
pass
pytest.raises(TypeError, lambda: _pytest._code.Code(Hello))
def getstatement(lineno, source):
def getstatement(lineno: int, source) -> Source:
from _pytest._code.source import getstatementrange_ast
source = _pytest._code.Source(source, deindent=False)
ast, start, end = getstatementrange_ast(lineno, source)
return source[start:end]
src = _pytest._code.Source(source, deindent=False)
ast, start, end = getstatementrange_ast(lineno, src)
return src[start:end]
def test_oneline():
def test_oneline() -> None:
source = getstatement(0, "raise ValueError")
assert str(source) == "raise ValueError"
def test_comment_and_no_newline_at_end():
def test_comment_and_no_newline_at_end() -> None:
from _pytest._code.source import getstatementrange_ast
source = Source(
@ -545,12 +560,12 @@ def test_comment_and_no_newline_at_end():
assert end == 2
def test_oneline_and_comment():
def test_oneline_and_comment() -> None:
source = getstatement(0, "raise ValueError\n#hello")
assert str(source) == "raise ValueError"
def test_comments():
def test_comments() -> None:
source = '''def test():
"comment 1"
x = 1
@ -576,7 +591,7 @@ comment 4
assert str(getstatement(line, source)) == '"""\ncomment 4\n"""'
def test_comment_in_statement():
def test_comment_in_statement() -> None:
source = """test(foo=1,
# comment 1
bar=2)
@ -588,17 +603,17 @@ def test_comment_in_statement():
)
def test_single_line_else():
def test_single_line_else() -> None:
source = getstatement(1, "if False: 2\nelse: 3")
assert str(source) == "else: 3"
def test_single_line_finally():
def test_single_line_finally() -> None:
source = getstatement(1, "try: 1\nfinally: 3")
assert str(source) == "finally: 3"
def test_issue55():
def test_issue55() -> None:
source = (
"def round_trip(dinp):\n assert 1 == dinp\n"
'def test_rt():\n round_trip("""\n""")\n'
@ -607,7 +622,7 @@ def test_issue55():
assert str(s) == ' round_trip("""\n""")'
def test_multiline():
def test_multiline() -> None:
source = getstatement(
0,
"""\
@ -621,7 +636,7 @@ x = 3
class TestTry:
def setup_class(self):
def setup_class(self) -> None:
self.source = """\
try:
raise ValueError
@ -631,25 +646,25 @@ else:
raise KeyError()
"""
def test_body(self):
def test_body(self) -> None:
source = getstatement(1, self.source)
assert str(source) == " raise ValueError"
def test_except_line(self):
def test_except_line(self) -> None:
source = getstatement(2, self.source)
assert str(source) == "except Something:"
def test_except_body(self):
def test_except_body(self) -> None:
source = getstatement(3, self.source)
assert str(source) == " raise IndexError(1)"
def test_else(self):
def test_else(self) -> None:
source = getstatement(5, self.source)
assert str(source) == " raise KeyError()"
class TestTryFinally:
def setup_class(self):
def setup_class(self) -> None:
self.source = """\
try:
raise ValueError
@ -657,17 +672,17 @@ finally:
raise IndexError(1)
"""
def test_body(self):
def test_body(self) -> None:
source = getstatement(1, self.source)
assert str(source) == " raise ValueError"
def test_finally(self):
def test_finally(self) -> None:
source = getstatement(3, self.source)
assert str(source) == " raise IndexError(1)"
class TestIf:
def setup_class(self):
def setup_class(self) -> None:
self.source = """\
if 1:
y = 3
@ -677,24 +692,24 @@ else:
y = 7
"""
def test_body(self):
def test_body(self) -> None:
source = getstatement(1, self.source)
assert str(source) == " y = 3"
def test_elif_clause(self):
def test_elif_clause(self) -> None:
source = getstatement(2, self.source)
assert str(source) == "elif False:"
def test_elif(self):
def test_elif(self) -> None:
source = getstatement(3, self.source)
assert str(source) == " y = 5"
def test_else(self):
def test_else(self) -> None:
source = getstatement(5, self.source)
assert str(source) == " y = 7"
def test_semicolon():
def test_semicolon() -> None:
s = """\
hello ; pytest.skip()
"""
@ -702,7 +717,7 @@ hello ; pytest.skip()
assert str(source) == s.strip()
def test_def_online():
def test_def_online() -> None:
s = """\
def func(): raise ValueError(42)
@ -713,7 +728,7 @@ def something():
assert str(source) == "def func(): raise ValueError(42)"
def XXX_test_expression_multiline():
def XXX_test_expression_multiline() -> None:
source = """\
something
'''
@ -722,7 +737,7 @@ something
assert str(result) == "'''\n'''"
def test_getstartingblock_multiline():
def test_getstartingblock_multiline() -> None:
class A:
def __init__(self, *args):
frame = sys._getframe(1)