python_api: handle array-like args in approx() (#8137)
This commit is contained in:
parent
6c899a0afa
commit
9ccbf5b899
|
@ -0,0 +1,10 @@
|
||||||
|
Fixed regression in ``approx``: in 6.2.0 ``approx`` no longer raises
|
||||||
|
``TypeError`` when dealing with non-numeric types, falling back to normal comparison.
|
||||||
|
Before 6.2.0, array types like tf.DeviceArray fell through to the scalar case,
|
||||||
|
and happened to compare correctly to a scalar if they had only one element.
|
||||||
|
After 6.2.0, these types began failing, because they inherited neither from
|
||||||
|
standard Python number hierarchy nor from ``numpy.ndarray``.
|
||||||
|
|
||||||
|
``approx`` now converts arguments to ``numpy.ndarray`` if they expose the array
|
||||||
|
protocol and are not scalars. This treats array-like objects like numpy arrays,
|
||||||
|
regardless of size.
|
|
@ -15,9 +15,14 @@ from typing import overload
|
||||||
from typing import Pattern
|
from typing import Pattern
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from numpy import ndarray
|
||||||
|
|
||||||
|
|
||||||
import _pytest._code
|
import _pytest._code
|
||||||
from _pytest.compat import final
|
from _pytest.compat import final
|
||||||
from _pytest.compat import STRING_TYPES
|
from _pytest.compat import STRING_TYPES
|
||||||
|
@ -232,10 +237,11 @@ class ApproxScalar(ApproxBase):
|
||||||
def __eq__(self, actual) -> bool:
|
def __eq__(self, actual) -> bool:
|
||||||
"""Return whether the given value is equal to the expected value
|
"""Return whether the given value is equal to the expected value
|
||||||
within the pre-specified tolerance."""
|
within the pre-specified tolerance."""
|
||||||
if _is_numpy_array(actual):
|
asarray = _as_numpy_array(actual)
|
||||||
|
if asarray is not None:
|
||||||
# Call ``__eq__()`` manually to prevent infinite-recursion with
|
# Call ``__eq__()`` manually to prevent infinite-recursion with
|
||||||
# numpy<1.13. See #3748.
|
# numpy<1.13. See #3748.
|
||||||
return all(self.__eq__(a) for a in actual.flat)
|
return all(self.__eq__(a) for a in asarray.flat)
|
||||||
|
|
||||||
# Short-circuit exact equality.
|
# Short-circuit exact equality.
|
||||||
if actual == self.expected:
|
if actual == self.expected:
|
||||||
|
@ -521,6 +527,7 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
|
||||||
elif isinstance(expected, Mapping):
|
elif isinstance(expected, Mapping):
|
||||||
cls = ApproxMapping
|
cls = ApproxMapping
|
||||||
elif _is_numpy_array(expected):
|
elif _is_numpy_array(expected):
|
||||||
|
expected = _as_numpy_array(expected)
|
||||||
cls = ApproxNumpy
|
cls = ApproxNumpy
|
||||||
elif (
|
elif (
|
||||||
isinstance(expected, Iterable)
|
isinstance(expected, Iterable)
|
||||||
|
@ -536,16 +543,30 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
|
||||||
|
|
||||||
|
|
||||||
def _is_numpy_array(obj: object) -> bool:
|
def _is_numpy_array(obj: object) -> bool:
|
||||||
"""Return true if the given object is a numpy array.
|
"""
|
||||||
|
Return true if the given object is implicitly convertible to ndarray,
|
||||||
|
and numpy is already imported.
|
||||||
|
"""
|
||||||
|
return _as_numpy_array(obj) is not None
|
||||||
|
|
||||||
A special effort is made to avoid importing numpy unless it's really necessary.
|
|
||||||
|
def _as_numpy_array(obj: object) -> Optional["ndarray"]:
|
||||||
|
"""
|
||||||
|
Return an ndarray if the given object is implicitly convertible to ndarray,
|
||||||
|
and numpy is already imported, otherwise None.
|
||||||
"""
|
"""
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
np: Any = sys.modules.get("numpy")
|
np: Any = sys.modules.get("numpy")
|
||||||
if np is not None:
|
if np is not None:
|
||||||
return isinstance(obj, np.ndarray)
|
# avoid infinite recursion on numpy scalars, which have __array__
|
||||||
return False
|
if np.isscalar(obj):
|
||||||
|
return None
|
||||||
|
elif isinstance(obj, np.ndarray):
|
||||||
|
return obj
|
||||||
|
elif hasattr(obj, "__array__") or hasattr("obj", "__array_interface__"):
|
||||||
|
return np.asarray(obj)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
# builtin pytest.raises helper
|
# builtin pytest.raises helper
|
||||||
|
|
|
@ -447,6 +447,36 @@ class TestApprox:
|
||||||
assert a12 != approx(a21)
|
assert a12 != approx(a21)
|
||||||
assert a21 != approx(a12)
|
assert a21 != approx(a12)
|
||||||
|
|
||||||
|
def test_numpy_array_protocol(self):
|
||||||
|
"""
|
||||||
|
array-like objects such as tensorflow's DeviceArray are handled like ndarray.
|
||||||
|
See issue #8132
|
||||||
|
"""
|
||||||
|
np = pytest.importorskip("numpy")
|
||||||
|
|
||||||
|
class DeviceArray:
|
||||||
|
def __init__(self, value, size):
|
||||||
|
self.value = value
|
||||||
|
self.size = size
|
||||||
|
|
||||||
|
def __array__(self):
|
||||||
|
return self.value * np.ones(self.size)
|
||||||
|
|
||||||
|
class DeviceScalar:
|
||||||
|
def __init__(self, value):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def __array__(self):
|
||||||
|
return np.array(self.value)
|
||||||
|
|
||||||
|
expected = 1
|
||||||
|
actual = 1 + 1e-6
|
||||||
|
assert approx(expected) == DeviceArray(actual, size=1)
|
||||||
|
assert approx(expected) == DeviceArray(actual, size=2)
|
||||||
|
assert approx(expected) == DeviceScalar(actual)
|
||||||
|
assert approx(DeviceScalar(expected)) == actual
|
||||||
|
assert approx(DeviceScalar(expected)) == DeviceScalar(actual)
|
||||||
|
|
||||||
def test_doctests(self, mocked_doctest_runner) -> None:
|
def test_doctests(self, mocked_doctest_runner) -> None:
|
||||||
import doctest
|
import doctest
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue