Add support for pytest.approx comparisons between array and scalar

This commit is contained in:
Tadeu Manoel 2018-03-14 15:29:40 -03:00
parent cbb2c55dea
commit c34dde7a3f
3 changed files with 33 additions and 4 deletions

View File

@ -31,6 +31,10 @@ class ApproxBase(object):
or sequences of numbers. or sequences of numbers.
""" """
# Tell numpy to use our `__eq__` operator instead of its when left side in a numpy array but right side is
# an instance of ApproxBase
__array_ufunc__ = None
def __init__(self, expected, rel=None, abs=None, nan_ok=False): def __init__(self, expected, rel=None, abs=None, nan_ok=False):
self.expected = expected self.expected = expected
self.abs = abs self.abs = abs
@ -89,7 +93,7 @@ class ApproxNumpy(ApproxBase):
except: # noqa except: # noqa
raise TypeError("cannot compare '{0}' to numpy.ndarray".format(actual)) raise TypeError("cannot compare '{0}' to numpy.ndarray".format(actual))
if actual.shape != self.expected.shape: if not np.isscalar(self.expected) and actual.shape != self.expected.shape:
return False return False
return ApproxBase.__eq__(self, actual) return ApproxBase.__eq__(self, actual)
@ -100,8 +104,13 @@ class ApproxNumpy(ApproxBase):
# We can be sure that `actual` is a numpy array, because it's # We can be sure that `actual` is a numpy array, because it's
# casted in `__eq__` before being passed to `ApproxBase.__eq__`, # casted in `__eq__` before being passed to `ApproxBase.__eq__`,
# which is the only method that calls this one. # which is the only method that calls this one.
for i in np.ndindex(self.expected.shape):
yield actual[i], self.expected[i] if np.isscalar(self.expected):
for i in np.ndindex(actual.shape):
yield actual[i], self.expected
else:
for i in np.ndindex(self.expected.shape):
yield actual[i], self.expected[i]
class ApproxMapping(ApproxBase): class ApproxMapping(ApproxBase):
@ -189,6 +198,8 @@ class ApproxScalar(ApproxBase):
Return true if the given value is equal to the expected value within Return true if the given value is equal to the expected value within
the pre-specified tolerance. the pre-specified tolerance.
""" """
if _is_numpy_array(actual):
return actual == ApproxNumpy(self.expected, self.abs, self.rel, self.nan_ok)
# Short-circuit exact equality. # Short-circuit exact equality.
if actual == self.expected: if actual == self.expected:
@ -308,12 +319,18 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
>>> {'a': 0.1 + 0.2, 'b': 0.2 + 0.4} == approx({'a': 0.3, 'b': 0.6}) >>> {'a': 0.1 + 0.2, 'b': 0.2 + 0.4} == approx({'a': 0.3, 'b': 0.6})
True True
And ``numpy`` arrays:: ``numpy`` arrays::
>>> import numpy as np # doctest: +SKIP >>> import numpy as np # doctest: +SKIP
>>> np.array([0.1, 0.2]) + np.array([0.2, 0.4]) == approx(np.array([0.3, 0.6])) # doctest: +SKIP >>> np.array([0.1, 0.2]) + np.array([0.2, 0.4]) == approx(np.array([0.3, 0.6])) # doctest: +SKIP
True True
And for a ``numpy`` array against a scalar::
>>> import numpy as np # doctest: +SKIP
>>> np.array([0.1, 0.2]) + np.array([0.2, 0.1]) == approx(0.3) # doctest: +SKIP
True
By default, ``approx`` considers numbers within a relative tolerance of By default, ``approx`` considers numbers within a relative tolerance of
``1e-6`` (i.e. one part in a million) of its expected value to be equal. ``1e-6`` (i.e. one part in a million) of its expected value to be equal.
This treatment would lead to surprising results if the expected value was This treatment would lead to surprising results if the expected value was

1
changelog/3312.feature Normal file
View File

@ -0,0 +1 @@
``pytest.approx`` now accepts comparing a numpy array with a scalar.

View File

@ -391,3 +391,14 @@ class TestApprox(object):
""" """
with pytest.raises(TypeError): with pytest.raises(TypeError):
op(1, approx(1, rel=1e-6, abs=1e-12)) op(1, approx(1, rel=1e-6, abs=1e-12))
def test_numpy_array_with_scalar(self):
np = pytest.importorskip('numpy')
actual = np.array([1 + 1e-7, 1 - 1e-8])
expected = 1.0
assert actual == approx(expected, rel=5e-7, abs=0)
assert actual != approx(expected, rel=5e-8, abs=0)
assert approx(expected, rel=5e-7, abs=0) == actual
assert approx(expected, rel=5e-8, abs=0) != actual