python_api: let approx() take nonnumeric values (#7710)

Co-authored-by: Bruno Oliveira <nicoddemus@gmail.com>
This commit is contained in:
Jakob van Santen 2020-09-28 17:17:23 +02:00 committed by GitHub
parent f324b27d02
commit 91fa11bed0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 117 additions and 21 deletions

View File

@ -129,6 +129,7 @@ Ilya Konstantinov
Ionuț Turturică
Iwan Briquemont
Jaap Broekhuizen
Jakob van Santen
Jakub Mitoraj
Jan Balster
Janne Vanhala

View File

@ -0,0 +1,3 @@
Use strict equality comparison for nonnumeric types in ``approx`` instead of
raising ``TypeError``.
This was the undocumented behavior before 3.7, but is now officially a supported feature.

View File

@ -4,7 +4,7 @@ from collections.abc import Iterable
from collections.abc import Mapping
from collections.abc import Sized
from decimal import Decimal
from numbers import Number
from numbers import Complex
from types import TracebackType
from typing import Any
from typing import Callable
@ -146,7 +146,10 @@ class ApproxMapping(ApproxBase):
)
def __eq__(self, actual) -> bool:
if set(actual.keys()) != set(self.expected.keys()):
try:
if set(actual.keys()) != set(self.expected.keys()):
return False
except AttributeError:
return False
return ApproxBase.__eq__(self, actual)
@ -161,8 +164,6 @@ class ApproxMapping(ApproxBase):
if isinstance(value, type(self.expected)):
msg = "pytest.approx() does not support nested dictionaries: key={!r} value={!r}\n full mapping={}"
raise TypeError(msg.format(key, value, pprint.pformat(self.expected)))
elif not isinstance(value, Number):
raise _non_numeric_type_error(self.expected, at="key={!r}".format(key))
class ApproxSequencelike(ApproxBase):
@ -177,7 +178,10 @@ class ApproxSequencelike(ApproxBase):
)
def __eq__(self, actual) -> bool:
if len(actual) != len(self.expected):
try:
if len(actual) != len(self.expected):
return False
except TypeError:
return False
return ApproxBase.__eq__(self, actual)
@ -190,10 +194,6 @@ class ApproxSequencelike(ApproxBase):
if isinstance(x, type(self.expected)):
msg = "pytest.approx() does not support nested data structures: {!r} at index {}\n full sequence: {}"
raise TypeError(msg.format(x, index, pprint.pformat(self.expected)))
elif not isinstance(x, Number):
raise _non_numeric_type_error(
self.expected, at="index {}".format(index)
)
class ApproxScalar(ApproxBase):
@ -211,16 +211,23 @@ class ApproxScalar(ApproxBase):
For example, ``1.0 ± 1e-6``, ``(3+4j) ± 5e-6 ±180°``.
"""
# Infinities aren't compared using tolerances, so don't show a
# tolerance. Need to call abs to handle complex numbers, e.g. (inf + 1j).
if math.isinf(abs(self.expected)):
# Don't show a tolerance for values that aren't compared using
# tolerances, i.e. non-numerics and infinities. Need to call abs to
# handle complex numbers, e.g. (inf + 1j).
if (not isinstance(self.expected, (Complex, Decimal))) or math.isinf(
abs(self.expected)
):
return str(self.expected)
# If a sensible tolerance can't be calculated, self.tolerance will
# raise a ValueError. In this case, display '???'.
try:
vetted_tolerance = "{:.1e}".format(self.tolerance)
if isinstance(self.expected, complex) and not math.isinf(self.tolerance):
if (
isinstance(self.expected, Complex)
and self.expected.imag
and not math.isinf(self.tolerance)
):
vetted_tolerance += " ∠ ±180°"
except ValueError:
vetted_tolerance = "???"
@ -239,6 +246,15 @@ class ApproxScalar(ApproxBase):
if actual == self.expected:
return True
# If either type is non-numeric, fall back to strict equality.
# NB: we need Complex, rather than just Number, to ensure that __abs__,
# __sub__, and __float__ are defined.
if not (
isinstance(self.expected, (Complex, Decimal))
and isinstance(actual, (Complex, Decimal))
):
return False
# Allow the user to control whether NaNs are considered equal to each
# other or not. The abs() calls are for compatibility with complex
# numbers.
@ -409,6 +425,18 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
>>> 1 + 1e-8 == approx(1, rel=1e-6, abs=1e-12)
True
You can also use ``approx`` to compare nonnumeric types, or dicts and
sequences containing nonnumeric types, in which case it falls back to
strict equality. This can be useful for comparing dicts and sequences that
can contain optional values::
>>> {"required": 1.0000005, "optional": None} == approx({"required": 1, "optional": None})
True
>>> [None, 1.0000005] == approx([None,1])
True
>>> ["foo", 1.0000005] == approx([None,1])
False
If you're thinking about using ``approx``, then you might want to know how
it compares to other good ways of comparing floating-point numbers. All of
these algorithms are based on relative and absolute tolerances and should
@ -466,6 +494,14 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
follows a fixed behavior. `More information...`__
__ https://docs.python.org/3/reference/datamodel.html#object.__ge__
.. versionchanged:: 3.7.1
``approx`` raises ``TypeError`` when it encounters a dict value or
sequence element of nonnumeric type.
.. versionchanged:: 6.1.0
``approx`` falls back to strict equality for nonnumeric types instead
of raising ``TypeError``.
"""
# Delegate the comparison to a class that knows how to deal with the type
@ -487,8 +523,6 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
if isinstance(expected, Decimal):
cls = ApproxDecimal # type: Type[ApproxBase]
elif isinstance(expected, Number):
cls = ApproxScalar
elif isinstance(expected, Mapping):
cls = ApproxMapping
elif _is_numpy_array(expected):
@ -501,7 +535,7 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
):
cls = ApproxSequencelike
else:
raise _non_numeric_type_error(expected, at=None)
cls = ApproxScalar
return cls(expected, rel, abs, nan_ok)

