Add a few missing type annotations in _pytest._code

These are more "dirty" than the previous batch (that's why they were
left out). The trouble is that `compile` can return either a code object
or an AST depending on a flag, so we need to add an overload to make the
common case Union free. But it's still worthwhile.
This commit is contained in:
Ran Benita 2019-11-25 17:20:54 +02:00
parent 3e6f0f34ff
commit 0c247be769
3 changed files with 86 additions and 11 deletions

View File

@ -67,7 +67,7 @@ class Code:
return not self == other return not self == other
@property @property
def path(self): def path(self) -> Union[py.path.local, str]:
""" return a path object pointing to source code (note that it """ return a path object pointing to source code (note that it
might not point to an actually existing file). """ might not point to an actually existing file). """
try: try:
@ -335,7 +335,7 @@ class Traceback(List[TracebackEntry]):
(path is None or codepath == path) (path is None or codepath == path)
and ( and (
excludepath is None excludepath is None
or not hasattr(codepath, "relto") or not isinstance(codepath, py.path.local)
or not codepath.relto(excludepath) or not codepath.relto(excludepath)
) )
and (lineno is None or x.lineno == lineno) and (lineno is None or x.lineno == lineno)

View File

@ -6,6 +6,7 @@ import textwrap
import tokenize import tokenize
import warnings import warnings
from bisect import bisect_right from bisect import bisect_right
from types import CodeType
from types import FrameType from types import FrameType
from typing import Iterator from typing import Iterator
from typing import List from typing import List
@ -17,6 +18,10 @@ from typing import Union
import py import py
from _pytest.compat import overload from _pytest.compat import overload
from _pytest.compat import TYPE_CHECKING
if TYPE_CHECKING:
from typing_extensions import Literal
class Source: class Source:
@ -120,7 +125,7 @@ class Source:
start, end = self.getstatementrange(lineno) start, end = self.getstatementrange(lineno)
return self[start:end] return self[start:end]
def getstatementrange(self, lineno: int): def getstatementrange(self, lineno: int) -> Tuple[int, int]:
""" return (start, end) tuple which spans the minimal """ return (start, end) tuple which spans the minimal
statement region which containing the given lineno. statement region which containing the given lineno.
""" """
@ -158,14 +163,36 @@ class Source:
def __str__(self) -> str: def __str__(self) -> str:
return "\n".join(self.lines) return "\n".join(self.lines)
@overload
def compile( def compile(
self, self,
filename=None, filename: Optional[str] = ...,
mode="exec", mode: str = ...,
flag: "Literal[0]" = ...,
dont_inherit: int = ...,
_genframe: Optional[FrameType] = ...,
) -> CodeType:
raise NotImplementedError()
@overload # noqa: F811
def compile( # noqa: F811
self,
filename: Optional[str] = ...,
mode: str = ...,
flag: int = ...,
dont_inherit: int = ...,
_genframe: Optional[FrameType] = ...,
) -> Union[CodeType, ast.AST]:
raise NotImplementedError()
def compile( # noqa: F811
self,
filename: Optional[str] = None,
mode: str = "exec",
flag: int = 0, flag: int = 0,
dont_inherit: int = 0, dont_inherit: int = 0,
_genframe: Optional[FrameType] = None, _genframe: Optional[FrameType] = None,
): ) -> Union[CodeType, ast.AST]:
""" return compiled code object. if filename is None """ return compiled code object. if filename is None
invent an artificial filename which displays invent an artificial filename which displays
the source/line position of the caller frame. the source/line position of the caller frame.
@ -196,7 +223,9 @@ class Source:
raise newex raise newex
else: else:
if flag & ast.PyCF_ONLY_AST: if flag & ast.PyCF_ONLY_AST:
assert isinstance(co, ast.AST)
return co return co
assert isinstance(co, CodeType)
lines = [(x + "\n") for x in self.lines] lines = [(x + "\n") for x in self.lines]
# Type ignored because linecache.cache is private. # Type ignored because linecache.cache is private.
linecache.cache[filename] = (1, None, lines, filename) # type: ignore linecache.cache[filename] = (1, None, lines, filename) # type: ignore
@ -208,7 +237,35 @@ class Source:
# #
def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: int = 0): @overload
def compile_(
source: Union[str, bytes, ast.mod, ast.AST],
filename: Optional[str] = ...,
mode: str = ...,
flags: "Literal[0]" = ...,
dont_inherit: int = ...,
) -> CodeType:
raise NotImplementedError()
@overload # noqa: F811
def compile_( # noqa: F811
source: Union[str, bytes, ast.mod, ast.AST],
filename: Optional[str] = ...,
mode: str = ...,
flags: int = ...,
dont_inherit: int = ...,
) -> Union[CodeType, ast.AST]:
raise NotImplementedError()
def compile_( # noqa: F811
source: Union[str, bytes, ast.mod, ast.AST],
filename: Optional[str] = None,
mode: str = "exec",
flags: int = 0,
dont_inherit: int = 0,
) -> Union[CodeType, ast.AST]:
""" compile the given source to a raw code object, """ compile the given source to a raw code object,
and maintain an internal cache which allows later and maintain an internal cache which allows later
retrieval of the source code for the code object retrieval of the source code for the code object
@ -216,14 +273,16 @@ def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: i
""" """
if isinstance(source, ast.AST): if isinstance(source, ast.AST):
# XXX should Source support having AST? # XXX should Source support having AST?
return compile(source, filename, mode, flags, dont_inherit) assert filename is not None
co = compile(source, filename, mode, flags, dont_inherit)
assert isinstance(co, (CodeType, ast.AST))
return co
_genframe = sys._getframe(1) # the caller _genframe = sys._getframe(1) # the caller
s = Source(source) s = Source(source)
co = s.compile(filename, mode, flags, _genframe=_genframe) return s.compile(filename, mode, flags, _genframe=_genframe)
return co
def getfslineno(obj): def getfslineno(obj) -> Tuple[Union[str, py.path.local], int]:
""" Return source location (path, lineno) for the given object. """ Return source location (path, lineno) for the given object.
If the source cannot be determined return ("", -1). If the source cannot be determined return ("", -1).

