python_api: handle array-like args in approx() (#8137)

This commit is contained in:
Jakob van Santen 2020-12-15 12:49:29 +01:00 committed by GitHub
parent 6c899a0afa
commit 9ccbf5b899
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 67 additions and 6 deletions

10
changelog/8132.bugfix.rst Normal file
View File

@ -0,0 +1,10 @@
Fixed regression in ``approx``: in 6.2.0 ``approx`` no longer raises
``TypeError`` when dealing with non-numeric types, falling back to normal comparison.
Before 6.2.0, array types like tf.DeviceArray fell through to the scalar case,
and happened to compare correctly to a scalar if they had only one element.
After 6.2.0, these types began failing, because they inherited neither from
standard Python number hierarchy nor from ``numpy.ndarray``.
``approx`` now converts arguments to ``numpy.ndarray`` if they expose the array
protocol and are not scalars. This treats array-like objects like numpy arrays,
regardless of size.

View File

@ -15,9 +15,14 @@ from typing import overload
from typing import Pattern from typing import Pattern
from typing import Tuple from typing import Tuple
from typing import Type from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar from typing import TypeVar
from typing import Union from typing import Union
if TYPE_CHECKING:
from numpy import ndarray
import _pytest._code import _pytest._code
from _pytest.compat import final from _pytest.compat import final
from _pytest.compat import STRING_TYPES from _pytest.compat import STRING_TYPES
@ -232,10 +237,11 @@ class ApproxScalar(ApproxBase):
def __eq__(self, actual) -> bool: def __eq__(self, actual) -> bool:
"""Return whether the given value is equal to the expected value """Return whether the given value is equal to the expected value
within the pre-specified tolerance.""" within the pre-specified tolerance."""
if _is_numpy_array(actual): asarray = _as_numpy_array(actual)
if asarray is not None:
# Call ``__eq__()`` manually to prevent infinite-recursion with # Call ``__eq__()`` manually to prevent infinite-recursion with
# numpy<1.13. See #3748. # numpy<1.13. See #3748.
return all(self.__eq__(a) for a in actual.flat) return all(self.__eq__(a) for a in asarray.flat)
# Short-circuit exact equality. # Short-circuit exact equality.
if actual == self.expected: if actual == self.expected:
@ -521,6 +527,7 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
elif isinstance(expected, Mapping): elif isinstance(expected, Mapping):
cls = ApproxMapping cls = ApproxMapping
elif _is_numpy_array(expected): elif _is_numpy_array(expected):
expected = _as_numpy_array(expected)
cls = ApproxNumpy cls = ApproxNumpy
elif ( elif (
isinstance(expected, Iterable) isinstance(expected, Iterable)
@ -536,16 +543,30 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
def _is_numpy_array(obj: object) -> bool: def _is_numpy_array(obj: object) -> bool:
"""Return true if the given object is a numpy array. """
Return true if the given object is implicitly convertible to ndarray,
and numpy is already imported.
"""
return _as_numpy_array(obj) is not None
A special effort is made to avoid importing numpy unless it's really necessary.
def _as_numpy_array(obj: object) -> Optional["ndarray"]:
"""
Return an ndarray if the given object is implicitly convertible to ndarray,
and numpy is already imported, otherwise None.
""" """
import sys import sys
np: Any = sys.modules.get("numpy") np: Any = sys.modules.get("numpy")
if np is not None: if np is not None:
return isinstance(obj, np.ndarray) # avoid infinite recursion on numpy scalars, which have __array__
return False if np.isscalar(obj):
return None
elif isinstance(obj, np.ndarray):
return obj
elif hasattr(obj, "__array__") or hasattr("obj", "__array_interface__"):
return np.asarray(obj)
return None
# builtin pytest.raises helper # builtin pytest.raises helper

View File

@ -447,6 +447,36 @@ class TestApprox:
assert a12 != approx(a21) assert a12 != approx(a21)
assert a21 != approx(a12) assert a21 != approx(a12)
def test_numpy_array_protocol(self):
"""
array-like objects such as tensorflow's DeviceArray are handled like ndarray.
See issue #8132
"""
np = pytest.importorskip("numpy")
class DeviceArray:
def __init__(self, value, size):
self.value = value
self.size = size
def __array__(self):
return self.value * np.ones(self.size)
class DeviceScalar:
def __init__(self, value):
self.value = value
def __array__(self):
return np.array(self.value)
expected = 1
actual = 1 + 1e-6
assert approx(expected) == DeviceArray(actual, size=1)
assert approx(expected) == DeviceArray(actual, size=2)
assert approx(expected) == DeviceScalar(actual)
assert approx(DeviceScalar(expected)) == actual
assert approx(DeviceScalar(expected)) == DeviceScalar(actual)
def test_doctests(self, mocked_doctest_runner) -> None: def test_doctests(self, mocked_doctest_runner) -> None:
import doctest import doctest