Merge pull request #2606 from kalekundert/simplify-numpy
Make approx more compatible with numpy
This commit is contained in:
commit
1b732fe361
|
@ -46,60 +46,13 @@ class ApproxBase(object):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
class ApproxNumpyBase(ApproxBase):
|
||||
class ApproxNumpy(ApproxBase):
|
||||
"""
|
||||
Perform approximate comparisons for numpy arrays.
|
||||
|
||||
This class should not be used directly. Instead, the `inherit_ndarray()`
|
||||
class method should be used to make a subclass that also inherits from
|
||||
`np.ndarray`. This indirection is necessary because the object doing the
|
||||
approximate comparison must inherit from `np.ndarray`, or it will only work
|
||||
on the left side of the `==` operator. But importing numpy is relatively
|
||||
expensive, so we also want to avoid that unless we actually have a numpy
|
||||
array to compare.
|
||||
|
||||
The reason why the approx object needs to inherit from `np.ndarray` has to
|
||||
do with how python decides whether to call `a.__eq__()` or `b.__eq__()`
|
||||
when it parses `a == b`. If `a` and `b` are not related by inheritance,
|
||||
`a` gets priority. So as long as `a.__eq__` is defined, it will be called.
|
||||
Because most implementations of `a.__eq__` end up calling `b.__eq__`, this
|
||||
detail usually doesn't matter. However, `np.ndarray.__eq__` treats the
|
||||
approx object as a scalar and builds a new array by comparing it to each
|
||||
item in the original array. `b.__eq__` is called to compare against each
|
||||
individual element in the array, but it has no way (that I can see) to
|
||||
prevent the return value from being an boolean array, and boolean arrays
|
||||
can't be used with assert because "the truth value of an array with more
|
||||
than one element is ambiguous."
|
||||
|
||||
The trick is that the priority rules change if `a` and `b` are related
|
||||
by inheritance. Specifically, `b.__eq__` gets priority if `b` is a
|
||||
subclass of `a`. So by inheriting from `np.ndarray`, we can guarantee that
|
||||
`ApproxNumpy.__eq__` gets called no matter which side of the `==` operator
|
||||
it appears on.
|
||||
"""
|
||||
|
||||
subclass = None
|
||||
|
||||
@classmethod
|
||||
def inherit_ndarray(cls):
|
||||
import numpy as np
|
||||
assert not isinstance(cls, np.ndarray)
|
||||
|
||||
if cls.subclass is None:
|
||||
cls.subclass = type('ApproxNumpy', (cls, np.ndarray), {})
|
||||
|
||||
return cls.subclass
|
||||
|
||||
def __new__(cls, expected, rel=None, abs=None, nan_ok=False):
|
||||
"""
|
||||
Numpy uses __new__ (rather than __init__) to initialize objects.
|
||||
|
||||
The `expected` argument must be a numpy array. This should be
|
||||
ensured by the approx() delegator function.
|
||||
"""
|
||||
obj = super(ApproxNumpyBase, cls).__new__(cls, ())
|
||||
obj.__init__(expected, rel, abs, nan_ok)
|
||||
return obj
|
||||
# Tell numpy to use our `__eq__` operator instead of its.
|
||||
__array_priority__ = 100
|
||||
|
||||
def __repr__(self):
|
||||
# It might be nice to rewrite this function to account for the
|
||||
|
@ -113,7 +66,7 @@ class ApproxNumpyBase(ApproxBase):
|
|||
try:
|
||||
actual = np.asarray(actual)
|
||||
except:
|
||||
raise ValueError("cannot cast '{0}' to numpy.ndarray".format(actual))
|
||||
raise TypeError("cannot compare '{0}' to numpy.ndarray".format(actual))
|
||||
|
||||
if actual.shape != self.expected.shape:
|
||||
return False
|
||||
|
@ -157,6 +110,9 @@ class ApproxSequence(ApproxBase):
|
|||
Perform approximate comparisons for sequences of numbers.
|
||||
"""
|
||||
|
||||
# Tell numpy to use our `__eq__` operator instead of its.
|
||||
__array_priority__ = 100
|
||||
|
||||
def __repr__(self):
|
||||
seq_type = type(self.expected)
|
||||
if seq_type not in (tuple, list, set):
|
||||
|
@ -422,9 +378,7 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
|
|||
# their keys, which is probably not what most people would expect.
|
||||
|
||||
if _is_numpy_array(expected):
|
||||
# Create the delegate class on the fly. This allow us to inherit from
|
||||
# ``np.ndarray`` while still not importing numpy unless we need to.
|
||||
cls = ApproxNumpyBase.inherit_ndarray()
|
||||
cls = ApproxNumpy
|
||||
elif isinstance(expected, Mapping):
|
||||
cls = ApproxMapping
|
||||
elif isinstance(expected, Sequence) and not isinstance(expected, String):
|
||||
|
|
Loading…
Reference in New Issue