Merge pull request #6205 from bluetech/type-annotations-8

Add type annotations to _pytest.compat and _pytest._code.code
This commit is contained in:
Ran Benita 2019-11-17 09:45:32 +02:00 committed by GitHub
commit fa578d7329
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 380 additions and 325 deletions

View File

@ -7,13 +7,17 @@ from inspect import CO_VARKEYWORDS
from io import StringIO
from traceback import format_exception_only
from types import CodeType
from types import FrameType
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generic
from typing import Iterable
from typing import List
from typing import Optional
from typing import Pattern
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import TypeVar
@ -27,9 +31,16 @@ import py
import _pytest
from _pytest._io.saferepr import safeformat
from _pytest._io.saferepr import saferepr
from _pytest.compat import overload
if False: # TYPE_CHECKING
from typing import Type
from typing_extensions import Literal
from weakref import ReferenceType # noqa: F401
from _pytest._code import Source
_TracebackStyle = Literal["long", "short", "no", "native"]
class Code:
@ -38,13 +49,12 @@ class Code:
def __init__(self, rawcode) -> None:
if not hasattr(rawcode, "co_filename"):
rawcode = getrawcode(rawcode)
try:
self.filename = rawcode.co_filename
self.firstlineno = rawcode.co_firstlineno - 1
self.name = rawcode.co_name
except AttributeError:
if not isinstance(rawcode, CodeType):
raise TypeError("not a code object: {!r}".format(rawcode))
self.raw = rawcode # type: CodeType
self.filename = rawcode.co_filename
self.firstlineno = rawcode.co_firstlineno - 1
self.name = rawcode.co_name
self.raw = rawcode
def __eq__(self, other):
return self.raw == other.raw
@ -72,7 +82,7 @@ class Code:
return p
@property
def fullsource(self):
def fullsource(self) -> Optional["Source"]:
""" return a _pytest._code.Source object for the full source file of the code
"""
from _pytest._code import source
@ -80,7 +90,7 @@ class Code:
full, _ = source.findsource(self.raw)
return full
def source(self):
def source(self) -> "Source":
""" return a _pytest._code.Source object for the code object's source only
"""
# return source only for that part of code
@ -88,7 +98,7 @@ class Code:
return _pytest._code.Source(self.raw)
def getargs(self, var=False):
def getargs(self, var: bool = False) -> Tuple[str, ...]:
""" return a tuple with the argument names for the code object
if 'var' is set True also return the names of the variable and
@ -107,7 +117,7 @@ class Frame:
"""Wrapper around a Python frame holding f_locals and f_globals
in which expressions can be evaluated."""
def __init__(self, frame):
def __init__(self, frame: FrameType) -> None:
self.lineno = frame.f_lineno - 1
self.f_globals = frame.f_globals
self.f_locals = frame.f_locals
@ -115,7 +125,7 @@ class Frame:
self.code = Code(frame.f_code)
@property
def statement(self):
def statement(self) -> "Source":
""" statement this frame is at """
import _pytest._code
@ -134,7 +144,7 @@ class Frame:
f_locals.update(vars)
return eval(code, self.f_globals, f_locals)
def exec_(self, code, **vars):
def exec_(self, code, **vars) -> None:
""" exec 'code' in the frame
'vars' are optional; additional local variables
@ -143,7 +153,7 @@ class Frame:
f_locals.update(vars)
exec(code, self.f_globals, f_locals)
def repr(self, object):
def repr(self, object: object) -> str:
""" return a 'safe' (non-recursive, one-line) string repr for 'object'
"""
return saferepr(object)
@ -151,7 +161,7 @@ class Frame:
def is_true(self, object):
return object
def getargs(self, var=False):
def getargs(self, var: bool = False):
""" return a list of tuples (name, value) for all arguments
if 'var' is set True also include the variable and keyword
@ -169,35 +179,34 @@ class Frame:
class TracebackEntry:
""" a single entry in a traceback """
_repr_style = None
_repr_style = None # type: Optional[Literal["short", "long"]]
exprinfo = None
def __init__(self, rawentry, excinfo=None):
def __init__(self, rawentry: TracebackType, excinfo=None) -> None:
self._excinfo = excinfo
self._rawentry = rawentry
self.lineno = rawentry.tb_lineno - 1
def set_repr_style(self, mode):
def set_repr_style(self, mode: "Literal['short', 'long']") -> None:
assert mode in ("short", "long")
self._repr_style = mode
@property
def frame(self):
import _pytest._code
return _pytest._code.Frame(self._rawentry.tb_frame)
def frame(self) -> Frame:
return Frame(self._rawentry.tb_frame)
@property
def relline(self):
def relline(self) -> int:
return self.lineno - self.frame.code.firstlineno
def __repr__(self):
def __repr__(self) -> str:
return "<TracebackEntry %s:%d>" % (self.frame.code.path, self.lineno + 1)
@property
def statement(self):
def statement(self) -> "Source":
""" _pytest._code.Source object for the current statement """
source = self.frame.code.fullsource
assert source is not None
return source.getstatement(self.lineno)
@property
@ -206,14 +215,14 @@ class TracebackEntry:
return self.frame.code.path
@property
def locals(self):
def locals(self) -> Dict[str, Any]:
""" locals of underlying frame """
return self.frame.f_locals
def getfirstlinesource(self):
def getfirstlinesource(self) -> int:
return self.frame.code.firstlineno
def getsource(self, astcache=None):
def getsource(self, astcache=None) -> Optional["Source"]:
""" return failing source code. """
# we use the passed in astcache to not reparse asttrees
# within exception info printing
@ -258,7 +267,7 @@ class TracebackEntry:
return tbh(None if self._excinfo is None else self._excinfo())
return tbh
def __str__(self):
def __str__(self) -> str:
try:
fn = str(self.path)
except py.error.Error:
@ -273,33 +282,42 @@ class TracebackEntry:
return " File %r:%d in %s\n %s\n" % (fn, self.lineno + 1, name, line)
@property
def name(self):
def name(self) -> str:
""" co_name of underlying code """
return self.frame.code.raw.co_name
class Traceback(list):
class Traceback(List[TracebackEntry]):
""" Traceback objects encapsulate and offer higher level
access to Traceback entries.
"""
Entry = TracebackEntry
def __init__(self, tb, excinfo=None):
def __init__(
self,
tb: Union[TracebackType, Iterable[TracebackEntry]],
excinfo: Optional["ReferenceType[ExceptionInfo]"] = None,
) -> None:
""" initialize from given python traceback object and ExceptionInfo """
self._excinfo = excinfo
if hasattr(tb, "tb_next"):
if isinstance(tb, TracebackType):
def f(cur):
while cur is not None:
yield self.Entry(cur, excinfo=excinfo)
cur = cur.tb_next
def f(cur: TracebackType) -> Iterable[TracebackEntry]:
cur_ = cur # type: Optional[TracebackType]
while cur_ is not None:
yield TracebackEntry(cur_, excinfo=excinfo)
cur_ = cur_.tb_next
list.__init__(self, f(tb))
super().__init__(f(tb))
else:
list.__init__(self, tb)
super().__init__(tb)
def cut(self, path=None, lineno=None, firstlineno=None, excludepath=None):
def cut(
self,
path=None,
lineno: Optional[int] = None,
firstlineno: Optional[int] = None,
excludepath=None,
) -> "Traceback":
""" return a Traceback instance wrapping part of this Traceback
by providing any combination of path, lineno and firstlineno, the
@ -325,13 +343,25 @@ class Traceback(list):
return Traceback(x._rawentry, self._excinfo)
return self
def __getitem__(self, key):
val = super().__getitem__(key)
if isinstance(key, type(slice(0))):
val = self.__class__(val)
return val
@overload
def __getitem__(self, key: int) -> TracebackEntry:
raise NotImplementedError()
def filter(self, fn=lambda x: not x.ishidden()):
@overload # noqa: F811
def __getitem__(self, key: slice) -> "Traceback": # noqa: F811
raise NotImplementedError()
def __getitem__( # noqa: F811
self, key: Union[int, slice]
) -> Union[TracebackEntry, "Traceback"]:
if isinstance(key, slice):
return self.__class__(super().__getitem__(key))
else:
return super().__getitem__(key)
def filter(
self, fn: Callable[[TracebackEntry], bool] = lambda x: not x.ishidden()
) -> "Traceback":
""" return a Traceback instance with certain items removed
fn is a function that gets a single argument, a TracebackEntry
@ -343,7 +373,7 @@ class Traceback(list):
"""
return Traceback(filter(fn, self), self._excinfo)
def getcrashentry(self):
def getcrashentry(self) -> TracebackEntry:
""" return last non-hidden traceback entry that lead
to the exception of a traceback.
"""
@ -353,7 +383,7 @@ class Traceback(list):
return entry
return self[-1]
def recursionindex(self):
def recursionindex(self) -> Optional[int]:
""" return the index of the frame/TracebackEntry where recursion
originates if appropriate, None if no recursion occurred
"""
@ -543,7 +573,7 @@ class ExceptionInfo(Generic[_E]):
def getrepr(
self,
showlocals: bool = False,
style: str = "long",
style: "_TracebackStyle" = "long",
abspath: bool = False,
tbfilter: bool = True,
funcargs: bool = False,
@ -621,16 +651,16 @@ class FormattedExcinfo:
flow_marker = ">"
fail_marker = "E"
showlocals = attr.ib(default=False)
style = attr.ib(default="long")
abspath = attr.ib(default=True)
tbfilter = attr.ib(default=True)
funcargs = attr.ib(default=False)
truncate_locals = attr.ib(default=True)
chain = attr.ib(default=True)
showlocals = attr.ib(type=bool, default=False)
style = attr.ib(type="_TracebackStyle", default="long")
abspath = attr.ib(type=bool, default=True)
tbfilter = attr.ib(type=bool, default=True)
funcargs = attr.ib(type=bool, default=False)
truncate_locals = attr.ib(type=bool, default=True)
chain = attr.ib(type=bool, default=True)
astcache = attr.ib(default=attr.Factory(dict), init=False, repr=False)
def _getindent(self, source):
def _getindent(self, source: "Source") -> int:
# figure out indent for given source
try:
s = str(source.getstatement(len(source) - 1))
@ -645,20 +675,27 @@ class FormattedExcinfo:
return 0
return 4 + (len(s) - len(s.lstrip()))
def _getentrysource(self, entry):
def _getentrysource(self, entry: TracebackEntry) -> Optional["Source"]:
source = entry.getsource(self.astcache)
if source is not None:
source = source.deindent()
return source
def repr_args(self, entry):
def repr_args(self, entry: TracebackEntry) -> Optional["ReprFuncArgs"]:
if self.funcargs:
args = []
for argname, argvalue in entry.frame.getargs(var=True):
args.append((argname, saferepr(argvalue)))
return ReprFuncArgs(args)
return None
def get_source(self, source, line_index=-1, excinfo=None, short=False) -> List[str]:
def get_source(
self,
source: "Source",
line_index: int = -1,
excinfo: Optional[ExceptionInfo] = None,
short: bool = False,
) -> List[str]:
""" return formatted and marked up source lines. """
import _pytest._code
@ -682,19 +719,21 @@ class FormattedExcinfo:
lines.extend(self.get_exconly(excinfo, indent=indent, markall=True))
return lines
def get_exconly(self, excinfo, indent=4, markall=False):
def get_exconly(
self, excinfo: ExceptionInfo, indent: int = 4, markall: bool = False
) -> List[str]:
lines = []
indent = " " * indent
indentstr = " " * indent
# get the real exception information out
exlines = excinfo.exconly(tryshort=True).split("\n")
failindent = self.fail_marker + indent[1:]
failindent = self.fail_marker + indentstr[1:]
for line in exlines:
lines.append(failindent + line)
if not markall:
failindent = indent
failindent = indentstr
return lines
def repr_locals(self, locals):
def repr_locals(self, locals: Dict[str, object]) -> Optional["ReprLocals"]:
if self.showlocals:
lines = []
keys = [loc for loc in locals if loc[0] != "@"]
@ -719,8 +758,11 @@ class FormattedExcinfo:
# # XXX
# pprint.pprint(value, stream=self.excinfowriter)
return ReprLocals(lines)
return None
def repr_traceback_entry(self, entry, excinfo=None):
def repr_traceback_entry(
self, entry: TracebackEntry, excinfo: Optional[ExceptionInfo] = None
) -> "ReprEntry":
import _pytest._code
source = self._getentrysource(entry)
@ -731,9 +773,7 @@ class FormattedExcinfo:
line_index = entry.lineno - entry.getfirstlinesource()
lines = [] # type: List[str]
style = entry._repr_style
if style is None:
style = self.style
style = entry._repr_style if entry._repr_style is not None else self.style
if style in ("short", "long"):
short = style == "short"
reprargs = self.repr_args(entry) if not short else None
@ -763,7 +803,7 @@ class FormattedExcinfo:
path = np
return path
def repr_traceback(self, excinfo):
def repr_traceback(self, excinfo: ExceptionInfo) -> "ReprTraceback":
traceback = excinfo.traceback
if self.tbfilter:
traceback = traceback.filter()
@ -781,7 +821,9 @@ class FormattedExcinfo:
entries.append(reprentry)
return ReprTraceback(entries, extraline, style=self.style)
def _truncate_recursive_traceback(self, traceback):
def _truncate_recursive_traceback(
self, traceback: Traceback
) -> Tuple[Traceback, Optional[str]]:
"""
Truncate the given recursive traceback trying to find the starting point
of the recursion.
@ -808,7 +850,9 @@ class FormattedExcinfo:
max_frames=max_frames,
total=len(traceback),
) # type: Optional[str]
traceback = traceback[:max_frames] + traceback[-max_frames:]
# Type ignored because adding two instaces of a List subtype
# currently incorrectly has type List instead of the subtype.
traceback = traceback[:max_frames] + traceback[-max_frames:] # type: ignore
else:
if recursionindex is not None:
extraline = "!!! Recursion detected (same locals & position)"
@ -865,7 +909,7 @@ class FormattedExcinfo:
class TerminalRepr:
def __str__(self):
def __str__(self) -> str:
# FYI this is called from pytest-xdist's serialization of exception
# information.
io = StringIO()
@ -873,7 +917,7 @@ class TerminalRepr:
self.toterminal(tw)
return io.getvalue().strip()
def __repr__(self):
def __repr__(self) -> str:
return "<{} instance at {:0x}>".format(self.__class__, id(self))
def toterminal(self, tw) -> None:
@ -884,7 +928,7 @@ class ExceptionRepr(TerminalRepr):
def __init__(self) -> None:
self.sections = [] # type: List[Tuple[str, str, str]]
def addsection(self, name, content, sep="-"):
def addsection(self, name: str, content: str, sep: str = "-") -> None:
self.sections.append((name, content, sep))
def toterminal(self, tw) -> None:
@ -894,7 +938,12 @@ class ExceptionRepr(TerminalRepr):
class ExceptionChainRepr(ExceptionRepr):
def __init__(self, chain):
def __init__(
self,
chain: Sequence[
Tuple["ReprTraceback", Optional["ReprFileLocation"], Optional[str]]
],
) -> None:
super().__init__()
self.chain = chain
# reprcrash and reprtraceback of the outermost (the newest) exception
@ -912,7 +961,9 @@ class ExceptionChainRepr(ExceptionRepr):
class ReprExceptionInfo(ExceptionRepr):
def __init__(self, reprtraceback, reprcrash):
def __init__(
self, reprtraceback: "ReprTraceback", reprcrash: "ReprFileLocation"
) -> None:
super().__init__()
self.reprtraceback = reprtraceback
self.reprcrash = reprcrash
@ -925,7 +976,12 @@ class ReprExceptionInfo(ExceptionRepr):
class ReprTraceback(TerminalRepr):
entrysep = "_ "
def __init__(self, reprentries, extraline, style):
def __init__(
self,
reprentries: Sequence[Union["ReprEntry", "ReprEntryNative"]],
extraline: Optional[str],
style: "_TracebackStyle",
) -> None:
self.reprentries = reprentries
self.extraline = extraline
self.style = style
@ -950,16 +1006,16 @@ class ReprTraceback(TerminalRepr):
class ReprTracebackNative(ReprTraceback):
def __init__(self, tblines):
def __init__(self, tblines: Sequence[str]) -> None:
self.style = "native"
self.reprentries = [ReprEntryNative(tblines)]
self.extraline = None
class ReprEntryNative(TerminalRepr):
style = "native"
style = "native" # type: _TracebackStyle
def __init__(self, tblines):
def __init__(self, tblines: Sequence[str]) -> None:
self.lines = tblines
def toterminal(self, tw) -> None:
@ -967,7 +1023,14 @@ class ReprEntryNative(TerminalRepr):
class ReprEntry(TerminalRepr):
def __init__(self, lines, reprfuncargs, reprlocals, filelocrepr, style):
def __init__(
self,
lines: Sequence[str],
reprfuncargs: Optional["ReprFuncArgs"],
reprlocals: Optional["ReprLocals"],
filelocrepr: Optional["ReprFileLocation"],
style: "_TracebackStyle",
) -> None:
self.lines = lines
self.reprfuncargs = reprfuncargs
self.reprlocals = reprlocals
@ -976,6 +1039,7 @@ class ReprEntry(TerminalRepr):
def toterminal(self, tw) -> None:
if self.style == "short":
assert self.reprfileloc is not None
self.reprfileloc.toterminal(tw)
for line in self.lines:
red = line.startswith("E ")
@ -994,14 +1058,14 @@ class ReprEntry(TerminalRepr):
tw.line("")
self.reprfileloc.toterminal(tw)
def __str__(self):
def __str__(self) -> str:
return "{}\n{}\n{}".format(
"\n".join(self.lines), self.reprlocals, self.reprfileloc
)
class ReprFileLocation(TerminalRepr):
def __init__(self, path, lineno, message):
def __init__(self, path, lineno: int, message: str) -> None:
self.path = str(path)
self.lineno = lineno
self.message = message
@ -1018,7 +1082,7 @@ class ReprFileLocation(TerminalRepr):
class ReprLocals(TerminalRepr):
def __init__(self, lines):
def __init__(self, lines: Sequence[str]) -> None:
self.lines = lines
def toterminal(self, tw) -> None:
@ -1027,7 +1091,7 @@ class ReprLocals(TerminalRepr):
class ReprFuncArgs(TerminalRepr):
def __init__(self, args):
def __init__(self, args: Sequence[Tuple[str, object]]) -> None:
self.args = args
def toterminal(self, tw) -> None:
@ -1049,13 +1113,11 @@ class ReprFuncArgs(TerminalRepr):
tw.line("")
def getrawcode(obj, trycall=True):
def getrawcode(obj, trycall: bool = True):
""" return code object for given function. """
try:
return obj.__code__
except AttributeError:
obj = getattr(obj, "im_func", obj)
obj = getattr(obj, "func_code", obj)
obj = getattr(obj, "f_code", obj)
obj = getattr(obj, "__code__", obj)
if trycall and not hasattr(obj, "co_firstlineno"):
@ -1079,7 +1141,7 @@ _PYTEST_DIR = py.path.local(_pytest.__file__).dirpath()
_PY_DIR = py.path.local(py.__file__).dirpath()
def filter_traceback(entry):
def filter_traceback(entry: TracebackEntry) -> bool:
"""Return True if a TracebackEntry instance should be removed from tracebacks:
* dynamically generated code (no code to show up for it);
* internal traceback from pytest or its internal libraries, py and pluggy.

View File

@ -8,6 +8,7 @@ import warnings
from ast import PyCF_ONLY_AST as _AST_FLAG
from bisect import bisect_right
from types import FrameType
from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
@ -60,7 +61,7 @@ class Source:
raise NotImplementedError()
@overload # noqa: F811
def __getitem__(self, key: slice) -> "Source":
def __getitem__(self, key: slice) -> "Source": # noqa: F811
raise NotImplementedError()
def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]: # noqa: F811
@ -73,6 +74,9 @@ class Source:
newsource.lines = self.lines[key.start : key.stop]
return newsource
def __iter__(self) -> Iterator[str]:
return iter(self.lines)
def __len__(self) -> int:
return len(self.lines)

