Add type annotations to _pytest._code.code

This commit is contained in:
Ran Benita 2019-11-15 23:02:55 +02:00
parent 562d4811d5
commit eaa34a9df0
4 changed files with 299 additions and 214 deletions

View File

@ -7,13 +7,17 @@ from inspect import CO_VARKEYWORDS
from io import StringIO from io import StringIO
from traceback import format_exception_only from traceback import format_exception_only
from types import CodeType from types import CodeType
from types import FrameType
from types import TracebackType from types import TracebackType
from typing import Any from typing import Any
from typing import Callable
from typing import Dict from typing import Dict
from typing import Generic from typing import Generic
from typing import Iterable
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Pattern from typing import Pattern
from typing import Sequence
from typing import Set from typing import Set
from typing import Tuple from typing import Tuple
from typing import TypeVar from typing import TypeVar
@ -27,9 +31,16 @@ import py
import _pytest import _pytest
from _pytest._io.saferepr import safeformat from _pytest._io.saferepr import safeformat
from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr
from _pytest.compat import overload
if False: # TYPE_CHECKING if False: # TYPE_CHECKING
from typing import Type from typing import Type
from typing_extensions import Literal
from weakref import ReferenceType # noqa: F401
from _pytest._code import Source
_TracebackStyle = Literal["long", "short", "no", "native"]
class Code: class Code:
@ -38,13 +49,12 @@ class Code:
def __init__(self, rawcode) -> None: def __init__(self, rawcode) -> None:
if not hasattr(rawcode, "co_filename"): if not hasattr(rawcode, "co_filename"):
rawcode = getrawcode(rawcode) rawcode = getrawcode(rawcode)
try: if not isinstance(rawcode, CodeType):
self.filename = rawcode.co_filename
self.firstlineno = rawcode.co_firstlineno - 1
self.name = rawcode.co_name
except AttributeError:
raise TypeError("not a code object: {!r}".format(rawcode)) raise TypeError("not a code object: {!r}".format(rawcode))
self.raw = rawcode # type: CodeType self.filename = rawcode.co_filename
self.firstlineno = rawcode.co_firstlineno - 1
self.name = rawcode.co_name
self.raw = rawcode
def __eq__(self, other): def __eq__(self, other):
return self.raw == other.raw return self.raw == other.raw
@ -72,7 +82,7 @@ class Code:
return p return p
@property @property
def fullsource(self): def fullsource(self) -> Optional["Source"]:
""" return a _pytest._code.Source object for the full source file of the code """ return a _pytest._code.Source object for the full source file of the code
""" """
from _pytest._code import source from _pytest._code import source
@ -80,7 +90,7 @@ class Code:
full, _ = source.findsource(self.raw) full, _ = source.findsource(self.raw)
return full return full
def source(self): def source(self) -> "Source":
""" return a _pytest._code.Source object for the code object's source only """ return a _pytest._code.Source object for the code object's source only
""" """
# return source only for that part of code # return source only for that part of code
@ -88,7 +98,7 @@ class Code:
return _pytest._code.Source(self.raw) return _pytest._code.Source(self.raw)
def getargs(self, var=False): def getargs(self, var: bool = False) -> Tuple[str, ...]:
""" return a tuple with the argument names for the code object """ return a tuple with the argument names for the code object
if 'var' is set True also return the names of the variable and if 'var' is set True also return the names of the variable and
@ -107,7 +117,7 @@ class Frame:
"""Wrapper around a Python frame holding f_locals and f_globals """Wrapper around a Python frame holding f_locals and f_globals
in which expressions can be evaluated.""" in which expressions can be evaluated."""
def __init__(self, frame): def __init__(self, frame: FrameType) -> None:
self.lineno = frame.f_lineno - 1 self.lineno = frame.f_lineno - 1
self.f_globals = frame.f_globals self.f_globals = frame.f_globals
self.f_locals = frame.f_locals self.f_locals = frame.f_locals
@ -115,7 +125,7 @@ class Frame:
self.code = Code(frame.f_code) self.code = Code(frame.f_code)
@property @property
def statement(self): def statement(self) -> "Source":
""" statement this frame is at """ """ statement this frame is at """
import _pytest._code import _pytest._code
@ -134,7 +144,7 @@ class Frame:
f_locals.update(vars) f_locals.update(vars)
return eval(code, self.f_globals, f_locals) return eval(code, self.f_globals, f_locals)
def exec_(self, code, **vars): def exec_(self, code, **vars) -> None:
""" exec 'code' in the frame """ exec 'code' in the frame
'vars' are optional; additional local variables 'vars' are optional; additional local variables
@ -143,7 +153,7 @@ class Frame:
f_locals.update(vars) f_locals.update(vars)
exec(code, self.f_globals, f_locals) exec(code, self.f_globals, f_locals)
def repr(self, object): def repr(self, object: object) -> str:
""" return a 'safe' (non-recursive, one-line) string repr for 'object' """ return a 'safe' (non-recursive, one-line) string repr for 'object'
""" """
return saferepr(object) return saferepr(object)
@ -151,7 +161,7 @@ class Frame:
def is_true(self, object): def is_true(self, object):
return object return object
def getargs(self, var=False): def getargs(self, var: bool = False):
""" return a list of tuples (name, value) for all arguments """ return a list of tuples (name, value) for all arguments
if 'var' is set True also include the variable and keyword if 'var' is set True also include the variable and keyword
@ -169,35 +179,34 @@ class Frame:
class TracebackEntry: class TracebackEntry:
""" a single entry in a traceback """ """ a single entry in a traceback """
_repr_style = None _repr_style = None # type: Optional[Literal["short", "long"]]
exprinfo = None exprinfo = None
def __init__(self, rawentry, excinfo=None): def __init__(self, rawentry: TracebackType, excinfo=None) -> None:
self._excinfo = excinfo self._excinfo = excinfo
self._rawentry = rawentry self._rawentry = rawentry
self.lineno = rawentry.tb_lineno - 1 self.lineno = rawentry.tb_lineno - 1
def set_repr_style(self, mode): def set_repr_style(self, mode: "Literal['short', 'long']") -> None:
assert mode in ("short", "long") assert mode in ("short", "long")
self._repr_style = mode self._repr_style = mode
@property @property
def frame(self): def frame(self) -> Frame:
import _pytest._code return Frame(self._rawentry.tb_frame)
return _pytest._code.Frame(self._rawentry.tb_frame)
@property @property
def relline(self): def relline(self) -> int:
return self.lineno - self.frame.code.firstlineno return self.lineno - self.frame.code.firstlineno
def __repr__(self): def __repr__(self) -> str:
return "<TracebackEntry %s:%d>" % (self.frame.code.path, self.lineno + 1) return "<TracebackEntry %s:%d>" % (self.frame.code.path, self.lineno + 1)
@property @property
def statement(self): def statement(self) -> "Source":
""" _pytest._code.Source object for the current statement """ """ _pytest._code.Source object for the current statement """
source = self.frame.code.fullsource source = self.frame.code.fullsource
assert source is not None
return source.getstatement(self.lineno) return source.getstatement(self.lineno)
@property @property
@ -206,14 +215,14 @@ class TracebackEntry:
return self.frame.code.path return self.frame.code.path
@property @property
def locals(self): def locals(self) -> Dict[str, Any]:
""" locals of underlying frame """ """ locals of underlying frame """
return self.frame.f_locals return self.frame.f_locals
def getfirstlinesource(self): def getfirstlinesource(self) -> int:
return self.frame.code.firstlineno return self.frame.code.firstlineno
def getsource(self, astcache=None): def getsource(self, astcache=None) -> Optional["Source"]:
""" return failing source code. """ """ return failing source code. """
# we use the passed in astcache to not reparse asttrees # we use the passed in astcache to not reparse asttrees
# within exception info printing # within exception info printing
@ -258,7 +267,7 @@ class TracebackEntry:
return tbh(None if self._excinfo is None else self._excinfo()) return tbh(None if self._excinfo is None else self._excinfo())
return tbh return tbh
def __str__(self): def __str__(self) -> str:
try: try:
fn = str(self.path) fn = str(self.path)
except py.error.Error: except py.error.Error:
@ -273,31 +282,42 @@ class TracebackEntry:
return " File %r:%d in %s\n %s\n" % (fn, self.lineno + 1, name, line) return " File %r:%d in %s\n %s\n" % (fn, self.lineno + 1, name, line)
@property @property
def name(self): def name(self) -> str:
""" co_name of underlying code """ """ co_name of underlying code """
return self.frame.code.raw.co_name return self.frame.code.raw.co_name
class Traceback(list): class Traceback(List[TracebackEntry]):
""" Traceback objects encapsulate and offer higher level """ Traceback objects encapsulate and offer higher level
access to Traceback entries. access to Traceback entries.
""" """
def __init__(self, tb, excinfo=None): def __init__(
self,
tb: Union[TracebackType, Iterable[TracebackEntry]],
excinfo: Optional["ReferenceType[ExceptionInfo]"] = None,
) -> None:
""" initialize from given python traceback object and ExceptionInfo """ """ initialize from given python traceback object and ExceptionInfo """
self._excinfo = excinfo self._excinfo = excinfo
if hasattr(tb, "tb_next"): if isinstance(tb, TracebackType):
def f(cur): def f(cur: TracebackType) -> Iterable[TracebackEntry]:
while cur is not None: cur_ = cur # type: Optional[TracebackType]
yield TracebackEntry(cur, excinfo=excinfo) while cur_ is not None:
cur = cur.tb_next yield TracebackEntry(cur_, excinfo=excinfo)
cur_ = cur_.tb_next
list.__init__(self, f(tb)) super().__init__(f(tb))
else: else:
list.__init__(self, tb) super().__init__(tb)
def cut(self, path=None, lineno=None, firstlineno=None, excludepath=None): def cut(
self,
path=None,
lineno: Optional[int] = None,
firstlineno: Optional[int] = None,
excludepath=None,
) -> "Traceback":
""" return a Traceback instance wrapping part of this Traceback """ return a Traceback instance wrapping part of this Traceback
by providing any combination of path, lineno and firstlineno, the by providing any combination of path, lineno and firstlineno, the
@ -323,13 +343,25 @@ class Traceback(list):
return Traceback(x._rawentry, self._excinfo) return Traceback(x._rawentry, self._excinfo)
return self return self
def __getitem__(self, key): @overload
val = super().__getitem__(key) def __getitem__(self, key: int) -> TracebackEntry:
if isinstance(key, type(slice(0))): raise NotImplementedError()
val = self.__class__(val)
return val
def filter(self, fn=lambda x: not x.ishidden()): @overload # noqa: F811
def __getitem__(self, key: slice) -> "Traceback": # noqa: F811
raise NotImplementedError()
def __getitem__( # noqa: F811
self, key: Union[int, slice]
) -> Union[TracebackEntry, "Traceback"]:
if isinstance(key, slice):
return self.__class__(super().__getitem__(key))
else:
return super().__getitem__(key)
def filter(
self, fn: Callable[[TracebackEntry], bool] = lambda x: not x.ishidden()
) -> "Traceback":
""" return a Traceback instance with certain items removed """ return a Traceback instance with certain items removed
fn is a function that gets a single argument, a TracebackEntry fn is a function that gets a single argument, a TracebackEntry
@ -341,7 +373,7 @@ class Traceback(list):
""" """
return Traceback(filter(fn, self), self._excinfo) return Traceback(filter(fn, self), self._excinfo)
def getcrashentry(self): def getcrashentry(self) -> TracebackEntry:
""" return last non-hidden traceback entry that lead """ return last non-hidden traceback entry that lead
to the exception of a traceback. to the exception of a traceback.
""" """
@ -351,7 +383,7 @@ class Traceback(list):
return entry return entry
return self[-1] return self[-1]
def recursionindex(self): def recursionindex(self) -> Optional[int]:
""" return the index of the frame/TracebackEntry where recursion """ return the index of the frame/TracebackEntry where recursion
originates if appropriate, None if no recursion occurred originates if appropriate, None if no recursion occurred
""" """
@ -541,7 +573,7 @@ class ExceptionInfo(Generic[_E]):
def getrepr( def getrepr(
self, self,
showlocals: bool = False, showlocals: bool = False,
style: str = "long", style: "_TracebackStyle" = "long",
abspath: bool = False, abspath: bool = False,
tbfilter: bool = True, tbfilter: bool = True,
funcargs: bool = False, funcargs: bool = False,
@ -619,16 +651,16 @@ class FormattedExcinfo:
flow_marker = ">" flow_marker = ">"
fail_marker = "E" fail_marker = "E"
showlocals = attr.ib(default=False) showlocals = attr.ib(type=bool, default=False)
style = attr.ib(default="long") style = attr.ib(type="_TracebackStyle", default="long")
abspath = attr.ib(default=True) abspath = attr.ib(type=bool, default=True)
tbfilter = attr.ib(default=True) tbfilter = attr.ib(type=bool, default=True)
funcargs = attr.ib(default=False) funcargs = attr.ib(type=bool, default=False)
truncate_locals = attr.ib(default=True) truncate_locals = attr.ib(type=bool, default=True)
chain = attr.ib(default=True) chain = attr.ib(type=bool, default=True)
astcache = attr.ib(default=attr.Factory(dict), init=False, repr=False) astcache = attr.ib(default=attr.Factory(dict), init=False, repr=False)
def _getindent(self, source): def _getindent(self, source: "Source") -> int:
# figure out indent for given source # figure out indent for given source
try: try:
s = str(source.getstatement(len(source) - 1)) s = str(source.getstatement(len(source) - 1))
@ -643,20 +675,27 @@ class FormattedExcinfo:
return 0 return 0
return 4 + (len(s) - len(s.lstrip())) return 4 + (len(s) - len(s.lstrip()))
def _getentrysource(self, entry): def _getentrysource(self, entry: TracebackEntry) -> Optional["Source"]:
source = entry.getsource(self.astcache) source = entry.getsource(self.astcache)
if source is not None: if source is not None:
source = source.deindent() source = source.deindent()
return source return source
def repr_args(self, entry): def repr_args(self, entry: TracebackEntry) -> Optional["ReprFuncArgs"]:
if self.funcargs: if self.funcargs:
args = [] args = []
for argname, argvalue in entry.frame.getargs(var=True): for argname, argvalue in entry.frame.getargs(var=True):
args.append((argname, saferepr(argvalue))) args.append((argname, saferepr(argvalue)))
return ReprFuncArgs(args) return ReprFuncArgs(args)
return None
def get_source(self, source, line_index=-1, excinfo=None, short=False) -> List[str]: def get_source(
self,
source: "Source",
line_index: int = -1,
excinfo: Optional[ExceptionInfo] = None,
short: bool = False,
) -> List[str]:
""" return formatted and marked up source lines. """ """ return formatted and marked up source lines. """
import _pytest._code import _pytest._code
@ -680,19 +719,21 @@ class FormattedExcinfo:
lines.extend(self.get_exconly(excinfo, indent=indent, markall=True)) lines.extend(self.get_exconly(excinfo, indent=indent, markall=True))
return lines return lines
def get_exconly(self, excinfo, indent=4, markall=False): def get_exconly(
self, excinfo: ExceptionInfo, indent: int = 4, markall: bool = False
) -> List[str]:
lines = [] lines = []
indent = " " * indent indentstr = " " * indent
# get the real exception information out # get the real exception information out
exlines = excinfo.exconly(tryshort=True).split("\n") exlines = excinfo.exconly(tryshort=True).split("\n")
failindent = self.fail_marker + indent[1:] failindent = self.fail_marker + indentstr[1:]
for line in exlines: for line in exlines:
lines.append(failindent + line) lines.append(failindent + line)
if not markall: if not markall:
failindent = indent failindent = indentstr
return lines return lines
def repr_locals(self, locals): def repr_locals(self, locals: Dict[str, object]) -> Optional["ReprLocals"]:
if self.showlocals: if self.showlocals:
lines = [] lines = []
keys = [loc for loc in locals if loc[0] != "@"] keys = [loc for loc in locals if loc[0] != "@"]
@ -717,8 +758,11 @@ class FormattedExcinfo:
# # XXX # # XXX
# pprint.pprint(value, stream=self.excinfowriter) # pprint.pprint(value, stream=self.excinfowriter)
return ReprLocals(lines) return ReprLocals(lines)
return None
def repr_traceback_entry(self, entry, excinfo=None): def repr_traceback_entry(
self, entry: TracebackEntry, excinfo: Optional[ExceptionInfo] = None
) -> "ReprEntry":
import _pytest._code import _pytest._code
source = self._getentrysource(entry) source = self._getentrysource(entry)
@ -729,9 +773,7 @@ class FormattedExcinfo:
line_index = entry.lineno - entry.getfirstlinesource() line_index = entry.lineno - entry.getfirstlinesource()
lines = [] # type: List[str] lines = [] # type: List[str]
style = entry._repr_style style = entry._repr_style if entry._repr_style is not None else self.style
if style is None:
style = self.style
if style in ("short", "long"): if style in ("short", "long"):
short = style == "short" short = style == "short"
reprargs = self.repr_args(entry) if not short else None reprargs = self.repr_args(entry) if not short else None
@ -761,7 +803,7 @@ class FormattedExcinfo:
path = np path = np
return path return path
def repr_traceback(self, excinfo): def repr_traceback(self, excinfo: ExceptionInfo) -> "ReprTraceback":
traceback = excinfo.traceback traceback = excinfo.traceback
if self.tbfilter: if self.tbfilter:
traceback = traceback.filter() traceback = traceback.filter()
@ -779,7 +821,9 @@ class FormattedExcinfo:
entries.append(reprentry) entries.append(reprentry)
return ReprTraceback(entries, extraline, style=self.style) return ReprTraceback(entries, extraline, style=self.style)
def _truncate_recursive_traceback(self, traceback): def _truncate_recursive_traceback(
self, traceback: Traceback
) -> Tuple[Traceback, Optional[str]]:
""" """
Truncate the given recursive traceback trying to find the starting point Truncate the given recursive traceback trying to find the starting point
of the recursion. of the recursion.
@ -806,7 +850,9 @@ class FormattedExcinfo:
max_frames=max_frames, max_frames=max_frames,
total=len(traceback), total=len(traceback),
) # type: Optional[str] ) # type: Optional[str]
traceback = traceback[:max_frames] + traceback[-max_frames:] # Type ignored because adding two instaces of a List subtype
# currently incorrectly has type List instead of the subtype.
traceback = traceback[:max_frames] + traceback[-max_frames:] # type: ignore
else: else:
if recursionindex is not None: if recursionindex is not None:
extraline = "!!! Recursion detected (same locals & position)" extraline = "!!! Recursion detected (same locals & position)"
@ -863,7 +909,7 @@ class FormattedExcinfo:
class TerminalRepr: class TerminalRepr:
def __str__(self): def __str__(self) -> str:
# FYI this is called from pytest-xdist's serialization of exception # FYI this is called from pytest-xdist's serialization of exception
# information. # information.
io = StringIO() io = StringIO()
@ -871,7 +917,7 @@ class TerminalRepr:
self.toterminal(tw) self.toterminal(tw)
return io.getvalue().strip() return io.getvalue().strip()
def __repr__(self): def __repr__(self) -> str:
return "<{} instance at {:0x}>".format(self.__class__, id(self)) return "<{} instance at {:0x}>".format(self.__class__, id(self))
def toterminal(self, tw) -> None: def toterminal(self, tw) -> None:
@ -882,7 +928,7 @@ class ExceptionRepr(TerminalRepr):
def __init__(self) -> None: def __init__(self) -> None:
self.sections = [] # type: List[Tuple[str, str, str]] self.sections = [] # type: List[Tuple[str, str, str]]
def addsection(self, name, content, sep="-"): def addsection(self, name: str, content: str, sep: str = "-") -> None:
self.sections.append((name, content, sep)) self.sections.append((name, content, sep))
def toterminal(self, tw) -> None: def toterminal(self, tw) -> None:
@ -892,7 +938,12 @@ class ExceptionRepr(TerminalRepr):
class ExceptionChainRepr(ExceptionRepr): class ExceptionChainRepr(ExceptionRepr):
def __init__(self, chain): def __init__(
self,
chain: Sequence[
Tuple["ReprTraceback", Optional["ReprFileLocation"], Optional[str]]
],
) -> None:
super().__init__() super().__init__()
self.chain = chain self.chain = chain
# reprcrash and reprtraceback of the outermost (the newest) exception # reprcrash and reprtraceback of the outermost (the newest) exception
@ -910,7 +961,9 @@ class ExceptionChainRepr(ExceptionRepr):
class ReprExceptionInfo(ExceptionRepr): class ReprExceptionInfo(ExceptionRepr):
def __init__(self, reprtraceback, reprcrash): def __init__(
self, reprtraceback: "ReprTraceback", reprcrash: "ReprFileLocation"
) -> None:
super().__init__() super().__init__()
self.reprtraceback = reprtraceback self.reprtraceback = reprtraceback
self.reprcrash = reprcrash self.reprcrash = reprcrash
@ -923,7 +976,12 @@ class ReprExceptionInfo(ExceptionRepr):
class ReprTraceback(TerminalRepr): class ReprTraceback(TerminalRepr):
entrysep = "_ " entrysep = "_ "
def __init__(self, reprentries, extraline, style): def __init__(
self,
reprentries: Sequence[Union["ReprEntry", "ReprEntryNative"]],
extraline: Optional[str],
style: "_TracebackStyle",
) -> None:
self.reprentries = reprentries self.reprentries = reprentries
self.extraline = extraline self.extraline = extraline
self.style = style self.style = style
@ -948,16 +1006,16 @@ class ReprTraceback(TerminalRepr):
class ReprTracebackNative(ReprTraceback): class ReprTracebackNative(ReprTraceback):
def __init__(self, tblines): def __init__(self, tblines: Sequence[str]) -> None:
self.style = "native" self.style = "native"
self.reprentries = [ReprEntryNative(tblines)] self.reprentries = [ReprEntryNative(tblines)]
self.extraline = None self.extraline = None
class ReprEntryNative(TerminalRepr): class ReprEntryNative(TerminalRepr):
style = "native" style = "native" # type: _TracebackStyle
def __init__(self, tblines): def __init__(self, tblines: Sequence[str]) -> None:
self.lines = tblines self.lines = tblines
def toterminal(self, tw) -> None: def toterminal(self, tw) -> None:
@ -965,7 +1023,14 @@ class ReprEntryNative(TerminalRepr):
class ReprEntry(TerminalRepr): class ReprEntry(TerminalRepr):
def __init__(self, lines, reprfuncargs, reprlocals, filelocrepr, style): def __init__(
self,
lines: Sequence[str],
reprfuncargs: Optional["ReprFuncArgs"],
reprlocals: Optional["ReprLocals"],
filelocrepr: Optional["ReprFileLocation"],
style: "_TracebackStyle",
) -> None:
self.lines = lines self.lines = lines
self.reprfuncargs = reprfuncargs self.reprfuncargs = reprfuncargs
self.reprlocals = reprlocals self.reprlocals = reprlocals
@ -974,6 +1039,7 @@ class ReprEntry(TerminalRepr):
def toterminal(self, tw) -> None: def toterminal(self, tw) -> None:
if self.style == "short": if self.style == "short":
assert self.reprfileloc is not None
self.reprfileloc.toterminal(tw) self.reprfileloc.toterminal(tw)
for line in self.lines: for line in self.lines:
red = line.startswith("E ") red = line.startswith("E ")
@ -992,14 +1058,14 @@ class ReprEntry(TerminalRepr):
tw.line("") tw.line("")
self.reprfileloc.toterminal(tw) self.reprfileloc.toterminal(tw)
def __str__(self): def __str__(self) -> str:
return "{}\n{}\n{}".format( return "{}\n{}\n{}".format(
"\n".join(self.lines), self.reprlocals, self.reprfileloc "\n".join(self.lines), self.reprlocals, self.reprfileloc
) )
class ReprFileLocation(TerminalRepr): class ReprFileLocation(TerminalRepr):
def __init__(self, path, lineno, message): def __init__(self, path, lineno: int, message: str) -> None:
self.path = str(path) self.path = str(path)
self.lineno = lineno self.lineno = lineno
self.message = message self.message = message
@ -1016,7 +1082,7 @@ class ReprFileLocation(TerminalRepr):
class ReprLocals(TerminalRepr): class ReprLocals(TerminalRepr):
def __init__(self, lines): def __init__(self, lines: Sequence[str]) -> None:
self.lines = lines self.lines = lines
def toterminal(self, tw) -> None: def toterminal(self, tw) -> None:
@ -1025,7 +1091,7 @@ class ReprLocals(TerminalRepr):
class ReprFuncArgs(TerminalRepr): class ReprFuncArgs(TerminalRepr):
def __init__(self, args): def __init__(self, args: Sequence[Tuple[str, object]]) -> None:
self.args = args self.args = args
def toterminal(self, tw) -> None: def toterminal(self, tw) -> None:
@ -1047,7 +1113,7 @@ class ReprFuncArgs(TerminalRepr):
tw.line("") tw.line("")
def getrawcode(obj, trycall=True): def getrawcode(obj, trycall: bool = True):
""" return code object for given function. """ """ return code object for given function. """
try: try:
return obj.__code__ return obj.__code__
@ -1075,7 +1141,7 @@ _PYTEST_DIR = py.path.local(_pytest.__file__).dirpath()
_PY_DIR = py.path.local(py.__file__).dirpath() _PY_DIR = py.path.local(py.__file__).dirpath()
def filter_traceback(entry): def filter_traceback(entry: TracebackEntry) -> bool:
"""Return True if a TracebackEntry instance should be removed from tracebacks: """Return True if a TracebackEntry instance should be removed from tracebacks:
* dynamically generated code (no code to show up for it); * dynamically generated code (no code to show up for it);
* internal traceback from pytest or its internal libraries, py and pluggy. * internal traceback from pytest or its internal libraries, py and pluggy.

View File

@ -1,18 +1,19 @@
import sys import sys
from types import FrameType
from unittest import mock from unittest import mock
import _pytest._code import _pytest._code
import pytest import pytest
def test_ne(): def test_ne() -> None:
code1 = _pytest._code.Code(compile('foo = "bar"', "", "exec")) code1 = _pytest._code.Code(compile('foo = "bar"', "", "exec"))
assert code1 == code1 assert code1 == code1
code2 = _pytest._code.Code(compile('foo = "baz"', "", "exec")) code2 = _pytest._code.Code(compile('foo = "baz"', "", "exec"))
assert code2 != code1 assert code2 != code1
def test_code_gives_back_name_for_not_existing_file(): def test_code_gives_back_name_for_not_existing_file() -> None:
name = "abc-123" name = "abc-123"
co_code = compile("pass\n", name, "exec") co_code = compile("pass\n", name, "exec")
assert co_code.co_filename == name assert co_code.co_filename == name
@ -21,68 +22,67 @@ def test_code_gives_back_name_for_not_existing_file():
assert code.fullsource is None assert code.fullsource is None
def test_code_with_class(): def test_code_with_class() -> None:
class A: class A:
pass pass
pytest.raises(TypeError, _pytest._code.Code, A) pytest.raises(TypeError, _pytest._code.Code, A)
def x(): def x() -> None:
raise NotImplementedError() raise NotImplementedError()
def test_code_fullsource(): def test_code_fullsource() -> None:
code = _pytest._code.Code(x) code = _pytest._code.Code(x)
full = code.fullsource full = code.fullsource
assert "test_code_fullsource()" in str(full) assert "test_code_fullsource()" in str(full)
def test_code_source(): def test_code_source() -> None:
code = _pytest._code.Code(x) code = _pytest._code.Code(x)
src = code.source() src = code.source()
expected = """def x(): expected = """def x() -> None:
raise NotImplementedError()""" raise NotImplementedError()"""
assert str(src) == expected assert str(src) == expected
def test_frame_getsourcelineno_myself(): def test_frame_getsourcelineno_myself() -> None:
def func(): def func() -> FrameType:
return sys._getframe(0) return sys._getframe(0)
f = func() f = _pytest._code.Frame(func())
f = _pytest._code.Frame(f)
source, lineno = f.code.fullsource, f.lineno source, lineno = f.code.fullsource, f.lineno
assert source is not None
assert source[lineno].startswith(" return sys._getframe(0)") assert source[lineno].startswith(" return sys._getframe(0)")
def test_getstatement_empty_fullsource(): def test_getstatement_empty_fullsource() -> None:
def func(): def func() -> FrameType:
return sys._getframe(0) return sys._getframe(0)
f = func() f = _pytest._code.Frame(func())
f = _pytest._code.Frame(f)
with mock.patch.object(f.code.__class__, "fullsource", None): with mock.patch.object(f.code.__class__, "fullsource", None):
assert f.statement == "" assert f.statement == ""
def test_code_from_func(): def test_code_from_func() -> None:
co = _pytest._code.Code(test_frame_getsourcelineno_myself) co = _pytest._code.Code(test_frame_getsourcelineno_myself)
assert co.firstlineno assert co.firstlineno
assert co.path assert co.path
def test_unicode_handling(): def test_unicode_handling() -> None:
value = "ąć".encode() value = "ąć".encode()
def f(): def f() -> None:
raise Exception(value) raise Exception(value)
excinfo = pytest.raises(Exception, f) excinfo = pytest.raises(Exception, f)
str(excinfo) str(excinfo)
def test_code_getargs(): def test_code_getargs() -> None:
def f1(x): def f1(x):
raise NotImplementedError() raise NotImplementedError()
@ -108,26 +108,26 @@ def test_code_getargs():
assert c4.getargs(var=True) == ("x", "y", "z") assert c4.getargs(var=True) == ("x", "y", "z")
def test_frame_getargs(): def test_frame_getargs() -> None:
def f1(x): def f1(x) -> FrameType:
return sys._getframe(0) return sys._getframe(0)
fr1 = _pytest._code.Frame(f1("a")) fr1 = _pytest._code.Frame(f1("a"))
assert fr1.getargs(var=True) == [("x", "a")] assert fr1.getargs(var=True) == [("x", "a")]
def f2(x, *y): def f2(x, *y) -> FrameType:
return sys._getframe(0) return sys._getframe(0)
fr2 = _pytest._code.Frame(f2("a", "b", "c")) fr2 = _pytest._code.Frame(f2("a", "b", "c"))
assert fr2.getargs(var=True) == [("x", "a"), ("y", ("b", "c"))] assert fr2.getargs(var=True) == [("x", "a"), ("y", ("b", "c"))]
def f3(x, **z): def f3(x, **z) -> FrameType:
return sys._getframe(0) return sys._getframe(0)
fr3 = _pytest._code.Frame(f3("a", b="c")) fr3 = _pytest._code.Frame(f3("a", b="c"))
assert fr3.getargs(var=True) == [("x", "a"), ("z", {"b": "c"})] assert fr3.getargs(var=True) == [("x", "a"), ("z", {"b": "c"})]
def f4(x, *y, **z): def f4(x, *y, **z) -> FrameType:
return sys._getframe(0) return sys._getframe(0)
fr4 = _pytest._code.Frame(f4("a", "b", c="d")) fr4 = _pytest._code.Frame(f4("a", "b", c="d"))
@ -135,7 +135,7 @@ def test_frame_getargs():
class TestExceptionInfo: class TestExceptionInfo:
def test_bad_getsource(self): def test_bad_getsource(self) -> None:
try: try:
if False: if False:
pass pass
@ -145,13 +145,13 @@ class TestExceptionInfo:
exci = _pytest._code.ExceptionInfo.from_current() exci = _pytest._code.ExceptionInfo.from_current()
assert exci.getrepr() assert exci.getrepr()
def test_from_current_with_missing(self): def test_from_current_with_missing(self) -> None:
with pytest.raises(AssertionError, match="no current exception"): with pytest.raises(AssertionError, match="no current exception"):
_pytest._code.ExceptionInfo.from_current() _pytest._code.ExceptionInfo.from_current()
class TestTracebackEntry: class TestTracebackEntry:
def test_getsource(self): def test_getsource(self) -> None:
try: try:
if False: if False:
pass pass
@ -161,12 +161,13 @@ class TestTracebackEntry:
exci = _pytest._code.ExceptionInfo.from_current() exci = _pytest._code.ExceptionInfo.from_current()
entry = exci.traceback[0] entry = exci.traceback[0]
source = entry.getsource() source = entry.getsource()
assert source is not None
assert len(source) == 6 assert len(source) == 6
assert "assert False" in source[5] assert "assert False" in source[5]
class TestReprFuncArgs: class TestReprFuncArgs:
def test_not_raise_exception_with_mixed_encoding(self, tw_mock): def test_not_raise_exception_with_mixed_encoding(self, tw_mock) -> None:
from _pytest._code.code import ReprFuncArgs from _pytest._code.code import ReprFuncArgs
args = [("unicode_string", "São Paulo"), ("utf8_string", b"S\xc3\xa3o Paulo")] args = [("unicode_string", "São Paulo"), ("utf8_string", b"S\xc3\xa3o Paulo")]

View File

@ -3,6 +3,7 @@ import os
import queue import queue
import sys import sys
import textwrap import textwrap
from typing import Union
import py import py
@ -224,23 +225,25 @@ class TestTraceback_f_g_h:
repr = excinfo.getrepr() repr = excinfo.getrepr()
assert "RuntimeError: hello" in str(repr.reprcrash) assert "RuntimeError: hello" in str(repr.reprcrash)
def test_traceback_no_recursion_index(self): def test_traceback_no_recursion_index(self) -> None:
def do_stuff(): def do_stuff() -> None:
raise RuntimeError raise RuntimeError
def reraise_me(): def reraise_me() -> None:
import sys import sys
exc, val, tb = sys.exc_info() exc, val, tb = sys.exc_info()
assert val is not None
raise val.with_traceback(tb) raise val.with_traceback(tb)
def f(n): def f(n: int) -> None:
try: try:
do_stuff() do_stuff()
except: # noqa except: # noqa
reraise_me() reraise_me()
excinfo = pytest.raises(RuntimeError, f, 8) excinfo = pytest.raises(RuntimeError, f, 8)
assert excinfo is not None
traceback = excinfo.traceback traceback = excinfo.traceback
recindex = traceback.recursionindex() recindex = traceback.recursionindex()
assert recindex is None assert recindex is None
@ -596,7 +599,6 @@ raise ValueError()
assert lines[3] == "E world" assert lines[3] == "E world"
assert not lines[4:] assert not lines[4:]
loc = repr_entry.reprlocals is not None
loc = repr_entry.reprfileloc loc = repr_entry.reprfileloc
assert loc.path == mod.__file__ assert loc.path == mod.__file__
assert loc.lineno == 3 assert loc.lineno == 3
@ -1286,9 +1288,10 @@ raise ValueError()
@pytest.mark.parametrize("style", ["short", "long"]) @pytest.mark.parametrize("style", ["short", "long"])
@pytest.mark.parametrize("encoding", [None, "utf8", "utf16"]) @pytest.mark.parametrize("encoding", [None, "utf8", "utf16"])
def test_repr_traceback_with_unicode(style, encoding): def test_repr_traceback_with_unicode(style, encoding):
msg = "" if encoding is None:
if encoding is not None: msg = "" # type: Union[str, bytes]
msg = msg.encode(encoding) else:
msg = "".encode(encoding)
try: try:
raise RuntimeError(msg) raise RuntimeError(msg)
except RuntimeError: except RuntimeError:

View File

@ -4,13 +4,16 @@
import ast import ast
import inspect import inspect
import sys import sys
from typing import Any
from typing import Dict
from typing import Optional
import _pytest._code import _pytest._code
import pytest import pytest
from _pytest._code import Source from _pytest._code import Source
def test_source_str_function(): def test_source_str_function() -> None:
x = Source("3") x = Source("3")
assert str(x) == "3" assert str(x) == "3"
@ -25,7 +28,7 @@ def test_source_str_function():
assert str(x) == "\n3" assert str(x) == "\n3"
def test_unicode(): def test_unicode() -> None:
x = Source("4") x = Source("4")
assert str(x) == "4" assert str(x) == "4"
co = _pytest._code.compile('"å"', mode="eval") co = _pytest._code.compile('"å"', mode="eval")
@ -33,12 +36,12 @@ def test_unicode():
assert isinstance(val, str) assert isinstance(val, str)
def test_source_from_function(): def test_source_from_function() -> None:
source = _pytest._code.Source(test_source_str_function) source = _pytest._code.Source(test_source_str_function)
assert str(source).startswith("def test_source_str_function():") assert str(source).startswith("def test_source_str_function() -> None:")
def test_source_from_method(): def test_source_from_method() -> None:
class TestClass: class TestClass:
def test_method(self): def test_method(self):
pass pass
@ -47,13 +50,13 @@ def test_source_from_method():
assert source.lines == ["def test_method(self):", " pass"] assert source.lines == ["def test_method(self):", " pass"]
def test_source_from_lines(): def test_source_from_lines() -> None:
lines = ["a \n", "b\n", "c"] lines = ["a \n", "b\n", "c"]
source = _pytest._code.Source(lines) source = _pytest._code.Source(lines)
assert source.lines == ["a ", "b", "c"] assert source.lines == ["a ", "b", "c"]
def test_source_from_inner_function(): def test_source_from_inner_function() -> None:
def f(): def f():
pass pass
@ -63,7 +66,7 @@ def test_source_from_inner_function():
assert str(source).startswith("def f():") assert str(source).startswith("def f():")
def test_source_putaround_simple(): def test_source_putaround_simple() -> None:
source = Source("raise ValueError") source = Source("raise ValueError")
source = source.putaround( source = source.putaround(
"try:", "try:",
@ -85,7 +88,7 @@ else:
) )
def test_source_putaround(): def test_source_putaround() -> None:
source = Source() source = Source()
source = source.putaround( source = source.putaround(
""" """
@ -96,28 +99,29 @@ def test_source_putaround():
assert str(source).strip() == "if 1:\n x=1" assert str(source).strip() == "if 1:\n x=1"
def test_source_strips(): def test_source_strips() -> None:
source = Source("") source = Source("")
assert source == Source() assert source == Source()
assert str(source) == "" assert str(source) == ""
assert source.strip() == source assert source.strip() == source
def test_source_strip_multiline(): def test_source_strip_multiline() -> None:
source = Source() source = Source()
source.lines = ["", " hello", " "] source.lines = ["", " hello", " "]
source2 = source.strip() source2 = source.strip()
assert source2.lines == [" hello"] assert source2.lines == [" hello"]
def test_syntaxerror_rerepresentation(): def test_syntaxerror_rerepresentation() -> None:
ex = pytest.raises(SyntaxError, _pytest._code.compile, "xyz xyz") ex = pytest.raises(SyntaxError, _pytest._code.compile, "xyz xyz")
assert ex is not None
assert ex.value.lineno == 1 assert ex.value.lineno == 1
assert ex.value.offset in {5, 7} # cpython: 7, pypy3.6 7.1.1: 5 assert ex.value.offset in {5, 7} # cpython: 7, pypy3.6 7.1.1: 5
assert ex.value.text.strip(), "x x" assert ex.value.text == "xyz xyz\n"
def test_isparseable(): def test_isparseable() -> None:
assert Source("hello").isparseable() assert Source("hello").isparseable()
assert Source("if 1:\n pass").isparseable() assert Source("if 1:\n pass").isparseable()
assert Source(" \nif 1:\n pass").isparseable() assert Source(" \nif 1:\n pass").isparseable()
@ -127,7 +131,7 @@ def test_isparseable():
class TestAccesses: class TestAccesses:
def setup_class(self): def setup_class(self) -> None:
self.source = Source( self.source = Source(
"""\ """\
def f(x): def f(x):
@ -137,26 +141,26 @@ class TestAccesses:
""" """
) )
def test_getrange(self): def test_getrange(self) -> None:
x = self.source[0:2] x = self.source[0:2]
assert x.isparseable() assert x.isparseable()
assert len(x.lines) == 2 assert len(x.lines) == 2
assert str(x) == "def f(x):\n pass" assert str(x) == "def f(x):\n pass"
def test_getline(self): def test_getline(self) -> None:
x = self.source[0] x = self.source[0]
assert x == "def f(x):" assert x == "def f(x):"
def test_len(self): def test_len(self) -> None:
assert len(self.source) == 4 assert len(self.source) == 4
def test_iter(self): def test_iter(self) -> None:
values = [x for x in self.source] values = [x for x in self.source]
assert len(values) == 4 assert len(values) == 4
class TestSourceParsingAndCompiling: class TestSourceParsingAndCompiling:
def setup_class(self): def setup_class(self) -> None:
self.source = Source( self.source = Source(
"""\ """\
def f(x): def f(x):
@ -166,19 +170,19 @@ class TestSourceParsingAndCompiling:
""" """
).strip() ).strip()
def test_compile(self): def test_compile(self) -> None:
co = _pytest._code.compile("x=3") co = _pytest._code.compile("x=3")
d = {} d = {} # type: Dict[str, Any]
exec(co, d) exec(co, d)
assert d["x"] == 3 assert d["x"] == 3
def test_compile_and_getsource_simple(self): def test_compile_and_getsource_simple(self) -> None:
co = _pytest._code.compile("x=3") co = _pytest._code.compile("x=3")
exec(co) exec(co)
source = _pytest._code.Source(co) source = _pytest._code.Source(co)
assert str(source) == "x=3" assert str(source) == "x=3"
def test_compile_and_getsource_through_same_function(self): def test_compile_and_getsource_through_same_function(self) -> None:
def gensource(source): def gensource(source):
return _pytest._code.compile(source) return _pytest._code.compile(source)
@ -199,7 +203,7 @@ class TestSourceParsingAndCompiling:
source2 = inspect.getsource(co2) source2 = inspect.getsource(co2)
assert "ValueError" in source2 assert "ValueError" in source2
def test_getstatement(self): def test_getstatement(self) -> None:
# print str(self.source) # print str(self.source)
ass = str(self.source[1:]) ass = str(self.source[1:])
for i in range(1, 4): for i in range(1, 4):
@ -208,7 +212,7 @@ class TestSourceParsingAndCompiling:
# x = s.deindent() # x = s.deindent()
assert str(s) == ass assert str(s) == ass
def test_getstatementrange_triple_quoted(self): def test_getstatementrange_triple_quoted(self) -> None:
# print str(self.source) # print str(self.source)
source = Source( source = Source(
"""hello(''' """hello('''
@ -219,7 +223,7 @@ class TestSourceParsingAndCompiling:
s = source.getstatement(1) s = source.getstatement(1)
assert s == str(source) assert s == str(source)
def test_getstatementrange_within_constructs(self): def test_getstatementrange_within_constructs(self) -> None:
source = Source( source = Source(
"""\ """\
try: try:
@ -241,7 +245,7 @@ class TestSourceParsingAndCompiling:
# assert source.getstatementrange(5) == (0, 7) # assert source.getstatementrange(5) == (0, 7)
assert source.getstatementrange(6) == (6, 7) assert source.getstatementrange(6) == (6, 7)
def test_getstatementrange_bug(self): def test_getstatementrange_bug(self) -> None:
source = Source( source = Source(
"""\ """\
try: try:
@ -255,7 +259,7 @@ class TestSourceParsingAndCompiling:
assert len(source) == 6 assert len(source) == 6
assert source.getstatementrange(2) == (1, 4) assert source.getstatementrange(2) == (1, 4)
def test_getstatementrange_bug2(self): def test_getstatementrange_bug2(self) -> None:
source = Source( source = Source(
"""\ """\
assert ( assert (
@ -272,7 +276,7 @@ class TestSourceParsingAndCompiling:
assert len(source) == 9 assert len(source) == 9
assert source.getstatementrange(5) == (0, 9) assert source.getstatementrange(5) == (0, 9)
def test_getstatementrange_ast_issue58(self): def test_getstatementrange_ast_issue58(self) -> None:
source = Source( source = Source(
"""\ """\
@ -286,38 +290,44 @@ class TestSourceParsingAndCompiling:
assert getstatement(2, source).lines == source.lines[2:3] assert getstatement(2, source).lines == source.lines[2:3]
assert getstatement(3, source).lines == source.lines[3:4] assert getstatement(3, source).lines == source.lines[3:4]
def test_getstatementrange_out_of_bounds_py3(self): def test_getstatementrange_out_of_bounds_py3(self) -> None:
source = Source("if xxx:\n from .collections import something") source = Source("if xxx:\n from .collections import something")
r = source.getstatementrange(1) r = source.getstatementrange(1)
assert r == (1, 2) assert r == (1, 2)
def test_getstatementrange_with_syntaxerror_issue7(self): def test_getstatementrange_with_syntaxerror_issue7(self) -> None:
source = Source(":") source = Source(":")
pytest.raises(SyntaxError, lambda: source.getstatementrange(0)) pytest.raises(SyntaxError, lambda: source.getstatementrange(0))
def test_compile_to_ast(self): def test_compile_to_ast(self) -> None:
source = Source("x = 4") source = Source("x = 4")
mod = source.compile(flag=ast.PyCF_ONLY_AST) mod = source.compile(flag=ast.PyCF_ONLY_AST)
assert isinstance(mod, ast.Module) assert isinstance(mod, ast.Module)
compile(mod, "<filename>", "exec") compile(mod, "<filename>", "exec")
def test_compile_and_getsource(self): def test_compile_and_getsource(self) -> None:
co = self.source.compile() co = self.source.compile()
exec(co, globals()) exec(co, globals())
f(7) f(7) # type: ignore
excinfo = pytest.raises(AssertionError, f, 6) excinfo = pytest.raises(AssertionError, f, 6) # type: ignore
assert excinfo is not None
frame = excinfo.traceback[-1].frame frame = excinfo.traceback[-1].frame
assert isinstance(frame.code.fullsource, Source)
stmt = frame.code.fullsource.getstatement(frame.lineno) stmt = frame.code.fullsource.getstatement(frame.lineno)
assert str(stmt).strip().startswith("assert") assert str(stmt).strip().startswith("assert")
@pytest.mark.parametrize("name", ["", None, "my"]) @pytest.mark.parametrize("name", ["", None, "my"])
def test_compilefuncs_and_path_sanity(self, name): def test_compilefuncs_and_path_sanity(self, name: Optional[str]) -> None:
def check(comp, name): def check(comp, name):
co = comp(self.source, name) co = comp(self.source, name)
if not name: if not name:
expected = "codegen %s:%d>" % (mypath, mylineno + 2 + 2) expected = "codegen %s:%d>" % (mypath, mylineno + 2 + 2) # type: ignore
else: else:
expected = "codegen %r %s:%d>" % (name, mypath, mylineno + 2 + 2) expected = "codegen %r %s:%d>" % (
name,
mypath, # type: ignore
mylineno + 2 + 2, # type: ignore
) # type: ignore
fn = co.co_filename fn = co.co_filename
assert fn.endswith(expected) assert fn.endswith(expected)
@ -332,9 +342,9 @@ class TestSourceParsingAndCompiling:
pytest.raises(SyntaxError, _pytest._code.compile, "lambda a,a: 0", mode="eval") pytest.raises(SyntaxError, _pytest._code.compile, "lambda a,a: 0", mode="eval")
def test_getstartingblock_singleline(): def test_getstartingblock_singleline() -> None:
class A: class A:
def __init__(self, *args): def __init__(self, *args) -> None:
frame = sys._getframe(1) frame = sys._getframe(1)
self.source = _pytest._code.Frame(frame).statement self.source = _pytest._code.Frame(frame).statement
@ -344,22 +354,22 @@ def test_getstartingblock_singleline():
assert len(values) == 1 assert len(values) == 1
def test_getline_finally(): def test_getline_finally() -> None:
def c(): def c() -> None:
pass pass
with pytest.raises(TypeError) as excinfo: with pytest.raises(TypeError) as excinfo:
teardown = None teardown = None
try: try:
c(1) c(1) # type: ignore
finally: finally:
if teardown: if teardown:
teardown() teardown()
source = excinfo.traceback[-1].statement source = excinfo.traceback[-1].statement
assert str(source).strip() == "c(1)" assert str(source).strip() == "c(1) # type: ignore"
def test_getfuncsource_dynamic(): def test_getfuncsource_dynamic() -> None:
source = """ source = """
def f(): def f():
raise ValueError raise ValueError
@ -368,11 +378,13 @@ def test_getfuncsource_dynamic():
""" """
co = _pytest._code.compile(source) co = _pytest._code.compile(source)
exec(co, globals()) exec(co, globals())
assert str(_pytest._code.Source(f)).strip() == "def f():\n raise ValueError" f_source = _pytest._code.Source(f) # type: ignore
assert str(_pytest._code.Source(g)).strip() == "def g(): pass" g_source = _pytest._code.Source(g) # type: ignore
assert str(f_source).strip() == "def f():\n raise ValueError"
assert str(g_source).strip() == "def g(): pass"
def test_getfuncsource_with_multine_string(): def test_getfuncsource_with_multine_string() -> None:
def f(): def f():
c = """while True: c = """while True:
pass pass
@ -387,7 +399,7 @@ def test_getfuncsource_with_multine_string():
assert str(_pytest._code.Source(f)) == expected.rstrip() assert str(_pytest._code.Source(f)) == expected.rstrip()
def test_deindent(): def test_deindent() -> None:
from _pytest._code.source import deindent as deindent from _pytest._code.source import deindent as deindent
assert deindent(["\tfoo", "\tbar"]) == ["foo", "bar"] assert deindent(["\tfoo", "\tbar"]) == ["foo", "bar"]
@ -401,7 +413,7 @@ def test_deindent():
assert lines == ["def f():", " def g():", " pass"] assert lines == ["def f():", " def g():", " pass"]
def test_source_of_class_at_eof_without_newline(tmpdir, _sys_snapshot): def test_source_of_class_at_eof_without_newline(tmpdir, _sys_snapshot) -> None:
# this test fails because the implicit inspect.getsource(A) below # this test fails because the implicit inspect.getsource(A) below
# does not return the "x = 1" last line. # does not return the "x = 1" last line.
source = _pytest._code.Source( source = _pytest._code.Source(
@ -423,7 +435,7 @@ if True:
pass pass
def test_getsource_fallback(): def test_getsource_fallback() -> None:
from _pytest._code.source import getsource from _pytest._code.source import getsource
expected = """def x(): expected = """def x():
@ -432,7 +444,7 @@ def test_getsource_fallback():
assert src == expected assert src == expected
def test_idem_compile_and_getsource(): def test_idem_compile_and_getsource() -> None:
from _pytest._code.source import getsource from _pytest._code.source import getsource
expected = "def x(): pass" expected = "def x(): pass"
@ -441,15 +453,16 @@ def test_idem_compile_and_getsource():
assert src == expected assert src == expected
def test_findsource_fallback(): def test_findsource_fallback() -> None:
from _pytest._code.source import findsource from _pytest._code.source import findsource
src, lineno = findsource(x) src, lineno = findsource(x)
assert src is not None
assert "test_findsource_simple" in str(src) assert "test_findsource_simple" in str(src)
assert src[lineno] == " def x():" assert src[lineno] == " def x():"
def test_findsource(): def test_findsource() -> None:
from _pytest._code.source import findsource from _pytest._code.source import findsource
co = _pytest._code.compile( co = _pytest._code.compile(
@ -460,19 +473,21 @@ def test_findsource():
) )
src, lineno = findsource(co) src, lineno = findsource(co)
assert src is not None
assert "if 1:" in str(src) assert "if 1:" in str(src)
d = {} d = {} # type: Dict[str, Any]
eval(co, d) eval(co, d)
src, lineno = findsource(d["x"]) src, lineno = findsource(d["x"])
assert src is not None
assert "if 1:" in str(src) assert "if 1:" in str(src)
assert src[lineno] == " def x():" assert src[lineno] == " def x():"
def test_getfslineno(): def test_getfslineno() -> None:
from _pytest._code import getfslineno from _pytest._code import getfslineno
def f(x): def f(x) -> None:
pass pass
fspath, lineno = getfslineno(f) fspath, lineno = getfslineno(f)
@ -498,40 +513,40 @@ def test_getfslineno():
assert getfslineno(B)[1] == -1 assert getfslineno(B)[1] == -1
def test_code_of_object_instance_with_call(): def test_code_of_object_instance_with_call() -> None:
class A: class A:
pass pass
pytest.raises(TypeError, lambda: _pytest._code.Source(A())) pytest.raises(TypeError, lambda: _pytest._code.Source(A()))
class WithCall: class WithCall:
def __call__(self): def __call__(self) -> None:
pass pass
code = _pytest._code.Code(WithCall()) code = _pytest._code.Code(WithCall())
assert "pass" in str(code.source()) assert "pass" in str(code.source())
class Hello: class Hello:
def __call__(self): def __call__(self) -> None:
pass pass
pytest.raises(TypeError, lambda: _pytest._code.Code(Hello)) pytest.raises(TypeError, lambda: _pytest._code.Code(Hello))
def getstatement(lineno, source): def getstatement(lineno: int, source) -> Source:
from _pytest._code.source import getstatementrange_ast from _pytest._code.source import getstatementrange_ast
source = _pytest._code.Source(source, deindent=False) src = _pytest._code.Source(source, deindent=False)
ast, start, end = getstatementrange_ast(lineno, source) ast, start, end = getstatementrange_ast(lineno, src)
return source[start:end] return src[start:end]
def test_oneline(): def test_oneline() -> None:
source = getstatement(0, "raise ValueError") source = getstatement(0, "raise ValueError")
assert str(source) == "raise ValueError" assert str(source) == "raise ValueError"
def test_comment_and_no_newline_at_end(): def test_comment_and_no_newline_at_end() -> None:
from _pytest._code.source import getstatementrange_ast from _pytest._code.source import getstatementrange_ast
source = Source( source = Source(
@ -545,12 +560,12 @@ def test_comment_and_no_newline_at_end():
assert end == 2 assert end == 2
def test_oneline_and_comment(): def test_oneline_and_comment() -> None:
source = getstatement(0, "raise ValueError\n#hello") source = getstatement(0, "raise ValueError\n#hello")
assert str(source) == "raise ValueError" assert str(source) == "raise ValueError"
def test_comments(): def test_comments() -> None:
source = '''def test(): source = '''def test():
"comment 1" "comment 1"
x = 1 x = 1
@ -576,7 +591,7 @@ comment 4
assert str(getstatement(line, source)) == '"""\ncomment 4\n"""' assert str(getstatement(line, source)) == '"""\ncomment 4\n"""'
def test_comment_in_statement(): def test_comment_in_statement() -> None:
source = """test(foo=1, source = """test(foo=1,
# comment 1 # comment 1
bar=2) bar=2)
@ -588,17 +603,17 @@ def test_comment_in_statement():
) )
def test_single_line_else(): def test_single_line_else() -> None:
source = getstatement(1, "if False: 2\nelse: 3") source = getstatement(1, "if False: 2\nelse: 3")
assert str(source) == "else: 3" assert str(source) == "else: 3"
def test_single_line_finally(): def test_single_line_finally() -> None:
source = getstatement(1, "try: 1\nfinally: 3") source = getstatement(1, "try: 1\nfinally: 3")
assert str(source) == "finally: 3" assert str(source) == "finally: 3"
def test_issue55(): def test_issue55() -> None:
source = ( source = (
"def round_trip(dinp):\n assert 1 == dinp\n" "def round_trip(dinp):\n assert 1 == dinp\n"
'def test_rt():\n round_trip("""\n""")\n' 'def test_rt():\n round_trip("""\n""")\n'
@ -607,7 +622,7 @@ def test_issue55():
assert str(s) == ' round_trip("""\n""")' assert str(s) == ' round_trip("""\n""")'
def test_multiline(): def test_multiline() -> None:
source = getstatement( source = getstatement(
0, 0,
"""\ """\
@ -621,7 +636,7 @@ x = 3
class TestTry: class TestTry:
def setup_class(self): def setup_class(self) -> None:
self.source = """\ self.source = """\
try: try:
raise ValueError raise ValueError
@ -631,25 +646,25 @@ else:
raise KeyError() raise KeyError()
""" """
def test_body(self): def test_body(self) -> None:
source = getstatement(1, self.source) source = getstatement(1, self.source)
assert str(source) == " raise ValueError" assert str(source) == " raise ValueError"
def test_except_line(self): def test_except_line(self) -> None:
source = getstatement(2, self.source) source = getstatement(2, self.source)
assert str(source) == "except Something:" assert str(source) == "except Something:"
def test_except_body(self): def test_except_body(self) -> None:
source = getstatement(3, self.source) source = getstatement(3, self.source)
assert str(source) == " raise IndexError(1)" assert str(source) == " raise IndexError(1)"
def test_else(self): def test_else(self) -> None:
source = getstatement(5, self.source) source = getstatement(5, self.source)
assert str(source) == " raise KeyError()" assert str(source) == " raise KeyError()"
class TestTryFinally: class TestTryFinally:
def setup_class(self): def setup_class(self) -> None:
self.source = """\ self.source = """\
try: try:
raise ValueError raise ValueError
@ -657,17 +672,17 @@ finally:
raise IndexError(1) raise IndexError(1)
""" """
def test_body(self): def test_body(self) -> None:
source = getstatement(1, self.source) source = getstatement(1, self.source)
assert str(source) == " raise ValueError" assert str(source) == " raise ValueError"
def test_finally(self): def test_finally(self) -> None:
source = getstatement(3, self.source) source = getstatement(3, self.source)
assert str(source) == " raise IndexError(1)" assert str(source) == " raise IndexError(1)"
class TestIf: class TestIf:
def setup_class(self): def setup_class(self) -> None:
self.source = """\ self.source = """\
if 1: if 1:
y = 3 y = 3
@ -677,24 +692,24 @@ else:
y = 7 y = 7
""" """
def test_body(self): def test_body(self) -> None:
source = getstatement(1, self.source) source = getstatement(1, self.source)
assert str(source) == " y = 3" assert str(source) == " y = 3"
def test_elif_clause(self): def test_elif_clause(self) -> None:
source = getstatement(2, self.source) source = getstatement(2, self.source)
assert str(source) == "elif False:" assert str(source) == "elif False:"
def test_elif(self): def test_elif(self) -> None:
source = getstatement(3, self.source) source = getstatement(3, self.source)
assert str(source) == " y = 5" assert str(source) == " y = 5"
def test_else(self): def test_else(self) -> None:
source = getstatement(5, self.source) source = getstatement(5, self.source)
assert str(source) == " y = 7" assert str(source) == " y = 7"
def test_semicolon(): def test_semicolon() -> None:
s = """\ s = """\
hello ; pytest.skip() hello ; pytest.skip()
""" """
@ -702,7 +717,7 @@ hello ; pytest.skip()
assert str(source) == s.strip() assert str(source) == s.strip()
def test_def_online(): def test_def_online() -> None:
s = """\ s = """\
def func(): raise ValueError(42) def func(): raise ValueError(42)
@ -713,7 +728,7 @@ def something():
assert str(source) == "def func(): raise ValueError(42)" assert str(source) == "def func(): raise ValueError(42)"
def XXX_test_expression_multiline(): def XXX_test_expression_multiline() -> None:
source = """\ source = """\
something something
''' '''
@ -722,7 +737,7 @@ something
assert str(result) == "'''\n'''" assert str(result) == "'''\n'''"
def test_getstartingblock_multiline(): def test_getstartingblock_multiline() -> None:
class A: class A:
def __init__(self, *args): def __init__(self, *args):
frame = sys._getframe(1) frame = sys._getframe(1)