diff --git a/src/_pytest/compat.py b/src/_pytest/compat.py index 09e621c5d..fc810b3e5 100644 --- a/src/_pytest/compat.py +++ b/src/_pytest/compat.py @@ -10,11 +10,14 @@ import sys from contextlib import contextmanager from inspect import Parameter from inspect import signature +from typing import Any from typing import Callable from typing import Generic from typing import Optional from typing import overload +from typing import Tuple from typing import TypeVar +from typing import Union import attr import py @@ -46,7 +49,7 @@ else: import importlib_metadata # noqa: F401 -def _format_args(func): +def _format_args(func: Callable[..., Any]) -> str: return str(signature(func)) @@ -67,12 +70,12 @@ else: fspath = os.fspath -def is_generator(func): +def is_generator(func: object) -> bool: genfunc = inspect.isgeneratorfunction(func) return genfunc and not iscoroutinefunction(func) -def iscoroutinefunction(func): +def iscoroutinefunction(func: object) -> bool: """ Return True if func is a coroutine function (a function defined with async def syntax, and doesn't contain yield), or a function decorated with @@ -85,7 +88,7 @@ def iscoroutinefunction(func): return inspect.iscoroutinefunction(func) or getattr(func, "_is_coroutine", False) -def getlocation(function, curdir=None): +def getlocation(function, curdir=None) -> str: function = get_real_func(function) fn = py.path.local(inspect.getfile(function)) lineno = function.__code__.co_firstlineno @@ -94,7 +97,7 @@ def getlocation(function, curdir=None): return "%s:%d" % (fn, lineno + 1) -def num_mock_patch_args(function): +def num_mock_patch_args(function) -> int: """ return number of arguments used up by mock arguments (if any) """ patchings = getattr(function, "patchings", None) if not patchings: @@ -113,7 +116,13 @@ def num_mock_patch_args(function): ) -def getfuncargnames(function, *, name: str = "", is_method=False, cls=None): +def getfuncargnames( + function: Callable[..., Any], + *, + name: str = "", + is_method: bool = False, + cls: Optional[type] = None +) -> Tuple[str, ...]: """Returns the names of a function's mandatory arguments. This should return the names of all function arguments that: @@ -181,7 +190,7 @@ else: from contextlib import nullcontext # noqa -def get_default_arg_names(function): +def get_default_arg_names(function: Callable[..., Any]) -> Tuple[str, ...]: # Note: this code intentionally mirrors the code at the beginning of getfuncargnames, # to get the arguments which were excluded from its result because they had default values return tuple( @@ -200,18 +209,18 @@ _non_printable_ascii_translate_table.update( ) -def _translate_non_printable(s): +def _translate_non_printable(s: str) -> str: return s.translate(_non_printable_ascii_translate_table) STRING_TYPES = bytes, str -def _bytes_to_ascii(val): +def _bytes_to_ascii(val: bytes) -> str: return val.decode("ascii", "backslashreplace") -def ascii_escaped(val): +def ascii_escaped(val: Union[bytes, str]): """If val is pure ascii, returns it as a str(). Otherwise, escapes bytes objects into a sequence of escaped bytes: @@ -308,7 +317,7 @@ def getimfunc(func): return func -def safe_getattr(object, name, default): +def safe_getattr(object: Any, name: str, default: Any) -> Any: """ Like getattr but return default upon any Exception or any OutcomeException. Attribute access can potentially fail for 'evil' Python objects. @@ -322,7 +331,7 @@ def safe_getattr(object, name, default): return default -def safe_isclass(obj): +def safe_isclass(obj: object) -> bool: """Ignore any exception via isinstance on Python 3.""" try: return inspect.isclass(obj) @@ -343,21 +352,23 @@ COLLECT_FAKEMODULE_ATTRIBUTES = ( ) -def _setup_collect_fakemodule(): +def _setup_collect_fakemodule() -> None: from types import ModuleType import pytest - pytest.collect = ModuleType("pytest.collect") - pytest.collect.__all__ = [] # used for setns + # Types ignored because the module is created dynamically. + pytest.collect = ModuleType("pytest.collect") # type: ignore + pytest.collect.__all__ = [] # type: ignore # used for setns for attr_name in COLLECT_FAKEMODULE_ATTRIBUTES: - setattr(pytest.collect, attr_name, getattr(pytest, attr_name)) + setattr(pytest.collect, attr_name, getattr(pytest, attr_name)) # type: ignore class CaptureIO(io.TextIOWrapper): - def __init__(self): + def __init__(self) -> None: super().__init__(io.BytesIO(), encoding="UTF-8", newline="", write_through=True) - def getvalue(self): + def getvalue(self) -> str: + assert isinstance(self.buffer, io.BytesIO) return self.buffer.getvalue().decode("UTF-8")