Add fixes to `numpy.approx` array-scalar comparisons (from PR suggestions)

This commit is contained in:
Tadeu Manoel 2018-03-15 13:41:58 -03:00
parent 97f9a8bfdf
commit 42c84f4f30
1 changed files with 8 additions and 4 deletions

View File

@ -76,8 +76,10 @@ class ApproxNumpy(ApproxBase):
def __repr__(self): def __repr__(self):
# It might be nice to rewrite this function to account for the # It might be nice to rewrite this function to account for the
# shape of the array... # shape of the array...
import numpy as np
return "approx({0!r})".format(list( return "approx({0!r})".format(list(
self._approx_scalar(x) for x in self.expected)) self._approx_scalar(x) for x in np.asarray(self.expected)))
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
__cmp__ = _cmp_raises_type_error __cmp__ = _cmp_raises_type_error
@ -100,9 +102,11 @@ class ApproxNumpy(ApproxBase):
def _yield_comparisons(self, actual): def _yield_comparisons(self, actual):
import numpy as np import numpy as np
# We can be sure that `actual` is a numpy array, because it's # For both `actual` and `self.expected`, they can independently be
# casted in `__eq__` before being passed to `ApproxBase.__eq__`, # either a `numpy.array` or a scalar (but both can't be scalar,
# which is the only method that calls this one. # in this case an `ApproxScalar` is used).
# They are treated in `__eq__` before being passed to
# `ApproxBase.__eq__`, which is the only method that calls this one.
if np.isscalar(self.expected): if np.isscalar(self.expected):
for i in np.ndindex(actual.shape): for i in np.ndindex(actual.shape):