Add type annotations to _pytest.compat

This commit is contained in:
Ran Benita 2019-11-15 16:26:46 +02:00
parent a649f157de
commit 562d4811d5
1 changed files with 29 additions and 18 deletions

View File

@ -10,11 +10,14 @@ import sys
from contextlib import contextmanager
from inspect import Parameter
from inspect import signature
from typing import Any
from typing import Callable
from typing import Generic
from typing import Optional
from typing import overload
from typing import Tuple
from typing import TypeVar
from typing import Union
import attr
import py
@ -46,7 +49,7 @@ else:
import importlib_metadata # noqa: F401
def _format_args(func):
def _format_args(func: Callable[..., Any]) -> str:
return str(signature(func))
@ -67,12 +70,12 @@ else:
fspath = os.fspath
def is_generator(func):
def is_generator(func: object) -> bool:
genfunc = inspect.isgeneratorfunction(func)
return genfunc and not iscoroutinefunction(func)
def iscoroutinefunction(func):
def iscoroutinefunction(func: object) -> bool:
"""
Return True if func is a coroutine function (a function defined with async
def syntax, and doesn't contain yield), or a function decorated with
@ -85,7 +88,7 @@ def iscoroutinefunction(func):
return inspect.iscoroutinefunction(func) or getattr(func, "_is_coroutine", False)
def getlocation(function, curdir=None):
def getlocation(function, curdir=None) -> str:
function = get_real_func(function)
fn = py.path.local(inspect.getfile(function))
lineno = function.__code__.co_firstlineno
@ -94,7 +97,7 @@ def getlocation(function, curdir=None):
return "%s:%d" % (fn, lineno + 1)
def num_mock_patch_args(function):
def num_mock_patch_args(function) -> int:
""" return number of arguments used up by mock arguments (if any) """
patchings = getattr(function, "patchings", None)
if not patchings:
@ -113,7 +116,13 @@ def num_mock_patch_args(function):
)
def getfuncargnames(function, *, name: str = "", is_method=False, cls=None):
def getfuncargnames(
function: Callable[..., Any],
*,
name: str = "",
is_method: bool = False,
cls: Optional[type] = None
) -> Tuple[str, ...]:
"""Returns the names of a function's mandatory arguments.
This should return the names of all function arguments that:
@ -181,7 +190,7 @@ else:
from contextlib import nullcontext # noqa
def get_default_arg_names(function):
def get_default_arg_names(function: Callable[..., Any]) -> Tuple[str, ...]:
# Note: this code intentionally mirrors the code at the beginning of getfuncargnames,
# to get the arguments which were excluded from its result because they had default values
return tuple(
@ -200,18 +209,18 @@ _non_printable_ascii_translate_table.update(
)
def _translate_non_printable(s):
def _translate_non_printable(s: str) -> str:
return s.translate(_non_printable_ascii_translate_table)
STRING_TYPES = bytes, str
def _bytes_to_ascii(val):
def _bytes_to_ascii(val: bytes) -> str:
return val.decode("ascii", "backslashreplace")
def ascii_escaped(val):
def ascii_escaped(val: Union[bytes, str]):
"""If val is pure ascii, returns it as a str(). Otherwise, escapes
bytes objects into a sequence of escaped bytes:
@ -308,7 +317,7 @@ def getimfunc(func):
return func
def safe_getattr(object, name, default):
def safe_getattr(object: Any, name: str, default: Any) -> Any:
""" Like getattr but return default upon any Exception or any OutcomeException.
Attribute access can potentially fail for 'evil' Python objects.
@ -322,7 +331,7 @@ def safe_getattr(object, name, default):
return default
def safe_isclass(obj):
def safe_isclass(obj: object) -> bool:
"""Ignore any exception via isinstance on Python 3."""
try:
return inspect.isclass(obj)
@ -343,21 +352,23 @@ COLLECT_FAKEMODULE_ATTRIBUTES = (
)
def _setup_collect_fakemodule():
def _setup_collect_fakemodule() -> None:
from types import ModuleType
import pytest
pytest.collect = ModuleType("pytest.collect")
pytest.collect.__all__ = [] # used for setns
# Types ignored because the module is created dynamically.
pytest.collect = ModuleType("pytest.collect") # type: ignore
pytest.collect.__all__ = [] # type: ignore # used for setns
for attr_name in COLLECT_FAKEMODULE_ATTRIBUTES:
setattr(pytest.collect, attr_name, getattr(pytest, attr_name))
setattr(pytest.collect, attr_name, getattr(pytest, attr_name)) # type: ignore
class CaptureIO(io.TextIOWrapper):
def __init__(self):
def __init__(self) -> None:
super().__init__(io.BytesIO(), encoding="UTF-8", newline="", write_through=True)
def getvalue(self):
def getvalue(self) -> str:
assert isinstance(self.buffer, io.BytesIO)
return self.buffer.getvalue().decode("UTF-8")