pprint: Type annotate the module

This will make it easier to refactor
This commit is contained in:
Benjamin Schubert 2023-11-21 22:05:37 +00:00
parent 64e72b79f6
commit e5a448cd5f
1 changed files with 217 additions and 36 deletions

View File

@ -21,7 +21,10 @@ from typing import Any
from typing import Callable
from typing import Dict
from typing import IO
from typing import Iterator
from typing import List
from typing import Optional
from typing import Tuple
class _safe_key:
@ -57,13 +60,13 @@ def _safe_tuple(t):
class PrettyPrinter:
def __init__(
self,
indent=4,
width=80,
depth=None,
indent: int = 4,
width: int = 80,
depth: Optional[int] = None,
*,
sort_dicts=True,
underscore_numbers=False,
):
sort_dicts: bool = True,
underscore_numbers: bool = False,
) -> None:
"""Handle pretty printing operations onto a stream using a set of
configured parameters.
@ -99,7 +102,15 @@ class PrettyPrinter:
self._format(object, sio, 0, 0, {}, 0)
return sio.getvalue()
def _format(self, object, stream, indent, allowance, context, level):
def _format(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
objid = id(object)
if objid in context:
stream.write(_recursion(object))
@ -129,7 +140,15 @@ class PrettyPrinter:
else:
stream.write(self._repr(object, context, level))
def _pprint_dataclass(self, object, stream, indent, allowance, context, level):
def _pprint_dataclass(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
cls_name = object.__class__.__name__
items = [
(f.name, getattr(object, f.name))
@ -142,10 +161,18 @@ class PrettyPrinter:
_dispatch: Dict[
Callable[..., str],
Callable[["PrettyPrinter", Any, IO[str], int, int, Dict[int, int], int], str],
Callable[["PrettyPrinter", Any, IO[str], int, int, Dict[int, int], int], None],
] = {}
def _pprint_dict(self, object, stream, indent, allowance, context, level):
def _pprint_dict(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
write = stream.write
write("{")
if self._sort_dicts:
@ -157,7 +184,15 @@ class PrettyPrinter:
_dispatch[dict.__repr__] = _pprint_dict
def _pprint_ordered_dict(self, object, stream, indent, allowance, context, level):
def _pprint_ordered_dict(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
if not len(object):
stream.write(repr(object))
return
@ -168,21 +203,45 @@ class PrettyPrinter:
_dispatch[_collections.OrderedDict.__repr__] = _pprint_ordered_dict
def _pprint_list(self, object, stream, indent, allowance, context, level):
def _pprint_list(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
stream.write("[")
self._format_items(object, stream, indent, allowance, context, level)
stream.write("]")
_dispatch[list.__repr__] = _pprint_list
def _pprint_tuple(self, object, stream, indent, allowance, context, level):
def _pprint_tuple(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
stream.write("(")
self._format_items(object, stream, indent, allowance, context, level)
stream.write(")")
_dispatch[tuple.__repr__] = _pprint_tuple
def _pprint_set(self, object, stream, indent, allowance, context, level):
def _pprint_set(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
if not len(object):
stream.write(repr(object))
return
@ -200,7 +259,15 @@ class PrettyPrinter:
_dispatch[set.__repr__] = _pprint_set
_dispatch[frozenset.__repr__] = _pprint_set
def _pprint_str(self, object, stream, indent, allowance, context, level):
def _pprint_str(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
write = stream.write
if not len(object):
write(repr(object))
@ -251,7 +318,15 @@ class PrettyPrinter:
_dispatch[str.__repr__] = _pprint_str
def _pprint_bytes(self, object, stream, indent, allowance, context, level):
def _pprint_bytes(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
write = stream.write
if len(object) <= 4:
write(repr(object))
@ -272,7 +347,15 @@ class PrettyPrinter:
_dispatch[bytes.__repr__] = _pprint_bytes
def _pprint_bytearray(self, object, stream, indent, allowance, context, level):
def _pprint_bytearray(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
write = stream.write
write("bytearray(")
self._pprint_bytes(
@ -282,7 +365,15 @@ class PrettyPrinter:
_dispatch[bytearray.__repr__] = _pprint_bytearray
def _pprint_mappingproxy(self, object, stream, indent, allowance, context, level):
def _pprint_mappingproxy(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
stream.write("mappingproxy(")
self._format(object.copy(), stream, indent, allowance, context, level)
stream.write(")")
@ -290,8 +381,14 @@ class PrettyPrinter:
_dispatch[_types.MappingProxyType.__repr__] = _pprint_mappingproxy
def _pprint_simplenamespace(
self, object, stream, indent, allowance, context, level
):
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
if type(object) is _types.SimpleNamespace:
# The SimpleNamespace repr is "namespace" instead of the class
# name, so we do the same here. For subclasses; use the class name.
@ -305,7 +402,15 @@ class PrettyPrinter:
_dispatch[_types.SimpleNamespace.__repr__] = _pprint_simplenamespace
def _format_dict_items(self, items, stream, indent, allowance, context, level):
def _format_dict_items(
self,
items: List[Tuple[Any, Any]],
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
if not items:
return
@ -321,7 +426,15 @@ class PrettyPrinter:
write("\n" + " " * indent)
def _format_namespace_items(self, items, stream, indent, allowance, context, level):
def _format_namespace_items(
self,
items: List[Tuple[Any, Any]],
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
if not items:
return
@ -350,7 +463,15 @@ class PrettyPrinter:
write("\n" + " " * indent)
def _format_items(self, items, stream, indent, allowance, context, level):
def _format_items(
self,
items: List[Any],
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
if not items:
return
@ -365,7 +486,7 @@ class PrettyPrinter:
write("\n" + " " * indent)
def _repr(self, object, context, level):
def _repr(self, object: Any, context: Dict[int, int], level: int) -> str:
repr, readable, recursive = self.format(
object, context.copy(), self._depth, level
)
@ -375,14 +496,24 @@ class PrettyPrinter:
self._recursive = True
return repr
def format(self, object, context, maxlevels, level):
def format(
self, object: Any, context: Dict[int, int], maxlevels: Optional[int], level: int
) -> Tuple[str, bool, bool]:
"""Format object for a specific context, returning a string
and flags indicating whether the representation is 'readable'
and whether the object represents a recursive construct.
"""
return self._safe_repr(object, context, maxlevels, level)
def _pprint_default_dict(self, object, stream, indent, allowance, context, level):
def _pprint_default_dict(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
rdf = self._repr(object.default_factory, context, level)
stream.write(f"{object.__class__.__name__}({rdf}, ")
self._pprint_dict(object, stream, indent, allowance, context, level)
@ -390,7 +521,15 @@ class PrettyPrinter:
_dispatch[_collections.defaultdict.__repr__] = _pprint_default_dict
def _pprint_counter(self, object, stream, indent, allowance, context, level):
def _pprint_counter(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
stream.write(object.__class__.__name__ + "(")
if object:
@ -403,7 +542,15 @@ class PrettyPrinter:
_dispatch[_collections.Counter.__repr__] = _pprint_counter
def _pprint_chain_map(self, object, stream, indent, allowance, context, level):
def _pprint_chain_map(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
if not len(object.maps) or (len(object.maps) == 1 and not len(object.maps[0])):
stream.write(repr(object))
return
@ -414,7 +561,15 @@ class PrettyPrinter:
_dispatch[_collections.ChainMap.__repr__] = _pprint_chain_map
def _pprint_deque(self, object, stream, indent, allowance, context, level):
def _pprint_deque(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
stream.write(object.__class__.__name__ + "(")
if object.maxlen is not None:
stream.write("maxlen=%d, " % object.maxlen)
@ -425,22 +580,48 @@ class PrettyPrinter:
_dispatch[_collections.deque.__repr__] = _pprint_deque
def _pprint_user_dict(self, object, stream, indent, allowance, context, level):
def _pprint_user_dict(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1)
_dispatch[_collections.UserDict.__repr__] = _pprint_user_dict
def _pprint_user_list(self, object, stream, indent, allowance, context, level):
def _pprint_user_list(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1)
_dispatch[_collections.UserList.__repr__] = _pprint_user_list
def _pprint_user_string(self, object, stream, indent, allowance, context, level):
def _pprint_user_string(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1)
_dispatch[_collections.UserString.__repr__] = _pprint_user_string
def _safe_repr(self, object, context, maxlevels, level):
def _safe_repr(
self, object: Any, context: Dict[int, int], maxlevels: Optional[int], level: int
) -> Tuple[str, bool, bool]:
# Return triple (repr_string, isreadable, isrecursive).
typ = type(object)
if typ in _builtin_scalars:
@ -517,17 +698,17 @@ class PrettyPrinter:
return format % ", ".join(components), readable, recursive
rep = repr(object)
return rep, (rep and not rep.startswith("<")), False
return rep, bool(rep and not rep.startswith("<")), False
_builtin_scalars = frozenset({str, bytes, bytearray, float, complex, bool, type(None)})
def _recursion(object):
def _recursion(object: Any) -> str:
return f"<Recursion on {type(object).__name__} with id={id(object)}>"
def _wrap_bytes_repr(object, width, allowance):
def _wrap_bytes_repr(object: Any, width: int, allowance: int) -> Iterator[str]:
current = b""
last = len(object) // 4 * 4
for i in range(0, len(object), 4):