View File

@ -4,10 +4,13 @@
import ast import ast
import inspect import inspect
import sys import sys
from types import CodeType
from typing import Any from typing import Any
from typing import Dict from typing import Dict
from typing import Optional from typing import Optional
import py
import _pytest._code import _pytest._code
import pytest import pytest
from _pytest._code import Source from _pytest._code import Source
@ -147,6 +150,10 @@ class TestAccesses:
assert len(x.lines) == 2 assert len(x.lines) == 2
assert str(x) == "def f(x):\n pass" assert str(x) == "def f(x):\n pass"
def test_getrange_step_not_supported(self) -> None:
with pytest.raises(IndexError, match=r"step"):
self.source[::2]
def test_getline(self) -> None: def test_getline(self) -> None:
x = self.source[0] x = self.source[0]
assert x == "def f(x):" assert x == "def f(x):"
@ -449,6 +456,14 @@ def test_idem_compile_and_getsource() -> None:
assert src == expected assert src == expected
def test_compile_ast() -> None:
# We don't necessarily want to support this.
# This test was added just for coverage.
stmt = ast.parse("def x(): pass")
co = _pytest._code.compile(stmt, filename="foo.py")
assert isinstance(co, CodeType)
def test_findsource_fallback() -> None: def test_findsource_fallback() -> None:
from _pytest._code.source import findsource from _pytest._code.source import findsource
@ -488,6 +503,7 @@ def test_getfslineno() -> None:
fspath, lineno = getfslineno(f) fspath, lineno = getfslineno(f)
assert isinstance(fspath, py.path.local)
assert fspath.basename == "test_source.py" assert fspath.basename == "test_source.py"
assert lineno == f.__code__.co_firstlineno - 1 # see findsource assert lineno == f.__code__.co_firstlineno - 1 # see findsource