View File

@ -1074,13 +1074,14 @@ def try_makedirs(cache_dir) -> bool:
def get_cache_dir(file_path: Path) -> Path:
"""Returns the cache directory to write .pyc files for the given .py file path"""
if sys.version_info >= (3, 8) and sys.pycache_prefix:
# Type ignored until added in next mypy release.
if sys.version_info >= (3, 8) and sys.pycache_prefix: # type: ignore
# given:
# prefix = '/tmp/pycs'
# path = '/home/user/proj/test_app.py'
# we want:
# '/tmp/pycs/home/user/proj'
return Path(sys.pycache_prefix) / Path(*file_path.parts[1:-1])
return Path(sys.pycache_prefix) / Path(*file_path.parts[1:-1]) # type: ignore
else:
# classic pycache directory
return file_path.parent / "__pycache__"

View File

@ -10,11 +10,14 @@ import sys
from contextlib import contextmanager
from inspect import Parameter
from inspect import signature
from typing import Any
from typing import Callable
from typing import Generic
from typing import Optional
from typing import overload
from typing import Tuple
from typing import TypeVar
from typing import Union
import attr
import py
@ -40,12 +43,13 @@ MODULE_NOT_FOUND_ERROR = (
if sys.version_info >= (3, 8):
from importlib import metadata as importlib_metadata # noqa: F401
# Type ignored until next mypy release.
from importlib import metadata as importlib_metadata # type: ignore
else:
import importlib_metadata # noqa: F401
def _format_args(func):
def _format_args(func: Callable[..., Any]) -> str:
return str(signature(func))
@ -66,12 +70,12 @@ else:
fspath = os.fspath
def is_generator(func):
def is_generator(func: object) -> bool:
genfunc = inspect.isgeneratorfunction(func)
return genfunc and not iscoroutinefunction(func)
def iscoroutinefunction(func):
def iscoroutinefunction(func: object) -> bool:
"""
Return True if func is a coroutine function (a function defined with async
def syntax, and doesn't contain yield), or a function decorated with
@ -84,7 +88,7 @@ def iscoroutinefunction(func):
return inspect.iscoroutinefunction(func) or getattr(func, "_is_coroutine", False)
def getlocation(function, curdir=None):
def getlocation(function, curdir=None) -> str:
function = get_real_func(function)
fn = py.path.local(inspect.getfile(function))
lineno = function.__code__.co_firstlineno
@ -93,7 +97,7 @@ def getlocation(function, curdir=None):
return "%s:%d" % (fn, lineno + 1)
def num_mock_patch_args(function):
def num_mock_patch_args(function) -> int:
""" return number of arguments used up by mock arguments (if any) """
patchings = getattr(function, "patchings", None)
if not patchings:
@ -112,7 +116,13 @@ def num_mock_patch_args(function):
)
def getfuncargnames(function, *, name: str = "", is_method=False, cls=None):
def getfuncargnames(
function: Callable[..., Any],
*,
name: str = "",
is_method: bool = False,
cls: Optional[type] = None
) -> Tuple[str, ...]:
"""Returns the names of a function's mandatory arguments.
This should return the names of all function arguments that:
@ -180,7 +190,7 @@ else:
from contextlib import nullcontext # noqa
def get_default_arg_names(function):
def get_default_arg_names(function: Callable[..., Any]) -> Tuple[str, ...]:
# Note: this code intentionally mirrors the code at the beginning of getfuncargnames,
# to get the arguments which were excluded from its result because they had default values
return tuple(
@ -199,18 +209,18 @@ _non_printable_ascii_translate_table.update(
)
def _translate_non_printable(s):
def _translate_non_printable(s: str) -> str:
return s.translate(_non_printable_ascii_translate_table)
STRING_TYPES = bytes, str
def _bytes_to_ascii(val):
def _bytes_to_ascii(val: bytes) -> str:
return val.decode("ascii", "backslashreplace")
def ascii_escaped(val):
def ascii_escaped(val: Union[bytes, str]):
"""If val is pure ascii, returns it as a str(). Otherwise, escapes
bytes objects into a sequence of escaped bytes:
@ -307,7 +317,7 @@ def getimfunc(func):
return func
def safe_getattr(object, name, default):
def safe_getattr(object: Any, name: str, default: Any) -> Any:
""" Like getattr but return default upon any Exception or any OutcomeException.
Attribute access can potentially fail for 'evil' Python objects.
@ -321,7 +331,7 @@ def safe_getattr(object, name, default):
return default
def safe_isclass(obj):
def safe_isclass(obj: object) -> bool:
"""Ignore any exception via isinstance on Python 3."""
try:
return inspect.isclass(obj)
@ -342,39 +352,26 @@ COLLECT_FAKEMODULE_ATTRIBUTES = (
)
def _setup_collect_fakemodule():
def _setup_collect_fakemodule() -> None:
from types import ModuleType
import pytest
pytest.collect = ModuleType("pytest.collect")
pytest.collect.__all__ = [] # used for setns
# Types ignored because the module is created dynamically.
pytest.collect = ModuleType("pytest.collect") # type: ignore
pytest.collect.__all__ = [] # type: ignore # used for setns
for attr_name in COLLECT_FAKEMODULE_ATTRIBUTES:
setattr(pytest.collect, attr_name, getattr(pytest, attr_name))
setattr(pytest.collect, attr_name, getattr(pytest, attr_name)) # type: ignore
class CaptureIO(io.TextIOWrapper):
def __init__(self):
def __init__(self) -> None:
super().__init__(io.BytesIO(), encoding="UTF-8", newline="", write_through=True)
def getvalue(self):
def getvalue(self) -> str:
assert isinstance(self.buffer, io.BytesIO)
return self.buffer.getvalue().decode("UTF-8")
class FuncargnamesCompatAttr:
""" helper class so that Metafunc, Function and FixtureRequest
don't need to each define the "funcargnames" compatibility attribute.
"""
@property
def funcargnames(self):
""" alias attribute for ``fixturenames`` for pre-2.3 compatibility"""
import warnings
from _pytest.deprecated import FUNCARGNAMES
warnings.warn(FUNCARGNAMES, stacklevel=2)
return self.fixturenames
if sys.version_info < (3, 5, 2): # pragma: no cover
def overload(f): # noqa: F811
@ -407,7 +404,9 @@ else:
raise NotImplementedError()
@overload # noqa: F811
def __get__(self, instance: _S, owner: Optional["Type[_S]"] = ...) -> _T:
def __get__( # noqa: F811
self, instance: _S, owner: Optional["Type[_S]"] = ...
) -> _T:
raise NotImplementedError()
def __get__(self, instance, owner=None): # noqa: F811

View File

@ -18,7 +18,6 @@ from _pytest._code.code import FormattedExcinfo
from _pytest._code.code import TerminalRepr
from _pytest.compat import _format_args
from _pytest.compat import _PytestWrapper
from _pytest.compat import FuncargnamesCompatAttr
from _pytest.compat import get_real_func
from _pytest.compat import get_real_method
from _pytest.compat import getfslineno
@ -29,6 +28,7 @@ from _pytest.compat import is_generator
from _pytest.compat import NOTSET
from _pytest.compat import safe_getattr
from _pytest.deprecated import FIXTURE_POSITIONAL_ARGUMENTS
from _pytest.deprecated import FUNCARGNAMES
from _pytest.outcomes import fail
from _pytest.outcomes import TEST_OUTCOME
@ -336,7 +336,7 @@ class FuncFixtureInfo:
self.names_closure[:] = sorted(closure, key=self.names_closure.index)
class FixtureRequest(FuncargnamesCompatAttr):
class FixtureRequest:
""" A request for a fixture from a test or fixture function.
A request object gives access to the requesting test context
@ -363,6 +363,12 @@ class FixtureRequest(FuncargnamesCompatAttr):
result.extend(set(self._fixture_defs).difference(result))
return result
@property
def funcargnames(self):
""" alias attribute for ``fixturenames`` for pre-2.3 compatibility"""
warnings.warn(FUNCARGNAMES, stacklevel=2)
return self.fixturenames
@property
def node(self):
""" underlying collection node (depends on current request scope)"""

View File

@ -31,6 +31,7 @@ from _pytest.compat import safe_getattr
from _pytest.compat import safe_isclass
from _pytest.compat import STRING_TYPES
from _pytest.config import hookimpl
from _pytest.deprecated import FUNCARGNAMES
from _pytest.main import FSHookProxy
from _pytest.mark import MARK_GEN
from _pytest.mark.structures import get_unpacked_marks
@ -882,7 +883,7 @@ class CallSpec2:
self.marks.extend(normalize_mark_list(marks))
class Metafunc(fixtures.FuncargnamesCompatAttr):
class Metafunc:
"""
Metafunc objects are passed to the :func:`pytest_generate_tests <_pytest.hookspec.pytest_generate_tests>` hook.
They help to inspect a test function and to generate tests according to
@ -916,6 +917,12 @@ class Metafunc(fixtures.FuncargnamesCompatAttr):
self._ids = set()
self._arg2fixturedefs = fixtureinfo.name2fixturedefs
@property
def funcargnames(self):
""" alias attribute for ``fixturenames`` for pre-2.3 compatibility"""
warnings.warn(FUNCARGNAMES, stacklevel=2)
return self.fixturenames
def parametrize(self, argnames, argvalues, indirect=False, ids=None, scope=None):
""" Add new invocations to the underlying test function using the list
of argvalues for the given argnames. Parametrization is performed
@ -1333,7 +1340,7 @@ def write_docstring(tw, doc, indent=" "):
tw.write(indent + line + "\n")
class Function(FunctionMixin, nodes.Item, fixtures.FuncargnamesCompatAttr):
class Function(FunctionMixin, nodes.Item):
""" a Function Item is responsible for setting up and executing a
Python test function.
"""
@ -1420,6 +1427,12 @@ class Function(FunctionMixin, nodes.Item, fixtures.FuncargnamesCompatAttr):
"(compatonly) for code expecting pytest-2.2 style request objects"
return self
@property
def funcargnames(self):
""" alias attribute for ``fixturenames`` for pre-2.3 compatibility"""
warnings.warn(FUNCARGNAMES, stacklevel=2)
return self.fixturenames
def runtest(self):
""" execute the underlying test function. """
self.ihook.pytest_pyfunc_call(pyfuncitem=self)

View File

@ -552,7 +552,7 @@ def raises(
@overload # noqa: F811
def raises(
def raises( # noqa: F811
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
func: Callable,
*args: Any,

View File

@ -60,18 +60,18 @@ def warns(
*,
match: "Optional[Union[str, Pattern]]" = ...
) -> "WarningsChecker":
... # pragma: no cover
raise NotImplementedError()
@overload # noqa: F811
def warns(
def warns( # noqa: F811
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
func: Callable,
*args: Any,
match: Optional[Union[str, "Pattern"]] = ...,
**kwargs: Any
) -> Union[Any]:
... # pragma: no cover
raise NotImplementedError()
def warns( # noqa: F811

View File

@ -1,18 +1,19 @@
import sys
from types import FrameType
from unittest import mock
import _pytest._code
import pytest
def test_ne():
def test_ne() -> None:
code1 = _pytest._code.Code(compile('foo = "bar"', "", "exec"))
assert code1 == code1
code2 = _pytest._code.Code(compile('foo = "baz"', "", "exec"))
assert code2 != code1
def test_code_gives_back_name_for_not_existing_file():
def test_code_gives_back_name_for_not_existing_file() -> None:
name = "abc-123"
co_code = compile("pass\n", name, "exec")
assert co_code.co_filename == name
@ -21,68 +22,67 @@ def test_code_gives_back_name_for_not_existing_file():
assert code.fullsource is None
def test_code_with_class():
def test_code_with_class() -> None:
class A:
pass
pytest.raises(TypeError, _pytest._code.Code, A)
def x():
def x() -> None:
raise NotImplementedError()
def test_code_fullsource():
def test_code_fullsource() -> None:
code = _pytest._code.Code(x)
full = code.fullsource
assert "test_code_fullsource()" in str(full)
def test_code_source():
def test_code_source() -> None:
code = _pytest._code.Code(x)
src = code.source()
expected = """def x():
expected = """def x() -> None:
raise NotImplementedError()"""
assert str(src) == expected
def test_frame_getsourcelineno_myself():
def func():
def test_frame_getsourcelineno_myself() -> None:
def func() -> FrameType:
return sys._getframe(0)
f = func()
f = _pytest._code.Frame(f)
f = _pytest._code.Frame(func())
source, lineno = f.code.fullsource, f.lineno
assert source is not None
assert source[lineno].startswith(" return sys._getframe(0)")
def test_getstatement_empty_fullsource():
def func():
def test_getstatement_empty_fullsource() -> None:
def func() -> FrameType:
return sys._getframe(0)
f = func()
f = _pytest._code.Frame(f)
f = _pytest._code.Frame(func())
with mock.patch.object(f.code.__class__, "fullsource", None):
assert f.statement == ""
def test_code_from_func():
def test_code_from_func() -> None:
co = _pytest._code.Code(test_frame_getsourcelineno_myself)
assert co.firstlineno
assert co.path
def test_unicode_handling():
def test_unicode_handling() -> None:
value = "ąć".encode()
def f():
def f() -> None:
raise Exception(value)
excinfo = pytest.raises(Exception, f)
str(excinfo)
def test_code_getargs():
def test_code_getargs() -> None:
def f1(x):
raise NotImplementedError()
@ -108,26 +108,26 @@ def test_code_getargs():
assert c4.getargs(var=True) == ("x", "y", "z")
def test_frame_getargs():
def f1(x):
def test_frame_getargs() -> None:
def f1(x) -> FrameType:
return sys._getframe(0)
fr1 = _pytest._code.Frame(f1("a"))
assert fr1.getargs(var=True) == [("x", "a")]
def f2(x, *y):
def f2(x, *y) -> FrameType:
return sys._getframe(0)
fr2 = _pytest._code.Frame(f2("a", "b", "c"))
assert fr2.getargs(var=True) == [("x", "a"), ("y", ("b", "c"))]
def f3(x, **z):
def f3(x, **z) -> FrameType:
return sys._getframe(0)
fr3 = _pytest._code.Frame(f3("a", b="c"))
assert fr3.getargs(var=True) == [("x", "a"), ("z", {"b": "c"})]
def f4(x, *y, **z):
def f4(x, *y, **z) -> FrameType:
return sys._getframe(0)
fr4 = _pytest._code.Frame(f4("a", "b", c="d"))
@ -135,7 +135,7 @@ def test_frame_getargs():
class TestExceptionInfo:
def test_bad_getsource(self):
def test_bad_getsource(self) -> None:
try:
if False:
pass
@ -145,13 +145,13 @@ class TestExceptionInfo:
exci = _pytest._code.ExceptionInfo.from_current()
assert exci.getrepr()
def test_from_current_with_missing(self):
def test_from_current_with_missing(self) -> None:
with pytest.raises(AssertionError, match="no current exception"):
_pytest._code.ExceptionInfo.from_current()
class TestTracebackEntry:
def test_getsource(self):
def test_getsource(self) -> None:
try:
if False:
pass
@ -161,12 +161,13 @@ class TestTracebackEntry:
exci = _pytest._code.ExceptionInfo.from_current()
entry = exci.traceback[0]
source = entry.getsource()
assert source is not None
assert len(source) == 6
assert "assert False" in source[5]
class TestReprFuncArgs:
def test_not_raise_exception_with_mixed_encoding(self, tw_mock):
def test_not_raise_exception_with_mixed_encoding(self, tw_mock) -> None:
from _pytest._code.code import ReprFuncArgs
args = [("unicode_string", "São Paulo"), ("utf8_string", b"S\xc3\xa3o Paulo")]

View File

@ -3,6 +3,7 @@ import os
import queue
import sys
import textwrap
from typing import Union
import py
@ -59,9 +60,9 @@ def test_excinfo_getstatement():
except ValueError:
excinfo = _pytest._code.ExceptionInfo.from_current()
linenumbers = [
_pytest._code.getrawcode(f).co_firstlineno - 1 + 4,
_pytest._code.getrawcode(f).co_firstlineno - 1 + 1,
_pytest._code.getrawcode(g).co_firstlineno - 1 + 1,
f.__code__.co_firstlineno - 1 + 4,
f.__code__.co_firstlineno - 1 + 1,
g.__code__.co_firstlineno - 1 + 1,
]
values = list(excinfo.traceback)
foundlinenumbers = [x.lineno for x in values]
@ -224,23 +225,25 @@ class TestTraceback_f_g_h:
repr = excinfo.getrepr()
assert "RuntimeError: hello" in str(repr.reprcrash)
def test_traceback_no_recursion_index(self):
def do_stuff():
def test_traceback_no_recursion_index(self) -> None:
def do_stuff() -> None:
raise RuntimeError
def reraise_me():
def reraise_me() -> None:
import sys
exc, val, tb = sys.exc_info()
assert val is not None
raise val.with_traceback(tb)
def f(n):
def f(n: int) -> None:
try:
do_stuff()
except: # noqa
reraise_me()
excinfo = pytest.raises(RuntimeError, f, 8)
assert excinfo is not None
traceback = excinfo.traceback
recindex = traceback.recursionindex()
assert recindex is None
@ -502,65 +505,18 @@ raise ValueError()
assert repr.reprtraceback.reprentries[1].lines[0] == "> ???"
assert repr.chain[0][0].reprentries[1].lines[0] == "> ???"
def test_repr_source_failing_fullsource(self):
def test_repr_source_failing_fullsource(self, monkeypatch) -> None:
pr = FormattedExcinfo()
class FakeCode:
class raw:
co_filename = "?"
try:
1 / 0
except ZeroDivisionError:
excinfo = ExceptionInfo.from_current()
path = "?"
firstlineno = 5
with monkeypatch.context() as m:
m.setattr(_pytest._code.Code, "fullsource", property(lambda self: None))
repr = pr.repr_excinfo(excinfo)
def fullsource(self):
return None
fullsource = property(fullsource)
class FakeFrame:
code = FakeCode()
f_locals = {}
f_globals = {}
class FakeTracebackEntry(_pytest._code.Traceback.Entry):
def __init__(self, tb, excinfo=None):
self.lineno = 5 + 3
@property
def frame(self):
return FakeFrame()
class Traceback(_pytest._code.Traceback):
Entry = FakeTracebackEntry
class FakeExcinfo(_pytest._code.ExceptionInfo):
typename = "Foo"
value = Exception()
def __init__(self):
pass
def exconly(self, tryshort):
return "EXC"
def errisinstance(self, cls):
return False
excinfo = FakeExcinfo()
class FakeRawTB:
tb_next = None
tb = FakeRawTB()
excinfo.traceback = Traceback(tb)
fail = IOError()
repr = pr.repr_excinfo(excinfo)
assert repr.reprtraceback.reprentries[0].lines[0] == "> ???"
assert repr.chain[0][0].reprentries[0].lines[0] == "> ???"
fail = py.error.ENOENT # noqa
repr = pr.repr_excinfo(excinfo)
assert repr.reprtraceback.reprentries[0].lines[0] == "> ???"
assert repr.chain[0][0].reprentries[0].lines[0] == "> ???"
@ -643,7 +599,6 @@ raise ValueError()
assert lines[3] == "E world"
assert not lines[4:]
loc = repr_entry.reprlocals is not None
loc = repr_entry.reprfileloc
assert loc.path == mod.__file__
assert loc.lineno == 3
@ -1333,9 +1288,10 @@ raise ValueError()
@pytest.mark.parametrize("style", ["short", "long"])
@pytest.mark.parametrize("encoding", [None, "utf8", "utf16"])
def test_repr_traceback_with_unicode(style, encoding):
msg = ""
if encoding is not None:
msg = msg.encode(encoding)
if encoding is None:
msg = "" # type: Union[str, bytes]
else:
msg = "".encode(encoding)
try:
raise RuntimeError(msg)
except RuntimeError:

View File

@ -4,13 +4,16 @@
import ast
import inspect
import sys
from typing import Any
from typing import Dict
from typing import Optional
import _pytest._code
import pytest
from _pytest._code import Source
def test_source_str_function():
def test_source_str_function() -> None:
x = Source("3")
assert str(x) == "3"
@ -25,7 +28,7 @@ def test_source_str_function():
assert str(x) == "\n3"
def test_unicode():
def test_unicode() -> None:
x = Source("4")
assert str(x) == "4"
co = _pytest._code.compile('"å"', mode="eval")
@ -33,12 +36,12 @@ def test_unicode():
assert isinstance(val, str)
def test_source_from_function():
def test_source_from_function() -> None:
source = _pytest._code.Source(test_source_str_function)
assert str(source).startswith("def test_source_str_function():")
assert str(source).startswith("def test_source_str_function() -> None:")
def test_source_from_method():
def test_source_from_method() -> None:
class TestClass:
def test_method(self):
pass
@ -47,13 +50,13 @@ def test_source_from_method():
assert source.lines == ["def test_method(self):", " pass"]
def test_source_from_lines():
def test_source_from_lines() -> None:
lines = ["a \n", "b\n", "c"]
source = _pytest._code.Source(lines)
assert source.lines == ["a ", "b", "c"]
def test_source_from_inner_function():
def test_source_from_inner_function() -> None:
def f():
pass
@ -63,7 +66,7 @@ def test_source_from_inner_function():
assert str(source).startswith("def f():")
def test_source_putaround_simple():
def test_source_putaround_simple() -> None:
source = Source("raise ValueError")
source = source.putaround(
"try:",
@ -85,7 +88,7 @@ else:
)
def test_source_putaround():
def test_source_putaround() -> None:
source = Source()
source = source.putaround(
"""
@ -96,28 +99,29 @@ def test_source_putaround():
assert str(source).strip() == "if 1:\n x=1"
def test_source_strips():
def test_source_strips() -> None:
source = Source("")
assert source == Source()
assert str(source) == ""
assert source.strip() == source
def test_source_strip_multiline():
def test_source_strip_multiline() -> None:
source = Source()
source.lines = ["", " hello", " "]
source2 = source.strip()
assert source2.lines == [" hello"]
def test_syntaxerror_rerepresentation():
def test_syntaxerror_rerepresentation() -> None:
ex = pytest.raises(SyntaxError, _pytest._code.compile, "xyz xyz")
assert ex is not None
assert ex.value.lineno == 1
assert ex.value.offset in {5, 7} # cpython: 7, pypy3.6 7.1.1: 5
assert ex.value.text.strip(), "x x"
assert ex.value.text == "xyz xyz\n"
def test_isparseable():
def test_isparseable() -> None:
assert Source("hello").isparseable()
assert Source("if 1:\n pass").isparseable()
assert Source(" \nif 1:\n pass").isparseable()
@ -127,7 +131,7 @@ def test_isparseable():
class TestAccesses:
def setup_class(self):
def setup_class(self) -> None:
self.source = Source(
"""\
def f(x):
@ -137,26 +141,26 @@ class TestAccesses:
"""
)
def test_getrange(self):
def test_getrange(self) -> None:
x = self.source[0:2]
assert x.isparseable()
assert len(x.lines) == 2
assert str(x) == "def f(x):\n pass"
def test_getline(self):
def test_getline(self) -> None:
x = self.source[0]
assert x == "def f(x):"
def test_len(self):
def test_len(self) -> None:
assert len(self.source) == 4
def test_iter(self):
def test_iter(self) -> None:
values = [x for x in self.source]
assert len(values) == 4
class TestSourceParsingAndCompiling:
def setup_class(self):
def setup_class(self) -> None:
self.source = Source(
"""\
def f(x):
@ -166,19 +170,19 @@ class TestSourceParsingAndCompiling:
"""
).strip()
def test_compile(self):
def test_compile(self) -> None:
co = _pytest._code.compile("x=3")
d = {}
d = {} # type: Dict[str, Any]
exec(co, d)
assert d["x"] == 3
def test_compile_and_getsource_simple(self):
def test_compile_and_getsource_simple(self) -> None:
co = _pytest._code.compile("x=3")
exec(co)
source = _pytest._code.Source(co)
assert str(source) == "x=3"
def test_compile_and_getsource_through_same_function(self):
def test_compile_and_getsource_through_same_function(self) -> None:
def gensource(source):
return _pytest._code.compile(source)
@ -199,7 +203,7 @@ class TestSourceParsingAndCompiling:
source2 = inspect.getsource(co2)
assert "ValueError" in source2
def test_getstatement(self):
def test_getstatement(self) -> None:
# print str(self.source)
ass = str(self.source[1:])
for i in range(1, 4):
@ -208,7 +212,7 @@ class TestSourceParsingAndCompiling:
# x = s.deindent()
assert str(s) == ass
def test_getstatementrange_triple_quoted(self):
def test_getstatementrange_triple_quoted(self) -> None:
# print str(self.source)
source = Source(
"""hello('''
@ -219,7 +223,7 @@ class TestSourceParsingAndCompiling:
s = source.getstatement(1)
assert s == str(source)
def test_getstatementrange_within_constructs(self):
def test_getstatementrange_within_constructs(self) -> None:
source = Source(
"""\
try:
@ -241,7 +245,7 @@ class TestSourceParsingAndCompiling:
# assert source.getstatementrange(5) == (0, 7)
assert source.getstatementrange(6) == (6, 7)
def test_getstatementrange_bug(self):
def test_getstatementrange_bug(self) -> None:
source = Source(
"""\
try:
@ -255,7 +259,7 @@ class TestSourceParsingAndCompiling:
assert len(source) == 6
assert source.getstatementrange(2) == (1, 4)
def test_getstatementrange_bug2(self):
def test_getstatementrange_bug2(self) -> None:
source = Source(
"""\
assert (
@ -272,7 +276,7 @@ class TestSourceParsingAndCompiling:
assert len(source) == 9
assert source.getstatementrange(5) == (0, 9)
def test_getstatementrange_ast_issue58(self):
def test_getstatementrange_ast_issue58(self) -> None:
source = Source(
"""\
@ -286,38 +290,44 @@ class TestSourceParsingAndCompiling:
assert getstatement(2, source).lines == source.lines[2:3]
assert getstatement(3, source).lines == source.lines[3:4]
def test_getstatementrange_out_of_bounds_py3(self):
def test_getstatementrange_out_of_bounds_py3(self) -> None:
source = Source("if xxx:\n from .collections import something")
r = source.getstatementrange(1)
assert r == (1, 2)
def test_getstatementrange_with_syntaxerror_issue7(self):
def test_getstatementrange_with_syntaxerror_issue7(self) -> None:
source = Source(":")
pytest.raises(SyntaxError, lambda: source.getstatementrange(0))
def test_compile_to_ast(self):
def test_compile_to_ast(self) -> None:
source = Source("x = 4")
mod = source.compile(flag=ast.PyCF_ONLY_AST)
assert isinstance(mod, ast.Module)
compile(mod, "<filename>", "exec")
def test_compile_and_getsource(self):
def test_compile_and_getsource(self) -> None:
co = self.source.compile()
exec(co, globals())
f(7)
excinfo = pytest.raises(AssertionError, f, 6)
f(7) # type: ignore
excinfo = pytest.raises(AssertionError, f, 6) # type: ignore
assert excinfo is not None
frame = excinfo.traceback[-1].frame
assert isinstance(frame.code.fullsource, Source)
stmt = frame.code.fullsource.getstatement(frame.lineno)
assert str(stmt).strip().startswith("assert")
@pytest.mark.parametrize("name", ["", None, "my"])
def test_compilefuncs_and_path_sanity(self, name):
def test_compilefuncs_and_path_sanity(self, name: Optional[str]) -> None:
def check(comp, name):
co = comp(self.source, name)
if not name:
expected = "codegen %s:%d>" % (mypath, mylineno + 2 + 2)
expected = "codegen %s:%d>" % (mypath, mylineno + 2 + 2) # type: ignore
else:
expected = "codegen %r %s:%d>" % (name, mypath, mylineno + 2 + 2)
expected = "codegen %r %s:%d>" % (
name,
mypath, # type: ignore
mylineno + 2 + 2, # type: ignore
) # type: ignore
fn = co.co_filename
assert fn.endswith(expected)
@ -332,9 +342,9 @@ class TestSourceParsingAndCompiling:
pytest.raises(SyntaxError, _pytest._code.compile, "lambda a,a: 0", mode="eval")
def test_getstartingblock_singleline():
def test_getstartingblock_singleline() -> None:
class A:
def __init__(self, *args):
def __init__(self, *args) -> None:
frame = sys._getframe(1)
self.source = _pytest._code.Frame(frame).statement
@ -344,22 +354,22 @@ def test_getstartingblock_singleline():
assert len(values) == 1
def test_getline_finally():
def c():
def test_getline_finally() -> None:
def c() -> None:
pass
with pytest.raises(TypeError) as excinfo:
teardown = None
try:
c(1)
c(1) # type: ignore
finally:
if teardown:
teardown()
source = excinfo.traceback[-1].statement
assert str(source).strip() == "c(1)"
assert str(source).strip() == "c(1) # type: ignore"
def test_getfuncsource_dynamic():
def test_getfuncsource_dynamic() -> None:
source = """
def f():
raise ValueError
@ -368,11 +378,13 @@ def test_getfuncsource_dynamic():
"""
co = _pytest._code.compile(source)
exec(co, globals())
assert str(_pytest._code.Source(f)).strip() == "def f():\n raise ValueError"
assert str(_pytest._code.Source(g)).strip() == "def g(): pass"
f_source = _pytest._code.Source(f) # type: ignore
g_source = _pytest._code.Source(g) # type: ignore
assert str(f_source).strip() == "def f():\n raise ValueError"
assert str(g_source).strip() == "def g(): pass"
def test_getfuncsource_with_multine_string():
def test_getfuncsource_with_multine_string() -> None:
def f():
c = """while True:
pass
@ -387,7 +399,7 @@ def test_getfuncsource_with_multine_string():
assert str(_pytest._code.Source(f)) == expected.rstrip()
def test_deindent():
def test_deindent() -> None:
from _pytest._code.source import deindent as deindent
assert deindent(["\tfoo", "\tbar"]) == ["foo", "bar"]
@ -401,7 +413,7 @@ def test_deindent():
assert lines == ["def f():", " def g():", " pass"]
def test_source_of_class_at_eof_without_newline(tmpdir, _sys_snapshot):
def test_source_of_class_at_eof_without_newline(tmpdir, _sys_snapshot) -> None:
# this test fails because the implicit inspect.getsource(A) below
# does not return the "x = 1" last line.
source = _pytest._code.Source(
@ -423,7 +435,7 @@ if True:
pass
def test_getsource_fallback():
def test_getsource_fallback() -> None:
from _pytest._code.source import getsource
expected = """def x():
@ -432,7 +444,7 @@ def test_getsource_fallback():
assert src == expected
def test_idem_compile_and_getsource():
def test_idem_compile_and_getsource() -> None:
from _pytest._code.source import getsource
expected = "def x(): pass"
@ -441,15 +453,16 @@ def test_idem_compile_and_getsource():
assert src == expected
def test_findsource_fallback():
def test_findsource_fallback() -> None:
from _pytest._code.source import findsource
src, lineno = findsource(x)
assert src is not None
assert "test_findsource_simple" in str(src)
assert src[lineno] == " def x():"
def test_findsource():
def test_findsource() -> None:
from _pytest._code.source import findsource
co = _pytest._code.compile(
@ -460,25 +473,27 @@ def test_findsource():
)
src, lineno = findsource(co)
assert src is not None
assert "if 1:" in str(src)
d = {}
d = {} # type: Dict[str, Any]
eval(co, d)
src, lineno = findsource(d["x"])
assert src is not None
assert "if 1:" in str(src)
assert src[lineno] == " def x():"
def test_getfslineno():
def test_getfslineno() -> None:
from _pytest._code import getfslineno
def f(x):
def f(x) -> None:
pass
fspath, lineno = getfslineno(f)
assert fspath.basename == "test_source.py"
assert lineno == _pytest._code.getrawcode(f).co_firstlineno - 1 # see findsource
assert lineno == f.__code__.co_firstlineno - 1 # see findsource
class A:
pass
@ -498,40 +513,40 @@ def test_getfslineno():
assert getfslineno(B)[1] == -1
def test_code_of_object_instance_with_call():
def test_code_of_object_instance_with_call() -> None:
class A:
pass
pytest.raises(TypeError, lambda: _pytest._code.Source(A()))
class WithCall:
def __call__(self):
def __call__(self) -> None:
pass
code = _pytest._code.Code(WithCall())
assert "pass" in str(code.source())
class Hello:
def __call__(self):
def __call__(self) -> None:
pass
pytest.raises(TypeError, lambda: _pytest._code.Code(Hello))
def getstatement(lineno, source):
def getstatement(lineno: int, source) -> Source:
from _pytest._code.source import getstatementrange_ast
source = _pytest._code.Source(source, deindent=False)
ast, start, end = getstatementrange_ast(lineno, source)
return source[start:end]
src = _pytest._code.Source(source, deindent=False)
ast, start, end = getstatementrange_ast(lineno, src)
return src[start:end]
def test_oneline():
def test_oneline() -> None:
source = getstatement(0, "raise ValueError")
assert str(source) == "raise ValueError"
def test_comment_and_no_newline_at_end():
def test_comment_and_no_newline_at_end() -> None:
from _pytest._code.source import getstatementrange_ast
source = Source(
@ -545,12 +560,12 @@ def test_comment_and_no_newline_at_end():
assert end == 2
def test_oneline_and_comment():
def test_oneline_and_comment() -> None:
source = getstatement(0, "raise ValueError\n#hello")
assert str(source) == "raise ValueError"
def test_comments():
def test_comments() -> None:
source = '''def test():
"comment 1"
x = 1
@ -576,7 +591,7 @@ comment 4
assert str(getstatement(line, source)) == '"""\ncomment 4\n"""'
def test_comment_in_statement():
def test_comment_in_statement() -> None:
source = """test(foo=1,
# comment 1
bar=2)
@ -588,17 +603,17 @@ def test_comment_in_statement():
)
def test_single_line_else():
def test_single_line_else() -> None:
source = getstatement(1, "if False: 2\nelse: 3")
assert str(source) == "else: 3"
def test_single_line_finally():
def test_single_line_finally() -> None:
source = getstatement(1, "try: 1\nfinally: 3")
assert str(source) == "finally: 3"
def test_issue55():
def test_issue55() -> None:
source = (
"def round_trip(dinp):\n assert 1 == dinp\n"
'def test_rt():\n round_trip("""\n""")\n'
@ -607,7 +622,7 @@ def test_issue55():
assert str(s) == ' round_trip("""\n""")'
def test_multiline():
def test_multiline() -> None:
source = getstatement(
0,
"""\
@ -621,7 +636,7 @@ x = 3
class TestTry:
def setup_class(self):
def setup_class(self) -> None:
self.source = """\
try:
raise ValueError
@ -631,25 +646,25 @@ else:
raise KeyError()
"""
def test_body(self):
def test_body(self) -> None:
source = getstatement(1, self.source)
assert str(source) == " raise ValueError"
def test_except_line(self):
def test_except_line(self) -> None:
source = getstatement(2, self.source)
assert str(source) == "except Something:"
def test_except_body(self):
def test_except_body(self) -> None:
source = getstatement(3, self.source)
assert str(source) == " raise IndexError(1)"
def test_else(self):
def test_else(self) -> None:
source = getstatement(5, self.source)
assert str(source) == " raise KeyError()"
class TestTryFinally:
def setup_class(self):
def setup_class(self) -> None:
self.source = """\
try:
raise ValueError
@ -657,17 +672,17 @@ finally:
raise IndexError(1)
"""
def test_body(self):
def test_body(self) -> None:
source = getstatement(1, self.source)
assert str(source) == " raise ValueError"
def test_finally(self):
def test_finally(self) -> None:
source = getstatement(3, self.source)
assert str(source) == " raise IndexError(1)"
class TestIf:
def setup_class(self):
def setup_class(self) -> None:
self.source = """\
if 1:
y = 3
@ -677,24 +692,24 @@ else:
y = 7
"""
def test_body(self):
def test_body(self) -> None:
source = getstatement(1, self.source)
assert str(source) == " y = 3"
def test_elif_clause(self):
def test_elif_clause(self) -> None:
source = getstatement(2, self.source)
assert str(source) == "elif False:"
def test_elif(self):
def test_elif(self) -> None:
source = getstatement(3, self.source)
assert str(source) == " y = 5"
def test_else(self):
def test_else(self) -> None:
source = getstatement(5, self.source)
assert str(source) == " y = 7"
def test_semicolon():
def test_semicolon() -> None:
s = """\
hello ; pytest.skip()
"""
@ -702,7 +717,7 @@ hello ; pytest.skip()
assert str(source) == s.strip()
def test_def_online():
def test_def_online() -> None:
s = """\
def func(): raise ValueError(42)
@ -713,7 +728,7 @@ def something():
assert str(source) == "def func(): raise ValueError(42)"
def XXX_test_expression_multiline():
def XXX_test_expression_multiline() -> None:
source = """\
something
'''
@ -722,7 +737,7 @@ something
assert str(result) == "'''\n'''"
def test_getstartingblock_multiline():
def test_getstartingblock_multiline() -> None:
class A:
def __init__(self, *args):
frame = sys._getframe(1)

View File

@ -92,8 +92,6 @@ class TestCaptureManager:
@pytest.mark.parametrize("method", ["fd", "sys"])
def test_capturing_unicode(testdir, method):
if hasattr(sys, "pypy_version_info") and sys.pypy_version_info < (2, 2):
pytest.xfail("does not work on pypy < 2.2")
obj = "'b\u00f6y'"
testdir.makepyfile(
"""\