Simplify how comparisons with numpy arrays work.

Previously I was subverting the natural order of operations by
subclassing from `ndarray`, but it turns out that you can tell just
numpy to call your operator instead of its by setting the
`__array_priority__` attribute on your class.  This is much simpler, and
it turns out the be a little more robust, too.
This commit is contained in:
Kale Kundert 2017-07-22 07:52:03 -07:00
parent 6461dc9fc6
commit 495f731760
No known key found for this signature in database
GPG Key ID: C6238221D17CAFAE
1 changed files with 8 additions and 54 deletions

View File

@ -46,60 +46,13 @@ class ApproxBase(object):
raise NotImplementedError
class ApproxNumpyBase(ApproxBase):
class ApproxNumpy(ApproxBase):
"""
Perform approximate comparisons for numpy arrays.
This class should not be used directly. Instead, the `inherit_ndarray()`
class method should be used to make a subclass that also inherits from
`np.ndarray`. This indirection is necessary because the object doing the
approximate comparison must inherit from `np.ndarray`, or it will only work
on the left side of the `==` operator. But importing numpy is relatively
expensive, so we also want to avoid that unless we actually have a numpy
array to compare.
The reason why the approx object needs to inherit from `np.ndarray` has to
do with how python decides whether to call `a.__eq__()` or `b.__eq__()`
when it parses `a == b`. If `a` and `b` are not related by inheritance,
`a` gets priority. So as long as `a.__eq__` is defined, it will be called.
Because most implementations of `a.__eq__` end up calling `b.__eq__`, this
detail usually doesn't matter. However, `np.ndarray.__eq__` treats the
approx object as a scalar and builds a new array by comparing it to each
item in the original array. `b.__eq__` is called to compare against each
individual element in the array, but it has no way (that I can see) to
prevent the return value from being an boolean array, and boolean arrays
can't be used with assert because "the truth value of an array with more
than one element is ambiguous."
The trick is that the priority rules change if `a` and `b` are related
by inheritance. Specifically, `b.__eq__` gets priority if `b` is a
subclass of `a`. So by inheriting from `np.ndarray`, we can guarantee that
`ApproxNumpy.__eq__` gets called no matter which side of the `==` operator
it appears on.
"""
subclass = None
@classmethod
def inherit_ndarray(cls):
import numpy as np
assert not isinstance(cls, np.ndarray)
if cls.subclass is None:
cls.subclass = type('ApproxNumpy', (cls, np.ndarray), {})
return cls.subclass
def __new__(cls, expected, rel=None, abs=None, nan_ok=False):
"""
Numpy uses __new__ (rather than __init__) to initialize objects.
The `expected` argument must be a numpy array. This should be
ensured by the approx() delegator function.
"""
obj = super(ApproxNumpyBase, cls).__new__(cls, ())
obj.__init__(expected, rel, abs, nan_ok)
return obj
# Tell numpy to use our `__eq__` operator instead of its.
__array_priority__ = 100
def __repr__(self):
# It might be nice to rewrite this function to account for the
@ -113,7 +66,7 @@ class ApproxNumpyBase(ApproxBase):
try:
actual = np.asarray(actual)
except:
raise ValueError("cannot cast '{0}' to numpy.ndarray".format(actual))
raise ValueError("cannot compare '{0}' to numpy.ndarray".format(actual))
if actual.shape != self.expected.shape:
return False
@ -157,6 +110,9 @@ class ApproxSequence(ApproxBase):
Perform approximate comparisons for sequences of numbers.
"""
# Tell numpy to use our `__eq__` operator instead of its.
__array_priority__ = 100
def __repr__(self):
seq_type = type(self.expected)
if seq_type not in (tuple, list, set):
@ -422,9 +378,7 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
# their keys, which is probably not what most people would expect.
if _is_numpy_array(expected):
# Create the delegate class on the fly. This allow us to inherit from
# ``np.ndarray`` while still not importing numpy unless we need to.
cls = ApproxNumpyBase.inherit_ndarray()
cls = ApproxNumpy
elif isinstance(expected, Mapping):
cls = ApproxMapping
elif isinstance(expected, Sequence) and not isinstance(expected, String):