From 42a7e0488da56f2cb0e5468af95bb1f124b05e65 Mon Sep 17 00:00:00 2001 From: Kale Kundert Date: Fri, 11 Mar 2016 08:49:26 -0800 Subject: [PATCH] Properly handle inf, nan, and built-in numeric types. This commit also: - Dramatically increases the number of unit tests , mostly by borrowing from the standard library's unit tests for math.isclose(). - Refactors approx() into two classes, one of which handles comparing individual numbers (ApproxNonIterable) and another which uses the first to compare individual numbers or sequences of numbers. --- _pytest/python.py | 126 +++++++++++++++--- testing/python/approx.py | 268 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 364 insertions(+), 30 deletions(-) diff --git a/_pytest/python.py b/_pytest/python.py index 1bed796d1..fe1758204 100644 --- a/_pytest/python.py +++ b/_pytest/python.py @@ -6,6 +6,7 @@ import inspect import re import types import sys +import math import py import pytest @@ -1412,37 +1413,120 @@ class approx(object): def __init__(self, expected, rel=None, abs=None): self.expected = expected - self.max_relative_error = rel - self.max_absolute_error = abs + self.abs = abs + self.rel = rel def __repr__(self): - from collections import Iterable - utf_8 = lambda s: s.encode('utf-8') if sys.version_info[0] == 2 else s - plus_minus = lambda x: utf_8(u'{0} \u00b1 {1:.1e}'.format(x, self._get_margin(x))) - - if isinstance(self.expected, Iterable): - return ', '.join([plus_minus(x) for x in self.expected]) - else: - return plus_minus(self.expected) + return ', '.join(repr(x) for x in self.expected) def __eq__(self, actual): from collections import Iterable - expected = self.expected - almost_eq = lambda a, x: abs(x - a) < self._get_margin(x) + if not isinstance(actual, Iterable): actual = [actual] + if len(actual) != len(self.expected): return False + return all(a == x for a, x in zip(actual, self.expected)) - if isinstance(actual, Iterable) and isinstance(expected, Iterable): - return all(almost_eq(a, x) for a, x in zip(actual, expected)) + @property + def expected(self): + from collections import Iterable + approx_non_iter = lambda x: ApproxNonIterable(x, self.rel, self.abs) + if isinstance(self._expected, Iterable): + return [approx_non_iter(x) for x in self._expected] else: - return almost_eq(actual, expected) + return [approx_non_iter(self._expected)] - def _get_margin(self, x): - margin = self.max_absolute_error or 1e-12 + @expected.setter + def expected(self, expected): + self._expected = expected + - if self.max_relative_error is None: - if self.max_absolute_error is not None: - return margin - return max(margin, x * (self.max_relative_error or 1e-6)) +class ApproxNonIterable(object): + """ + Perform approximate comparisons for single numbers only. + + This class contains most of the + """ + + def __init__(self, expected, rel=None, abs=None): + self.expected = expected + self.abs = abs + self.rel = rel + + def __repr__(self): + # Infinities aren't compared using tolerances, so don't show a + # tolerance. + if math.isinf(self.expected): + return str(self.expected) + + # If a sensible tolerance can't be calculated, self.tolerance will + # raise a ValueError. In this case, display '???'. + try: + vetted_tolerance = '{:.1e}'.format(self.tolerance) + except ValueError: + vetted_tolerance = '???' + + repr = u'{0} \u00b1 {1}'.format(self.expected, vetted_tolerance) + + # In python2, __repr__() must return a string (i.e. not a unicode + # object). In python3, __repr__() must return a unicode object + # (although now strings are unicode objects and bytes are what + # strings were). + if sys.version_info[0] == 2: + return repr.encode('utf-8') + else: + return repr + + def __eq__(self, actual): + # Short-circuit exact equality. + if actual == self.expected: + return True + + # Infinity shouldn't be approximately equal to anything but itself, but + # if there's a relative tolerance, it will be infinite and infinity + # will seem approximately equal to everything. The equal-to-itself + # case would have been short circuited above, so here we can just + # return false if the expected value is infinite. The abs() call is + # for compatibility with complex numbers. + if math.isinf(abs(self.expected)): + return False + + # Return true if the two numbers are within the tolerance. + return abs(self.expected - actual) <= self.tolerance + + @property + def tolerance(self): + set_default = lambda x, default: x if x is not None else default + + # Figure out what the absolute tolerance should be. ``self.abs`` is + # either None or a value specified by the user. + absolute_tolerance = set_default(self.abs, 1e-12) + + if absolute_tolerance < 0: + raise ValueError("absolute tolerance can't be negative: {}".format(absolute_tolerance)) + if math.isnan(absolute_tolerance): + raise ValueError("absolute tolerance can't be NaN.") + + # If the user specified an absolute tolerance but not a relative one, + # just return the absolute tolerance. + if self.rel is None: + if self.abs is not None: + return absolute_tolerance + + # Figure out what the absolute tolerance should be. ``self.rel`` is + # either None or a value specified by the user. This is done after + # we've made sure the user didn't ask for an absolute tolerance only, + # because we don't want to raise errors about the relative tolerance if + # it isn't even being used. + relative_tolerance = set_default(self.rel, 1e-6) * abs(self.expected) + + if relative_tolerance < 0: + raise ValueError("relative tolerance can't be negative: {}".format(absolute_tolerance)) + if math.isnan(relative_tolerance): + raise ValueError("relative tolerance can't be NaN.") + + # Return the larger of the relative and absolute tolerances. + return max(relative_tolerance, absolute_tolerance) + # diff --git a/testing/python/approx.py b/testing/python/approx.py index 2f37c9b9b..2d7eb6ea3 100644 --- a/testing/python/approx.py +++ b/testing/python/approx.py @@ -3,6 +3,12 @@ import pytest import doctest +from pytest import approx +from operator import eq, ne +from decimal import Decimal +from fractions import Fraction +inf, nan = float('inf'), float('nan') + class MyDocTestRunner(doctest.DocTestRunner): def __init__(self): @@ -15,19 +21,263 @@ class MyDocTestRunner(doctest.DocTestRunner): class TestApprox: - def test_approx_doctests(self): + def test_repr_string(self): + # Just make sure the Unicode handling doesn't raise any exceptions. + print(approx(1.0)) + print(approx([1.0, 2.0, 3.0])) + print(approx(inf)) + print(approx(1.0, rel=nan)) + print(approx(1.0, rel=inf)) + + def test_operator_overloading(self): + assert 1 == approx(1, rel=1e-6, abs=1e-12) + assert 10 != approx(1, rel=1e-6, abs=1e-12) + + def test_exactly_equal(self): + examples = [ + (2.0, 2.0), + (0.1e200, 0.1e200), + (1.123e-300, 1.123e-300), + (12345, 12345.0), + (0.0, -0.0), + (345678, 345678), + (Decimal(1.0001), Decimal(1.0001)), + ] + for a, x in examples: + assert a == approx(x) + + def test_opposite_sign(self): + examples = [ + (eq, 1e-100, -1e-100), + (ne, 1e100, -1e100), + ] + for op, a, x in examples: + assert op(a, approx(x)) + + def test_zero_tolerance(self): + within_1e10 = [ + (1.1e-100, 1e-100), + (-1.1e-100, -1e-100), + ] + for a, x in within_1e10: + assert x == approx(x, rel=0.0, abs=0.0) + assert a != approx(x, rel=0.0, abs=0.0) + assert a == approx(x, rel=0.0, abs=5e-101) + assert a != approx(x, rel=0.0, abs=5e-102) + assert a == approx(x, rel=5e-1, abs=0.0) + assert a != approx(x, rel=5e-2, abs=0.0) + + def test_negative_tolerance(self): + # Negative tolerances are not allowed. + illegal_kwargs = [ + dict(rel=-1e100), + dict(abs=-1e100), + dict(rel=1e100, abs=-1e100), + dict(rel=-1e100, abs=1e100), + dict(rel=-1e100, abs=-1e100), + ] + for kwargs in illegal_kwargs: + with pytest.raises(ValueError): + 1.1 == approx(1, **kwargs) + + def test_inf_tolerance(self): + # Everything should be equal if the tolerance is infinite. + large_diffs = [ + (1, 1000), + (1e-50, 1e50), + (-1.0, -1e300), + (0.0, 10), + ] + for a, x in large_diffs: + assert a != approx(x, rel=0.0, abs=0.0) + assert a == approx(x, rel=inf, abs=0.0) + assert a == approx(x, rel=0.0, abs=inf) + assert a == approx(x, rel=inf, abs=inf) + + def test_inf_tolerance_expecting_zero(self): + # If the relative tolerance is zero but the expected value is infinite, + # the actual tolerance is a NaN, which should be an error. + illegal_kwargs = [ + dict(rel=inf, abs=0.0), + dict(rel=inf, abs=inf), + ] + for kwargs in illegal_kwargs: + with pytest.raises(ValueError): + 1 == approx(0, **kwargs) + + def test_nan_tolerance(self): + illegal_kwargs = [ + dict(rel=nan), + dict(abs=nan), + dict(rel=nan, abs=nan), + ] + for kwargs in illegal_kwargs: + with pytest.raises(ValueError): + 1.1 == approx(1, **kwargs) + + def test_reasonable_defaults(self): + # Whatever the defaults are, they should work for numbers close to 1 + # than have a small amount of floating-point error. + assert 0.1 + 0.2 == approx(0.3) + + def test_default_tolerances(self): + # This tests the defaults as they are currently set. If you change the + # defaults, this test will fail but you should feel free to change it. + # None of the other tests (except the doctests) should be affected by + # the choice of defaults. + examples = [ + # Relative tolerance used. + (eq, 1e100 + 1e94, 1e100), + (ne, 1e100 + 2e94, 1e100), + (eq, 1e0 + 1e-6, 1e0), + (ne, 1e0 + 2e-6, 1e0), + # Absolute tolerance used. + (eq, 1e-100, + 1e-106), + (eq, 1e-100, + 2e-106), + (eq, 1e-100, 0), + ] + for op, a, x in examples: + assert op(a, approx(x)) + + def test_custom_tolerances(self): + assert 1e8 + 1e0 == approx(1e8, rel=5e-8, abs=5e0) + assert 1e8 + 1e0 == approx(1e8, rel=5e-9, abs=5e0) + assert 1e8 + 1e0 == approx(1e8, rel=5e-8, abs=5e-1) + assert 1e8 + 1e0 != approx(1e8, rel=5e-9, abs=5e-1) + + assert 1e0 + 1e-8 == approx(1e0, rel=5e-8, abs=5e-8) + assert 1e0 + 1e-8 == approx(1e0, rel=5e-9, abs=5e-8) + assert 1e0 + 1e-8 == approx(1e0, rel=5e-8, abs=5e-9) + assert 1e0 + 1e-8 != approx(1e0, rel=5e-9, abs=5e-9) + + assert 1e-8 + 1e-16 == approx(1e-8, rel=5e-8, abs=5e-16) + assert 1e-8 + 1e-16 == approx(1e-8, rel=5e-9, abs=5e-16) + assert 1e-8 + 1e-16 == approx(1e-8, rel=5e-8, abs=5e-17) + assert 1e-8 + 1e-16 != approx(1e-8, rel=5e-9, abs=5e-17) + + def test_relative_tolerance(self): + within_1e8_rel = [ + (1e8 + 1e0, 1e8), + (1e0 + 1e-8, 1e0), + (1e-8 + 1e-16, 1e-8), + ] + for a, x in within_1e8_rel: + assert a == approx(x, rel=5e-8, abs=0.0) + assert a != approx(x, rel=5e-9, abs=0.0) + + def test_absolute_tolerance(self): + within_1e8_abs = [ + (1e8 + 9e-9, 1e8), + (1e0 + 9e-9, 1e0), + (1e-8 + 9e-9, 1e-8), + ] + for a, x in within_1e8_abs: + assert a == approx(x, rel=0, abs=5e-8) + assert a != approx(x, rel=0, abs=5e-9) + + def test_expecting_zero(self): + examples = [ + (ne, 1e-6, 0.0), + (ne, -1e-6, 0.0), + (eq, 1e-12, 0.0), + (eq, -1e-12, 0.0), + (ne, 2e-12, 0.0), + (ne, -2e-12, 0.0), + (ne, inf, 0.0), + (ne, nan, 0.0), + ] + for op, a, x in examples: + assert op(a, approx(x, rel=0.0, abs=1e-12)) + assert op(a, approx(x, rel=1e-6, abs=1e-12)) + + def test_expecting_inf(self): + examples = [ + (eq, inf, inf), + (eq, -inf, -inf), + (ne, inf, -inf), + (ne, 0.0, inf), + (ne, nan, inf), + ] + for op, a, x in examples: + assert op(a, approx(x)) + + def test_expecting_nan(self): + examples = [ + (nan, nan), + (-nan, -nan), + (nan, -nan), + (0.0, nan), + (inf, nan), + ] + for a, x in examples: + # If there is a relative tolerance and the expected value is NaN, + # the actual tolerance is a NaN, which should be an error. + with pytest.raises(ValueError): + a != approx(x, rel=inf) + + # You can make comparisons against NaN by not specifying a relative + # tolerance, so only an absolute tolerance is calculated. + assert a != approx(x, abs=inf) + + def test_expecting_sequence(self): + within_1e8 = [ + (1e8 + 1e0, 1e8), + (1e0 + 1e-8, 1e0), + (1e-8 + 1e-16, 1e-8), + ] + actual, expected = zip(*within_1e8) + assert actual == approx(expected, rel=5e-8, abs=0.0) + + def test_expecting_sequence_wrong_len(self): + assert [1, 2] != approx([1]) + assert [1, 2] != approx([1,2,3]) + + def test_complex(self): + within_1e6 = [ + ( 1.000001 + 1.0j, 1.0 + 1.0j), + (1.0 + 1.000001j, 1.0 + 1.0j), + (-1.000001 + 1.0j, -1.0 + 1.0j), + (1.0 - 1.000001j, 1.0 - 1.0j), + ] + for a, x in within_1e6: + assert a == approx(x, rel=5e-6, abs=0) + assert a != approx(x, rel=5e-7, abs=0) + + def test_int(self): + within_1e6 = [ + (1000001, 1000000), + (-1000001, -1000000), + ] + for a, x in within_1e6: + assert a == approx(x, rel=5e-6, abs=0) + assert a != approx(x, rel=5e-7, abs=0) + + def test_decimal(self): + within_1e6 = [ + (Decimal('1.000001'), Decimal('1.0')), + (Decimal('-1.000001'), Decimal('-1.0')), + ] + for a, x in within_1e6: + assert a == approx(x, rel=Decimal(5e-6), abs=0) + assert a != approx(x, rel=Decimal(5e-7), abs=0) + + def test_fraction(self): + within_1e6 = [ + (1 + Fraction(1, 1000000), Fraction(1)), + (-1 - Fraction(-1, 1000000), Fraction(-1)), + ] + for a, x in within_1e6: + assert a == approx(x, rel=5e-6, abs=0) + assert a != approx(x, rel=5e-7, abs=0) + + def test_doctests(self): parser = doctest.DocTestParser() test = parser.get_doctest( - pytest.approx.__doc__, - {'approx': pytest.approx}, - pytest.approx.__name__, + approx.__doc__, + {'approx': approx}, + approx.__name__, None, None, ) runner = MyDocTestRunner() runner.run(test) - def test_repr_string(self): - # Just make sure the Unicode handling doesn't raise any exceptions. - print(pytest.approx(1.0)) - print(pytest.approx([1.0, 2.0, 3.0])) -