View File

@ -1,4 +1,5 @@
import operator
import sys
from decimal import Decimal
from fractions import Fraction
from operator import eq
@ -329,6 +330,9 @@ class TestApprox:
assert (1, 2) != approx((1,))
assert (1, 2) != approx((1, 2, 3))
def test_tuple_vs_other(self):
assert 1 != approx((1,))
def test_dict(self):
actual = {"a": 1 + 1e-7, "b": 2 + 1e-8}
# Dictionaries became ordered in python3.6, so switch up the order here
@ -346,6 +350,13 @@ class TestApprox:
assert {"a": 1, "b": 2} != approx({"a": 1, "c": 2})
assert {"a": 1, "b": 2} != approx({"a": 1, "b": 2, "c": 3})
def test_dict_nonnumeric(self):
assert {"a": 1.0, "b": None} == pytest.approx({"a": 1.0, "b": None})
assert {"a": 1.0, "b": 1} != pytest.approx({"a": 1.0, "b": None})
def test_dict_vs_other(self):
assert 1 != approx({"a": 0})
def test_numpy_array(self):
np = pytest.importorskip("numpy")
@ -463,20 +474,67 @@ class TestApprox:
["*At index 0 diff: 3 != 4 ± {}".format(expected), "=* 1 failed in *="]
)
@pytest.mark.parametrize(
"x, name",
[
pytest.param([[1]], "data structures", id="nested-list"),
pytest.param({"key": {"key": 1}}, "dictionaries", id="nested-dict"),
],
)
def test_expected_value_type_error(self, x, name):
with pytest.raises(
TypeError,
match=r"pytest.approx\(\) does not support nested {}:".format(name),
):
approx(x)
@pytest.mark.parametrize(
"x",
[
pytest.param(None),
pytest.param("string"),
pytest.param(["string"], id="nested-str"),
pytest.param([[1]], id="nested-list"),
pytest.param({"key": "string"}, id="dict-with-string"),
pytest.param({"key": {"key": 1}}, id="nested-dict"),
],
)
def test_expected_value_type_error(self, x):
with pytest.raises(TypeError):
approx(x)
def test_nonnumeric_okay_if_equal(self, x):
assert x == approx(x)
@pytest.mark.parametrize(
"x",
[
pytest.param("string"),
pytest.param(["string"], id="nested-str"),
pytest.param({"key": "string"}, id="dict-with-string"),
],
)
def test_nonnumeric_false_if_unequal(self, x):
"""For nonnumeric types, x != pytest.approx(y) reduces to x != y"""
assert "ab" != approx("abc")
assert ["ab"] != approx(["abc"])
# in particular, both of these should return False
assert {"a": 1.0} != approx({"a": None})
assert {"a": None} != approx({"a": 1.0})
assert 1.0 != approx(None)
assert None != approx(1.0) # noqa: E711
assert 1.0 != approx([None])
assert None != approx([1.0]) # noqa: E711
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires ordered dicts")
def test_nonnumeric_dict_repr(self):
"""Dicts with non-numerics and infinites have no tolerances"""
x1 = {"foo": 1.0000005, "bar": None, "foobar": inf}
assert (
repr(approx(x1))
== "approx({'foo': 1.0000005 ± 1.0e-06, 'bar': None, 'foobar': inf})"
)
def test_nonnumeric_list_repr(self):
"""Lists with non-numerics and infinites have no tolerances"""
x1 = [1.0000005, None, inf]
assert repr(approx(x1)) == "approx([1.0000005 ± 1.0e-06, None, inf])"
@pytest.mark.parametrize(
"op",