Fix some check_untyped_defs mypy errors in terminal

This commit is contained in:
Ran Benita 2019-09-13 23:09:08 +03:00
parent 5dca7a2f4f
commit 0267b25c66
1 changed files with 26 additions and 16 deletions

View File

@ -9,6 +9,12 @@ import platform
import sys import sys
import time import time
from functools import partial from functools import partial
from typing import Callable
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
from typing import Set
import attr import attr
import pluggy import pluggy
@ -195,8 +201,8 @@ class WarningReport:
file system location of the source of the warning (see ``get_location``). file system location of the source of the warning (see ``get_location``).
""" """
message = attr.ib() message = attr.ib(type=str)
nodeid = attr.ib(default=None) nodeid = attr.ib(type=Optional[str], default=None)
fslocation = attr.ib(default=None) fslocation = attr.ib(default=None)
count_towards_summary = True count_towards_summary = True
@ -240,7 +246,7 @@ class TerminalReporter:
self.reportchars = getreportopt(config) self.reportchars = getreportopt(config)
self.hasmarkup = self._tw.hasmarkup self.hasmarkup = self._tw.hasmarkup
self.isatty = file.isatty() self.isatty = file.isatty()
self._progress_nodeids_reported = set() self._progress_nodeids_reported = set() # type: Set[str]
self._show_progress_info = self._determine_show_progress_info() self._show_progress_info = self._determine_show_progress_info()
self._collect_report_last_write = None self._collect_report_last_write = None
@ -619,7 +625,7 @@ class TerminalReporter:
# because later versions are going to get rid of them anyway # because later versions are going to get rid of them anyway
if self.config.option.verbose < 0: if self.config.option.verbose < 0:
if self.config.option.verbose < -1: if self.config.option.verbose < -1:
counts = {} counts = {} # type: Dict[str, int]
for item in items: for item in items:
name = item.nodeid.split("::", 1)[0] name = item.nodeid.split("::", 1)[0]
counts[name] = counts.get(name, 0) + 1 counts[name] = counts.get(name, 0) + 1
@ -750,7 +756,9 @@ class TerminalReporter:
def summary_warnings(self): def summary_warnings(self):
if self.hasopt("w"): if self.hasopt("w"):
all_warnings = self.stats.get("warnings") all_warnings = self.stats.get(
"warnings"
) # type: Optional[List[WarningReport]]
if not all_warnings: if not all_warnings:
return return
@ -763,7 +771,9 @@ class TerminalReporter:
if not warning_reports: if not warning_reports:
return return
reports_grouped_by_message = collections.OrderedDict() reports_grouped_by_message = (
collections.OrderedDict()
) # type: collections.OrderedDict[str, List[WarningReport]]
for wr in warning_reports: for wr in warning_reports:
reports_grouped_by_message.setdefault(wr.message, []).append(wr) reports_grouped_by_message.setdefault(wr.message, []).append(wr)
@ -900,11 +910,11 @@ class TerminalReporter:
else: else:
self.write_line(msg, **main_markup) self.write_line(msg, **main_markup)
def short_test_summary(self): def short_test_summary(self) -> None:
if not self.reportchars: if not self.reportchars:
return return
def show_simple(stat, lines): def show_simple(stat, lines: List[str]) -> None:
failed = self.stats.get(stat, []) failed = self.stats.get(stat, [])
if not failed: if not failed:
return return
@ -914,7 +924,7 @@ class TerminalReporter:
line = _get_line_with_reprcrash_message(config, rep, termwidth) line = _get_line_with_reprcrash_message(config, rep, termwidth)
lines.append(line) lines.append(line)
def show_xfailed(lines): def show_xfailed(lines: List[str]) -> None:
xfailed = self.stats.get("xfailed", []) xfailed = self.stats.get("xfailed", [])
for rep in xfailed: for rep in xfailed:
verbose_word = rep._get_verbose_word(self.config) verbose_word = rep._get_verbose_word(self.config)
@ -924,7 +934,7 @@ class TerminalReporter:
if reason: if reason:
lines.append(" " + str(reason)) lines.append(" " + str(reason))
def show_xpassed(lines): def show_xpassed(lines: List[str]) -> None:
xpassed = self.stats.get("xpassed", []) xpassed = self.stats.get("xpassed", [])
for rep in xpassed: for rep in xpassed:
verbose_word = rep._get_verbose_word(self.config) verbose_word = rep._get_verbose_word(self.config)
@ -932,7 +942,7 @@ class TerminalReporter:
reason = rep.wasxfail reason = rep.wasxfail
lines.append("{} {} {}".format(verbose_word, pos, reason)) lines.append("{} {} {}".format(verbose_word, pos, reason))
def show_skipped(lines): def show_skipped(lines: List[str]) -> None:
skipped = self.stats.get("skipped", []) skipped = self.stats.get("skipped", [])
fskips = _folded_skips(skipped) if skipped else [] fskips = _folded_skips(skipped) if skipped else []
if not fskips: if not fskips:
@ -958,9 +968,9 @@ class TerminalReporter:
"S": show_skipped, "S": show_skipped,
"p": partial(show_simple, "passed"), "p": partial(show_simple, "passed"),
"E": partial(show_simple, "error"), "E": partial(show_simple, "error"),
} } # type: Mapping[str, Callable[[List[str]], None]]
lines = [] lines = [] # type: List[str]
for char in self.reportchars: for char in self.reportchars:
action = REPORTCHAR_ACTIONS.get(char) action = REPORTCHAR_ACTIONS.get(char)
if action: # skipping e.g. "P" (passed with output) here. if action: # skipping e.g. "P" (passed with output) here.
@ -1084,8 +1094,8 @@ def build_summary_stats_line(stats):
return parts, main_color return parts, main_color
def _plugin_nameversions(plugininfo): def _plugin_nameversions(plugininfo) -> List[str]:
values = [] values = [] # type: List[str]
for plugin, dist in plugininfo: for plugin, dist in plugininfo:
# gets us name and version! # gets us name and version!
name = "{dist.project_name}-{dist.version}".format(dist=dist) name = "{dist.project_name}-{dist.version}".format(dist=dist)
@ -1099,7 +1109,7 @@ def _plugin_nameversions(plugininfo):
return values return values
def format_session_duration(seconds): def format_session_duration(seconds: float) -> str:
"""Format the given seconds in a human readable manner to show in the final summary""" """Format the given seconds in a human readable manner to show in the final summary"""
if seconds < 60: if seconds < 60:
return "{:.2f}s".format(seconds) return "{:.2f}s".format(seconds)