Add support for numpy arrays (and dicts) to approx.

This fixes #1994.  It turned out to require a lot of refactoring because
subclassing numpy.ndarray was necessary to coerce python into calling
the right `__eq__` operator.
This commit is contained in:
Kale Kundert 2017-06-11 19:27:41 -07:00
parent 467c526307
commit 9f3122fec6
No known key found for this signature in database
GPG Key ID: C6238221D17CAFAE
2 changed files with 408 additions and 171 deletions

View File

@ -1117,7 +1117,6 @@ def raises(expected_exception, *args, **kwargs):
... ...
Failed: Expecting ZeroDivisionError Failed: Expecting ZeroDivisionError
.. note:: .. note::
When using ``pytest.raises`` as a context manager, it's worthwhile to When using ``pytest.raises`` as a context manager, it's worthwhile to
@ -1150,7 +1149,6 @@ def raises(expected_exception, *args, **kwargs):
>>> with raises(ValueError, match=r'must be \d+$'): >>> with raises(ValueError, match=r'must be \d+$'):
... raise ValueError("value must be 42") ... raise ValueError("value must be 42")
Or you can specify a callable by passing a to-be-called lambda:: Or you can specify a callable by passing a to-be-called lambda::
>>> raises(ZeroDivisionError, lambda: 1/0) >>> raises(ZeroDivisionError, lambda: 1/0)
@ -1230,10 +1228,8 @@ def raises(expected_exception, *args, **kwargs):
return _pytest._code.ExceptionInfo() return _pytest._code.ExceptionInfo()
fail(message) fail(message)
raises.Exception = fail.Exception raises.Exception = fail.Exception
class RaisesContext(object): class RaisesContext(object):
def __init__(self, expected_exception, message, match_expr): def __init__(self, expected_exception, message, match_expr):
self.expected_exception = expected_exception self.expected_exception = expected_exception
@ -1265,9 +1261,271 @@ class RaisesContext(object):
return suppress_exception return suppress_exception
# builtin pytest.approx helper # builtin pytest.approx helper
class ApproxBase(object):
    """
    Provide shared utilities for making approximate comparisons between numbers
    or sequences of numbers.

    Subclasses describe *which* values take part in a comparison by
    overriding `_yield_expected()` (used by `__repr__`) and
    `_yield_comparisons()` (used by `__eq__`).  Every individual value is
    wrapped by `_approx_scalar()`, which carries this object's `rel` and
    `abs` tolerances along.
    """

    def __init__(self, expected, rel=None, abs=None):
        self.expected = expected
        self.abs = abs
        self.rel = rel

    def __repr__(self):
        return ', '.join(
                repr(self._approx_scalar(x))
                for x in self._yield_expected())

    def __eq__(self, actual):
        # Equal only if every (actual, expected) pair is approximately equal.
        return all(
                a == self._approx_scalar(x)
                for a, x in self._yield_comparisons(actual))

    # Instances define `__eq__` but are explicitly marked unhashable.
    __hash__ = None

    def __ne__(self, actual):
        return not (actual == self)

    def _approx_scalar(self, x):
        """
        Wrap a single expected value so that it compares using this object's
        relative and absolute tolerances.
        """
        return ApproxScalar(x, rel=self.rel, abs=self.abs)

    def _yield_expected(self):
        """
        Yield all the expected values associated with this object.  This is
        used to implement the `__repr__` method.

        It takes no arguments: `__repr__` calls it without any, and every
        override is expected to have the same signature.
        """
        raise NotImplementedError

    def _yield_comparisons(self, actual):
        """
        Yield all the pairs of numbers to be compared.  This is used to
        implement the `__eq__` method.
        """
        raise NotImplementedError
