From 6f5e1e386a685cc5975762f0a0457738a3522b13 Mon Sep 17 00:00:00 2001 From: Kale Kundert Date: Sun, 6 Mar 2016 18:53:48 -0800 Subject: [PATCH] Add a convenient and correct way to compare floats. --- _pytest/python.py | 102 ++++++++++++++++++++++++++++++++++++++- testing/python/approx.py | 27 +++++++++++ 2 files changed, 128 insertions(+), 1 deletion(-) create mode 100644 testing/python/approx.py diff --git a/_pytest/python.py b/_pytest/python.py index ec346f587..3d458ed82 100644 --- a/_pytest/python.py +++ b/_pytest/python.py @@ -261,7 +261,8 @@ def pytest_namespace(): return { 'fixture': fixture, 'yield_fixture': yield_fixture, - 'raises' : raises, + 'raises': raises, + 'approx': approx, 'collect': { 'Module': Module, 'Class': Class, 'Instance': Instance, 'Function': Function, 'Generator': Generator, @@ -1336,6 +1337,105 @@ class RaisesContext(object): self.excinfo.__init__(tp) return issubclass(self.excinfo.type, self.expected_exception) +# builtin pytest.approx helper + +class approx: + """ assert that two numbers (or two sets of numbers) are equal to each + other within some margin. + + Due to the intricacies of floating-point arithmetic, numbers that we would + intuitively expect to be the same are not always so:: + + >>> 0.1 + 0.2 == 0.3 + False + + This problem is commonly encountered when writing tests, e.g. to make sure + that a floating-point function returns the expected values. The best way + to deal with this problem is to assert that two floating point numbers are + equal to within some appropriate margin:: + + >>> abs((0.1 + 0.2) - 0.3) < 1e-6 + True + + However, comparisons like this are tedious to write and difficult to + understand. Furthermore, absolute comparisons like the one above are + usually discouraged in favor of relative comparisons, which can't even be + easily written on one line. The ``approx`` class provides a way to make + floating point comparisons that solves both these problems:: + + >>> from pytest import approx + >>> 0.1 + 0.2 == approx(0.3) + True + + ``approx`` also makes is easy to compare ordered sets of numbers, which + would otherwise be very tedious:: + + >>> (0.1 + 0.2, 0.2 + 0.4) == approx((0.3, 0.6)) + True + + By default, ``approx`` considers two numbers to be equal if the relative + error between them is less than one part in a million (e.g. 1e-6). + Relative error is defined as ``abs(x - a) / x`` where ``x`` is the value + you're expecting and ``a`` is the value you're comparing to. This + definition breaks down when the numbers being compared get very close to + zero, so ``approx`` will also consider two numbers to be equal if the + absolute difference between them is less than 1e-100. + + Both the relative and absolute error thresholds can be changed by passing + arguments to the ``approx`` constructor:: + + >>> 1.0001 == approx(1) + False + >>> 1.0001 == approx(1, rel=1e-3) + True + >>> 1.0001 == approx(1, abs=1e-3) + True + + Note that if you specify ``abs`` but not ``rel``, the comparison will not + consider the relative error between the two values at all. In other words, + two number that are within the default relative error threshold of 1e-6 + will still be considered unequal if they exceed the specified absolute + error threshold:: + + >>> 0.1 + 0.2 == approx(0.3, abs=1e-100) + False + """ + + def __init__(self, expected, rel=None, abs=None): + self.expected = expected + self.max_relative_error = rel + self.max_absolute_error = abs + + def __repr__(self): + from collections import Iterable + plus_minus = lambda x: '{}\u00B1{}'.format(x, self._margin(x)) + + if isinstance(self.expected, Iterable): + return str([plus_minus(x) for x in self.expected]) + else: + plus_minus(self.expected) + + def __eq__(self, actual): + from collections import Iterable + expected = self.expected + almost_eq = lambda a, x: abs(x - a) < self._margin(x) + + if isinstance(actual, Iterable) and isinstance(expected, Iterable): + return all(almost_eq(a, x) for a, x in zip(actual, expected)) + else: + return almost_eq(actual, expected) + + def _margin(self, x): + margin = self.max_absolute_error or 1e-100 + + if self.max_relative_error is None: + if self.max_absolute_error is not None: + return margin + + return max(margin, x * (self.max_relative_error or 1e-6)) + + + # # the basic pytest Function item # diff --git a/testing/python/approx.py b/testing/python/approx.py new file mode 100644 index 000000000..76064114d --- /dev/null +++ b/testing/python/approx.py @@ -0,0 +1,27 @@ +import pytest +import doctest + +class MyDocTestRunner(doctest.DocTestRunner): + + def __init__(self): + doctest.DocTestRunner.__init__(self) + + def report_failure(self, out, test, example, got): + raise AssertionError("'{}' evaluates to '{}', not '{}'".format( + example.source.strip(), got.strip(), example.want.strip())) + + +class TestApprox: + + def test_approx(self): + parser = doctest.DocTestParser() + test = parser.get_doctest( + pytest.approx.__doc__, + {'approx': pytest.approx}, + pytest.approx.__name__, + None, None, + ) + runner = MyDocTestRunner() + runner.run(test) + +