Merge pull request #7418 from bluetech/typing-3

More typing work
This commit is contained in:
Ran Benita 2020-06-27 10:54:29 +03:00 committed by GitHub
commit 7450b6dd95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 196 additions and 146 deletions

View File

@ -0,0 +1,2 @@
Remove the `pytest_doctest_prepare_content` hook specification. This hook
hasn't been triggered by pytest for at least 10 years.

View File

@ -924,9 +924,7 @@ def _teardown_yield_fixture(fixturefunc, it) -> None:
except StopIteration: except StopIteration:
pass pass
else: else:
fail_fixturefunc( fail_fixturefunc(fixturefunc, "fixture function has more than one 'yield'")
fixturefunc, "yield_fixture function has more than one 'yield'"
)
def _eval_scope_callable( def _eval_scope_callable(

View File

@ -2,9 +2,13 @@
Provides a function to report all internal modules for using freezing tools Provides a function to report all internal modules for using freezing tools
pytest pytest
""" """
import types
from typing import Iterator
from typing import List
from typing import Union
def freeze_includes(): def freeze_includes() -> List[str]:
""" """
Returns a list of module names used by pytest that should be Returns a list of module names used by pytest that should be
included by cx_freeze. included by cx_freeze.
@ -17,7 +21,9 @@ def freeze_includes():
return result return result
def _iter_all_modules(package, prefix=""): def _iter_all_modules(
package: Union[str, types.ModuleType], prefix: str = "",
) -> Iterator[str]:
""" """
Iterates over the names of all modules that can be found in the given Iterates over the names of all modules that can be found in the given
package, recursively. package, recursively.
@ -29,10 +35,13 @@ def _iter_all_modules(package, prefix=""):
import os import os
import pkgutil import pkgutil
if type(package) is not str: if isinstance(package, str):
path, prefix = package.__path__[0], package.__name__ + "."
else:
path = package path = package
else:
# Type ignored because typeshed doesn't define ModuleType.__path__
# (only defined on packages).
package_path = package.__path__ # type: ignore[attr-defined]
path, prefix = package_path[0], package.__name__ + "."
for _, name, is_package in pkgutil.iter_modules([path]): for _, name, is_package in pkgutil.iter_modules([path]):
if is_package: if is_package:
for m in _iter_all_modules(os.path.join(path, name), prefix=name + "."): for m in _iter_all_modules(os.path.join(path, name), prefix=name + "."):

View File

@ -1,5 +1,6 @@
""" hook specifications for pytest plugins, invoked from main.py and builtin plugins. """ """ hook specifications for pytest plugins, invoked from main.py and builtin plugins. """
from typing import Any from typing import Any
from typing import Dict
from typing import List from typing import List
from typing import Mapping from typing import Mapping
from typing import Optional from typing import Optional
@ -37,7 +38,6 @@ if TYPE_CHECKING:
from _pytest.python import Metafunc from _pytest.python import Metafunc
from _pytest.python import Module from _pytest.python import Module
from _pytest.python import PyCollector from _pytest.python import PyCollector
from _pytest.reports import BaseReport
from _pytest.reports import CollectReport from _pytest.reports import CollectReport
from _pytest.reports import TestReport from _pytest.reports import TestReport
from _pytest.runner import CallInfo from _pytest.runner import CallInfo
@ -172,7 +172,7 @@ def pytest_cmdline_preparse(config: "Config", args: List[str]) -> None:
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_cmdline_main(config: "Config") -> "Optional[Union[ExitCode, int]]": def pytest_cmdline_main(config: "Config") -> Optional[Union["ExitCode", int]]:
""" called for performing the main command line action. The default """ called for performing the main command line action. The default
implementation will invoke the configure hooks and runtest_mainloop. implementation will invoke the configure hooks and runtest_mainloop.
@ -206,7 +206,7 @@ def pytest_load_initial_conftests(
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_collection(session: "Session") -> Optional[Any]: def pytest_collection(session: "Session") -> Optional[object]:
"""Perform the collection protocol for the given session. """Perform the collection protocol for the given session.
Stops at first non-None result, see :ref:`firstresult`. Stops at first non-None result, see :ref:`firstresult`.
@ -242,20 +242,21 @@ def pytest_collection_modifyitems(
""" """
def pytest_collection_finish(session: "Session"): def pytest_collection_finish(session: "Session") -> None:
""" called after collection has been performed and modified. """Called after collection has been performed and modified.
:param _pytest.main.Session session: the pytest session object :param _pytest.main.Session session: the pytest session object
""" """
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_ignore_collect(path, config: "Config"): def pytest_ignore_collect(path: py.path.local, config: "Config") -> Optional[bool]:
""" return True to prevent considering this path for collection. """Return True to prevent considering this path for collection.
This hook is consulted for all files and directories prior to calling This hook is consulted for all files and directories prior to calling
more specific hooks. more specific hooks.
Stops at first non-None result, see :ref:`firstresult` Stops at first non-None result, see :ref:`firstresult`.
:param path: a :py:class:`py.path.local` - the path to analyze :param path: a :py:class:`py.path.local` - the path to analyze
:param _pytest.config.Config config: pytest config object :param _pytest.config.Config config: pytest config object
@ -263,18 +264,19 @@ def pytest_ignore_collect(path, config: "Config"):
@hookspec(firstresult=True, warn_on_impl=COLLECT_DIRECTORY_HOOK) @hookspec(firstresult=True, warn_on_impl=COLLECT_DIRECTORY_HOOK)
def pytest_collect_directory(path, parent): def pytest_collect_directory(path: py.path.local, parent) -> Optional[object]:
""" called before traversing a directory for collection files. """Called before traversing a directory for collection files.
Stops at first non-None result, see :ref:`firstresult` Stops at first non-None result, see :ref:`firstresult`.
:param path: a :py:class:`py.path.local` - the path to analyze :param path: a :py:class:`py.path.local` - the path to analyze
""" """
def pytest_collect_file(path: py.path.local, parent) -> "Optional[Collector]": def pytest_collect_file(path: py.path.local, parent) -> "Optional[Collector]":
""" return collection Node or None for the given path. Any new node """Return collection Node or None for the given path.
needs to have the specified ``parent`` as a parent.
Any new node needs to have the specified ``parent`` as a parent.
:param path: a :py:class:`py.path.local` - the path to collect :param path: a :py:class:`py.path.local` - the path to collect
""" """
@ -287,16 +289,16 @@ def pytest_collectstart(collector: "Collector") -> None:
""" collector starts collecting. """ """ collector starts collecting. """
def pytest_itemcollected(item): def pytest_itemcollected(item: "Item") -> None:
""" we just collected a test item. """ """We just collected a test item."""
def pytest_collectreport(report: "CollectReport") -> None: def pytest_collectreport(report: "CollectReport") -> None:
""" collector finished collecting. """ """ collector finished collecting. """
def pytest_deselected(items): def pytest_deselected(items: Sequence["Item"]) -> None:
""" called for test items deselected, e.g. by keyword. """ """Called for deselected test items, e.g. by keyword."""
@hookspec(firstresult=True) @hookspec(firstresult=True)
@ -312,13 +314,14 @@ def pytest_make_collect_report(collector: "Collector") -> "Optional[CollectRepor
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_pycollect_makemodule(path: py.path.local, parent) -> "Optional[Module]": def pytest_pycollect_makemodule(path: py.path.local, parent) -> Optional["Module"]:
""" return a Module collector or None for the given path. """Return a Module collector or None for the given path.
This hook will be called for each matching test module path. This hook will be called for each matching test module path.
The pytest_collect_file hook needs to be used if you want to The pytest_collect_file hook needs to be used if you want to
create test modules for files that do not match as a test module. create test modules for files that do not match as a test module.
Stops at first non-None result, see :ref:`firstresult` Stops at first non-None result, see :ref:`firstresult`.
:param path: a :py:class:`py.path.local` - the path of module to collect :param path: a :py:class:`py.path.local` - the path of module to collect
""" """
@ -326,11 +329,12 @@ def pytest_pycollect_makemodule(path: py.path.local, parent) -> "Optional[Module
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_pycollect_makeitem( def pytest_pycollect_makeitem(
collector: "PyCollector", name: str, obj collector: "PyCollector", name: str, obj: object
) -> "Union[None, Item, Collector, List[Union[Item, Collector]]]": ) -> Union[None, "Item", "Collector", List[Union["Item", "Collector"]]]:
""" return custom item/collector for a python object in a module, or None. """Return a custom item/collector for a Python object in a module, or None.
Stops at first non-None result, see :ref:`firstresult` """ Stops at first non-None result, see :ref:`firstresult`.
"""
@hookspec(firstresult=True) @hookspec(firstresult=True)
@ -466,7 +470,7 @@ def pytest_runtest_call(item: "Item") -> None:
""" """
def pytest_runtest_teardown(item: "Item", nextitem: "Optional[Item]") -> None: def pytest_runtest_teardown(item: "Item", nextitem: Optional["Item"]) -> None:
"""Called to perform the teardown phase for a test item. """Called to perform the teardown phase for a test item.
The default implementation runs the finalizers and calls ``teardown()`` The default implementation runs the finalizers and calls ``teardown()``
@ -505,7 +509,9 @@ def pytest_runtest_logreport(report: "TestReport") -> None:
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_report_to_serializable(config: "Config", report: "BaseReport"): def pytest_report_to_serializable(
config: "Config", report: Union["CollectReport", "TestReport"],
) -> Optional[Dict[str, Any]]:
""" """
Serializes the given report object into a data structure suitable for sending Serializes the given report object into a data structure suitable for sending
over the wire, e.g. converted to JSON. over the wire, e.g. converted to JSON.
@ -513,7 +519,9 @@ def pytest_report_to_serializable(config: "Config", report: "BaseReport"):
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_report_from_serializable(config: "Config", data): def pytest_report_from_serializable(
config: "Config", data: Dict[str, Any],
) -> Optional[Union["CollectReport", "TestReport"]]:
""" """
Restores a report object previously serialized with pytest_report_to_serializable(). Restores a report object previously serialized with pytest_report_to_serializable().
""" """
@ -528,11 +536,11 @@ def pytest_report_from_serializable(config: "Config", data):
def pytest_fixture_setup( def pytest_fixture_setup(
fixturedef: "FixtureDef", request: "SubRequest" fixturedef: "FixtureDef", request: "SubRequest"
) -> Optional[object]: ) -> Optional[object]:
""" performs fixture setup execution. """Performs fixture setup execution.
:return: The return value of the call to the fixture function :return: The return value of the call to the fixture function.
Stops at first non-None result, see :ref:`firstresult` Stops at first non-None result, see :ref:`firstresult`.
.. note:: .. note::
If the fixture function returns None, other implementations of If the fixture function returns None, other implementations of
@ -555,7 +563,7 @@ def pytest_fixture_post_finalizer(
def pytest_sessionstart(session: "Session") -> None: def pytest_sessionstart(session: "Session") -> None:
""" called after the ``Session`` object has been created and before performing collection """Called after the ``Session`` object has been created and before performing collection
and entering the run test loop. and entering the run test loop.
:param _pytest.main.Session session: the pytest session object :param _pytest.main.Session session: the pytest session object
@ -563,9 +571,9 @@ def pytest_sessionstart(session: "Session") -> None:
def pytest_sessionfinish( def pytest_sessionfinish(
session: "Session", exitstatus: "Union[int, ExitCode]" session: "Session", exitstatus: Union[int, "ExitCode"],
) -> None: ) -> None:
""" called after whole test run finished, right before returning the exit status to the system. """Called after whole test run finished, right before returning the exit status to the system.
:param _pytest.main.Session session: the pytest session object :param _pytest.main.Session session: the pytest session object
:param int exitstatus: the status which pytest will return to the system :param int exitstatus: the status which pytest will return to the system
@ -573,7 +581,7 @@ def pytest_sessionfinish(
def pytest_unconfigure(config: "Config") -> None: def pytest_unconfigure(config: "Config") -> None:
""" called before test process is exited. """Called before test process is exited.
:param _pytest.config.Config config: pytest config object :param _pytest.config.Config config: pytest config object
""" """
@ -587,7 +595,7 @@ def pytest_unconfigure(config: "Config") -> None:
def pytest_assertrepr_compare( def pytest_assertrepr_compare(
config: "Config", op: str, left: object, right: object config: "Config", op: str, left: object, right: object
) -> Optional[List[str]]: ) -> Optional[List[str]]:
"""return explanation for comparisons in failing assert expressions. """Return explanation for comparisons in failing assert expressions.
Return None for no custom explanation, otherwise return a list Return None for no custom explanation, otherwise return a list
of strings. The strings will be joined by newlines but any newlines of strings. The strings will be joined by newlines but any newlines
@ -598,7 +606,7 @@ def pytest_assertrepr_compare(
""" """
def pytest_assertion_pass(item, lineno: int, orig: str, expl: str) -> None: def pytest_assertion_pass(item: "Item", lineno: int, orig: str, expl: str) -> None:
""" """
**(Experimental)** **(Experimental)**
@ -665,12 +673,12 @@ def pytest_report_header(
def pytest_report_collectionfinish( def pytest_report_collectionfinish(
config: "Config", startdir: py.path.local, items: "Sequence[Item]" config: "Config", startdir: py.path.local, items: Sequence["Item"],
) -> Union[str, List[str]]: ) -> Union[str, List[str]]:
""" """
.. versionadded:: 3.2 .. versionadded:: 3.2
return a string or list of strings to be displayed after collection has finished successfully. Return a string or list of strings to be displayed after collection has finished successfully.
These strings will be displayed after the standard "collected X items" message. These strings will be displayed after the standard "collected X items" message.
@ -689,7 +697,7 @@ def pytest_report_collectionfinish(
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_report_teststatus( def pytest_report_teststatus(
report: "BaseReport", config: "Config" report: Union["CollectReport", "TestReport"], config: "Config"
) -> Tuple[ ) -> Tuple[
str, str, Union[str, Mapping[str, bool]], str, str, Union[str, Mapping[str, bool]],
]: ]:
@ -734,7 +742,7 @@ def pytest_terminal_summary(
def pytest_warning_captured( def pytest_warning_captured(
warning_message: "warnings.WarningMessage", warning_message: "warnings.WarningMessage",
when: "Literal['config', 'collect', 'runtest']", when: "Literal['config', 'collect', 'runtest']",
item: "Optional[Item]", item: Optional["Item"],
location: Optional[Tuple[str, int, str]], location: Optional[Tuple[str, int, str]],
) -> None: ) -> None:
"""(**Deprecated**) Process a warning captured by the internal pytest warnings plugin. """(**Deprecated**) Process a warning captured by the internal pytest warnings plugin.
@ -797,18 +805,6 @@ def pytest_warning_recorded(
""" """
# -------------------------------------------------------------------------
# doctest hooks
# -------------------------------------------------------------------------
@hookspec(firstresult=True)
def pytest_doctest_prepare_content(content):
""" return processed content for a given doctest
Stops at first non-None result, see :ref:`firstresult` """
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# error handling and internal debugging hooks # error handling and internal debugging hooks
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -831,7 +827,9 @@ def pytest_keyboard_interrupt(
def pytest_exception_interact( def pytest_exception_interact(
node: "Node", call: "CallInfo[object]", report: "Union[CollectReport, TestReport]" node: "Node",
call: "CallInfo[object]",
report: Union["CollectReport", "TestReport"],
) -> None: ) -> None:
"""Called when an exception was raised which can potentially be """Called when an exception was raised which can potentially be
interactively handled. interactively handled.

View File

@ -302,8 +302,8 @@ def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
return None return None
def pytest_collection(session: "Session") -> Sequence[nodes.Item]: def pytest_collection(session: "Session") -> None:
return session.perform_collect() session.perform_collect()
def pytest_runtestloop(session: "Session") -> bool: def pytest_runtestloop(session: "Session") -> bool:
@ -343,9 +343,7 @@ def _in_venv(path: py.path.local) -> bool:
return any([fname.basename in activates for fname in bindir.listdir()]) return any([fname.basename in activates for fname in bindir.listdir()])
def pytest_ignore_collect( def pytest_ignore_collect(path: py.path.local, config: Config) -> Optional[bool]:
path: py.path.local, config: Config
) -> "Optional[Literal[True]]":
ignore_paths = config._getconftest_pathlist("collect_ignore", path=path.dirpath()) ignore_paths = config._getconftest_pathlist("collect_ignore", path=path.dirpath())
ignore_paths = ignore_paths or [] ignore_paths = ignore_paths or []
excludeopt = config.getoption("ignore") excludeopt = config.getoption("ignore")

View File

@ -422,7 +422,7 @@ class PyCollector(PyobjMixin, nodes.Collector):
return values return values
def _makeitem( def _makeitem(
self, name: str, obj self, name: str, obj: object
) -> Union[ ) -> Union[
None, nodes.Item, nodes.Collector, List[Union[nodes.Item, nodes.Collector]] None, nodes.Item, nodes.Collector, List[Union[nodes.Item, nodes.Collector]]
]: ]:

View File

@ -33,7 +33,7 @@ if TYPE_CHECKING:
BASE_TYPE = (type, STRING_TYPES) BASE_TYPE = (type, STRING_TYPES)
def _non_numeric_type_error(value, at): def _non_numeric_type_error(value, at: Optional[str]) -> TypeError:
at_str = " at {}".format(at) if at else "" at_str = " at {}".format(at) if at else ""
return TypeError( return TypeError(
"cannot make approximate comparisons to non-numeric values: {!r} {}".format( "cannot make approximate comparisons to non-numeric values: {!r} {}".format(
@ -55,7 +55,7 @@ class ApproxBase:
__array_ufunc__ = None __array_ufunc__ = None
__array_priority__ = 100 __array_priority__ = 100
def __init__(self, expected, rel=None, abs=None, nan_ok=False): def __init__(self, expected, rel=None, abs=None, nan_ok: bool = False) -> None:
__tracebackhide__ = True __tracebackhide__ = True
self.expected = expected self.expected = expected
self.abs = abs self.abs = abs
@ -63,10 +63,10 @@ class ApproxBase:
self.nan_ok = nan_ok self.nan_ok = nan_ok
self._check_type() self._check_type()
def __repr__(self): def __repr__(self) -> str:
raise NotImplementedError raise NotImplementedError
def __eq__(self, actual): def __eq__(self, actual) -> bool:
return all( return all(
a == self._approx_scalar(x) for a, x in self._yield_comparisons(actual) a == self._approx_scalar(x) for a, x in self._yield_comparisons(actual)
) )
@ -74,10 +74,10 @@ class ApproxBase:
# Ignore type because of https://github.com/python/mypy/issues/4266. # Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore __hash__ = None # type: ignore
def __ne__(self, actual): def __ne__(self, actual) -> bool:
return not (actual == self) return not (actual == self)
def _approx_scalar(self, x): def _approx_scalar(self, x) -> "ApproxScalar":
return ApproxScalar(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok) return ApproxScalar(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)
def _yield_comparisons(self, actual): def _yield_comparisons(self, actual):
@ -87,7 +87,7 @@ class ApproxBase:
""" """
raise NotImplementedError raise NotImplementedError
def _check_type(self): def _check_type(self) -> None:
""" """
Raise a TypeError if the expected value is not a valid type. Raise a TypeError if the expected value is not a valid type.
""" """
@ -111,11 +111,11 @@ class ApproxNumpy(ApproxBase):
Perform approximate comparisons where the expected value is numpy array. Perform approximate comparisons where the expected value is numpy array.
""" """
def __repr__(self): def __repr__(self) -> str:
list_scalars = _recursive_list_map(self._approx_scalar, self.expected.tolist()) list_scalars = _recursive_list_map(self._approx_scalar, self.expected.tolist())
return "approx({!r})".format(list_scalars) return "approx({!r})".format(list_scalars)
def __eq__(self, actual): def __eq__(self, actual) -> bool:
import numpy as np import numpy as np
# self.expected is supposed to always be an array here # self.expected is supposed to always be an array here
@ -154,12 +154,12 @@ class ApproxMapping(ApproxBase):
numeric values (the keys can be anything). numeric values (the keys can be anything).
""" """
def __repr__(self): def __repr__(self) -> str:
return "approx({!r})".format( return "approx({!r})".format(
{k: self._approx_scalar(v) for k, v in self.expected.items()} {k: self._approx_scalar(v) for k, v in self.expected.items()}
) )
def __eq__(self, actual): def __eq__(self, actual) -> bool:
if set(actual.keys()) != set(self.expected.keys()): if set(actual.keys()) != set(self.expected.keys()):
return False return False
@ -169,7 +169,7 @@ class ApproxMapping(ApproxBase):
for k in self.expected.keys(): for k in self.expected.keys():
yield actual[k], self.expected[k] yield actual[k], self.expected[k]
def _check_type(self): def _check_type(self) -> None:
__tracebackhide__ = True __tracebackhide__ = True
for key, value in self.expected.items(): for key, value in self.expected.items():
if isinstance(value, type(self.expected)): if isinstance(value, type(self.expected)):
@ -185,7 +185,7 @@ class ApproxSequencelike(ApproxBase):
numbers. numbers.
""" """
def __repr__(self): def __repr__(self) -> str:
seq_type = type(self.expected) seq_type = type(self.expected)
if seq_type not in (tuple, list, set): if seq_type not in (tuple, list, set):
seq_type = list seq_type = list
@ -193,7 +193,7 @@ class ApproxSequencelike(ApproxBase):
seq_type(self._approx_scalar(x) for x in self.expected) seq_type(self._approx_scalar(x) for x in self.expected)
) )
def __eq__(self, actual): def __eq__(self, actual) -> bool:
if len(actual) != len(self.expected): if len(actual) != len(self.expected):
return False return False
return ApproxBase.__eq__(self, actual) return ApproxBase.__eq__(self, actual)
@ -201,7 +201,7 @@ class ApproxSequencelike(ApproxBase):
def _yield_comparisons(self, actual): def _yield_comparisons(self, actual):
return zip(actual, self.expected) return zip(actual, self.expected)
def _check_type(self): def _check_type(self) -> None:
__tracebackhide__ = True __tracebackhide__ = True
for index, x in enumerate(self.expected): for index, x in enumerate(self.expected):
if isinstance(x, type(self.expected)): if isinstance(x, type(self.expected)):
@ -223,7 +223,7 @@ class ApproxScalar(ApproxBase):
DEFAULT_ABSOLUTE_TOLERANCE = 1e-12 # type: Union[float, Decimal] DEFAULT_ABSOLUTE_TOLERANCE = 1e-12 # type: Union[float, Decimal]
DEFAULT_RELATIVE_TOLERANCE = 1e-6 # type: Union[float, Decimal] DEFAULT_RELATIVE_TOLERANCE = 1e-6 # type: Union[float, Decimal]
def __repr__(self): def __repr__(self) -> str:
""" """
Return a string communicating both the expected value and the tolerance Return a string communicating both the expected value and the tolerance
for the comparison being made, e.g. '1.0 ± 1e-6', '(3+4j) ± 5e-6 ∠ ±180°'. for the comparison being made, e.g. '1.0 ± 1e-6', '(3+4j) ± 5e-6 ∠ ±180°'.
@ -245,7 +245,7 @@ class ApproxScalar(ApproxBase):
return "{} ± {}".format(self.expected, vetted_tolerance) return "{} ± {}".format(self.expected, vetted_tolerance)
def __eq__(self, actual): def __eq__(self, actual) -> bool:
""" """
Return true if the given value is equal to the expected value within Return true if the given value is equal to the expected value within
the pre-specified tolerance. the pre-specified tolerance.
@ -275,7 +275,8 @@ class ApproxScalar(ApproxBase):
return False return False
# Return true if the two numbers are within the tolerance. # Return true if the two numbers are within the tolerance.
return abs(self.expected - actual) <= self.tolerance result = abs(self.expected - actual) <= self.tolerance # type: bool
return result
# Ignore type because of https://github.com/python/mypy/issues/4266. # Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore __hash__ = None # type: ignore
@ -337,7 +338,7 @@ class ApproxDecimal(ApproxScalar):
DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6") DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6")
def approx(expected, rel=None, abs=None, nan_ok=False): def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
""" """
Assert that two numbers (or two sets of numbers) are equal to each other Assert that two numbers (or two sets of numbers) are equal to each other
within some tolerance. within some tolerance.
@ -527,7 +528,7 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
return cls(expected, rel, abs, nan_ok) return cls(expected, rel, abs, nan_ok)
def _is_numpy_array(obj): def _is_numpy_array(obj: object) -> bool:
""" """
Return true if the given object is a numpy array. Make a special effort to Return true if the given object is a numpy array. Make a special effort to
avoid importing numpy unless it's really necessary. avoid importing numpy unless it's really necessary.

View File

@ -4,24 +4,29 @@ import warnings
from types import TracebackType from types import TracebackType
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import Generator
from typing import Iterator from typing import Iterator
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Pattern from typing import Pattern
from typing import Tuple from typing import Tuple
from typing import TypeVar
from typing import Union from typing import Union
from _pytest.compat import overload from _pytest.compat import overload
from _pytest.compat import TYPE_CHECKING from _pytest.compat import TYPE_CHECKING
from _pytest.fixtures import yield_fixture from _pytest.fixtures import fixture
from _pytest.outcomes import fail from _pytest.outcomes import fail
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Type from typing import Type
@yield_fixture T = TypeVar("T")
def recwarn():
@fixture
def recwarn() -> Generator["WarningsRecorder", None, None]:
"""Return a :class:`WarningsRecorder` instance that records all warnings emitted by test functions. """Return a :class:`WarningsRecorder` instance that records all warnings emitted by test functions.
See http://docs.python.org/library/warnings.html for information See http://docs.python.org/library/warnings.html for information
@ -33,9 +38,26 @@ def recwarn():
yield wrec yield wrec
def deprecated_call(func=None, *args, **kwargs): @overload
"""context manager that can be used to ensure a block of code triggers a def deprecated_call(
``DeprecationWarning`` or ``PendingDeprecationWarning``:: *, match: Optional[Union[str, "Pattern"]] = ...
) -> "WarningsRecorder":
raise NotImplementedError()
@overload # noqa: F811
def deprecated_call( # noqa: F811
func: Callable[..., T], *args: Any, **kwargs: Any
) -> T:
raise NotImplementedError()
def deprecated_call( # noqa: F811
func: Optional[Callable] = None, *args: Any, **kwargs: Any
) -> Union["WarningsRecorder", Any]:
"""Assert that code produces a ``DeprecationWarning`` or ``PendingDeprecationWarning``.
This function can be used as a context manager::
>>> import warnings >>> import warnings
>>> def api_call_v2(): >>> def api_call_v2():
@ -45,9 +67,15 @@ def deprecated_call(func=None, *args, **kwargs):
>>> with deprecated_call(): >>> with deprecated_call():
... assert api_call_v2() == 200 ... assert api_call_v2() == 200
``deprecated_call`` can also be used by passing a function and ``*args`` and ``*kwargs``, It can also be used by passing a function and ``*args`` and ``**kwargs``,
in which case it will ensure calling ``func(*args, **kwargs)`` produces one of the warnings in which case it will ensure calling ``func(*args, **kwargs)`` produces one of
types above. the warnings types above. The return value is the return value of the function.
In the context manager form you may use the keyword argument ``match`` to assert
that the warning matches a text or regex.
The context manager produces a list of :class:`warnings.WarningMessage` objects,
one for each warning raised.
""" """
__tracebackhide__ = True __tracebackhide__ = True
if func is not None: if func is not None:
@ -67,11 +95,10 @@ def warns(
@overload # noqa: F811 @overload # noqa: F811
def warns( # noqa: F811 def warns( # noqa: F811
expected_warning: Optional[Union["Type[Warning]", Tuple["Type[Warning]", ...]]], expected_warning: Optional[Union["Type[Warning]", Tuple["Type[Warning]", ...]]],
func: Callable, func: Callable[..., T],
*args: Any, *args: Any,
match: Optional[Union[str, "Pattern"]] = ...,
**kwargs: Any **kwargs: Any
) -> Union[Any]: ) -> T:
raise NotImplementedError() raise NotImplementedError()
@ -97,7 +124,7 @@ def warns( # noqa: F811
... warnings.warn("my warning", RuntimeWarning) ... warnings.warn("my warning", RuntimeWarning)
In the context manager form you may use the keyword argument ``match`` to assert In the context manager form you may use the keyword argument ``match`` to assert
that the exception matches a text or regex:: that the warning matches a text or regex::
>>> with warns(UserWarning, match='must be 0 or None'): >>> with warns(UserWarning, match='must be 0 or None'):
... warnings.warn("value must be 0 or None", UserWarning) ... warnings.warn("value must be 0 or None", UserWarning)

View File

@ -1,6 +1,7 @@
from io import StringIO from io import StringIO
from pprint import pprint from pprint import pprint
from typing import Any from typing import Any
from typing import Dict
from typing import Iterable from typing import Iterable
from typing import Iterator from typing import Iterator
from typing import List from typing import List
@ -69,7 +70,7 @@ class BaseReport:
def __getattr__(self, key: str) -> Any: def __getattr__(self, key: str) -> Any:
raise NotImplementedError() raise NotImplementedError()
def toterminal(self, out) -> None: def toterminal(self, out: TerminalWriter) -> None:
if hasattr(self, "node"): if hasattr(self, "node"):
out.line(getworkerinfoline(self.node)) out.line(getworkerinfoline(self.node))
@ -187,7 +188,7 @@ class BaseReport:
) )
return verbose return verbose
def _to_json(self): def _to_json(self) -> Dict[str, Any]:
""" """
This was originally the serialize_report() function from xdist (ca03269). This was originally the serialize_report() function from xdist (ca03269).
@ -199,7 +200,7 @@ class BaseReport:
return _report_to_json(self) return _report_to_json(self)
@classmethod @classmethod
def _from_json(cls: "Type[_R]", reportdict) -> _R: def _from_json(cls: "Type[_R]", reportdict: Dict[str, object]) -> _R:
""" """
This was originally the serialize_report() function from xdist (ca03269). This was originally the serialize_report() function from xdist (ca03269).
@ -382,11 +383,13 @@ class CollectErrorRepr(TerminalRepr):
def __init__(self, msg) -> None: def __init__(self, msg) -> None:
self.longrepr = msg self.longrepr = msg
def toterminal(self, out) -> None: def toterminal(self, out: TerminalWriter) -> None:
out.line(self.longrepr, red=True) out.line(self.longrepr, red=True)
def pytest_report_to_serializable(report: BaseReport): def pytest_report_to_serializable(
report: Union[CollectReport, TestReport]
) -> Optional[Dict[str, Any]]:
if isinstance(report, (TestReport, CollectReport)): if isinstance(report, (TestReport, CollectReport)):
data = report._to_json() data = report._to_json()
data["$report_type"] = report.__class__.__name__ data["$report_type"] = report.__class__.__name__
@ -394,7 +397,9 @@ def pytest_report_to_serializable(report: BaseReport):
return None return None
def pytest_report_from_serializable(data) -> Optional[BaseReport]: def pytest_report_from_serializable(
data: Dict[str, Any],
) -> Optional[Union[CollectReport, TestReport]]:
if "$report_type" in data: if "$report_type" in data:
if data["$report_type"] == "TestReport": if data["$report_type"] == "TestReport":
return TestReport._from_json(data) return TestReport._from_json(data)
@ -406,7 +411,7 @@ def pytest_report_from_serializable(data) -> Optional[BaseReport]:
return None return None
def _report_to_json(report: BaseReport): def _report_to_json(report: BaseReport) -> Dict[str, Any]:
""" """
This was originally the serialize_report() function from xdist (ca03269). This was originally the serialize_report() function from xdist (ca03269).
@ -414,7 +419,9 @@ def _report_to_json(report: BaseReport):
serialization. serialization.
""" """
def serialize_repr_entry(entry: Union[ReprEntry, ReprEntryNative]): def serialize_repr_entry(
entry: Union[ReprEntry, ReprEntryNative]
) -> Dict[str, Any]:
data = attr.asdict(entry) data = attr.asdict(entry)
for key, value in data.items(): for key, value in data.items():
if hasattr(value, "__dict__"): if hasattr(value, "__dict__"):
@ -422,25 +429,28 @@ def _report_to_json(report: BaseReport):
entry_data = {"type": type(entry).__name__, "data": data} entry_data = {"type": type(entry).__name__, "data": data}
return entry_data return entry_data
def serialize_repr_traceback(reprtraceback: ReprTraceback): def serialize_repr_traceback(reprtraceback: ReprTraceback) -> Dict[str, Any]:
result = attr.asdict(reprtraceback) result = attr.asdict(reprtraceback)
result["reprentries"] = [ result["reprentries"] = [
serialize_repr_entry(x) for x in reprtraceback.reprentries serialize_repr_entry(x) for x in reprtraceback.reprentries
] ]
return result return result
def serialize_repr_crash(reprcrash: Optional[ReprFileLocation]): def serialize_repr_crash(
reprcrash: Optional[ReprFileLocation],
) -> Optional[Dict[str, Any]]:
if reprcrash is not None: if reprcrash is not None:
return attr.asdict(reprcrash) return attr.asdict(reprcrash)
else: else:
return None return None
def serialize_longrepr(rep): def serialize_longrepr(rep: BaseReport) -> Dict[str, Any]:
assert rep.longrepr is not None
result = { result = {
"reprcrash": serialize_repr_crash(rep.longrepr.reprcrash), "reprcrash": serialize_repr_crash(rep.longrepr.reprcrash),
"reprtraceback": serialize_repr_traceback(rep.longrepr.reprtraceback), "reprtraceback": serialize_repr_traceback(rep.longrepr.reprtraceback),
"sections": rep.longrepr.sections, "sections": rep.longrepr.sections,
} } # type: Dict[str, Any]
if isinstance(rep.longrepr, ExceptionChainRepr): if isinstance(rep.longrepr, ExceptionChainRepr):
result["chain"] = [] result["chain"] = []
for repr_traceback, repr_crash, description in rep.longrepr.chain: for repr_traceback, repr_crash, description in rep.longrepr.chain:
@ -473,7 +483,7 @@ def _report_to_json(report: BaseReport):
return d return d
def _report_kwargs_from_json(reportdict): def _report_kwargs_from_json(reportdict: Dict[str, Any]) -> Dict[str, Any]:
""" """
This was originally the serialize_report() function from xdist (ca03269). This was originally the serialize_report() function from xdist (ca03269).

View File

@ -473,9 +473,9 @@ class TerminalReporter:
def line(self, msg: str, **kw: bool) -> None: def line(self, msg: str, **kw: bool) -> None:
self._tw.line(msg, **kw) self._tw.line(msg, **kw)
def _add_stats(self, category: str, items: List) -> None: def _add_stats(self, category: str, items: Sequence) -> None:
set_main_color = category not in self.stats set_main_color = category not in self.stats
self.stats.setdefault(category, []).extend(items[:]) self.stats.setdefault(category, []).extend(items)
if set_main_color: if set_main_color:
self._set_main_color() self._set_main_color()
@ -505,7 +505,7 @@ class TerminalReporter:
# which garbles our output if we use self.write_line # which garbles our output if we use self.write_line
self.write_line(msg) self.write_line(msg)
def pytest_deselected(self, items) -> None: def pytest_deselected(self, items: Sequence[Item]) -> None:
self._add_stats("deselected", items) self._add_stats("deselected", items)
def pytest_runtest_logstart( def pytest_runtest_logstart(

View File

@ -44,7 +44,7 @@ if TYPE_CHECKING:
def pytest_pycollect_makeitem( def pytest_pycollect_makeitem(
collector: PyCollector, name: str, obj collector: PyCollector, name: str, obj: object
) -> Optional["UnitTestCase"]: ) -> Optional["UnitTestCase"]:
# has unittest been imported and is obj a subclass of its TestCase? # has unittest been imported and is obj a subclass of its TestCase?
try: try:

View File

@ -33,13 +33,13 @@ def checked_order():
] ]
@pytest.yield_fixture(scope="module") @pytest.fixture(scope="module")
def fix1(request, arg1, checked_order): def fix1(request, arg1, checked_order):
checked_order.append((request.node.name, "fix1", arg1)) checked_order.append((request.node.name, "fix1", arg1))
yield "fix1-" + arg1 yield "fix1-" + arg1
@pytest.yield_fixture(scope="function") @pytest.fixture(scope="function")
def fix2(request, fix1, arg2, checked_order): def fix2(request, fix1, arg2, checked_order):
checked_order.append((request.node.name, "fix2", arg2)) checked_order.append((request.node.name, "fix2", arg2))
yield "fix2-" + arg2 + fix1 yield "fix2-" + arg2 + fix1

View File

@ -3,6 +3,7 @@ from decimal import Decimal
from fractions import Fraction from fractions import Fraction
from operator import eq from operator import eq
from operator import ne from operator import ne
from typing import Optional
import pytest import pytest
from pytest import approx from pytest import approx
@ -121,18 +122,22 @@ class TestApprox:
assert a == approx(x, rel=5e-1, abs=0.0) assert a == approx(x, rel=5e-1, abs=0.0)
assert a != approx(x, rel=5e-2, abs=0.0) assert a != approx(x, rel=5e-2, abs=0.0)
def test_negative_tolerance(self): @pytest.mark.parametrize(
("rel", "abs"),
[
(-1e100, None),
(None, -1e100),
(1e100, -1e100),
(-1e100, 1e100),
(-1e100, -1e100),
],
)
def test_negative_tolerance(
self, rel: Optional[float], abs: Optional[float]
) -> None:
# Negative tolerances are not allowed. # Negative tolerances are not allowed.
illegal_kwargs = [
dict(rel=-1e100),
dict(abs=-1e100),
dict(rel=1e100, abs=-1e100),
dict(rel=-1e100, abs=1e100),
dict(rel=-1e100, abs=-1e100),
]
for kwargs in illegal_kwargs:
with pytest.raises(ValueError): with pytest.raises(ValueError):
1.1 == approx(1, **kwargs) 1.1 == approx(1, rel, abs)
def test_inf_tolerance(self): def test_inf_tolerance(self):
# Everything should be equal if the tolerance is infinite. # Everything should be equal if the tolerance is infinite.
@ -143,19 +148,21 @@ class TestApprox:
assert a == approx(x, rel=0.0, abs=inf) assert a == approx(x, rel=0.0, abs=inf)
assert a == approx(x, rel=inf, abs=inf) assert a == approx(x, rel=inf, abs=inf)
def test_inf_tolerance_expecting_zero(self): def test_inf_tolerance_expecting_zero(self) -> None:
# If the relative tolerance is zero but the expected value is infinite, # If the relative tolerance is zero but the expected value is infinite,
# the actual tolerance is a NaN, which should be an error. # the actual tolerance is a NaN, which should be an error.
illegal_kwargs = [dict(rel=inf, abs=0.0), dict(rel=inf, abs=inf)]
for kwargs in illegal_kwargs:
with pytest.raises(ValueError): with pytest.raises(ValueError):
1 == approx(0, **kwargs) 1 == approx(0, rel=inf, abs=0.0)
with pytest.raises(ValueError):
1 == approx(0, rel=inf, abs=inf)
def test_nan_tolerance(self): def test_nan_tolerance(self) -> None:
illegal_kwargs = [dict(rel=nan), dict(abs=nan), dict(rel=nan, abs=nan)]
for kwargs in illegal_kwargs:
with pytest.raises(ValueError): with pytest.raises(ValueError):
1.1 == approx(1, **kwargs) 1.1 == approx(1, rel=nan)
with pytest.raises(ValueError):
1.1 == approx(1, abs=nan)
with pytest.raises(ValueError):
1.1 == approx(1, rel=nan, abs=nan)
def test_reasonable_defaults(self): def test_reasonable_defaults(self):
# Whatever the defaults are, they should work for numbers close to 1 # Whatever the defaults are, they should work for numbers close to 1

View File

@ -1315,7 +1315,7 @@ class TestFixtureUsages:
DB_INITIALIZED = None DB_INITIALIZED = None
@pytest.yield_fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def db(): def db():
global DB_INITIALIZED global DB_INITIALIZED
DB_INITIALIZED = True DB_INITIALIZED = True
@ -2960,8 +2960,7 @@ class TestFixtureMarker:
""" """
import pytest import pytest
@pytest.yield_fixture(params=[object(), object()], @pytest.fixture(params=[object(), object()], ids=['alpha', 'beta'])
ids=['alpha', 'beta'])
def fix(request): def fix(request):
yield request.param yield request.param

View File

@ -1176,7 +1176,7 @@ class TestDoctestAutoUseFixtures:
import pytest import pytest
import sys import sys
@pytest.yield_fixture(autouse=True, scope='session') @pytest.fixture(autouse=True, scope='session')
def myfixture(): def myfixture():
assert not hasattr(sys, 'pytest_session_data') assert not hasattr(sys, 'pytest_session_data')
sys.pytest_session_data = 1 sys.pytest_session_data = 1

View File

@ -91,7 +91,7 @@ class TestImportPath:
Having our own pyimport-like function is inline with removing py.path dependency in the future. Having our own pyimport-like function is inline with removing py.path dependency in the future.
""" """
@pytest.yield_fixture(scope="session") @pytest.fixture(scope="session")
def path1(self, tmpdir_factory): def path1(self, tmpdir_factory):
path = tmpdir_factory.mktemp("path") path = tmpdir_factory.mktemp("path")
self.setuptestfs(path) self.setuptestfs(path)

View File

@ -370,13 +370,14 @@ class TestWarns:
@pytest.mark.filterwarnings("ignore") @pytest.mark.filterwarnings("ignore")
def test_can_capture_previously_warned(self) -> None: def test_can_capture_previously_warned(self) -> None:
def f(): def f() -> int:
warnings.warn(UserWarning("ohai")) warnings.warn(UserWarning("ohai"))
return 10 return 10
assert f() == 10 assert f() == 10
assert pytest.warns(UserWarning, f) == 10 assert pytest.warns(UserWarning, f) == 10
assert pytest.warns(UserWarning, f) == 10 assert pytest.warns(UserWarning, f) == 10
assert pytest.warns(UserWarning, f) != "10" # type: ignore[comparison-overlap]
def test_warns_context_manager_with_kwargs(self) -> None: def test_warns_context_manager_with_kwargs(self) -> None:
with pytest.raises(TypeError) as excinfo: with pytest.raises(TypeError) as excinfo: