Add a convenient and correct way to compare floats.
This commit is contained in:
parent
c2b9196a7c
commit
6f5e1e386a
|
@ -262,6 +262,7 @@ def pytest_namespace():
|
|||
'fixture': fixture,
|
||||
'yield_fixture': yield_fixture,
|
||||
'raises': raises,
|
||||
'approx': approx,
|
||||
'collect': {
|
||||
'Module': Module, 'Class': Class, 'Instance': Instance,
|
||||
'Function': Function, 'Generator': Generator,
|
||||
|
@ -1336,6 +1337,105 @@ class RaisesContext(object):
|
|||
self.excinfo.__init__(tp)
|
||||
return issubclass(self.excinfo.type, self.expected_exception)
|
||||
|
||||
# builtin pytest.approx helper
|
||||
|
||||
class approx:
|
||||
""" assert that two numbers (or two sets of numbers) are equal to each
|
||||
other within some margin.
|
||||
|
||||
Due to the intricacies of floating-point arithmetic, numbers that we would
|
||||
intuitively expect to be the same are not always so::
|
||||
|
||||
>>> 0.1 + 0.2 == 0.3
|
||||
False
|
||||
|
||||
This problem is commonly encountered when writing tests, e.g. to make sure
|
||||
that a floating-point function returns the expected values. The best way
|
||||
to deal with this problem is to assert that two floating point numbers are
|
||||
equal to within some appropriate margin::
|
||||
|
||||
>>> abs((0.1 + 0.2) - 0.3) < 1e-6
|
||||
True
|
||||
|
||||
However, comparisons like this are tedious to write and difficult to
|
||||
understand. Furthermore, absolute comparisons like the one above are
|
||||
usually discouraged in favor of relative comparisons, which can't even be
|
||||
easily written on one line. The ``approx`` class provides a way to make
|
||||
floating point comparisons that solves both these problems::
|
||||
|
||||
>>> from pytest import approx
|
||||
>>> 0.1 + 0.2 == approx(0.3)
|
||||
True
|
||||
|
||||
``approx`` also makes is easy to compare ordered sets of numbers, which
|
||||
would otherwise be very tedious::
|
||||
|
||||
>>> (0.1 + 0.2, 0.2 + 0.4) == approx((0.3, 0.6))
|
||||
True
|
||||
|
||||
By default, ``approx`` considers two numbers to be equal if the relative
|
||||
error between them is less than one part in a million (e.g. 1e-6).
|
||||
Relative error is defined as ``abs(x - a) / x`` where ``x`` is the value
|
||||
you're expecting and ``a`` is the value you're comparing to. This
|
||||
definition breaks down when the numbers being compared get very close to
|
||||
zero, so ``approx`` will also consider two numbers to be equal if the
|
||||
absolute difference between them is less than 1e-100.
|
||||
|
||||
Both the relative and absolute error thresholds can be changed by passing
|
||||
arguments to the ``approx`` constructor::
|
||||
|
||||
>>> 1.0001 == approx(1)
|
||||
False
|
||||
>>> 1.0001 == approx(1, rel=1e-3)
|
||||
True
|
||||
>>> 1.0001 == approx(1, abs=1e-3)
|
||||
True
|
||||
|
||||
Note that if you specify ``abs`` but not ``rel``, the comparison will not
|
||||
consider the relative error between the two values at all. In other words,
|
||||
two number that are within the default relative error threshold of 1e-6
|
||||
will still be considered unequal if they exceed the specified absolute
|
||||
error threshold::
|
||||
|
||||
>>> 0.1 + 0.2 == approx(0.3, abs=1e-100)
|
||||
False
|
||||
"""
|
||||
|
||||
def __init__(self, expected, rel=None, abs=None):
|
||||
self.expected = expected
|
||||
self.max_relative_error = rel
|
||||
self.max_absolute_error = abs
|
||||
|
||||
def __repr__(self):
|
||||
from collections import Iterable
|
||||
plus_minus = lambda x: '{}\u00B1{}'.format(x, self._margin(x))
|
||||
|
||||
if isinstance(self.expected, Iterable):
|
||||
return str([plus_minus(x) for x in self.expected])
|
||||
else:
|
||||
plus_minus(self.expected)
|
||||
|
||||
def __eq__(self, actual):
|
||||
from collections import Iterable
|
||||
expected = self.expected
|
||||
almost_eq = lambda a, x: abs(x - a) < self._margin(x)
|
||||
|
||||
if isinstance(actual, Iterable) and isinstance(expected, Iterable):
|
||||
return all(almost_eq(a, x) for a, x in zip(actual, expected))
|
||||
else:
|
||||
return almost_eq(actual, expected)
|
||||
|
||||
def _margin(self, x):
|
||||
margin = self.max_absolute_error or 1e-100
|
||||
|
||||
if self.max_relative_error is None:
|
||||
if self.max_absolute_error is not None:
|
||||
return margin
|
||||
|
||||
return max(margin, x * (self.max_relative_error or 1e-6))
|
||||
|
||||
|
||||
|
||||
#
|
||||
# the basic pytest Function item
|
||||
#
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
import pytest
|
||||
import doctest
|
||||
|
||||
class MyDocTestRunner(doctest.DocTestRunner):
|
||||
|
||||
def __init__(self):
|
||||
doctest.DocTestRunner.__init__(self)
|
||||
|
||||
def report_failure(self, out, test, example, got):
|
||||
raise AssertionError("'{}' evaluates to '{}', not '{}'".format(
|
||||
example.source.strip(), got.strip(), example.want.strip()))
|
||||
|
||||
|
||||
class TestApprox:
|
||||
|
||||
def test_approx(self):
|
||||
parser = doctest.DocTestParser()
|
||||
test = parser.get_doctest(
|
||||
pytest.approx.__doc__,
|
||||
{'approx': pytest.approx},
|
||||
pytest.approx.__name__,
|
||||
None, None,
|
||||
)
|
||||
runner = MyDocTestRunner()
|
||||
runner.run(test)
|
||||
|
||||
|
Loading…
Reference in New Issue