try:
    import numpy as np

    class ApproxNumpy(ApproxBase, np.ndarray):
        """
        Perform approximate comparisons for numpy arrays.

        This class must inherit from numpy.ndarray in order to allow the approx
        to be on either side of the `==` operator.  The reason for this has to
        do with how python decides whether to call `a.__eq__()` or `b.__eq__()`
        when it encounters `a == b`.

        If `a` and `b` are not related by inheritance, `a` gets priority.  So
        as long as `a.__eq__` is defined, it will be called.  Because most
        implementations of `a.__eq__` end up calling `b.__eq__`, this detail
        usually doesn't matter.  However, `numpy.ndarray.__eq__` raises an
        error complaining that "the truth value of an array with more than
        one element is ambiguous. Use a.any() or a.all()" when compared with a
        custom class, so `b.__eq__` never gets called.

        The trick is that the priority rules change if `a` and `b` are related
        by inheritance.  Specifically, `b.__eq__` gets priority if `b` is a
        subclass of `a`.  So we can guarantee that `ApproxNumpy.__eq__` gets
        called by inheriting from `numpy.ndarray`.
        """

        def __new__(cls, expected, rel=None, abs=None):
            """
            Numpy uses __new__ (rather than __init__) to initialize objects.

            The `expected` argument must be a numpy array.  This should be
            ensured by the approx() delegator function.
            """
            assert isinstance(expected, np.ndarray)
            obj = super(ApproxNumpy, cls).__new__(cls, expected.shape)
            obj.__init__(expected, rel, abs)
            return obj

        def __repr__(self):
            # It might be nice to rewrite this function to account for the
            # shape of the array...
            return '[' + ApproxBase.__repr__(self) + ']'

        def __eq__(self, actual):
            # Catch `Exception` rather than everything, so that
            # KeyboardInterrupt and SystemExit are never converted into a
            # ValueError.
            try:
                actual = np.array(actual)
            except Exception:
                raise ValueError("cannot cast '{0}' to numpy.ndarray".format(actual))

            if actual.shape != self.expected.shape:
                return False

            return ApproxBase.__eq__(self, actual)

        def _yield_expected(self):
            # Walk the flattened array so that every element is wrapped
            # individually whatever the array's rank.  Iterating the array
            # itself would hand whole rows to `_approx_scalar()` for arrays
            # with more than one dimension.
            for x in self.expected.flat:
                yield x

        def _yield_comparisons(self, actual):
            # We can be sure that `actual` is a numpy array, because it's
            # cast in `__eq__` before being passed to `ApproxBase.__eq__`,
            # which is the only method that calls this one.
            for i in np.ndindex(self.expected.shape):
                yield actual[i], self.expected[i]

except ImportError:
    # numpy is optional; `np` doubles as the "is numpy available" flag.
    np = None
class ApproxMapping(ApproxBase):
    """
    Perform approximate comparisons for mappings where the values are numbers
    (the keys can be anything).
    """

    def __repr__(self):
        # Keys are shown with their own repr, so string keys stay quoted
        # while other key types are not dressed up as strings.
        return '{' + ', '.join(
                '{0!r}: {1}'.format(k, self._approx_scalar(v))
                for k, v in self.expected.items()) + '}'

    def __eq__(self, actual):
        # Compare the keys as sets so the result never depends on key order
        # (`keys()` may return a plain list).  Anything without a `keys()`
        # method is not a mapping and is therefore simply unequal.
        try:
            if set(actual.keys()) != set(self.expected.keys()):
                return False
        except AttributeError:
            return False

        return ApproxBase.__eq__(self, actual)

    def _yield_comparisons(self, actual):
        # `__eq__` has already established that both mappings share the
        # same keys, so indexing `actual` here is safe.
        for k in self.expected:
            yield actual[k], self.expected[k]
class ApproxSequence(ApproxBase):
    """
    Perform approximate comparisons for sequences of numbers.
    """

    def __repr__(self):
        # Tuples are displayed with parentheses, every other sequence with
        # square brackets.
        opener, closer = '()' if isinstance(self.expected, tuple) else '[]'
        return opener + ApproxBase.__repr__(self) + closer

    def __eq__(self, actual):
        # Sequences of different lengths are never equal.  An operand that
        # has no length at all (e.g. a bare number) is not a sequence, so it
        # is reported as unequal instead of leaking a TypeError out of `==`.
        try:
            if len(actual) != len(self.expected):
                return False
        except TypeError:
            return False

        return ApproxBase.__eq__(self, actual)

    def _yield_expected(self):
        return iter(self.expected)

    def _yield_comparisons(self, actual):
        return zip(actual, self.expected)
