Merge pull request #2606 from kalekundert/simplify-numpy
Make approx more compatible with numpy
This commit is contained in:
commit
1b732fe361
|
@ -46,60 +46,13 @@ class ApproxBase(object):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class ApproxNumpyBase(ApproxBase):
|
class ApproxNumpy(ApproxBase):
|
||||||
"""
|
"""
|
||||||
Perform approximate comparisons for numpy arrays.
|
Perform approximate comparisons for numpy arrays.
|
||||||
|
|
||||||
This class should not be used directly. Instead, the `inherit_ndarray()`
|
|
||||||
class method should be used to make a subclass that also inherits from
|
|
||||||
`np.ndarray`. This indirection is necessary because the object doing the
|
|
||||||
approximate comparison must inherit from `np.ndarray`, or it will only work
|
|
||||||
on the left side of the `==` operator. But importing numpy is relatively
|
|
||||||
expensive, so we also want to avoid that unless we actually have a numpy
|
|
||||||
array to compare.
|
|
||||||
|
|
||||||
The reason why the approx object needs to inherit from `np.ndarray` has to
|
|
||||||
do with how python decides whether to call `a.__eq__()` or `b.__eq__()`
|
|
||||||
when it parses `a == b`. If `a` and `b` are not related by inheritance,
|
|
||||||
`a` gets priority. So as long as `a.__eq__` is defined, it will be called.
|
|
||||||
Because most implementations of `a.__eq__` end up calling `b.__eq__`, this
|
|
||||||
detail usually doesn't matter. However, `np.ndarray.__eq__` treats the
|
|
||||||
approx object as a scalar and builds a new array by comparing it to each
|
|
||||||
item in the original array. `b.__eq__` is called to compare against each
|
|
||||||
individual element in the array, but it has no way (that I can see) to
|
|
||||||
prevent the return value from being an boolean array, and boolean arrays
|
|
||||||
can't be used with assert because "the truth value of an array with more
|
|
||||||
than one element is ambiguous."
|
|
||||||
|
|
||||||
The trick is that the priority rules change if `a` and `b` are related
|
|
||||||
by inheritance. Specifically, `b.__eq__` gets priority if `b` is a
|
|
||||||
subclass of `a`. So by inheriting from `np.ndarray`, we can guarantee that
|
|
||||||
`ApproxNumpy.__eq__` gets called no matter which side of the `==` operator
|
|
||||||
it appears on.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
subclass = None
|
# Tell numpy to use our `__eq__` operator instead of its.
|
||||||
|
__array_priority__ = 100
|
||||||
@classmethod
|
|
||||||
def inherit_ndarray(cls):
|
|
||||||
import numpy as np
|
|
||||||
assert not isinstance(cls, np.ndarray)
|
|
||||||
|
|
||||||
if cls.subclass is None:
|
|
||||||
cls.subclass = type('ApproxNumpy', (cls, np.ndarray), {})
|
|
||||||
|
|
||||||
return cls.subclass
|
|
||||||
|
|
||||||
def __new__(cls, expected, rel=None, abs=None, nan_ok=False):
|
|
||||||
"""
|
|
||||||
Numpy uses __new__ (rather than __init__) to initialize objects.
|
|
||||||
|
|
||||||
The `expected` argument must be a numpy array. This should be
|
|
||||||
ensured by the approx() delegator function.
|
|
||||||
"""
|
|
||||||
obj = super(ApproxNumpyBase, cls).__new__(cls, ())
|
|
||||||
obj.__init__(expected, rel, abs, nan_ok)
|
|
||||||
return obj
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
# It might be nice to rewrite this function to account for the
|
# It might be nice to rewrite this function to account for the
|
||||||
|
@ -113,7 +66,7 @@ class ApproxNumpyBase(ApproxBase):
|
||||||
try:
|
try:
|
||||||
actual = np.asarray(actual)
|
actual = np.asarray(actual)
|
||||||
except:
|
except:
|
||||||
raise ValueError("cannot cast '{0}' to numpy.ndarray".format(actual))
|
raise TypeError("cannot compare '{0}' to numpy.ndarray".format(actual))
|
||||||
|
|
||||||
if actual.shape != self.expected.shape:
|
if actual.shape != self.expected.shape:
|
||||||
return False
|
return False
|
||||||
|
@ -157,6 +110,9 @@ class ApproxSequence(ApproxBase):
|
||||||
Perform approximate comparisons for sequences of numbers.
|
Perform approximate comparisons for sequences of numbers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Tell numpy to use our `__eq__` operator instead of its.
|
||||||
|
__array_priority__ = 100
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
seq_type = type(self.expected)
|
seq_type = type(self.expected)
|
||||||
if seq_type not in (tuple, list, set):
|
if seq_type not in (tuple, list, set):
|
||||||
|
@ -422,9 +378,7 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
|
||||||
# their keys, which is probably not what most people would expect.
|
# their keys, which is probably not what most people would expect.
|
||||||
|
|
||||||
if _is_numpy_array(expected):
|
if _is_numpy_array(expected):
|
||||||
# Create the delegate class on the fly. This allow us to inherit from
|
cls = ApproxNumpy
|
||||||
# ``np.ndarray`` while still not importing numpy unless we need to.
|
|
||||||
cls = ApproxNumpyBase.inherit_ndarray()
|
|
||||||
elif isinstance(expected, Mapping):
|
elif isinstance(expected, Mapping):
|
||||||
cls = ApproxMapping
|
cls = ApproxMapping
|
||||||
elif isinstance(expected, Sequence) and not isinstance(expected, String):
|
elif isinstance(expected, Sequence) and not isinstance(expected, String):
|
||||||
|
|
Loading…
Reference in New Issue