terminalwriter: add type annotations

This commit is contained in:
Ran Benita 2020-04-29 17:08:18 +03:00
parent f6564a548a
commit e8fc5f99fa
1 changed files with 28 additions and 22 deletions

View File

@ -4,7 +4,9 @@ import shutil
import sys import sys
import unicodedata import unicodedata
from functools import lru_cache from functools import lru_cache
from typing import Optional
from typing import Sequence from typing import Sequence
from typing import TextIO
# This code was initially copied from py 1.8.1, file _io/terminalwriter.py. # This code was initially copied from py 1.8.1, file _io/terminalwriter.py.
@ -26,12 +28,12 @@ def char_width(c: str) -> int:
return 2 if unicodedata.east_asian_width(c) in ("F", "W") else 1 return 2 if unicodedata.east_asian_width(c) in ("F", "W") else 1
def get_line_width(text): def get_line_width(text: str) -> int:
text = unicodedata.normalize("NFC", text) text = unicodedata.normalize("NFC", text)
return sum(char_width(c) for c in text) return sum(char_width(c) for c in text)
def should_do_markup(file): def should_do_markup(file: TextIO) -> bool:
if os.environ.get("PY_COLORS") == "1": if os.environ.get("PY_COLORS") == "1":
return True return True
if os.environ.get("PY_COLORS") == "0": if os.environ.get("PY_COLORS") == "0":
@ -68,7 +70,7 @@ class TerminalWriter:
invert=7, invert=7,
) )
def __init__(self, file=None): def __init__(self, file: Optional[TextIO] = None) -> None:
if file is None: if file is None:
file = sys.stdout file = sys.stdout
if hasattr(file, "isatty") and file.isatty() and sys.platform == "win32": if hasattr(file, "isatty") and file.isatty() and sys.platform == "win32":
@ -78,38 +80,33 @@ class TerminalWriter:
pass pass
else: else:
file = colorama.AnsiToWin32(file).stream file = colorama.AnsiToWin32(file).stream
assert file is not None
self._file = file self._file = file
self.hasmarkup = should_do_markup(file) self.hasmarkup = should_do_markup(file)
self._chars_on_current_line = 0 self._chars_on_current_line = 0
self._width_of_current_line = 0 self._width_of_current_line = 0
@property @property
def fullwidth(self): def fullwidth(self) -> int:
if hasattr(self, "_terminal_width"): if hasattr(self, "_terminal_width"):
return self._terminal_width return self._terminal_width
return get_terminal_width() return get_terminal_width()
@fullwidth.setter @fullwidth.setter
def fullwidth(self, value): def fullwidth(self, value: int) -> None:
self._terminal_width = value self._terminal_width = value
@property @property
def chars_on_current_line(self): def chars_on_current_line(self) -> int:
"""Return the number of characters written so far in the current line. """Return the number of characters written so far in the current line."""
:rtype: int
"""
return self._chars_on_current_line return self._chars_on_current_line
@property @property
def width_of_current_line(self): def width_of_current_line(self) -> int:
"""Return an estimate of the width so far in the current line. """Return an estimate of the width so far in the current line."""
:rtype: int
"""
return self._width_of_current_line return self._width_of_current_line
def markup(self, text, **kw): def markup(self, text: str, **kw: bool) -> str:
esc = [] esc = []
for name in kw: for name in kw:
if name not in self._esctable: if name not in self._esctable:
@ -120,7 +117,13 @@ class TerminalWriter:
text = "".join("\x1b[%sm" % cod for cod in esc) + text + "\x1b[0m" text = "".join("\x1b[%sm" % cod for cod in esc) + text + "\x1b[0m"
return text return text
def sep(self, sepchar, title=None, fullwidth=None, **kw): def sep(
self,
sepchar: str,
title: Optional[str] = None,
fullwidth: Optional[int] = None,
**kw: bool
) -> None:
if fullwidth is None: if fullwidth is None:
fullwidth = self.fullwidth fullwidth = self.fullwidth
# the goal is to have the line be as long as possible # the goal is to have the line be as long as possible
@ -151,7 +154,7 @@ class TerminalWriter:
self.line(line, **kw) self.line(line, **kw)
def write(self, msg: str, **kw) -> None: def write(self, msg: str, **kw: bool) -> None:
if msg: if msg:
self._update_chars_on_current_line(msg) self._update_chars_on_current_line(msg)
@ -171,7 +174,7 @@ class TerminalWriter:
self._chars_on_current_line += len(current_line) self._chars_on_current_line += len(current_line)
self._width_of_current_line += get_line_width(current_line) self._width_of_current_line += get_line_width(current_line)
def line(self, s: str = "", **kw): def line(self, s: str = "", **kw: bool) -> None:
self.write(s, **kw) self.write(s, **kw)
self.write("\n") self.write("\n")
@ -195,8 +198,8 @@ class TerminalWriter:
for indent, new_line in zip(indents, new_lines): for indent, new_line in zip(indents, new_lines):
self.line(indent + new_line) self.line(indent + new_line)
def _highlight(self, source): def _highlight(self, source: str) -> str:
"""Highlight the given source code if we have markup support""" """Highlight the given source code if we have markup support."""
if not self.hasmarkup: if not self.hasmarkup:
return source return source
try: try:
@ -206,4 +209,7 @@ class TerminalWriter:
except ImportError: except ImportError:
return source return source
else: else:
return highlight(source, PythonLexer(), TerminalFormatter(bg="dark")) highlighted = highlight(
source, PythonLexer(), TerminalFormatter(bg="dark")
) # type: str
return highlighted