class ApproxScalar(ApproxBase):
    """
    Perform approximate comparisons for single numbers only.
    """

    def __repr__(self):
        """
        Return a string communicating both the expected value and the tolerance
        for the comparison being made, e.g. '1.0 +- 1e-6'.  Use the unicode
        plus/minus symbol if this is python3 (it's too hard to get right for
        python2).
        """
        # Complex expected values are displayed without a tolerance.
        if isinstance(self.expected, complex):
            return str(self.expected)

        # Infinities aren't compared using tolerances, so don't show a
        # tolerance.
        if math.isinf(self.expected):
            return str(self.expected)

        # If a sensible tolerance can't be calculated, self.tolerance will
        # raise a ValueError.  In this case, display '???'.
        try:
            vetted_tolerance = '{:.1e}'.format(self.tolerance)
        except ValueError:
            vetted_tolerance = '???'

        if sys.version_info[0] == 2:
            return '{0} +- {1}'.format(self.expected, vetted_tolerance)
        else:
            return u'{0} \u00b1 {1}'.format(self.expected, vetted_tolerance)

    def __eq__(self, actual):
        """
        Return true if the given value is equal to the expected value within
        the pre-specified tolerance.

        Raise ValueError if either the given or the expected value is not a
        number.
        """
        from numbers import Number

        # Give a good error message if we get values to compare that aren't
        # numbers, rather than choking on them later on.
        if not isinstance(actual, Number):
            raise ValueError("approx can only compare numbers, not '{0}'".format(actual))
        if not isinstance(self.expected, Number):
            raise ValueError("approx can only compare numbers, not '{0}'".format(self.expected))

        # Short-circuit exact equality.
        if actual == self.expected:
            return True

        # Infinity shouldn't be approximately equal to anything but itself, but
        # if there's a relative tolerance, it will be infinite and infinity
        # will seem approximately equal to everything.  The equal-to-itself
        # case would have been short circuited above, so here we can just
        # return false if the expected value is infinite.  The abs() call is
        # for compatibility with complex numbers.
        if math.isinf(abs(self.expected)):
            return False

        # Return true if the two numbers are within the tolerance.
        return abs(self.expected - actual) <= self.tolerance

    __hash__ = None

    @property
    def tolerance(self):
        """
        Return the tolerance for the comparison.  This could be either an
        absolute tolerance or a relative tolerance, depending on what the user
        specified or which would be larger.

        Raise ValueError if a tolerance that is going to be used is negative
        or NaN.
        """
        def set_default(x, default):
            return x if x is not None else default

        # Figure out what the absolute tolerance should be.  ``self.abs`` is
        # either None or a value specified by the user.
        absolute_tolerance = set_default(self.abs, 1e-12)

        if absolute_tolerance < 0:
            raise ValueError("absolute tolerance can't be negative: {0}".format(absolute_tolerance))
        if math.isnan(absolute_tolerance):
            raise ValueError("absolute tolerance can't be NaN.")

        # If the user specified an absolute tolerance but not a relative one,
        # just return the absolute tolerance.
        if self.rel is None and self.abs is not None:
            return absolute_tolerance

        # Figure out what the relative tolerance should be.  ``self.rel`` is
        # either None or a value specified by the user.  This is done after
        # we've made sure the user didn't ask for an absolute tolerance only,
        # because we don't want to raise errors about the relative tolerance if
        # we aren't even going to use it.
        relative_tolerance = set_default(self.rel, 1e-6) * abs(self.expected)

        if relative_tolerance < 0:
            # Report the offending relative tolerance itself, not the
            # absolute one.
            raise ValueError("relative tolerance can't be negative: {0}".format(relative_tolerance))
        if math.isnan(relative_tolerance):
            raise ValueError("relative tolerance can't be NaN.")

        # Return the larger of the relative and absolute tolerances.
        return max(relative_tolerance, absolute_tolerance)
