Add a few missing type annotations in _pytest._code
These are more "dirty" than the previous batch (that's why they were left out). The trouble is that `compile` can return either a code object or an AST depending on a flag, so we need to add an overload to make the common case Union free. But it's still worthwhile.
This commit is contained in:
parent
3e6f0f34ff
commit
0c247be769
|
@ -67,7 +67,7 @@ class Code:
|
||||||
return not self == other
|
return not self == other
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def path(self):
|
def path(self) -> Union[py.path.local, str]:
|
||||||
""" return a path object pointing to source code (note that it
|
""" return a path object pointing to source code (note that it
|
||||||
might not point to an actually existing file). """
|
might not point to an actually existing file). """
|
||||||
try:
|
try:
|
||||||
|
@ -335,7 +335,7 @@ class Traceback(List[TracebackEntry]):
|
||||||
(path is None or codepath == path)
|
(path is None or codepath == path)
|
||||||
and (
|
and (
|
||||||
excludepath is None
|
excludepath is None
|
||||||
or not hasattr(codepath, "relto")
|
or not isinstance(codepath, py.path.local)
|
||||||
or not codepath.relto(excludepath)
|
or not codepath.relto(excludepath)
|
||||||
)
|
)
|
||||||
and (lineno is None or x.lineno == lineno)
|
and (lineno is None or x.lineno == lineno)
|
||||||
|
|
|
@ -6,6 +6,7 @@ import textwrap
|
||||||
import tokenize
|
import tokenize
|
||||||
import warnings
|
import warnings
|
||||||
from bisect import bisect_right
|
from bisect import bisect_right
|
||||||
|
from types import CodeType
|
||||||
from types import FrameType
|
from types import FrameType
|
||||||
from typing import Iterator
|
from typing import Iterator
|
||||||
from typing import List
|
from typing import List
|
||||||
|
@ -17,6 +18,10 @@ from typing import Union
|
||||||
import py
|
import py
|
||||||
|
|
||||||
from _pytest.compat import overload
|
from _pytest.compat import overload
|
||||||
|
from _pytest.compat import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
||||||
class Source:
|
class Source:
|
||||||
|
@ -120,7 +125,7 @@ class Source:
|
||||||
start, end = self.getstatementrange(lineno)
|
start, end = self.getstatementrange(lineno)
|
||||||
return self[start:end]
|
return self[start:end]
|
||||||
|
|
||||||
def getstatementrange(self, lineno: int):
|
def getstatementrange(self, lineno: int) -> Tuple[int, int]:
|
||||||
""" return (start, end) tuple which spans the minimal
|
""" return (start, end) tuple which spans the minimal
|
||||||
statement region which containing the given lineno.
|
statement region which containing the given lineno.
|
||||||
"""
|
"""
|
||||||
|
@ -158,14 +163,36 @@ class Source:
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return "\n".join(self.lines)
|
return "\n".join(self.lines)
|
||||||
|
|
||||||
|
@overload
|
||||||
def compile(
|
def compile(
|
||||||
self,
|
self,
|
||||||
filename=None,
|
filename: Optional[str] = ...,
|
||||||
mode="exec",
|
mode: str = ...,
|
||||||
|
flag: "Literal[0]" = ...,
|
||||||
|
dont_inherit: int = ...,
|
||||||
|
_genframe: Optional[FrameType] = ...,
|
||||||
|
) -> CodeType:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@overload # noqa: F811
|
||||||
|
def compile( # noqa: F811
|
||||||
|
self,
|
||||||
|
filename: Optional[str] = ...,
|
||||||
|
mode: str = ...,
|
||||||
|
flag: int = ...,
|
||||||
|
dont_inherit: int = ...,
|
||||||
|
_genframe: Optional[FrameType] = ...,
|
||||||
|
) -> Union[CodeType, ast.AST]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def compile( # noqa: F811
|
||||||
|
self,
|
||||||
|
filename: Optional[str] = None,
|
||||||
|
mode: str = "exec",
|
||||||
flag: int = 0,
|
flag: int = 0,
|
||||||
dont_inherit: int = 0,
|
dont_inherit: int = 0,
|
||||||
_genframe: Optional[FrameType] = None,
|
_genframe: Optional[FrameType] = None,
|
||||||
):
|
) -> Union[CodeType, ast.AST]:
|
||||||
""" return compiled code object. if filename is None
|
""" return compiled code object. if filename is None
|
||||||
invent an artificial filename which displays
|
invent an artificial filename which displays
|
||||||
the source/line position of the caller frame.
|
the source/line position of the caller frame.
|
||||||
|
@ -196,7 +223,9 @@ class Source:
|
||||||
raise newex
|
raise newex
|
||||||
else:
|
else:
|
||||||
if flag & ast.PyCF_ONLY_AST:
|
if flag & ast.PyCF_ONLY_AST:
|
||||||
|
assert isinstance(co, ast.AST)
|
||||||
return co
|
return co
|
||||||
|
assert isinstance(co, CodeType)
|
||||||
lines = [(x + "\n") for x in self.lines]
|
lines = [(x + "\n") for x in self.lines]
|
||||||
# Type ignored because linecache.cache is private.
|
# Type ignored because linecache.cache is private.
|
||||||
linecache.cache[filename] = (1, None, lines, filename) # type: ignore
|
linecache.cache[filename] = (1, None, lines, filename) # type: ignore
|
||||||
|
@ -208,7 +237,35 @@ class Source:
|
||||||
#
|
#
|
||||||
|
|
||||||
|
|
||||||
def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: int = 0):
|
@overload
|
||||||
|
def compile_(
|
||||||
|
source: Union[str, bytes, ast.mod, ast.AST],
|
||||||
|
filename: Optional[str] = ...,
|
||||||
|
mode: str = ...,
|
||||||
|
flags: "Literal[0]" = ...,
|
||||||
|
dont_inherit: int = ...,
|
||||||
|
) -> CodeType:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
@overload # noqa: F811
|
||||||
|
def compile_( # noqa: F811
|
||||||
|
source: Union[str, bytes, ast.mod, ast.AST],
|
||||||
|
filename: Optional[str] = ...,
|
||||||
|
mode: str = ...,
|
||||||
|
flags: int = ...,
|
||||||
|
dont_inherit: int = ...,
|
||||||
|
) -> Union[CodeType, ast.AST]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
def compile_( # noqa: F811
|
||||||
|
source: Union[str, bytes, ast.mod, ast.AST],
|
||||||
|
filename: Optional[str] = None,
|
||||||
|
mode: str = "exec",
|
||||||
|
flags: int = 0,
|
||||||
|
dont_inherit: int = 0,
|
||||||
|
) -> Union[CodeType, ast.AST]:
|
||||||
""" compile the given source to a raw code object,
|
""" compile the given source to a raw code object,
|
||||||
and maintain an internal cache which allows later
|
and maintain an internal cache which allows later
|
||||||
retrieval of the source code for the code object
|
retrieval of the source code for the code object
|
||||||
|
@ -216,14 +273,16 @@ def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: i
|
||||||
"""
|
"""
|
||||||
if isinstance(source, ast.AST):
|
if isinstance(source, ast.AST):
|
||||||
# XXX should Source support having AST?
|
# XXX should Source support having AST?
|
||||||
return compile(source, filename, mode, flags, dont_inherit)
|
assert filename is not None
|
||||||
|
co = compile(source, filename, mode, flags, dont_inherit)
|
||||||
|
assert isinstance(co, (CodeType, ast.AST))
|
||||||
|
return co
|
||||||
_genframe = sys._getframe(1) # the caller
|
_genframe = sys._getframe(1) # the caller
|
||||||
s = Source(source)
|
s = Source(source)
|
||||||
co = s.compile(filename, mode, flags, _genframe=_genframe)
|
return s.compile(filename, mode, flags, _genframe=_genframe)
|
||||||
return co
|
|
||||||
|
|
||||||
|
|
||||||
def getfslineno(obj):
|
def getfslineno(obj) -> Tuple[Union[str, py.path.local], int]:
|
||||||
""" Return source location (path, lineno) for the given object.
|
""" Return source location (path, lineno) for the given object.
|
||||||
If the source cannot be determined return ("", -1).
|
If the source cannot be determined return ("", -1).
|
||||||
|
|
||||||
|
|
|
@ -4,10 +4,13 @@
|
||||||
import ast
|
import ast
|
||||||
import inspect
|
import inspect
|
||||||
import sys
|
import sys
|
||||||
|
from types import CodeType
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import py
|
||||||
|
|
||||||
import _pytest._code
|
import _pytest._code
|
||||||
import pytest
|
import pytest
|
||||||
from _pytest._code import Source
|
from _pytest._code import Source
|
||||||
|
@ -147,6 +150,10 @@ class TestAccesses:
|
||||||
assert len(x.lines) == 2
|
assert len(x.lines) == 2
|
||||||
assert str(x) == "def f(x):\n pass"
|
assert str(x) == "def f(x):\n pass"
|
||||||
|
|
||||||
|
def test_getrange_step_not_supported(self) -> None:
|
||||||
|
with pytest.raises(IndexError, match=r"step"):
|
||||||
|
self.source[::2]
|
||||||
|
|
||||||
def test_getline(self) -> None:
|
def test_getline(self) -> None:
|
||||||
x = self.source[0]
|
x = self.source[0]
|
||||||
assert x == "def f(x):"
|
assert x == "def f(x):"
|
||||||
|
@ -449,6 +456,14 @@ def test_idem_compile_and_getsource() -> None:
|
||||||
assert src == expected
|
assert src == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_compile_ast() -> None:
|
||||||
|
# We don't necessarily want to support this.
|
||||||
|
# This test was added just for coverage.
|
||||||
|
stmt = ast.parse("def x(): pass")
|
||||||
|
co = _pytest._code.compile(stmt, filename="foo.py")
|
||||||
|
assert isinstance(co, CodeType)
|
||||||
|
|
||||||
|
|
||||||
def test_findsource_fallback() -> None:
|
def test_findsource_fallback() -> None:
|
||||||
from _pytest._code.source import findsource
|
from _pytest._code.source import findsource
|
||||||
|
|
||||||
|
@ -488,6 +503,7 @@ def test_getfslineno() -> None:
|
||||||
|
|
||||||
fspath, lineno = getfslineno(f)
|
fspath, lineno = getfslineno(f)
|
||||||
|
|
||||||
|
assert isinstance(fspath, py.path.local)
|
||||||
assert fspath.basename == "test_source.py"
|
assert fspath.basename == "test_source.py"
|
||||||
assert lineno == f.__code__.co_firstlineno - 1 # see findsource
|
assert lineno == f.__code__.co_firstlineno - 1 # see findsource
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue