diff --git a/src/_pytest/assertion/util.py b/src/_pytest/assertion/util.py index 1d9fffd34..46e578188 100644 --- a/src/_pytest/assertion/util.py +++ b/src/_pytest/assertion/util.py @@ -1,9 +1,15 @@ """Utilities for assertion debugging""" +import collections.abc import pprint -from collections.abc import Sequence +from typing import AbstractSet +from typing import Any from typing import Callable +from typing import Iterable from typing import List +from typing import Mapping from typing import Optional +from typing import Sequence +from typing import Tuple import _pytest._code from _pytest import outcomes @@ -22,7 +28,7 @@ _reprcompare = None # type: Optional[Callable[[str, object, object], Optional[s _assertion_pass = None # type: Optional[Callable[[int, str, str], None]] -def format_explanation(explanation): +def format_explanation(explanation: str) -> str: """This formats an explanation Normally all embedded newlines are escaped, however there are @@ -38,7 +44,7 @@ def format_explanation(explanation): return "\n".join(result) -def _split_explanation(explanation): +def _split_explanation(explanation: str) -> List[str]: """Return a list of individual lines in the explanation This will return a list of lines split on '\n{', '\n}' and '\n~'. @@ -55,7 +61,7 @@ def _split_explanation(explanation): return lines -def _format_lines(lines): +def _format_lines(lines: Sequence[str]) -> List[str]: """Format the individual lines This will replace the '{', '}' and '~' characters of our mini @@ -64,7 +70,7 @@ def _format_lines(lines): Return a list of formatted lines. """ - result = lines[:1] + result = list(lines[:1]) stack = [0] stackcnt = [0] for line in lines[1:]: @@ -90,31 +96,31 @@ def _format_lines(lines): return result -def issequence(x): - return isinstance(x, Sequence) and not isinstance(x, str) +def issequence(x: Any) -> bool: + return isinstance(x, collections.abc.Sequence) and not isinstance(x, str) -def istext(x): +def istext(x: Any) -> bool: return isinstance(x, str) -def isdict(x): +def isdict(x: Any) -> bool: return isinstance(x, dict) -def isset(x): +def isset(x: Any) -> bool: return isinstance(x, (set, frozenset)) -def isdatacls(obj): +def isdatacls(obj: Any) -> bool: return getattr(obj, "__dataclass_fields__", None) is not None -def isattrs(obj): +def isattrs(obj: Any) -> bool: return getattr(obj, "__attrs_attrs__", None) is not None -def isiterable(obj): +def isiterable(obj: Any) -> bool: try: iter(obj) return not istext(obj) @@ -122,7 +128,7 @@ def isiterable(obj): return False -def assertrepr_compare(config, op, left, right): +def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[str]]: """Return specialised explanations for some operators/operands""" verbose = config.getoption("verbose") if verbose > 1: @@ -180,7 +186,7 @@ def assertrepr_compare(config, op, left, right): return [summary] + explanation -def _diff_text(left, right, verbose=0): +def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]: """Return the explanation for the diff between text. Unless --verbose is used this will skip leading and trailing @@ -226,7 +232,7 @@ def _diff_text(left, right, verbose=0): return explanation -def _compare_eq_verbose(left, right): +def _compare_eq_verbose(left: Any, right: Any) -> List[str]: keepends = True left_lines = repr(left).splitlines(keepends) right_lines = repr(right).splitlines(keepends) @@ -238,7 +244,7 @@ def _compare_eq_verbose(left, right): return explanation -def _surrounding_parens_on_own_lines(lines): # type: (List) -> None +def _surrounding_parens_on_own_lines(lines: List[str]) -> None: """Move opening/closing parenthesis/bracket to own lines.""" opening = lines[0][:1] if opening in ["(", "[", "{"]: @@ -250,7 +256,9 @@ def _surrounding_parens_on_own_lines(lines): # type: (List) -> None lines[:] = lines + [closing] -def _compare_eq_iterable(left, right, verbose=0): +def _compare_eq_iterable( + left: Iterable[Any], right: Iterable[Any], verbose: int = 0 +) -> List[str]: if not verbose: return ["Use -v to get the full diff"] # dynamic import to speedup pytest @@ -283,7 +291,9 @@ def _compare_eq_iterable(left, right, verbose=0): return explanation -def _compare_eq_sequence(left, right, verbose=0): +def _compare_eq_sequence( + left: Sequence[Any], right: Sequence[Any], verbose: int = 0 +) -> List[str]: comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes) explanation = [] # type: List[str] len_left = len(left) @@ -337,7 +347,9 @@ def _compare_eq_sequence(left, right, verbose=0): return explanation -def _compare_eq_set(left, right, verbose=0): +def _compare_eq_set( + left: AbstractSet[Any], right: AbstractSet[Any], verbose: int = 0 +) -> List[str]: explanation = [] diff_left = left - right diff_right = right - left @@ -352,7 +364,9 @@ def _compare_eq_set(left, right, verbose=0): return explanation -def _compare_eq_dict(left, right, verbose=0): +def _compare_eq_dict( + left: Mapping[Any, Any], right: Mapping[Any, Any], verbose: int = 0 +) -> List[str]: explanation = [] # type: List[str] set_left = set(left) set_right = set(right) @@ -391,7 +405,12 @@ def _compare_eq_dict(left, right, verbose=0): return explanation -def _compare_eq_cls(left, right, verbose, type_fns): +def _compare_eq_cls( + left: Any, + right: Any, + verbose: int, + type_fns: Tuple[Callable[[Any], bool], Callable[[Any], bool]], +) -> List[str]: isdatacls, isattrs = type_fns if isdatacls(left): all_fields = left.__dataclass_fields__ @@ -425,7 +444,7 @@ def _compare_eq_cls(left, right, verbose, type_fns): return explanation -def _notin_text(term, text, verbose=0): +def _notin_text(term: str, text: str, verbose: int = 0) -> List[str]: index = text.find(term) head = text[:index] tail = text[index + len(term) :]