python_api: type annotate some parts of pytest.approx()

This commit is contained in:
Ran Benita 2020-06-25 15:02:04 +03:00
parent 142d8963e6
commit 8f8f472379
2 changed files with 49 additions and 41 deletions

View File

@ -33,7 +33,7 @@ if TYPE_CHECKING:
BASE_TYPE = (type, STRING_TYPES) BASE_TYPE = (type, STRING_TYPES)
def _non_numeric_type_error(value, at): def _non_numeric_type_error(value, at: Optional[str]) -> TypeError:
at_str = " at {}".format(at) if at else "" at_str = " at {}".format(at) if at else ""
return TypeError( return TypeError(
"cannot make approximate comparisons to non-numeric values: {!r} {}".format( "cannot make approximate comparisons to non-numeric values: {!r} {}".format(
@ -55,7 +55,7 @@ class ApproxBase:
__array_ufunc__ = None __array_ufunc__ = None
__array_priority__ = 100 __array_priority__ = 100
def __init__(self, expected, rel=None, abs=None, nan_ok=False): def __init__(self, expected, rel=None, abs=None, nan_ok: bool = False) -> None:
__tracebackhide__ = True __tracebackhide__ = True
self.expected = expected self.expected = expected
self.abs = abs self.abs = abs
@ -63,10 +63,10 @@ class ApproxBase:
self.nan_ok = nan_ok self.nan_ok = nan_ok
self._check_type() self._check_type()
def __repr__(self): def __repr__(self) -> str:
raise NotImplementedError raise NotImplementedError
def __eq__(self, actual): def __eq__(self, actual) -> bool:
return all( return all(
a == self._approx_scalar(x) for a, x in self._yield_comparisons(actual) a == self._approx_scalar(x) for a, x in self._yield_comparisons(actual)
) )
@ -74,10 +74,10 @@ class ApproxBase:
# Ignore type because of https://github.com/python/mypy/issues/4266. # Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore __hash__ = None # type: ignore
def __ne__(self, actual): def __ne__(self, actual) -> bool:
return not (actual == self) return not (actual == self)
def _approx_scalar(self, x): def _approx_scalar(self, x) -> "ApproxScalar":
return ApproxScalar(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok) return ApproxScalar(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)
def _yield_comparisons(self, actual): def _yield_comparisons(self, actual):
@ -87,7 +87,7 @@ class ApproxBase:
""" """
raise NotImplementedError raise NotImplementedError
def _check_type(self): def _check_type(self) -> None:
""" """
Raise a TypeError if the expected value is not a valid type. Raise a TypeError if the expected value is not a valid type.
""" """
@ -111,11 +111,11 @@ class ApproxNumpy(ApproxBase):
Perform approximate comparisons where the expected value is numpy array. Perform approximate comparisons where the expected value is numpy array.
""" """
def __repr__(self): def __repr__(self) -> str:
list_scalars = _recursive_list_map(self._approx_scalar, self.expected.tolist()) list_scalars = _recursive_list_map(self._approx_scalar, self.expected.tolist())
return "approx({!r})".format(list_scalars) return "approx({!r})".format(list_scalars)
def __eq__(self, actual): def __eq__(self, actual) -> bool:
import numpy as np import numpy as np
# self.expected is supposed to always be an array here # self.expected is supposed to always be an array here
@ -154,12 +154,12 @@ class ApproxMapping(ApproxBase):
numeric values (the keys can be anything). numeric values (the keys can be anything).
""" """
def __repr__(self): def __repr__(self) -> str:
return "approx({!r})".format( return "approx({!r})".format(
{k: self._approx_scalar(v) for k, v in self.expected.items()} {k: self._approx_scalar(v) for k, v in self.expected.items()}
) )
def __eq__(self, actual): def __eq__(self, actual) -> bool:
if set(actual.keys()) != set(self.expected.keys()): if set(actual.keys()) != set(self.expected.keys()):
return False return False
@ -169,7 +169,7 @@ class ApproxMapping(ApproxBase):
for k in self.expected.keys(): for k in self.expected.keys():
yield actual[k], self.expected[k] yield actual[k], self.expected[k]
def _check_type(self): def _check_type(self) -> None:
__tracebackhide__ = True __tracebackhide__ = True
for key, value in self.expected.items(): for key, value in self.expected.items():
if isinstance(value, type(self.expected)): if isinstance(value, type(self.expected)):
@ -185,7 +185,7 @@ class ApproxSequencelike(ApproxBase):
numbers. numbers.
""" """
def __repr__(self): def __repr__(self) -> str:
seq_type = type(self.expected) seq_type = type(self.expected)
if seq_type not in (tuple, list, set): if seq_type not in (tuple, list, set):
seq_type = list seq_type = list
@ -193,7 +193,7 @@ class ApproxSequencelike(ApproxBase):
seq_type(self._approx_scalar(x) for x in self.expected) seq_type(self._approx_scalar(x) for x in self.expected)
) )
def __eq__(self, actual): def __eq__(self, actual) -> bool:
if len(actual) != len(self.expected): if len(actual) != len(self.expected):
return False return False
return ApproxBase.__eq__(self, actual) return ApproxBase.__eq__(self, actual)
@ -201,7 +201,7 @@ class ApproxSequencelike(ApproxBase):
def _yield_comparisons(self, actual): def _yield_comparisons(self, actual):
return zip(actual, self.expected) return zip(actual, self.expected)
def _check_type(self): def _check_type(self) -> None:
__tracebackhide__ = True __tracebackhide__ = True
for index, x in enumerate(self.expected): for index, x in enumerate(self.expected):
if isinstance(x, type(self.expected)): if isinstance(x, type(self.expected)):
@ -223,7 +223,7 @@ class ApproxScalar(ApproxBase):
DEFAULT_ABSOLUTE_TOLERANCE = 1e-12 # type: Union[float, Decimal] DEFAULT_ABSOLUTE_TOLERANCE = 1e-12 # type: Union[float, Decimal]
DEFAULT_RELATIVE_TOLERANCE = 1e-6 # type: Union[float, Decimal] DEFAULT_RELATIVE_TOLERANCE = 1e-6 # type: Union[float, Decimal]
def __repr__(self): def __repr__(self) -> str:
""" """
Return a string communicating both the expected value and the tolerance Return a string communicating both the expected value and the tolerance
for the comparison being made, e.g. '1.0 ± 1e-6', '(3+4j) ± 5e-6 ∠ ±180°'. for the comparison being made, e.g. '1.0 ± 1e-6', '(3+4j) ± 5e-6 ∠ ±180°'.
@ -245,7 +245,7 @@ class ApproxScalar(ApproxBase):
return "{} ± {}".format(self.expected, vetted_tolerance) return "{} ± {}".format(self.expected, vetted_tolerance)
def __eq__(self, actual): def __eq__(self, actual) -> bool:
""" """
Return true if the given value is equal to the expected value within Return true if the given value is equal to the expected value within
the pre-specified tolerance. the pre-specified tolerance.
@ -275,7 +275,8 @@ class ApproxScalar(ApproxBase):
return False return False
# Return true if the two numbers are within the tolerance. # Return true if the two numbers are within the tolerance.
return abs(self.expected - actual) <= self.tolerance result = abs(self.expected - actual) <= self.tolerance # type: bool
return result
# Ignore type because of https://github.com/python/mypy/issues/4266. # Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore __hash__ = None # type: ignore
@ -337,7 +338,7 @@ class ApproxDecimal(ApproxScalar):
DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6") DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6")
def approx(expected, rel=None, abs=None, nan_ok=False): def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
""" """
Assert that two numbers (or two sets of numbers) are equal to each other Assert that two numbers (or two sets of numbers) are equal to each other
within some tolerance. within some tolerance.
@ -527,7 +528,7 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
return cls(expected, rel, abs, nan_ok) return cls(expected, rel, abs, nan_ok)
def _is_numpy_array(obj): def _is_numpy_array(obj: object) -> bool:
""" """
Return true if the given object is a numpy array. Make a special effort to Return true if the given object is a numpy array. Make a special effort to
avoid importing numpy unless it's really necessary. avoid importing numpy unless it's really necessary.

View File

@ -3,6 +3,7 @@ from decimal import Decimal
from fractions import Fraction from fractions import Fraction
from operator import eq from operator import eq
from operator import ne from operator import ne
from typing import Optional
import pytest import pytest
from pytest import approx from pytest import approx
@ -121,18 +122,22 @@ class TestApprox:
assert a == approx(x, rel=5e-1, abs=0.0) assert a == approx(x, rel=5e-1, abs=0.0)
assert a != approx(x, rel=5e-2, abs=0.0) assert a != approx(x, rel=5e-2, abs=0.0)
def test_negative_tolerance(self): @pytest.mark.parametrize(
("rel", "abs"),
[
(-1e100, None),
(None, -1e100),
(1e100, -1e100),
(-1e100, 1e100),
(-1e100, -1e100),
],
)
def test_negative_tolerance(
self, rel: Optional[float], abs: Optional[float]
) -> None:
# Negative tolerances are not allowed. # Negative tolerances are not allowed.
illegal_kwargs = [ with pytest.raises(ValueError):
dict(rel=-1e100), 1.1 == approx(1, rel, abs)
dict(abs=-1e100),
dict(rel=1e100, abs=-1e100),
dict(rel=-1e100, abs=1e100),
dict(rel=-1e100, abs=-1e100),
]
for kwargs in illegal_kwargs:
with pytest.raises(ValueError):
1.1 == approx(1, **kwargs)
def test_inf_tolerance(self): def test_inf_tolerance(self):
# Everything should be equal if the tolerance is infinite. # Everything should be equal if the tolerance is infinite.
@ -143,19 +148,21 @@ class TestApprox:
assert a == approx(x, rel=0.0, abs=inf) assert a == approx(x, rel=0.0, abs=inf)
assert a == approx(x, rel=inf, abs=inf) assert a == approx(x, rel=inf, abs=inf)
def test_inf_tolerance_expecting_zero(self): def test_inf_tolerance_expecting_zero(self) -> None:
# If the relative tolerance is zero but the expected value is infinite, # If the relative tolerance is zero but the expected value is infinite,
# the actual tolerance is a NaN, which should be an error. # the actual tolerance is a NaN, which should be an error.
illegal_kwargs = [dict(rel=inf, abs=0.0), dict(rel=inf, abs=inf)] with pytest.raises(ValueError):
for kwargs in illegal_kwargs: 1 == approx(0, rel=inf, abs=0.0)
with pytest.raises(ValueError): with pytest.raises(ValueError):
1 == approx(0, **kwargs) 1 == approx(0, rel=inf, abs=inf)
def test_nan_tolerance(self): def test_nan_tolerance(self) -> None:
illegal_kwargs = [dict(rel=nan), dict(abs=nan), dict(rel=nan, abs=nan)] with pytest.raises(ValueError):
for kwargs in illegal_kwargs: 1.1 == approx(1, rel=nan)
with pytest.raises(ValueError): with pytest.raises(ValueError):
1.1 == approx(1, **kwargs) 1.1 == approx(1, abs=nan)
with pytest.raises(ValueError):
1.1 == approx(1, rel=nan, abs=nan)
def test_reasonable_defaults(self): def test_reasonable_defaults(self):
# Whatever the defaults are, they should work for numbers close to 1 # Whatever the defaults are, they should work for numbers close to 1