Fix some check_untyped_defs mypy errors in terminal

This commit is contained in:
Ran Benita 2019-09-13 23:09:08 +03:00
parent 5dca7a2f4f
commit 0267b25c66
1 changed files with 26 additions and 16 deletions

View File

@ -9,6 +9,12 @@ import platform
import sys
import time
from functools import partial
from typing import Callable
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
from typing import Set
import attr
import pluggy
@ -195,8 +201,8 @@ class WarningReport:
file system location of the source of the warning (see ``get_location``).
"""
message = attr.ib()
nodeid = attr.ib(default=None)
message = attr.ib(type=str)
nodeid = attr.ib(type=Optional[str], default=None)
fslocation = attr.ib(default=None)
count_towards_summary = True
@ -240,7 +246,7 @@ class TerminalReporter:
self.reportchars = getreportopt(config)
self.hasmarkup = self._tw.hasmarkup
self.isatty = file.isatty()
self._progress_nodeids_reported = set()
self._progress_nodeids_reported = set() # type: Set[str]
self._show_progress_info = self._determine_show_progress_info()
self._collect_report_last_write = None
@ -619,7 +625,7 @@ class TerminalReporter:
# because later versions are going to get rid of them anyway
if self.config.option.verbose < 0:
if self.config.option.verbose < -1:
counts = {}
counts = {} # type: Dict[str, int]
for item in items:
name = item.nodeid.split("::", 1)[0]
counts[name] = counts.get(name, 0) + 1
@ -750,7 +756,9 @@ class TerminalReporter:
def summary_warnings(self):
if self.hasopt("w"):
all_warnings = self.stats.get("warnings")
all_warnings = self.stats.get(
"warnings"
) # type: Optional[List[WarningReport]]
if not all_warnings:
return
@ -763,7 +771,9 @@ class TerminalReporter:
if not warning_reports:
return
reports_grouped_by_message = collections.OrderedDict()
reports_grouped_by_message = (
collections.OrderedDict()
) # type: collections.OrderedDict[str, List[WarningReport]]
for wr in warning_reports:
reports_grouped_by_message.setdefault(wr.message, []).append(wr)
@ -900,11 +910,11 @@ class TerminalReporter:
else:
self.write_line(msg, **main_markup)
def short_test_summary(self):
def short_test_summary(self) -> None:
if not self.reportchars:
return
def show_simple(stat, lines):
def show_simple(stat, lines: List[str]) -> None:
failed = self.stats.get(stat, [])
if not failed:
return
@ -914,7 +924,7 @@ class TerminalReporter:
line = _get_line_with_reprcrash_message(config, rep, termwidth)
lines.append(line)
def show_xfailed(lines):
def show_xfailed(lines: List[str]) -> None:
xfailed = self.stats.get("xfailed", [])
for rep in xfailed:
verbose_word = rep._get_verbose_word(self.config)
@ -924,7 +934,7 @@ class TerminalReporter:
if reason:
lines.append(" " + str(reason))
def show_xpassed(lines):
def show_xpassed(lines: List[str]) -> None:
xpassed = self.stats.get("xpassed", [])
for rep in xpassed:
verbose_word = rep._get_verbose_word(self.config)
@ -932,7 +942,7 @@ class TerminalReporter:
reason = rep.wasxfail
lines.append("{} {} {}".format(verbose_word, pos, reason))
def show_skipped(lines):
def show_skipped(lines: List[str]) -> None:
skipped = self.stats.get("skipped", [])
fskips = _folded_skips(skipped) if skipped else []
if not fskips:
@ -958,9 +968,9 @@ class TerminalReporter:
"S": show_skipped,
"p": partial(show_simple, "passed"),
"E": partial(show_simple, "error"),
}
} # type: Mapping[str, Callable[[List[str]], None]]
lines = []
lines = [] # type: List[str]
for char in self.reportchars:
action = REPORTCHAR_ACTIONS.get(char)
if action: # skipping e.g. "P" (passed with output) here.
@ -1084,8 +1094,8 @@ def build_summary_stats_line(stats):
return parts, main_color
def _plugin_nameversions(plugininfo):
values = []
def _plugin_nameversions(plugininfo) -> List[str]:
values = [] # type: List[str]
for plugin, dist in plugininfo:
# gets us name and version!
name = "{dist.project_name}-{dist.version}".format(dist=dist)
@ -1099,7 +1109,7 @@ def _plugin_nameversions(plugininfo):
return values
def format_session_duration(seconds):
def format_session_duration(seconds: float) -> str:
"""Format the given seconds in a human readable manner to show in the final summary"""
if seconds < 60:
return "{:.2f}s".format(seconds)