diff --git a/src/_pytest/_code/code.py b/src/_pytest/_code/code.py index 14428c885..55c9e9100 100644 --- a/src/_pytest/_code/code.py +++ b/src/_pytest/_code/code.py @@ -67,7 +67,7 @@ class Code: return not self == other @property - def path(self): + def path(self) -> Union[py.path.local, str]: """ return a path object pointing to source code (note that it might not point to an actually existing file). """ try: @@ -335,7 +335,7 @@ class Traceback(List[TracebackEntry]): (path is None or codepath == path) and ( excludepath is None - or not hasattr(codepath, "relto") + or not isinstance(codepath, py.path.local) or not codepath.relto(excludepath) ) and (lineno is None or x.lineno == lineno) diff --git a/src/_pytest/_code/source.py b/src/_pytest/_code/source.py index ee3f7cb14..67c74143f 100644 --- a/src/_pytest/_code/source.py +++ b/src/_pytest/_code/source.py @@ -6,6 +6,7 @@ import textwrap import tokenize import warnings from bisect import bisect_right +from types import CodeType from types import FrameType from typing import Iterator from typing import List @@ -17,6 +18,10 @@ from typing import Union import py from _pytest.compat import overload +from _pytest.compat import TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import Literal class Source: @@ -120,7 +125,7 @@ class Source: start, end = self.getstatementrange(lineno) return self[start:end] - def getstatementrange(self, lineno: int): + def getstatementrange(self, lineno: int) -> Tuple[int, int]: """ return (start, end) tuple which spans the minimal statement region which containing the given lineno. """ @@ -158,14 +163,36 @@ class Source: def __str__(self) -> str: return "\n".join(self.lines) + @overload def compile( self, - filename=None, - mode="exec", + filename: Optional[str] = ..., + mode: str = ..., + flag: "Literal[0]" = ..., + dont_inherit: int = ..., + _genframe: Optional[FrameType] = ..., + ) -> CodeType: + raise NotImplementedError() + + @overload # noqa: F811 + def compile( # noqa: F811 + self, + filename: Optional[str] = ..., + mode: str = ..., + flag: int = ..., + dont_inherit: int = ..., + _genframe: Optional[FrameType] = ..., + ) -> Union[CodeType, ast.AST]: + raise NotImplementedError() + + def compile( # noqa: F811 + self, + filename: Optional[str] = None, + mode: str = "exec", flag: int = 0, dont_inherit: int = 0, _genframe: Optional[FrameType] = None, - ): + ) -> Union[CodeType, ast.AST]: """ return compiled code object. if filename is None invent an artificial filename which displays the source/line position of the caller frame. @@ -196,7 +223,9 @@ class Source: raise newex else: if flag & ast.PyCF_ONLY_AST: + assert isinstance(co, ast.AST) return co + assert isinstance(co, CodeType) lines = [(x + "\n") for x in self.lines] # Type ignored because linecache.cache is private. linecache.cache[filename] = (1, None, lines, filename) # type: ignore @@ -208,7 +237,35 @@ class Source: # -def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: int = 0): +@overload +def compile_( + source: Union[str, bytes, ast.mod, ast.AST], + filename: Optional[str] = ..., + mode: str = ..., + flags: "Literal[0]" = ..., + dont_inherit: int = ..., +) -> CodeType: + raise NotImplementedError() + + +@overload # noqa: F811 +def compile_( # noqa: F811 + source: Union[str, bytes, ast.mod, ast.AST], + filename: Optional[str] = ..., + mode: str = ..., + flags: int = ..., + dont_inherit: int = ..., +) -> Union[CodeType, ast.AST]: + raise NotImplementedError() + + +def compile_( # noqa: F811 + source: Union[str, bytes, ast.mod, ast.AST], + filename: Optional[str] = None, + mode: str = "exec", + flags: int = 0, + dont_inherit: int = 0, +) -> Union[CodeType, ast.AST]: """ compile the given source to a raw code object, and maintain an internal cache which allows later retrieval of the source code for the code object @@ -216,14 +273,16 @@ def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: i """ if isinstance(source, ast.AST): # XXX should Source support having AST? - return compile(source, filename, mode, flags, dont_inherit) + assert filename is not None + co = compile(source, filename, mode, flags, dont_inherit) + assert isinstance(co, (CodeType, ast.AST)) + return co _genframe = sys._getframe(1) # the caller s = Source(source) - co = s.compile(filename, mode, flags, _genframe=_genframe) - return co + return s.compile(filename, mode, flags, _genframe=_genframe) -def getfslineno(obj): +def getfslineno(obj) -> Tuple[Union[str, py.path.local], int]: """ Return source location (path, lineno) for the given object. If the source cannot be determined return ("", -1). diff --git a/testing/code/test_source.py b/testing/code/test_source.py index 1390d8b0a..030e60676 100644 --- a/testing/code/test_source.py +++ b/testing/code/test_source.py @@ -4,10 +4,13 @@ import ast import inspect import sys +from types import CodeType from typing import Any from typing import Dict from typing import Optional +import py + import _pytest._code import pytest from _pytest._code import Source @@ -147,6 +150,10 @@ class TestAccesses: assert len(x.lines) == 2 assert str(x) == "def f(x):\n pass" + def test_getrange_step_not_supported(self) -> None: + with pytest.raises(IndexError, match=r"step"): + self.source[::2] + def test_getline(self) -> None: x = self.source[0] assert x == "def f(x):" @@ -449,6 +456,14 @@ def test_idem_compile_and_getsource() -> None: assert src == expected +def test_compile_ast() -> None: + # We don't necessarily want to support this. + # This test was added just for coverage. + stmt = ast.parse("def x(): pass") + co = _pytest._code.compile(stmt, filename="foo.py") + assert isinstance(co, CodeType) + + def test_findsource_fallback() -> None: from _pytest._code.source import findsource @@ -488,6 +503,7 @@ def test_getfslineno() -> None: fspath, lineno = getfslineno(f) + assert isinstance(fspath, py.path.local) assert fspath.basename == "test_source.py" assert lineno == f.__code__.co_firstlineno - 1 # see findsource