Avoid making multiple ApproxNumpy types.

This commit is contained in:
Kale Kundert 2017-07-03 22:45:24 -07:00
parent 8524a57075
commit c111e9dac3
No known key found for this signature in database
GPG Key ID: C6238221D17CAFAE
1 changed files with 16 additions and 9 deletions

View File

@ -49,13 +49,9 @@ class ApproxNumpyBase(ApproxBase):
""" """
Perform approximate comparisons for numpy arrays. Perform approximate comparisons for numpy arrays.
This class should not be used directly. Instead, it should be used to make This class should not be used directly. Instead, the `inherit_ndarray()`
a subclass that also inherits from `np.ndarray`, e.g.:: class method should be used to make a subclass that also inherits from
`np.ndarray`. This indirection is necessary because the object doing the
import numpy as np
ApproxNumpy = type('ApproxNumpy', (ApproxNumpyBase, np.ndarray), {})
This bizarre invocation is necessary because the object doing the
approximate comparison must inherit from `np.ndarray`, or it will only work approximate comparison must inherit from `np.ndarray`, or it will only work
on the left side of the `==` operator. But importing numpy is relatively on the left side of the `==` operator. But importing numpy is relatively
expensive, so we also want to avoid that unless we actually have a numpy expensive, so we also want to avoid that unless we actually have a numpy
@ -81,6 +77,18 @@ class ApproxNumpyBase(ApproxBase):
it appears on. it appears on.
""" """
subclass = None
@classmethod
def inherit_ndarray(cls):
import numpy as np
assert not isinstance(cls, np.ndarray)
if cls.subclass is None:
cls.subclass = type('ApproxNumpy', (ApproxNumpyBase, np.ndarray), {})
return cls.subclass
def __new__(cls, expected, rel=None, abs=None, nan_ok=False): def __new__(cls, expected, rel=None, abs=None, nan_ok=False):
""" """
Numpy uses __new__ (rather than __init__) to initialize objects. Numpy uses __new__ (rather than __init__) to initialize objects.
@ -416,8 +424,7 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
if _is_numpy_array(expected): if _is_numpy_array(expected):
# Create the delegate class on the fly. This allow us to inherit from # Create the delegate class on the fly. This allow us to inherit from
# ``np.ndarray`` while still not importing numpy unless we need to. # ``np.ndarray`` while still not importing numpy unless we need to.
import numpy as np cls = ApproxNumpyBase.inherit_ndarray()
cls = type('ApproxNumpy', (ApproxNumpyBase, np.ndarray), {})
elif isinstance(expected, Mapping): elif isinstance(expected, Mapping):
cls = ApproxMapping cls = ApproxMapping
elif isinstance(expected, Sequence) and not isinstance(expected, String): elif isinstance(expected, Sequence) and not isinstance(expected, String):