python_api: type annotate some parts of pytest.approx()
This commit is contained in:
parent
142d8963e6
commit
8f8f472379
|
@ -33,7 +33,7 @@ if TYPE_CHECKING:
|
||||||
BASE_TYPE = (type, STRING_TYPES)
|
BASE_TYPE = (type, STRING_TYPES)
|
||||||
|
|
||||||
|
|
||||||
def _non_numeric_type_error(value, at):
|
def _non_numeric_type_error(value, at: Optional[str]) -> TypeError:
|
||||||
at_str = " at {}".format(at) if at else ""
|
at_str = " at {}".format(at) if at else ""
|
||||||
return TypeError(
|
return TypeError(
|
||||||
"cannot make approximate comparisons to non-numeric values: {!r} {}".format(
|
"cannot make approximate comparisons to non-numeric values: {!r} {}".format(
|
||||||
|
@ -55,7 +55,7 @@ class ApproxBase:
|
||||||
__array_ufunc__ = None
|
__array_ufunc__ = None
|
||||||
__array_priority__ = 100
|
__array_priority__ = 100
|
||||||
|
|
||||||
def __init__(self, expected, rel=None, abs=None, nan_ok=False):
|
def __init__(self, expected, rel=None, abs=None, nan_ok: bool = False) -> None:
|
||||||
__tracebackhide__ = True
|
__tracebackhide__ = True
|
||||||
self.expected = expected
|
self.expected = expected
|
||||||
self.abs = abs
|
self.abs = abs
|
||||||
|
@ -63,10 +63,10 @@ class ApproxBase:
|
||||||
self.nan_ok = nan_ok
|
self.nan_ok = nan_ok
|
||||||
self._check_type()
|
self._check_type()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __eq__(self, actual):
|
def __eq__(self, actual) -> bool:
|
||||||
return all(
|
return all(
|
||||||
a == self._approx_scalar(x) for a, x in self._yield_comparisons(actual)
|
a == self._approx_scalar(x) for a, x in self._yield_comparisons(actual)
|
||||||
)
|
)
|
||||||
|
@ -74,10 +74,10 @@ class ApproxBase:
|
||||||
# Ignore type because of https://github.com/python/mypy/issues/4266.
|
# Ignore type because of https://github.com/python/mypy/issues/4266.
|
||||||
__hash__ = None # type: ignore
|
__hash__ = None # type: ignore
|
||||||
|
|
||||||
def __ne__(self, actual):
|
def __ne__(self, actual) -> bool:
|
||||||
return not (actual == self)
|
return not (actual == self)
|
||||||
|
|
||||||
def _approx_scalar(self, x):
|
def _approx_scalar(self, x) -> "ApproxScalar":
|
||||||
return ApproxScalar(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)
|
return ApproxScalar(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)
|
||||||
|
|
||||||
def _yield_comparisons(self, actual):
|
def _yield_comparisons(self, actual):
|
||||||
|
@ -87,7 +87,7 @@ class ApproxBase:
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def _check_type(self):
|
def _check_type(self) -> None:
|
||||||
"""
|
"""
|
||||||
Raise a TypeError if the expected value is not a valid type.
|
Raise a TypeError if the expected value is not a valid type.
|
||||||
"""
|
"""
|
||||||
|
@ -111,11 +111,11 @@ class ApproxNumpy(ApproxBase):
|
||||||
Perform approximate comparisons where the expected value is numpy array.
|
Perform approximate comparisons where the expected value is numpy array.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
list_scalars = _recursive_list_map(self._approx_scalar, self.expected.tolist())
|
list_scalars = _recursive_list_map(self._approx_scalar, self.expected.tolist())
|
||||||
return "approx({!r})".format(list_scalars)
|
return "approx({!r})".format(list_scalars)
|
||||||
|
|
||||||
def __eq__(self, actual):
|
def __eq__(self, actual) -> bool:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
# self.expected is supposed to always be an array here
|
# self.expected is supposed to always be an array here
|
||||||
|
@ -154,12 +154,12 @@ class ApproxMapping(ApproxBase):
|
||||||
numeric values (the keys can be anything).
|
numeric values (the keys can be anything).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return "approx({!r})".format(
|
return "approx({!r})".format(
|
||||||
{k: self._approx_scalar(v) for k, v in self.expected.items()}
|
{k: self._approx_scalar(v) for k, v in self.expected.items()}
|
||||||
)
|
)
|
||||||
|
|
||||||
def __eq__(self, actual):
|
def __eq__(self, actual) -> bool:
|
||||||
if set(actual.keys()) != set(self.expected.keys()):
|
if set(actual.keys()) != set(self.expected.keys()):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -169,7 +169,7 @@ class ApproxMapping(ApproxBase):
|
||||||
for k in self.expected.keys():
|
for k in self.expected.keys():
|
||||||
yield actual[k], self.expected[k]
|
yield actual[k], self.expected[k]
|
||||||
|
|
||||||
def _check_type(self):
|
def _check_type(self) -> None:
|
||||||
__tracebackhide__ = True
|
__tracebackhide__ = True
|
||||||
for key, value in self.expected.items():
|
for key, value in self.expected.items():
|
||||||
if isinstance(value, type(self.expected)):
|
if isinstance(value, type(self.expected)):
|
||||||
|
@ -185,7 +185,7 @@ class ApproxSequencelike(ApproxBase):
|
||||||
numbers.
|
numbers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
seq_type = type(self.expected)
|
seq_type = type(self.expected)
|
||||||
if seq_type not in (tuple, list, set):
|
if seq_type not in (tuple, list, set):
|
||||||
seq_type = list
|
seq_type = list
|
||||||
|
@ -193,7 +193,7 @@ class ApproxSequencelike(ApproxBase):
|
||||||
seq_type(self._approx_scalar(x) for x in self.expected)
|
seq_type(self._approx_scalar(x) for x in self.expected)
|
||||||
)
|
)
|
||||||
|
|
||||||
def __eq__(self, actual):
|
def __eq__(self, actual) -> bool:
|
||||||
if len(actual) != len(self.expected):
|
if len(actual) != len(self.expected):
|
||||||
return False
|
return False
|
||||||
return ApproxBase.__eq__(self, actual)
|
return ApproxBase.__eq__(self, actual)
|
||||||
|
@ -201,7 +201,7 @@ class ApproxSequencelike(ApproxBase):
|
||||||
def _yield_comparisons(self, actual):
|
def _yield_comparisons(self, actual):
|
||||||
return zip(actual, self.expected)
|
return zip(actual, self.expected)
|
||||||
|
|
||||||
def _check_type(self):
|
def _check_type(self) -> None:
|
||||||
__tracebackhide__ = True
|
__tracebackhide__ = True
|
||||||
for index, x in enumerate(self.expected):
|
for index, x in enumerate(self.expected):
|
||||||
if isinstance(x, type(self.expected)):
|
if isinstance(x, type(self.expected)):
|
||||||
|
@ -223,7 +223,7 @@ class ApproxScalar(ApproxBase):
|
||||||
DEFAULT_ABSOLUTE_TOLERANCE = 1e-12 # type: Union[float, Decimal]
|
DEFAULT_ABSOLUTE_TOLERANCE = 1e-12 # type: Union[float, Decimal]
|
||||||
DEFAULT_RELATIVE_TOLERANCE = 1e-6 # type: Union[float, Decimal]
|
DEFAULT_RELATIVE_TOLERANCE = 1e-6 # type: Union[float, Decimal]
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
"""
|
"""
|
||||||
Return a string communicating both the expected value and the tolerance
|
Return a string communicating both the expected value and the tolerance
|
||||||
for the comparison being made, e.g. '1.0 ± 1e-6', '(3+4j) ± 5e-6 ∠ ±180°'.
|
for the comparison being made, e.g. '1.0 ± 1e-6', '(3+4j) ± 5e-6 ∠ ±180°'.
|
||||||
|
@ -245,7 +245,7 @@ class ApproxScalar(ApproxBase):
|
||||||
|
|
||||||
return "{} ± {}".format(self.expected, vetted_tolerance)
|
return "{} ± {}".format(self.expected, vetted_tolerance)
|
||||||
|
|
||||||
def __eq__(self, actual):
|
def __eq__(self, actual) -> bool:
|
||||||
"""
|
"""
|
||||||
Return true if the given value is equal to the expected value within
|
Return true if the given value is equal to the expected value within
|
||||||
the pre-specified tolerance.
|
the pre-specified tolerance.
|
||||||
|
@ -275,7 +275,8 @@ class ApproxScalar(ApproxBase):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Return true if the two numbers are within the tolerance.
|
# Return true if the two numbers are within the tolerance.
|
||||||
return abs(self.expected - actual) <= self.tolerance
|
result = abs(self.expected - actual) <= self.tolerance # type: bool
|
||||||
|
return result
|
||||||
|
|
||||||
# Ignore type because of https://github.com/python/mypy/issues/4266.
|
# Ignore type because of https://github.com/python/mypy/issues/4266.
|
||||||
__hash__ = None # type: ignore
|
__hash__ = None # type: ignore
|
||||||
|
@ -337,7 +338,7 @@ class ApproxDecimal(ApproxScalar):
|
||||||
DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6")
|
DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6")
|
||||||
|
|
||||||
|
|
||||||
def approx(expected, rel=None, abs=None, nan_ok=False):
|
def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
|
||||||
"""
|
"""
|
||||||
Assert that two numbers (or two sets of numbers) are equal to each other
|
Assert that two numbers (or two sets of numbers) are equal to each other
|
||||||
within some tolerance.
|
within some tolerance.
|
||||||
|
@ -527,7 +528,7 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
|
||||||
return cls(expected, rel, abs, nan_ok)
|
return cls(expected, rel, abs, nan_ok)
|
||||||
|
|
||||||
|
|
||||||
def _is_numpy_array(obj):
|
def _is_numpy_array(obj: object) -> bool:
|
||||||
"""
|
"""
|
||||||
Return true if the given object is a numpy array. Make a special effort to
|
Return true if the given object is a numpy array. Make a special effort to
|
||||||
avoid importing numpy unless it's really necessary.
|
avoid importing numpy unless it's really necessary.
|
||||||
|
|
|
@ -3,6 +3,7 @@ from decimal import Decimal
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from operator import eq
|
from operator import eq
|
||||||
from operator import ne
|
from operator import ne
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pytest import approx
|
from pytest import approx
|
||||||
|
@ -121,18 +122,22 @@ class TestApprox:
|
||||||
assert a == approx(x, rel=5e-1, abs=0.0)
|
assert a == approx(x, rel=5e-1, abs=0.0)
|
||||||
assert a != approx(x, rel=5e-2, abs=0.0)
|
assert a != approx(x, rel=5e-2, abs=0.0)
|
||||||
|
|
||||||
def test_negative_tolerance(self):
|
@pytest.mark.parametrize(
|
||||||
|
("rel", "abs"),
|
||||||
|
[
|
||||||
|
(-1e100, None),
|
||||||
|
(None, -1e100),
|
||||||
|
(1e100, -1e100),
|
||||||
|
(-1e100, 1e100),
|
||||||
|
(-1e100, -1e100),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_negative_tolerance(
|
||||||
|
self, rel: Optional[float], abs: Optional[float]
|
||||||
|
) -> None:
|
||||||
# Negative tolerances are not allowed.
|
# Negative tolerances are not allowed.
|
||||||
illegal_kwargs = [
|
with pytest.raises(ValueError):
|
||||||
dict(rel=-1e100),
|
1.1 == approx(1, rel, abs)
|
||||||
dict(abs=-1e100),
|
|
||||||
dict(rel=1e100, abs=-1e100),
|
|
||||||
dict(rel=-1e100, abs=1e100),
|
|
||||||
dict(rel=-1e100, abs=-1e100),
|
|
||||||
]
|
|
||||||
for kwargs in illegal_kwargs:
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
1.1 == approx(1, **kwargs)
|
|
||||||
|
|
||||||
def test_inf_tolerance(self):
|
def test_inf_tolerance(self):
|
||||||
# Everything should be equal if the tolerance is infinite.
|
# Everything should be equal if the tolerance is infinite.
|
||||||
|
@ -143,19 +148,21 @@ class TestApprox:
|
||||||
assert a == approx(x, rel=0.0, abs=inf)
|
assert a == approx(x, rel=0.0, abs=inf)
|
||||||
assert a == approx(x, rel=inf, abs=inf)
|
assert a == approx(x, rel=inf, abs=inf)
|
||||||
|
|
||||||
def test_inf_tolerance_expecting_zero(self):
|
def test_inf_tolerance_expecting_zero(self) -> None:
|
||||||
# If the relative tolerance is zero but the expected value is infinite,
|
# If the relative tolerance is zero but the expected value is infinite,
|
||||||
# the actual tolerance is a NaN, which should be an error.
|
# the actual tolerance is a NaN, which should be an error.
|
||||||
illegal_kwargs = [dict(rel=inf, abs=0.0), dict(rel=inf, abs=inf)]
|
with pytest.raises(ValueError):
|
||||||
for kwargs in illegal_kwargs:
|
1 == approx(0, rel=inf, abs=0.0)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
1 == approx(0, **kwargs)
|
1 == approx(0, rel=inf, abs=inf)
|
||||||
|
|
||||||
def test_nan_tolerance(self):
|
def test_nan_tolerance(self) -> None:
|
||||||
illegal_kwargs = [dict(rel=nan), dict(abs=nan), dict(rel=nan, abs=nan)]
|
with pytest.raises(ValueError):
|
||||||
for kwargs in illegal_kwargs:
|
1.1 == approx(1, rel=nan)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
1.1 == approx(1, **kwargs)
|
1.1 == approx(1, abs=nan)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
1.1 == approx(1, rel=nan, abs=nan)
|
||||||
|
|
||||||
def test_reasonable_defaults(self):
|
def test_reasonable_defaults(self):
|
||||||
# Whatever the defaults are, they should work for numbers close to 1
|
# Whatever the defaults are, they should work for numbers close to 1
|
||||||
|
|
Loading…
Reference in New Issue