diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index c225eff5f..4e7db8369 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -78,7 +78,8 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder): # there's nothing to rewrite there # python3.5 - python3.6: `namespace` # python3.7+: `None` - or spec.origin in {None, "namespace"} + or spec.origin == "namespace" + or spec.origin is None # we can only rewrite source files or not isinstance(spec.loader, importlib.machinery.SourceFileLoader) # if the file doesn't exist, we can't rewrite it @@ -743,8 +744,7 @@ class AssertionRewriter(ast.NodeVisitor): from _pytest.warning_types import PytestAssertRewriteWarning import warnings - # Ignore type: typeshed bug https://github.com/python/typeshed/pull/3121 - warnings.warn_explicit( # type: ignore + warnings.warn_explicit( PytestAssertRewriteWarning( "assertion is always true, perhaps remove parentheses?" ), diff --git a/src/_pytest/cacheprovider.py b/src/_pytest/cacheprovider.py index 7a5deaa39..dad76f13f 100755 --- a/src/_pytest/cacheprovider.py +++ b/src/_pytest/cacheprovider.py @@ -7,6 +7,7 @@ ignores the external pytest-cache import json import os from collections import OrderedDict +from typing import List import attr import py @@ -15,6 +16,9 @@ import pytest from .pathlib import Path from .pathlib import resolve_from_str from .pathlib import rm_rf +from _pytest import nodes +from _pytest.config import Config +from _pytest.main import Session README_CONTENT = """\ # pytest cache directory # @@ -263,10 +267,12 @@ class NFPlugin: self.active = config.option.newfirst self.cached_nodeids = config.cache.get("cache/nodeids", []) - def pytest_collection_modifyitems(self, session, config, items): + def pytest_collection_modifyitems( + self, session: Session, config: Config, items: List[nodes.Item] + ) -> None: if self.active: - new_items = OrderedDict() - other_items = OrderedDict() + new_items = OrderedDict() # type: OrderedDict[str, nodes.Item] + other_items = OrderedDict() # type: OrderedDict[str, nodes.Item] for item in items: if item.nodeid not in self.cached_nodeids: new_items[item.nodeid] = item diff --git a/src/_pytest/capture.py b/src/_pytest/capture.py index c4099e6b0..56707822d 100644 --- a/src/_pytest/capture.py +++ b/src/_pytest/capture.py @@ -12,6 +12,7 @@ from tempfile import TemporaryFile import pytest from _pytest.compat import CaptureIO +from _pytest.fixtures import FixtureRequest patchsysdict = {0: "stdin", 1: "stdout", 2: "stderr"} @@ -241,13 +242,12 @@ class CaptureManager: capture_fixtures = {"capfd", "capfdbinary", "capsys", "capsysbinary"} -def _ensure_only_one_capture_fixture(request, name): - fixtures = set(request.fixturenames) & capture_fixtures - {name} +def _ensure_only_one_capture_fixture(request: FixtureRequest, name): + fixtures = sorted(set(request.fixturenames) & capture_fixtures - {name}) if fixtures: - fixtures = sorted(fixtures) - fixtures = fixtures[0] if len(fixtures) == 1 else fixtures + arg = fixtures[0] if len(fixtures) == 1 else fixtures raise request.raiseerror( - "cannot use {} and {} at the same time".format(fixtures, name) + "cannot use {} and {} at the same time".format(arg, name) ) diff --git a/src/_pytest/doctest.py b/src/_pytest/doctest.py index db1de1986..48c934e3a 100644 --- a/src/_pytest/doctest.py +++ b/src/_pytest/doctest.py @@ -6,8 +6,12 @@ import sys import traceback import warnings from contextlib import contextmanager +from typing import Dict +from typing import List +from typing import Optional from typing import Sequence from typing import Tuple +from typing import Union import pytest from _pytest import outcomes @@ -20,6 +24,10 @@ from _pytest.outcomes import Skipped from _pytest.python_api import approx from _pytest.warning_types import PytestWarning +if False: # TYPE_CHECKING + import doctest + from typing import Type + DOCTEST_REPORT_CHOICE_NONE = "none" DOCTEST_REPORT_CHOICE_CDIFF = "cdiff" DOCTEST_REPORT_CHOICE_NDIFF = "ndiff" @@ -36,6 +44,8 @@ DOCTEST_REPORT_CHOICES = ( # Lazy definition of runner class RUNNER_CLASS = None +# Lazy definition of output checker class +CHECKER_CLASS = None # type: Optional[Type[doctest.OutputChecker]] def pytest_addoption(parser): @@ -139,7 +149,7 @@ class MultipleDoctestFailures(Exception): self.failures = failures -def _init_runner_class(): +def _init_runner_class() -> "Type[doctest.DocTestRunner]": import doctest class PytestDoctestRunner(doctest.DebugRunner): @@ -177,12 +187,19 @@ def _init_runner_class(): return PytestDoctestRunner -def _get_runner(checker=None, verbose=None, optionflags=0, continue_on_failure=True): +def _get_runner( + checker: Optional["doctest.OutputChecker"] = None, + verbose: Optional[bool] = None, + optionflags: int = 0, + continue_on_failure: bool = True, +) -> "doctest.DocTestRunner": # We need this in order to do a lazy import on doctest global RUNNER_CLASS if RUNNER_CLASS is None: RUNNER_CLASS = _init_runner_class() - return RUNNER_CLASS( + # Type ignored because the continue_on_failure argument is only defined on + # PytestDoctestRunner, which is lazily defined so can't be used as a type. + return RUNNER_CLASS( # type: ignore checker=checker, verbose=verbose, optionflags=optionflags, @@ -211,7 +228,7 @@ class DoctestItem(pytest.Item): def runtest(self): _check_all_skipped(self.dtest) self._disable_output_capturing_for_darwin() - failures = [] + failures = [] # type: List[doctest.DocTestFailure] self.runner.run(self.dtest, out=failures) if failures: raise MultipleDoctestFailures(failures) @@ -232,7 +249,9 @@ class DoctestItem(pytest.Item): def repr_failure(self, excinfo): import doctest - failures = None + failures = ( + None + ) # type: Optional[List[Union[doctest.DocTestFailure, doctest.UnexpectedException]]] if excinfo.errisinstance((doctest.DocTestFailure, doctest.UnexpectedException)): failures = [excinfo.value] elif excinfo.errisinstance(MultipleDoctestFailures): @@ -255,8 +274,10 @@ class DoctestItem(pytest.Item): self.config.getoption("doctestreport") ) if lineno is not None: + assert failure.test.docstring is not None lines = failure.test.docstring.splitlines(False) # add line numbers to the left of the error message + assert test.lineno is not None lines = [ "%03d %s" % (i + test.lineno + 1, x) for (i, x) in enumerate(lines) @@ -288,7 +309,7 @@ class DoctestItem(pytest.Item): return self.fspath, self.dtest.lineno, "[doctest] %s" % self.name -def _get_flag_lookup(): +def _get_flag_lookup() -> Dict[str, int]: import doctest return dict( @@ -340,7 +361,7 @@ class DoctestTextfile(pytest.Module): optionflags = get_optionflags(self) runner = _get_runner( - verbose=0, + verbose=False, optionflags=optionflags, checker=_get_checker(), continue_on_failure=_get_continue_on_failure(self.config), @@ -419,7 +440,8 @@ class DoctestModule(pytest.Module): return with _patch_unwrap_mock_aware(): - doctest.DocTestFinder._find( + # Type ignored because this is a private function. + doctest.DocTestFinder._find( # type: ignore self, tests, obj, name, module, source_lines, globs, seen ) @@ -437,7 +459,7 @@ class DoctestModule(pytest.Module): finder = MockAwareDocTestFinder() optionflags = get_optionflags(self) runner = _get_runner( - verbose=0, + verbose=False, optionflags=optionflags, checker=_get_checker(), continue_on_failure=_get_continue_on_failure(self.config), @@ -466,24 +488,7 @@ def _setup_fixtures(doctest_item): return fixture_request -def _get_checker(): - """ - Returns a doctest.OutputChecker subclass that supports some - additional options: - - * ALLOW_UNICODE and ALLOW_BYTES options to ignore u'' and b'' - prefixes (respectively) in string literals. Useful when the same - doctest should run in Python 2 and Python 3. - - * NUMBER to ignore floating-point differences smaller than the - precision of the literal number in the doctest. - - An inner class is used to avoid importing "doctest" at the module - level. - """ - if hasattr(_get_checker, "LiteralsOutputChecker"): - return _get_checker.LiteralsOutputChecker() - +def _init_checker_class() -> "Type[doctest.OutputChecker]": import doctest import re @@ -573,11 +578,31 @@ def _get_checker(): offset += w.end() - w.start() - (g.end() - g.start()) return got - _get_checker.LiteralsOutputChecker = LiteralsOutputChecker - return _get_checker.LiteralsOutputChecker() + return LiteralsOutputChecker -def _get_allow_unicode_flag(): +def _get_checker() -> "doctest.OutputChecker": + """ + Returns a doctest.OutputChecker subclass that supports some + additional options: + + * ALLOW_UNICODE and ALLOW_BYTES options to ignore u'' and b'' + prefixes (respectively) in string literals. Useful when the same + doctest should run in Python 2 and Python 3. + + * NUMBER to ignore floating-point differences smaller than the + precision of the literal number in the doctest. + + An inner class is used to avoid importing "doctest" at the module + level. + """ + global CHECKER_CLASS + if CHECKER_CLASS is None: + CHECKER_CLASS = _init_checker_class() + return CHECKER_CLASS() + + +def _get_allow_unicode_flag() -> int: """ Registers and returns the ALLOW_UNICODE flag. """ @@ -586,7 +611,7 @@ def _get_allow_unicode_flag(): return doctest.register_optionflag("ALLOW_UNICODE") -def _get_allow_bytes_flag(): +def _get_allow_bytes_flag() -> int: """ Registers and returns the ALLOW_BYTES flag. """ @@ -595,7 +620,7 @@ def _get_allow_bytes_flag(): return doctest.register_optionflag("ALLOW_BYTES") -def _get_number_flag(): +def _get_number_flag() -> int: """ Registers and returns the NUMBER flag. """ @@ -604,7 +629,7 @@ def _get_number_flag(): return doctest.register_optionflag("NUMBER") -def _get_report_choice(key): +def _get_report_choice(key: str) -> int: """ This function returns the actual `doctest` module flag value, we want to do it as late as possible to avoid importing `doctest` and all its dependencies when parsing options, as it adds overhead and breaks tests. diff --git a/src/_pytest/logging.py b/src/_pytest/logging.py index 054bfc866..c72f76118 100644 --- a/src/_pytest/logging.py +++ b/src/_pytest/logging.py @@ -2,6 +2,10 @@ import logging import re from contextlib import contextmanager +from typing import AbstractSet +from typing import Dict +from typing import List +from typing import Mapping import py @@ -32,14 +36,15 @@ class ColoredLevelFormatter(logging.Formatter): logging.INFO: {"green"}, logging.DEBUG: {"purple"}, logging.NOTSET: set(), - } + } # type: Mapping[int, AbstractSet[str]] LEVELNAME_FMT_REGEX = re.compile(r"%\(levelname\)([+-.]?\d*s)") - def __init__(self, terminalwriter, *args, **kwargs): + def __init__(self, terminalwriter, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._original_fmt = self._style._fmt - self._level_to_fmt_mapping = {} + self._level_to_fmt_mapping = {} # type: Dict[int, str] + assert self._fmt is not None levelname_fmt_match = self.LEVELNAME_FMT_REGEX.search(self._fmt) if not levelname_fmt_match: return @@ -216,17 +221,17 @@ def catching_logs(handler, formatter=None, level=None): class LogCaptureHandler(logging.StreamHandler): """A logging handler that stores log records and the log text.""" - def __init__(self): + def __init__(self) -> None: """Creates a new log handler.""" logging.StreamHandler.__init__(self, py.io.TextIO()) - self.records = [] + self.records = [] # type: List[logging.LogRecord] - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: """Keep the log records in a list in addition to the log text.""" self.records.append(record) logging.StreamHandler.emit(self, record) - def reset(self): + def reset(self) -> None: self.records = [] self.stream = py.io.TextIO() @@ -234,13 +239,13 @@ class LogCaptureHandler(logging.StreamHandler): class LogCaptureFixture: """Provides access and control of log capturing.""" - def __init__(self, item): + def __init__(self, item) -> None: """Creates a new funcarg.""" self._item = item # dict of log name -> log level - self._initial_log_levels = {} # Dict[str, int] + self._initial_log_levels = {} # type: Dict[str, int] - def _finalize(self): + def _finalize(self) -> None: """Finalizes the fixture. This restores the log levels changed by :meth:`set_level`. @@ -453,7 +458,7 @@ class LoggingPlugin: ): formatter = ColoredLevelFormatter( create_terminal_writer(self._config), log_format, log_date_format - ) + ) # type: logging.Formatter else: formatter = logging.Formatter(log_format, log_date_format) diff --git a/src/_pytest/nodes.py b/src/_pytest/nodes.py index e6dee1547..71036dc7e 100644 --- a/src/_pytest/nodes.py +++ b/src/_pytest/nodes.py @@ -139,8 +139,7 @@ class Node: ) ) path, lineno = get_fslocation_from_item(self) - # Type ignored: https://github.com/python/typeshed/pull/3121 - warnings.warn_explicit( # type: ignore + warnings.warn_explicit( warning, category=None, filename=str(path), diff --git a/src/_pytest/terminal.py b/src/_pytest/terminal.py index fd30d8572..35f6d324b 100644 --- a/src/_pytest/terminal.py +++ b/src/_pytest/terminal.py @@ -9,6 +9,12 @@ import platform import sys import time from functools import partial +from typing import Callable +from typing import Dict +from typing import List +from typing import Mapping +from typing import Optional +from typing import Set import attr import pluggy @@ -195,8 +201,8 @@ class WarningReport: file system location of the source of the warning (see ``get_location``). """ - message = attr.ib() - nodeid = attr.ib(default=None) + message = attr.ib(type=str) + nodeid = attr.ib(type=Optional[str], default=None) fslocation = attr.ib(default=None) count_towards_summary = True @@ -240,7 +246,7 @@ class TerminalReporter: self.reportchars = getreportopt(config) self.hasmarkup = self._tw.hasmarkup self.isatty = file.isatty() - self._progress_nodeids_reported = set() + self._progress_nodeids_reported = set() # type: Set[str] self._show_progress_info = self._determine_show_progress_info() self._collect_report_last_write = None @@ -619,7 +625,7 @@ class TerminalReporter: # because later versions are going to get rid of them anyway if self.config.option.verbose < 0: if self.config.option.verbose < -1: - counts = {} + counts = {} # type: Dict[str, int] for item in items: name = item.nodeid.split("::", 1)[0] counts[name] = counts.get(name, 0) + 1 @@ -750,7 +756,9 @@ class TerminalReporter: def summary_warnings(self): if self.hasopt("w"): - all_warnings = self.stats.get("warnings") + all_warnings = self.stats.get( + "warnings" + ) # type: Optional[List[WarningReport]] if not all_warnings: return @@ -763,7 +771,9 @@ class TerminalReporter: if not warning_reports: return - reports_grouped_by_message = collections.OrderedDict() + reports_grouped_by_message = ( + collections.OrderedDict() + ) # type: collections.OrderedDict[str, List[WarningReport]] for wr in warning_reports: reports_grouped_by_message.setdefault(wr.message, []).append(wr) @@ -900,11 +910,11 @@ class TerminalReporter: else: self.write_line(msg, **main_markup) - def short_test_summary(self): + def short_test_summary(self) -> None: if not self.reportchars: return - def show_simple(stat, lines): + def show_simple(stat, lines: List[str]) -> None: failed = self.stats.get(stat, []) if not failed: return @@ -914,7 +924,7 @@ class TerminalReporter: line = _get_line_with_reprcrash_message(config, rep, termwidth) lines.append(line) - def show_xfailed(lines): + def show_xfailed(lines: List[str]) -> None: xfailed = self.stats.get("xfailed", []) for rep in xfailed: verbose_word = rep._get_verbose_word(self.config) @@ -924,7 +934,7 @@ class TerminalReporter: if reason: lines.append(" " + str(reason)) - def show_xpassed(lines): + def show_xpassed(lines: List[str]) -> None: xpassed = self.stats.get("xpassed", []) for rep in xpassed: verbose_word = rep._get_verbose_word(self.config) @@ -932,7 +942,7 @@ class TerminalReporter: reason = rep.wasxfail lines.append("{} {} {}".format(verbose_word, pos, reason)) - def show_skipped(lines): + def show_skipped(lines: List[str]) -> None: skipped = self.stats.get("skipped", []) fskips = _folded_skips(skipped) if skipped else [] if not fskips: @@ -958,9 +968,9 @@ class TerminalReporter: "S": show_skipped, "p": partial(show_simple, "passed"), "E": partial(show_simple, "error"), - } + } # type: Mapping[str, Callable[[List[str]], None]] - lines = [] + lines = [] # type: List[str] for char in self.reportchars: action = REPORTCHAR_ACTIONS.get(char) if action: # skipping e.g. "P" (passed with output) here. @@ -1084,8 +1094,8 @@ def build_summary_stats_line(stats): return parts, main_color -def _plugin_nameversions(plugininfo): - values = [] +def _plugin_nameversions(plugininfo) -> List[str]: + values = [] # type: List[str] for plugin, dist in plugininfo: # gets us name and version! name = "{dist.project_name}-{dist.version}".format(dist=dist) @@ -1099,7 +1109,7 @@ def _plugin_nameversions(plugininfo): return values -def format_session_duration(seconds): +def format_session_duration(seconds: float) -> str: """Format the given seconds in a human readable manner to show in the final summary""" if seconds < 60: return "{:.2f}s".format(seconds) diff --git a/src/_pytest/warnings.py b/src/_pytest/warnings.py index d817a5cfa..8fdb61c2b 100644 --- a/src/_pytest/warnings.py +++ b/src/_pytest/warnings.py @@ -66,6 +66,8 @@ def catch_warnings_for_item(config, ihook, when, item): cmdline_filters = config.getoption("pythonwarnings") or [] inifilters = config.getini("filterwarnings") with warnings.catch_warnings(record=True) as log: + # mypy can't infer that record=True means log is not None; help it. + assert log is not None if not sys.warnoptions: # if user is not explicitly configuring warning filters, show deprecation warnings by default (#2908) @@ -145,6 +147,8 @@ def _issue_warning_captured(warning, hook, stacklevel): with warnings.catch_warnings(record=True) as records: warnings.simplefilter("always", type(warning)) warnings.warn(warning, stacklevel=stacklevel) + # Mypy can't infer that record=True means records is not None; help it. + assert records is not None hook.pytest_warning_captured.call_historic( kwargs=dict(warning_message=records[0], when="config", item=None) ) diff --git a/testing/test_capture.py b/testing/test_capture.py index 0f7db4b8e..180637db6 100644 --- a/testing/test_capture.py +++ b/testing/test_capture.py @@ -6,6 +6,8 @@ import subprocess import sys import textwrap from io import UnsupportedOperation +from typing import List +from typing import TextIO import py @@ -857,8 +859,8 @@ def tmpfile(testdir): @needsosdup -def test_dupfile(tmpfile): - flist = [] +def test_dupfile(tmpfile) -> None: + flist = [] # type: List[TextIO] for i in range(5): nf = capture.safe_text_dupfile(tmpfile, "wb") assert nf != tmpfile diff --git a/testing/test_doctest.py b/testing/test_doctest.py index 755f26286..37b3988f7 100644 --- a/testing/test_doctest.py +++ b/testing/test_doctest.py @@ -839,7 +839,8 @@ class TestLiterals: reprec = testdir.inline_run() reprec.assertoutcome(failed=1) - def test_number_re(self): + def test_number_re(self) -> None: + _number_re = _get_checker()._number_re # type: ignore for s in [ "1.", "+1.", @@ -861,12 +862,12 @@ class TestLiterals: "-1.2e-3", ]: print(s) - m = _get_checker()._number_re.match(s) + m = _number_re.match(s) assert m is not None assert float(m.group()) == pytest.approx(float(s)) for s in ["1", "abc"]: print(s) - assert _get_checker()._number_re.match(s) is None + assert _number_re.match(s) is None @pytest.mark.parametrize("config_mode", ["ini", "comment"]) def test_number_precision(self, testdir, config_mode):