pprint: Type annotate the module

This will make it easier to refactor
This commit is contained in:
Benjamin Schubert 2023-11-21 22:05:37 +00:00
parent 64e72b79f6
commit e5a448cd5f
1 changed files with 217 additions and 36 deletions

View File

@ -21,7 +21,10 @@ from typing import Any
from typing import Callable from typing import Callable
from typing import Dict from typing import Dict
from typing import IO from typing import IO
from typing import Iterator
from typing import List from typing import List
from typing import Optional
from typing import Tuple
class _safe_key: class _safe_key:
@ -57,13 +60,13 @@ def _safe_tuple(t):
class PrettyPrinter: class PrettyPrinter:
def __init__( def __init__(
self, self,
indent=4, indent: int = 4,
width=80, width: int = 80,
depth=None, depth: Optional[int] = None,
*, *,
sort_dicts=True, sort_dicts: bool = True,
underscore_numbers=False, underscore_numbers: bool = False,
): ) -> None:
"""Handle pretty printing operations onto a stream using a set of """Handle pretty printing operations onto a stream using a set of
configured parameters. configured parameters.
@ -99,7 +102,15 @@ class PrettyPrinter:
self._format(object, sio, 0, 0, {}, 0) self._format(object, sio, 0, 0, {}, 0)
return sio.getvalue() return sio.getvalue()
def _format(self, object, stream, indent, allowance, context, level): def _format(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
objid = id(object) objid = id(object)
if objid in context: if objid in context:
stream.write(_recursion(object)) stream.write(_recursion(object))
@ -129,7 +140,15 @@ class PrettyPrinter:
else: else:
stream.write(self._repr(object, context, level)) stream.write(self._repr(object, context, level))
def _pprint_dataclass(self, object, stream, indent, allowance, context, level): def _pprint_dataclass(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
cls_name = object.__class__.__name__ cls_name = object.__class__.__name__
items = [ items = [
(f.name, getattr(object, f.name)) (f.name, getattr(object, f.name))
@ -142,10 +161,18 @@ class PrettyPrinter:
_dispatch: Dict[ _dispatch: Dict[
Callable[..., str], Callable[..., str],
Callable[["PrettyPrinter", Any, IO[str], int, int, Dict[int, int], int], str], Callable[["PrettyPrinter", Any, IO[str], int, int, Dict[int, int], int], None],
] = {} ] = {}
def _pprint_dict(self, object, stream, indent, allowance, context, level): def _pprint_dict(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
write = stream.write write = stream.write
write("{") write("{")
if self._sort_dicts: if self._sort_dicts:
@ -157,7 +184,15 @@ class PrettyPrinter:
_dispatch[dict.__repr__] = _pprint_dict _dispatch[dict.__repr__] = _pprint_dict
def _pprint_ordered_dict(self, object, stream, indent, allowance, context, level): def _pprint_ordered_dict(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
if not len(object): if not len(object):
stream.write(repr(object)) stream.write(repr(object))
return return
@ -168,21 +203,45 @@ class PrettyPrinter:
_dispatch[_collections.OrderedDict.__repr__] = _pprint_ordered_dict _dispatch[_collections.OrderedDict.__repr__] = _pprint_ordered_dict
def _pprint_list(self, object, stream, indent, allowance, context, level): def _pprint_list(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
stream.write("[") stream.write("[")
self._format_items(object, stream, indent, allowance, context, level) self._format_items(object, stream, indent, allowance, context, level)
stream.write("]") stream.write("]")
_dispatch[list.__repr__] = _pprint_list _dispatch[list.__repr__] = _pprint_list
def _pprint_tuple(self, object, stream, indent, allowance, context, level): def _pprint_tuple(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
stream.write("(") stream.write("(")
self._format_items(object, stream, indent, allowance, context, level) self._format_items(object, stream, indent, allowance, context, level)
stream.write(")") stream.write(")")
_dispatch[tuple.__repr__] = _pprint_tuple _dispatch[tuple.__repr__] = _pprint_tuple
def _pprint_set(self, object, stream, indent, allowance, context, level): def _pprint_set(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
if not len(object): if not len(object):
stream.write(repr(object)) stream.write(repr(object))
return return
@ -200,7 +259,15 @@ class PrettyPrinter:
_dispatch[set.__repr__] = _pprint_set _dispatch[set.__repr__] = _pprint_set
_dispatch[frozenset.__repr__] = _pprint_set _dispatch[frozenset.__repr__] = _pprint_set
def _pprint_str(self, object, stream, indent, allowance, context, level): def _pprint_str(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
write = stream.write write = stream.write
if not len(object): if not len(object):
write(repr(object)) write(repr(object))
@ -251,7 +318,15 @@ class PrettyPrinter:
_dispatch[str.__repr__] = _pprint_str _dispatch[str.__repr__] = _pprint_str
def _pprint_bytes(self, object, stream, indent, allowance, context, level): def _pprint_bytes(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
write = stream.write write = stream.write
if len(object) <= 4: if len(object) <= 4:
write(repr(object)) write(repr(object))
@ -272,7 +347,15 @@ class PrettyPrinter:
_dispatch[bytes.__repr__] = _pprint_bytes _dispatch[bytes.__repr__] = _pprint_bytes
def _pprint_bytearray(self, object, stream, indent, allowance, context, level): def _pprint_bytearray(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
write = stream.write write = stream.write
write("bytearray(") write("bytearray(")
self._pprint_bytes( self._pprint_bytes(
@ -282,7 +365,15 @@ class PrettyPrinter:
_dispatch[bytearray.__repr__] = _pprint_bytearray _dispatch[bytearray.__repr__] = _pprint_bytearray
def _pprint_mappingproxy(self, object, stream, indent, allowance, context, level): def _pprint_mappingproxy(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
stream.write("mappingproxy(") stream.write("mappingproxy(")
self._format(object.copy(), stream, indent, allowance, context, level) self._format(object.copy(), stream, indent, allowance, context, level)
stream.write(")") stream.write(")")
@ -290,8 +381,14 @@ class PrettyPrinter:
_dispatch[_types.MappingProxyType.__repr__] = _pprint_mappingproxy _dispatch[_types.MappingProxyType.__repr__] = _pprint_mappingproxy
def _pprint_simplenamespace( def _pprint_simplenamespace(
self, object, stream, indent, allowance, context, level self,
): object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
if type(object) is _types.SimpleNamespace: if type(object) is _types.SimpleNamespace:
# The SimpleNamespace repr is "namespace" instead of the class # The SimpleNamespace repr is "namespace" instead of the class
# name, so we do the same here. For subclasses; use the class name. # name, so we do the same here. For subclasses; use the class name.
@ -305,7 +402,15 @@ class PrettyPrinter:
_dispatch[_types.SimpleNamespace.__repr__] = _pprint_simplenamespace _dispatch[_types.SimpleNamespace.__repr__] = _pprint_simplenamespace
def _format_dict_items(self, items, stream, indent, allowance, context, level): def _format_dict_items(
self,
items: List[Tuple[Any, Any]],
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
if not items: if not items:
return return
@ -321,7 +426,15 @@ class PrettyPrinter:
write("\n" + " " * indent) write("\n" + " " * indent)
def _format_namespace_items(self, items, stream, indent, allowance, context, level): def _format_namespace_items(
self,
items: List[Tuple[Any, Any]],
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
if not items: if not items:
return return
@ -350,7 +463,15 @@ class PrettyPrinter:
write("\n" + " " * indent) write("\n" + " " * indent)
def _format_items(self, items, stream, indent, allowance, context, level): def _format_items(
self,
items: List[Any],
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
if not items: if not items:
return return
@ -365,7 +486,7 @@ class PrettyPrinter:
write("\n" + " " * indent) write("\n" + " " * indent)
def _repr(self, object, context, level): def _repr(self, object: Any, context: Dict[int, int], level: int) -> str:
repr, readable, recursive = self.format( repr, readable, recursive = self.format(
object, context.copy(), self._depth, level object, context.copy(), self._depth, level
) )
@ -375,14 +496,24 @@ class PrettyPrinter:
self._recursive = True self._recursive = True
return repr return repr
def format(self, object, context, maxlevels, level): def format(
self, object: Any, context: Dict[int, int], maxlevels: Optional[int], level: int
) -> Tuple[str, bool, bool]:
"""Format object for a specific context, returning a string """Format object for a specific context, returning a string
and flags indicating whether the representation is 'readable' and flags indicating whether the representation is 'readable'
and whether the object represents a recursive construct. and whether the object represents a recursive construct.
""" """
return self._safe_repr(object, context, maxlevels, level) return self._safe_repr(object, context, maxlevels, level)
def _pprint_default_dict(self, object, stream, indent, allowance, context, level): def _pprint_default_dict(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
rdf = self._repr(object.default_factory, context, level) rdf = self._repr(object.default_factory, context, level)
stream.write(f"{object.__class__.__name__}({rdf}, ") stream.write(f"{object.__class__.__name__}({rdf}, ")
self._pprint_dict(object, stream, indent, allowance, context, level) self._pprint_dict(object, stream, indent, allowance, context, level)
@ -390,7 +521,15 @@ class PrettyPrinter:
_dispatch[_collections.defaultdict.__repr__] = _pprint_default_dict _dispatch[_collections.defaultdict.__repr__] = _pprint_default_dict
def _pprint_counter(self, object, stream, indent, allowance, context, level): def _pprint_counter(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
stream.write(object.__class__.__name__ + "(") stream.write(object.__class__.__name__ + "(")
if object: if object:
@ -403,7 +542,15 @@ class PrettyPrinter:
_dispatch[_collections.Counter.__repr__] = _pprint_counter _dispatch[_collections.Counter.__repr__] = _pprint_counter
def _pprint_chain_map(self, object, stream, indent, allowance, context, level): def _pprint_chain_map(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
if not len(object.maps) or (len(object.maps) == 1 and not len(object.maps[0])): if not len(object.maps) or (len(object.maps) == 1 and not len(object.maps[0])):
stream.write(repr(object)) stream.write(repr(object))
return return
@ -414,7 +561,15 @@ class PrettyPrinter:
_dispatch[_collections.ChainMap.__repr__] = _pprint_chain_map _dispatch[_collections.ChainMap.__repr__] = _pprint_chain_map
def _pprint_deque(self, object, stream, indent, allowance, context, level): def _pprint_deque(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
stream.write(object.__class__.__name__ + "(") stream.write(object.__class__.__name__ + "(")
if object.maxlen is not None: if object.maxlen is not None:
stream.write("maxlen=%d, " % object.maxlen) stream.write("maxlen=%d, " % object.maxlen)
@ -425,22 +580,48 @@ class PrettyPrinter:
_dispatch[_collections.deque.__repr__] = _pprint_deque _dispatch[_collections.deque.__repr__] = _pprint_deque
def _pprint_user_dict(self, object, stream, indent, allowance, context, level): def _pprint_user_dict(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1) self._format(object.data, stream, indent, allowance, context, level - 1)
_dispatch[_collections.UserDict.__repr__] = _pprint_user_dict _dispatch[_collections.UserDict.__repr__] = _pprint_user_dict
def _pprint_user_list(self, object, stream, indent, allowance, context, level): def _pprint_user_list(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1) self._format(object.data, stream, indent, allowance, context, level - 1)
_dispatch[_collections.UserList.__repr__] = _pprint_user_list _dispatch[_collections.UserList.__repr__] = _pprint_user_list
def _pprint_user_string(self, object, stream, indent, allowance, context, level): def _pprint_user_string(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, int],
level: int,
) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1) self._format(object.data, stream, indent, allowance, context, level - 1)
_dispatch[_collections.UserString.__repr__] = _pprint_user_string _dispatch[_collections.UserString.__repr__] = _pprint_user_string
def _safe_repr(self, object, context, maxlevels, level): def _safe_repr(
self, object: Any, context: Dict[int, int], maxlevels: Optional[int], level: int
) -> Tuple[str, bool, bool]:
# Return triple (repr_string, isreadable, isrecursive). # Return triple (repr_string, isreadable, isrecursive).
typ = type(object) typ = type(object)
if typ in _builtin_scalars: if typ in _builtin_scalars:
@ -517,17 +698,17 @@ class PrettyPrinter:
return format % ", ".join(components), readable, recursive return format % ", ".join(components), readable, recursive
rep = repr(object) rep = repr(object)
return rep, (rep and not rep.startswith("<")), False return rep, bool(rep and not rep.startswith("<")), False
_builtin_scalars = frozenset({str, bytes, bytearray, float, complex, bool, type(None)}) _builtin_scalars = frozenset({str, bytes, bytearray, float, complex, bool, type(None)})
def _recursion(object): def _recursion(object: Any) -> str:
return f"<Recursion on {type(object).__name__} with id={id(object)}>" return f"<Recursion on {type(object).__name__} with id={id(object)}>"
def _wrap_bytes_repr(object, width, allowance): def _wrap_bytes_repr(object: Any, width: int, allowance: int) -> Iterator[str]:
current = b"" current = b""
last = len(object) // 4 * 4 last = len(object) // 4 * 4
for i in range(0, len(object), 4): for i in range(0, len(object), 4):