From 7d3ce374d28d7eadd85c59ed1e59319556e61635 Mon Sep 17 00:00:00 2001
From: Ran Benita <ran@unusedvar.com>
Date: Sun, 3 Nov 2019 16:57:14 +0200
Subject: [PATCH] Add type annotations to _pytest.assertion.util

---
 src/_pytest/assertion/util.py | 65 ++++++++++++++++++++++-------------
 1 file changed, 42 insertions(+), 23 deletions(-)

diff --git a/src/_pytest/assertion/util.py b/src/_pytest/assertion/util.py
index 1d9fffd34..46e578188 100644
--- a/src/_pytest/assertion/util.py
+++ b/src/_pytest/assertion/util.py
@@ -1,9 +1,15 @@
 """Utilities for assertion debugging"""
+import collections.abc
 import pprint
-from collections.abc import Sequence
+from typing import AbstractSet
+from typing import Any
 from typing import Callable
+from typing import Iterable
 from typing import List
+from typing import Mapping
 from typing import Optional
+from typing import Sequence
+from typing import Tuple
 
 import _pytest._code
 from _pytest import outcomes
@@ -22,7 +28,7 @@ _reprcompare = None  # type: Optional[Callable[[str, object, object], Optional[s
 _assertion_pass = None  # type: Optional[Callable[[int, str, str], None]]
 
 
-def format_explanation(explanation):
+def format_explanation(explanation: str) -> str:
     """This formats an explanation
 
     Normally all embedded newlines are escaped, however there are
@@ -38,7 +44,7 @@ def format_explanation(explanation):
     return "\n".join(result)
 
 
-def _split_explanation(explanation):
+def _split_explanation(explanation: str) -> List[str]:
     """Return a list of individual lines in the explanation
 
     This will return a list of lines split on '\n{', '\n}' and '\n~'.
@@ -55,7 +61,7 @@ def _split_explanation(explanation):
     return lines
 
 
-def _format_lines(lines):
+def _format_lines(lines: Sequence[str]) -> List[str]:
     """Format the individual lines
 
     This will replace the '{', '}' and '~' characters of our mini
@@ -64,7 +70,7 @@ def _format_lines(lines):
 
     Return a list of formatted lines.
     """
-    result = lines[:1]
+    result = list(lines[:1])
     stack = [0]
     stackcnt = [0]
     for line in lines[1:]:
@@ -90,31 +96,31 @@ def _format_lines(lines):
     return result
 
 
-def issequence(x):
-    return isinstance(x, Sequence) and not isinstance(x, str)
+def issequence(x: Any) -> bool:
+    return isinstance(x, collections.abc.Sequence) and not isinstance(x, str)
 
 
-def istext(x):
+def istext(x: Any) -> bool:
     return isinstance(x, str)
 
 
-def isdict(x):
+def isdict(x: Any) -> bool:
     return isinstance(x, dict)
 
 
-def isset(x):
+def isset(x: Any) -> bool:
     return isinstance(x, (set, frozenset))
 
 
-def isdatacls(obj):
+def isdatacls(obj: Any) -> bool:
     return getattr(obj, "__dataclass_fields__", None) is not None
 
 
-def isattrs(obj):
+def isattrs(obj: Any) -> bool:
     return getattr(obj, "__attrs_attrs__", None) is not None
 
 
-def isiterable(obj):
+def isiterable(obj: Any) -> bool:
     try:
         iter(obj)
         return not istext(obj)
@@ -122,7 +128,7 @@ def isiterable(obj):
         return False
 
 
-def assertrepr_compare(config, op, left, right):
+def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[str]]:
     """Return specialised explanations for some operators/operands"""
     verbose = config.getoption("verbose")
     if verbose > 1:
@@ -180,7 +186,7 @@ def assertrepr_compare(config, op, left, right):
     return [summary] + explanation
 
 
-def _diff_text(left, right, verbose=0):
+def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
     """Return the explanation for the diff between text.
 
     Unless --verbose is used this will skip leading and trailing
@@ -226,7 +232,7 @@ def _diff_text(left, right, verbose=0):
     return explanation
 
 
-def _compare_eq_verbose(left, right):
+def _compare_eq_verbose(left: Any, right: Any) -> List[str]:
     keepends = True
     left_lines = repr(left).splitlines(keepends)
     right_lines = repr(right).splitlines(keepends)
@@ -238,7 +244,7 @@ def _compare_eq_verbose(left, right):
     return explanation
 
 
-def _surrounding_parens_on_own_lines(lines):  # type: (List) -> None
+def _surrounding_parens_on_own_lines(lines: List[str]) -> None:
     """Move opening/closing parenthesis/bracket to own lines."""
     opening = lines[0][:1]
     if opening in ["(", "[", "{"]:
@@ -250,7 +256,9 @@ def _surrounding_parens_on_own_lines(lines):  # type: (List) -> None
         lines[:] = lines + [closing]
 
 
-def _compare_eq_iterable(left, right, verbose=0):
+def _compare_eq_iterable(
+    left: Iterable[Any], right: Iterable[Any], verbose: int = 0
+) -> List[str]:
     if not verbose:
         return ["Use -v to get the full diff"]
     # dynamic import to speedup pytest
@@ -283,7 +291,9 @@ def _compare_eq_iterable(left, right, verbose=0):
     return explanation
 
 
-def _compare_eq_sequence(left, right, verbose=0):
+def _compare_eq_sequence(
+    left: Sequence[Any], right: Sequence[Any], verbose: int = 0
+) -> List[str]:
     comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes)
     explanation = []  # type: List[str]
     len_left = len(left)
@@ -337,7 +347,9 @@ def _compare_eq_sequence(left, right, verbose=0):
     return explanation
 
 
-def _compare_eq_set(left, right, verbose=0):
+def _compare_eq_set(
+    left: AbstractSet[Any], right: AbstractSet[Any], verbose: int = 0
+) -> List[str]:
     explanation = []
     diff_left = left - right
     diff_right = right - left
@@ -352,7 +364,9 @@ def _compare_eq_set(left, right, verbose=0):
     return explanation
 
 
-def _compare_eq_dict(left, right, verbose=0):
+def _compare_eq_dict(
+    left: Mapping[Any, Any], right: Mapping[Any, Any], verbose: int = 0
+) -> List[str]:
     explanation = []  # type: List[str]
     set_left = set(left)
     set_right = set(right)
@@ -391,7 +405,12 @@ def _compare_eq_dict(left, right, verbose=0):
     return explanation
 
 
-def _compare_eq_cls(left, right, verbose, type_fns):
+def _compare_eq_cls(
+    left: Any,
+    right: Any,
+    verbose: int,
+    type_fns: Tuple[Callable[[Any], bool], Callable[[Any], bool]],
+) -> List[str]:
     isdatacls, isattrs = type_fns
     if isdatacls(left):
         all_fields = left.__dataclass_fields__
@@ -425,7 +444,7 @@ def _compare_eq_cls(left, right, verbose, type_fns):
     return explanation
 
 
-def _notin_text(term, text, verbose=0):
+def _notin_text(term: str, text: str, verbose: int = 0) -> List[str]:
     index = text.find(term)
     head = text[:index]
     tail = text[index + len(term) :]