def approx(expected, rel=None, abs=None):
""" """
Assert that two numbers (or two sets of numbers) are equal to each other Assert that two numbers (or two sets of numbers) are equal to each other
within some tolerance. within some tolerance.
@ -1307,6 +1565,8 @@ class approx(object):
>>> (0.1 + 0.2, 0.2 + 0.4) == approx((0.3, 0.6)) >>> (0.1 + 0.2, 0.2 + 0.4) == approx((0.3, 0.6))
True True
>>> {'a': 0.1 + 0.2, 'b': 0.2 + 0.4} == approx({'a': 0.3, 'b': 0.6})
True
By default, ``approx`` considers numbers within a relative tolerance of By default, ``approx`` considers numbers within a relative tolerance of
``1e-6`` (i.e. one part in a million) of its expected value to be equal. ``1e-6`` (i.e. one part in a million) of its expected value to be equal.
@ -1380,139 +1640,37 @@ class approx(object):
special case that you explicitly specify an absolute tolerance but not a special case that you explicitly specify an absolute tolerance but not a
relative tolerance, only the absolute tolerance is considered. relative tolerance, only the absolute tolerance is considered.
""" """
from collections import Mapping, Sequence
try:
String = basestring # python2
except NameError:
String = str, bytes # python3
def __init__(self, expected, rel=None, abs=None): # Delegate the comparison to a class that knows how to deal with the type
self.expected = expected # of the expected value (e.g. int, float, list, dict, numpy.array, etc).
self.abs = abs #
self.rel = rel # This architecture is really driven by the need to support numpy arrays.
# The only way to override `==` for arrays without requiring that approx be
# the left operand is to inherit the approx object from `numpy.ndarray`.
# But that can't be a general solution, because it requires (1) numpy to be
# installed and (2) the expected value to be a numpy array. So the general
# solution is to delegate each type of expected value to a different class.
#
# This has the advantage that it made it easy to support mapping types
# (i.e. dict). The old code accepted mapping types, but would only compare
# their keys, which is probably not what most people would expect.
def __repr__(self): if np and isinstance(expected, np.ndarray):
return ', '.join(repr(x) for x in self.expected) cls = ApproxNumpy
elif isinstance(expected, Mapping):
cls = ApproxMapping
elif isinstance(expected, Sequence) and not isinstance(expected, String):
cls = ApproxSequence
else:
cls = ApproxScalar
def __eq__(self, actual): return cls(expected, rel, abs)
from collections import Iterable
if not isinstance(actual, Iterable):
actual = [actual]
if len(actual) != len(self.expected):
return False
return all(a == x for a, x in zip(actual, self.expected))
__hash__ = None
def __ne__(self, actual):
return not (actual == self)
@property
def expected(self):
# Regardless of whether the user-specified expected value is a number
# or a sequence of numbers, return a list of ApproxNotIterable objects
# that can be compared against.
from collections import Iterable
approx_non_iter = lambda x: ApproxNonIterable(x, self.rel, self.abs)
if isinstance(self._expected, Iterable):
return [approx_non_iter(x) for x in self._expected]
else:
return [approx_non_iter(self._expected)]
@expected.setter
def expected(self, expected):
self._expected = expected
class ApproxNonIterable(object):
    """
    Perform approximate comparisons for single numbers only.

    In other words, the ``expected`` attribute for objects of this class must
    be some sort of number.  This is in contrast to the ``approx`` class, where
    the ``expected`` attribute can either be a number or a sequence of numbers.
    This class is responsible for making comparisons, while ``approx`` is
    responsible for abstracting the difference between numbers and sequences of
    numbers.  Although this class can stand on its own, it's only meant to be
    used within ``approx``.
    """

    def __init__(self, expected, rel=None, abs=None):
        self.expected = expected
        self.abs = abs
        self.rel = rel

    def __repr__(self):
        # Complex expected values are displayed without a tolerance.
        if isinstance(self.expected, complex):
            return str(self.expected)

        # Infinities aren't compared using tolerances, so don't show a
        # tolerance.
        if math.isinf(self.expected):
            return str(self.expected)

        # If a sensible tolerance can't be calculated, self.tolerance will
        # raise a ValueError.  In this case, display '???'.
        try:
            vetted_tolerance = '{:.1e}'.format(self.tolerance)
        except ValueError:
            vetted_tolerance = '???'

        if sys.version_info[0] == 2:
            return '{0} +- {1}'.format(self.expected, vetted_tolerance)
        else:
            return u'{0} \u00b1 {1}'.format(self.expected, vetted_tolerance)

    def __eq__(self, actual):
        # Short-circuit exact equality.
        if actual == self.expected:
            return True

        # Infinity shouldn't be approximately equal to anything but itself, but
        # if there's a relative tolerance, it will be infinite and infinity
        # will seem approximately equal to everything.  The equal-to-itself
        # case would have been short circuited above, so here we can just
        # return false if the expected value is infinite.  The abs() call is
        # for compatibility with complex numbers.
        if math.isinf(abs(self.expected)):
            return False

        # Return true if the two numbers are within the tolerance.
        return abs(self.expected - actual) <= self.tolerance

    __hash__ = None

    def __ne__(self, actual):
        return not (actual == self)

    @property
    def tolerance(self):
        """
        Return the tolerance for the comparison: the absolute tolerance alone
        if only that was specified, otherwise the larger of the relative and
        absolute tolerances.

        Raise ValueError if a tolerance that is going to be used is negative
        or NaN.
        """
        def set_default(x, default):
            return x if x is not None else default

        # Figure out what the absolute tolerance should be.  ``self.abs`` is
        # either None or a value specified by the user.
        absolute_tolerance = set_default(self.abs, 1e-12)

        if absolute_tolerance < 0:
            raise ValueError("absolute tolerance can't be negative: {0}".format(absolute_tolerance))
        if math.isnan(absolute_tolerance):
            raise ValueError("absolute tolerance can't be NaN.")

        # If the user specified an absolute tolerance but not a relative one,
        # just return the absolute tolerance.
        if self.rel is None and self.abs is not None:
            return absolute_tolerance

        # Figure out what the relative tolerance should be.  ``self.rel`` is
        # either None or a value specified by the user.  This is done after
        # we've made sure the user didn't ask for an absolute tolerance only,
        # because we don't want to raise errors about the relative tolerance if
        # we aren't even going to use it.
        relative_tolerance = set_default(self.rel, 1e-6) * abs(self.expected)

        if relative_tolerance < 0:
            # Report the offending relative tolerance itself, not the
            # absolute one.
            raise ValueError("relative tolerance can't be negative: {0}".format(relative_tolerance))
        if math.isnan(relative_tolerance):
            raise ValueError("relative tolerance can't be NaN.")

        # Return the larger of the relative and absolute tolerances.
        return max(relative_tolerance, absolute_tolerance)
