From 161d4e5fe4730f46e1fb803595198068b20c94c5 Mon Sep 17 00:00:00 2001 From: Tadeu Manoel Date: Wed, 14 Mar 2018 16:29:04 -0300 Subject: [PATCH] Add support for pytest.approx comparisons between scalar and array (inverted order) --- _pytest/python_api.py | 15 ++++++++++----- testing/python/approx.py | 11 +++++++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/_pytest/python_api.py b/_pytest/python_api.py index 4b428322a..af4d77644 100644 --- a/_pytest/python_api.py +++ b/_pytest/python_api.py @@ -88,12 +88,14 @@ class ApproxNumpy(ApproxBase): def __eq__(self, actual): import numpy as np - try: - actual = np.asarray(actual) - except: # noqa - raise TypeError("cannot compare '{0}' to numpy.ndarray".format(actual)) + if not np.isscalar(actual): + try: + actual = np.asarray(actual) + except: # noqa + raise TypeError("cannot compare '{0}' to numpy.ndarray".format(actual)) - if not np.isscalar(self.expected) and actual.shape != self.expected.shape: + if (not np.isscalar(self.expected) and not np.isscalar(actual) + and actual.shape != self.expected.shape): return False return ApproxBase.__eq__(self, actual) @@ -108,6 +110,9 @@ class ApproxNumpy(ApproxBase): if np.isscalar(self.expected): for i in np.ndindex(actual.shape): yield actual[i], self.expected + elif np.isscalar(actual): + for i in np.ndindex(self.expected.shape): + yield actual, self.expected[i] else: for i in np.ndindex(self.expected.shape): yield actual[i], self.expected[i] diff --git a/testing/python/approx.py b/testing/python/approx.py index b9d28aadb..9ca21bdf8 100644 --- a/testing/python/approx.py +++ b/testing/python/approx.py @@ -402,3 +402,14 @@ class TestApprox(object): assert actual != approx(expected, rel=5e-8, abs=0) assert approx(expected, rel=5e-7, abs=0) == actual assert approx(expected, rel=5e-8, abs=0) != actual + + def test_numpy_scalar_with_array(self): + np = pytest.importorskip('numpy') + + actual = 1.0 + expected = np.array([1 + 1e-7, 1 - 1e-8]) + + assert actual == approx(expected, rel=5e-7, abs=0) + assert actual != approx(expected, rel=5e-8, abs=0) + assert approx(expected, rel=5e-7, abs=0) == actual + assert approx(expected, rel=5e-8, abs=0) != actual