python_api: type annotate some parts of pytest.approx()

This commit is contained in:
Ran Benita 2020-06-25 15:02:04 +03:00
parent 142d8963e6
commit 8f8f472379
2 changed files with 49 additions and 41 deletions

View File

@ -33,7 +33,7 @@ if TYPE_CHECKING:
BASE_TYPE = (type, STRING_TYPES)
def _non_numeric_type_error(value, at):
def _non_numeric_type_error(value, at: Optional[str]) -> TypeError:
at_str = " at {}".format(at) if at else ""
return TypeError(
"cannot make approximate comparisons to non-numeric values: {!r} {}".format(
@ -55,7 +55,7 @@ class ApproxBase:
__array_ufunc__ = None
__array_priority__ = 100
def __init__(self, expected, rel=None, abs=None, nan_ok=False):
def __init__(self, expected, rel=None, abs=None, nan_ok: bool = False) -> None:
__tracebackhide__ = True
self.expected = expected
self.abs = abs
@ -63,10 +63,10 @@ class ApproxBase:
self.nan_ok = nan_ok
self._check_type()
def __repr__(self):
def __repr__(self) -> str:
raise NotImplementedError
def __eq__(self, actual):
def __eq__(self, actual) -> bool:
return all(
a == self._approx_scalar(x) for a, x in self._yield_comparisons(actual)
)
@ -74,10 +74,10 @@ class ApproxBase:
# Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore
def __ne__(self, actual):
def __ne__(self, actual) -> bool:
return not (actual == self)
def _approx_scalar(self, x):
def _approx_scalar(self, x) -> "ApproxScalar":
return ApproxScalar(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)
def _yield_comparisons(self, actual):
@ -87,7 +87,7 @@ class ApproxBase:
"""
raise NotImplementedError
def _check_type(self):
def _check_type(self) -> None:
"""
Raise a TypeError if the expected value is not a valid type.
"""
@ -111,11 +111,11 @@ class ApproxNumpy(ApproxBase):
Perform approximate comparisons where the expected value is numpy array.
"""
def __repr__(self):
def __repr__(self) -> str:
list_scalars = _recursive_list_map(self._approx_scalar, self.expected.tolist())
return "approx({!r})".format(list_scalars)
def __eq__(self, actual):
def __eq__(self, actual) -> bool:
import numpy as np
# self.expected is supposed to always be an array here
@ -154,12 +154,12 @@ class ApproxMapping(ApproxBase):
numeric values (the keys can be anything).
"""
def __repr__(self):
def __repr__(self) -> str:
return "approx({!r})".format(
{k: self._approx_scalar(v) for k, v in self.expected.items()}
)
def __eq__(self, actual):
def __eq__(self, actual) -> bool:
if set(actual.keys()) != set(self.expected.keys()):
return False
@ -169,7 +169,7 @@ class ApproxMapping(ApproxBase):
for k in self.expected.keys():
yield actual[k], self.expected[k]
def _check_type(self):
def _check_type(self) -> None:
__tracebackhide__ = True
for key, value in self.expected.items():
if isinstance(value, type(self.expected)):
@ -185,7 +185,7 @@ class ApproxSequencelike(ApproxBase):
numbers.
"""
def __repr__(self):
def __repr__(self) -> str:
seq_type = type(self.expected)
if seq_type not in (tuple, list, set):
seq_type = list
@ -193,7 +193,7 @@ class ApproxSequencelike(ApproxBase):
seq_type(self._approx_scalar(x) for x in self.expected)
)
def __eq__(self, actual):
def __eq__(self, actual) -> bool:
if len(actual) != len(self.expected):
return False
return ApproxBase.__eq__(self, actual)
@ -201,7 +201,7 @@ class ApproxSequencelike(ApproxBase):
def _yield_comparisons(self, actual):
return zip(actual, self.expected)
def _check_type(self):
def _check_type(self) -> None:
__tracebackhide__ = True
for index, x in enumerate(self.expected):
if isinstance(x, type(self.expected)):
@ -223,7 +223,7 @@ class ApproxScalar(ApproxBase):
DEFAULT_ABSOLUTE_TOLERANCE = 1e-12 # type: Union[float, Decimal]
DEFAULT_RELATIVE_TOLERANCE = 1e-6 # type: Union[float, Decimal]
def __repr__(self):
def __repr__(self) -> str:
"""
Return a string communicating both the expected value and the tolerance
for the comparison being made, e.g. '1.0 ± 1e-6', '(3+4j) ± 5e-6 ∠ ±180°'.
@ -245,7 +245,7 @@ class ApproxScalar(ApproxBase):
return "{} ± {}".format(self.expected, vetted_tolerance)
def __eq__(self, actual):
def __eq__(self, actual) -> bool:
"""
Return true if the given value is equal to the expected value within
the pre-specified tolerance.
@ -275,7 +275,8 @@ class ApproxScalar(ApproxBase):
return False
# Return true if the two numbers are within the tolerance.
return abs(self.expected - actual) <= self.tolerance
result = abs(self.expected - actual) <= self.tolerance # type: bool
return result
# Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore
@ -337,7 +338,7 @@ class ApproxDecimal(ApproxScalar):
DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6")
def approx(expected, rel=None, abs=None, nan_ok=False):
def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
"""
Assert that two numbers (or two sets of numbers) are equal to each other
within some tolerance.
@ -527,7 +528,7 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
return cls(expected, rel, abs, nan_ok)
def _is_numpy_array(obj):
def _is_numpy_array(obj: object) -> bool:
"""
Return true if the given object is a numpy array. Make a special effort to
avoid importing numpy unless it's really necessary.

View File

@ -3,6 +3,7 @@ from decimal import Decimal
from fractions import Fraction
from operator import eq
from operator import ne
from typing import Optional
import pytest
from pytest import approx
@ -121,18 +122,22 @@ class TestApprox:
assert a == approx(x, rel=5e-1, abs=0.0)
assert a != approx(x, rel=5e-2, abs=0.0)
def test_negative_tolerance(self):
@pytest.mark.parametrize(
("rel", "abs"),
[
(-1e100, None),
(None, -1e100),
(1e100, -1e100),
(-1e100, 1e100),
(-1e100, -1e100),
],
)
def test_negative_tolerance(
self, rel: Optional[float], abs: Optional[float]
) -> None:
# Negative tolerances are not allowed.
illegal_kwargs = [
dict(rel=-1e100),
dict(abs=-1e100),
dict(rel=1e100, abs=-1e100),
dict(rel=-1e100, abs=1e100),
dict(rel=-1e100, abs=-1e100),
]
for kwargs in illegal_kwargs:
with pytest.raises(ValueError):
1.1 == approx(1, **kwargs)
with pytest.raises(ValueError):
1.1 == approx(1, rel, abs)
def test_inf_tolerance(self):
# Everything should be equal if the tolerance is infinite.
@ -143,19 +148,21 @@ class TestApprox:
assert a == approx(x, rel=0.0, abs=inf)
assert a == approx(x, rel=inf, abs=inf)
def test_inf_tolerance_expecting_zero(self):
def test_inf_tolerance_expecting_zero(self) -> None:
# If the relative tolerance is zero but the expected value is infinite,
# the actual tolerance is a NaN, which should be an error.
illegal_kwargs = [dict(rel=inf, abs=0.0), dict(rel=inf, abs=inf)]
for kwargs in illegal_kwargs:
with pytest.raises(ValueError):
1 == approx(0, **kwargs)
with pytest.raises(ValueError):
1 == approx(0, rel=inf, abs=0.0)
with pytest.raises(ValueError):
1 == approx(0, rel=inf, abs=inf)
def test_nan_tolerance(self):
illegal_kwargs = [dict(rel=nan), dict(abs=nan), dict(rel=nan, abs=nan)]
for kwargs in illegal_kwargs:
with pytest.raises(ValueError):
1.1 == approx(1, **kwargs)
def test_nan_tolerance(self) -> None:
with pytest.raises(ValueError):
1.1 == approx(1, rel=nan)
with pytest.raises(ValueError):
1.1 == approx(1, abs=nan)
with pytest.raises(ValueError):
1.1 == approx(1, rel=nan, abs=nan)
def test_reasonable_defaults(self):
# Whatever the defaults are, they should work for numbers close to 1