# #

View File

@ -9,7 +9,6 @@ from decimal import Decimal
from fractions import Fraction from fractions import Fraction
inf, nan = float('inf'), float('nan') inf, nan = float('inf'), float('nan')
class MyDocTestRunner(doctest.DocTestRunner): class MyDocTestRunner(doctest.DocTestRunner):
def __init__(self): def __init__(self):
@ -29,12 +28,19 @@ class TestApprox(object):
if sys.version_info[:2] == (2, 6): if sys.version_info[:2] == (2, 6):
tol1, tol2, infr = '???', '???', '???' tol1, tol2, infr = '???', '???', '???'
assert repr(approx(1.0)) == '1.0 {pm} {tol1}'.format(pm=plus_minus, tol1=tol1) assert repr(approx(1.0)) == '1.0 {pm} {tol1}'.format(pm=plus_minus, tol1=tol1)
assert repr(approx([1.0, 2.0])) == '1.0 {pm} {tol1}, 2.0 {pm} {tol2}'.format(pm=plus_minus, tol1=tol1, tol2=tol2) assert repr(approx([1.0, 2.0])) == '[1.0 {pm} {tol1}, 2.0 {pm} {tol2}]'.format(pm=plus_minus, tol1=tol1, tol2=tol2)
assert repr(approx((1.0, 2.0))) == '(1.0 {pm} {tol1}, 2.0 {pm} {tol2})'.format(pm=plus_minus, tol1=tol1, tol2=tol2)
assert repr(approx(inf)) == 'inf' assert repr(approx(inf)) == 'inf'
assert repr(approx(1.0, rel=nan)) == '1.0 {pm} ???'.format(pm=plus_minus) assert repr(approx(1.0, rel=nan)) == '1.0 {pm} ???'.format(pm=plus_minus)
assert repr(approx(1.0, rel=inf)) == '1.0 {pm} {infr}'.format(pm=plus_minus, infr=infr) assert repr(approx(1.0, rel=inf)) == '1.0 {pm} {infr}'.format(pm=plus_minus, infr=infr)
assert repr(approx(1.0j, rel=inf)) == '1j' assert repr(approx(1.0j, rel=inf)) == '1j'
# Dictionaries aren't ordered, so we need to check both orders.
assert repr(approx({'a': 1.0, 'b': 2.0})) in (
"{{'a': 1.0 {pm} {tol1}, 'b': 2.0 {pm} {tol2}}}".format(pm=plus_minus, tol1=tol1, tol2=tol2),
"{{'b': 2.0 {pm} {tol2}, 'a': 1.0 {pm} {tol1}}}".format(pm=plus_minus, tol1=tol1, tol2=tol2),
)
def test_operator_overloading(self): def test_operator_overloading(self):
assert 1 == approx(1, rel=1e-6, abs=1e-12) assert 1 == approx(1, rel=1e-6, abs=1e-12)
assert not (1 != approx(1, rel=1e-6, abs=1e-12)) assert not (1 != approx(1, rel=1e-6, abs=1e-12))
@ -228,18 +234,38 @@ class TestApprox(object):
# tolerance, so only an absolute tolerance is calculated. # tolerance, so only an absolute tolerance is calculated.
assert a != approx(x, abs=inf) assert a != approx(x, abs=inf)
def test_expecting_sequence(self): def test_int(self):
within_1e8 = [ within_1e6 = [
(1e8 + 1e0, 1e8), (1000001, 1000000),
(1e0 + 1e-8, 1e0), (-1000001, -1000000),
(1e-8 + 1e-16, 1e-8),
] ]
actual, expected = zip(*within_1e8) for a, x in within_1e6:
assert actual == approx(expected, rel=5e-8, abs=0.0) assert a == approx(x, rel=5e-6, abs=0)
assert a != approx(x, rel=5e-7, abs=0)
assert approx(x, rel=5e-6, abs=0) == a
assert approx(x, rel=5e-7, abs=0) != a
def test_expecting_sequence_wrong_len(self): def test_decimal(self):
assert [1, 2] != approx([1]) within_1e6 = [
assert [1, 2] != approx([1,2,3]) (Decimal('1.000001'), Decimal('1.0')),
(Decimal('-1.000001'), Decimal('-1.0')),
]
for a, x in within_1e6:
assert a == approx(x, rel=Decimal('5e-6'), abs=0)
assert a != approx(x, rel=Decimal('5e-7'), abs=0)
assert approx(x, rel=Decimal('5e-6'), abs=0) == a
assert approx(x, rel=Decimal('5e-7'), abs=0) != a
def test_fraction(self):
within_1e6 = [
(1 + Fraction(1, 1000000), Fraction(1)),
(-1 - Fraction(-1, 1000000), Fraction(-1)),
]
for a, x in within_1e6:
assert a == approx(x, rel=5e-6, abs=0)
assert a != approx(x, rel=5e-7, abs=0)
assert approx(x, rel=5e-6, abs=0) == a
assert approx(x, rel=5e-7, abs=0) != a
def test_complex(self): def test_complex(self):
within_1e6 = [ within_1e6 = [
@ -251,33 +277,86 @@ class TestApprox(object):
for a, x in within_1e6: for a, x in within_1e6:
assert a == approx(x, rel=5e-6, abs=0) assert a == approx(x, rel=5e-6, abs=0)
assert a != approx(x, rel=5e-7, abs=0) assert a != approx(x, rel=5e-7, abs=0)
assert approx(x, rel=5e-6, abs=0) == a
assert approx(x, rel=5e-7, abs=0) != a
def test_int(self): def test_list(self):
within_1e6 = [ actual = [1 + 1e-7, 2 + 1e-8]
(1000001, 1000000), expected = [1, 2]
(-1000001, -1000000),
]
for a, x in within_1e6:
assert a == approx(x, rel=5e-6, abs=0)
assert a != approx(x, rel=5e-7, abs=0)
def test_decimal(self): # Return false if any element is outside the tolerance.
within_1e6 = [ assert actual == approx(expected, rel=5e-7, abs=0)
(Decimal('1.000001'), Decimal('1.0')), assert actual != approx(expected, rel=5e-8, abs=0)
(Decimal('-1.000001'), Decimal('-1.0')), assert approx(expected, rel=5e-7, abs=0) == actual
] assert approx(expected, rel=5e-8, abs=0) != actual
for a, x in within_1e6:
assert a == approx(x, rel=Decimal('5e-6'), abs=0)
assert a != approx(x, rel=Decimal('5e-7'), abs=0)
def test_fraction(self): def test_list_wrong_len(self):
within_1e6 = [ assert [1, 2] != approx([1])
(1 + Fraction(1, 1000000), Fraction(1)), assert [1, 2] != approx([1,2,3])
(-1 - Fraction(-1, 1000000), Fraction(-1)),
] def test_tuple(self):
for a, x in within_1e6: actual = (1 + 1e-7, 2 + 1e-8)
assert a == approx(x, rel=5e-6, abs=0) expected = (1, 2)
assert a != approx(x, rel=5e-7, abs=0)
# Return false if any element is outside the tolerance.
assert actual == approx(expected, rel=5e-7, abs=0)
assert actual != approx(expected, rel=5e-8, abs=0)
assert approx(expected, rel=5e-7, abs=0) == actual
assert approx(expected, rel=5e-8, abs=0) != actual
def test_tuple_wrong_len(self):
    """A tuple never equals an approx built from a shorter or longer tuple."""
    actual = (1, 2)
    for expected in ((1,), (1, 2, 3)):
        assert actual != approx(expected)
def test_dict(self):
    """Dict values are compared key by key, with approx on either side."""
    actual = {'a': 1 + 1e-7, 'b': 2 + 1e-8}
    # Dictionaries became ordered in python3.6, so list the keys in a
    # different order to make sure the order doesn't matter.
    expected = {'b': 2, 'a': 1}

    within = approx(expected, rel=5e-7, abs=0)
    outside = approx(expected, rel=5e-8, abs=0)

    # Return false if any element is outside the tolerance.
    assert actual == within
    assert actual != outside
    assert within == actual
    assert outside != actual
def test_dict_wrong_len(self):
    """Dicts whose key sets differ are unequal, whatever their values."""
    actual = {'a': 1, 'b': 2}
    mismatched = [
        {'a': 1},
        {'a': 1, 'c': 2},
        {'a': 1, 'b': 2, 'c': 3},
    ]
    for expected in mismatched:
        assert actual != approx(expected)
def test_numpy_array(self):
    """Numpy arrays compare element-wise, with approx on either side."""
    # Skip (rather than fail) when numpy isn't installed.
    np = pytest.importorskip("numpy")

    actual = np.array([1 + 1e-7, 2 + 1e-8])
    expected = np.array([1, 2])

    # Return false if any element is outside the tolerance.
    assert actual == approx(expected, rel=5e-7, abs=0)
    assert actual != approx(expected, rel=5e-8, abs=0)
    # Compare against `actual`: comparing the approx object with the very
    # array it was built from would pass no matter what the tolerance is.
    assert approx(expected, rel=5e-7, abs=0) == actual
    assert approx(expected, rel=5e-8, abs=0) != actual
def test_numpy_array_wrong_shape(self):
    """Arrays with the same elements but different shapes are unequal."""
    # Skip (rather than fail) when numpy isn't installed.
    np = pytest.importorskip("numpy")

    a12 = np.array([[1, 2]])
    a21 = np.array([[1], [2]])

    assert a12 != approx(a21)
    assert a21 != approx(a12)
def test_non_number(self):
    """Comparing a number with a string raises ValueError on either side."""
    for actual, expected in [(1, "1"), ("1", 1)]:
        with pytest.raises(ValueError):
            actual == approx(expected)
def test_doctests(self): def test_doctests(self):
parser = doctest.DocTestParser() parser = doctest.DocTestParser()