python_api: let approx() take nonnumeric values (#7710)

Co-authored-by: Bruno Oliveira <nicoddemus@gmail.com>
This commit is contained in:
Jakob van Santen 2020-09-28 17:17:23 +02:00 committed by GitHub
parent f324b27d02
commit 91fa11bed0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 117 additions and 21 deletions

View File

@ -129,6 +129,7 @@ Ilya Konstantinov
Ionuț Turturică Ionuț Turturică
Iwan Briquemont Iwan Briquemont
Jaap Broekhuizen Jaap Broekhuizen
Jakob van Santen
Jakub Mitoraj Jakub Mitoraj
Jan Balster Jan Balster
Janne Vanhala Janne Vanhala

View File

@ -0,0 +1,3 @@
Use strict equality comparison for nonnumeric types in ``approx`` instead of
raising ``TypeError``.
This was the undocumented behavior before 3.7, but is now officially a supported feature.

View File

@ -4,7 +4,7 @@ from collections.abc import Iterable
from collections.abc import Mapping from collections.abc import Mapping
from collections.abc import Sized from collections.abc import Sized
from decimal import Decimal from decimal import Decimal
from numbers import Number from numbers import Complex
from types import TracebackType from types import TracebackType
from typing import Any from typing import Any
from typing import Callable from typing import Callable
@ -146,7 +146,10 @@ class ApproxMapping(ApproxBase):
) )
def __eq__(self, actual) -> bool: def __eq__(self, actual) -> bool:
if set(actual.keys()) != set(self.expected.keys()): try:
if set(actual.keys()) != set(self.expected.keys()):
return False
except AttributeError:
return False return False
return ApproxBase.__eq__(self, actual) return ApproxBase.__eq__(self, actual)
@ -161,8 +164,6 @@ class ApproxMapping(ApproxBase):
if isinstance(value, type(self.expected)): if isinstance(value, type(self.expected)):
msg = "pytest.approx() does not support nested dictionaries: key={!r} value={!r}\n full mapping={}" msg = "pytest.approx() does not support nested dictionaries: key={!r} value={!r}\n full mapping={}"
raise TypeError(msg.format(key, value, pprint.pformat(self.expected))) raise TypeError(msg.format(key, value, pprint.pformat(self.expected)))
elif not isinstance(value, Number):
raise _non_numeric_type_error(self.expected, at="key={!r}".format(key))
class ApproxSequencelike(ApproxBase): class ApproxSequencelike(ApproxBase):
@ -177,7 +178,10 @@ class ApproxSequencelike(ApproxBase):
) )
def __eq__(self, actual) -> bool: def __eq__(self, actual) -> bool:
if len(actual) != len(self.expected): try:
if len(actual) != len(self.expected):
return False
except TypeError:
return False return False
return ApproxBase.__eq__(self, actual) return ApproxBase.__eq__(self, actual)
@ -190,10 +194,6 @@ class ApproxSequencelike(ApproxBase):
if isinstance(x, type(self.expected)): if isinstance(x, type(self.expected)):
msg = "pytest.approx() does not support nested data structures: {!r} at index {}\n full sequence: {}" msg = "pytest.approx() does not support nested data structures: {!r} at index {}\n full sequence: {}"
raise TypeError(msg.format(x, index, pprint.pformat(self.expected))) raise TypeError(msg.format(x, index, pprint.pformat(self.expected)))
elif not isinstance(x, Number):
raise _non_numeric_type_error(
self.expected, at="index {}".format(index)
)
class ApproxScalar(ApproxBase): class ApproxScalar(ApproxBase):
@ -211,16 +211,23 @@ class ApproxScalar(ApproxBase):
For example, ``1.0 ± 1e-6``, ``(3+4j) ± 5e-6 ±180°``. For example, ``1.0 ± 1e-6``, ``(3+4j) ± 5e-6 ±180°``.
""" """
# Infinities aren't compared using tolerances, so don't show a # Don't show a tolerance for values that aren't compared using
# tolerance. Need to call abs to handle complex numbers, e.g. (inf + 1j). # tolerances, i.e. non-numerics and infinities. Need to call abs to
if math.isinf(abs(self.expected)): # handle complex numbers, e.g. (inf + 1j).
if (not isinstance(self.expected, (Complex, Decimal))) or math.isinf(
abs(self.expected)
):
return str(self.expected) return str(self.expected)
# If a sensible tolerance can't be calculated, self.tolerance will # If a sensible tolerance can't be calculated, self.tolerance will
# raise a ValueError. In this case, display '???'. # raise a ValueError. In this case, display '???'.
try: try:
vetted_tolerance = "{:.1e}".format(self.tolerance) vetted_tolerance = "{:.1e}".format(self.tolerance)
if isinstance(self.expected, complex) and not math.isinf(self.tolerance): if (
isinstance(self.expected, Complex)
and self.expected.imag
and not math.isinf(self.tolerance)
):
vetted_tolerance += " ∠ ±180°" vetted_tolerance += " ∠ ±180°"
except ValueError: except ValueError:
vetted_tolerance = "???" vetted_tolerance = "???"
@ -239,6 +246,15 @@ class ApproxScalar(ApproxBase):
if actual == self.expected: if actual == self.expected:
return True return True
# If either type is non-numeric, fall back to strict equality.
# NB: we need Complex, rather than just Number, to ensure that __abs__,
# __sub__, and __float__ are defined.
if not (
isinstance(self.expected, (Complex, Decimal))
and isinstance(actual, (Complex, Decimal))
):
return False
# Allow the user to control whether NaNs are considered equal to each # Allow the user to control whether NaNs are considered equal to each
# other or not. The abs() calls are for compatibility with complex # other or not. The abs() calls are for compatibility with complex
# numbers. # numbers.
@ -409,6 +425,18 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
>>> 1 + 1e-8 == approx(1, rel=1e-6, abs=1e-12) >>> 1 + 1e-8 == approx(1, rel=1e-6, abs=1e-12)
True True
You can also use ``approx`` to compare nonnumeric types, or dicts and
sequences containing nonnumeric types, in which case it falls back to
strict equality. This can be useful for comparing dicts and sequences that
can contain optional values::
>>> {"required": 1.0000005, "optional": None} == approx({"required": 1, "optional": None})
True
>>> [None, 1.0000005] == approx([None,1])
True
>>> ["foo", 1.0000005] == approx([None,1])
False
If you're thinking about using ``approx``, then you might want to know how If you're thinking about using ``approx``, then you might want to know how
it compares to other good ways of comparing floating-point numbers. All of it compares to other good ways of comparing floating-point numbers. All of
these algorithms are based on relative and absolute tolerances and should these algorithms are based on relative and absolute tolerances and should
@ -466,6 +494,14 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
follows a fixed behavior. `More information...`__ follows a fixed behavior. `More information...`__
__ https://docs.python.org/3/reference/datamodel.html#object.__ge__ __ https://docs.python.org/3/reference/datamodel.html#object.__ge__
.. versionchanged:: 3.7.1
``approx`` raises ``TypeError`` when it encounters a dict value or
sequence element of nonnumeric type.
.. versionchanged:: 6.1.0
``approx`` falls back to strict equality for nonnumeric types instead
of raising ``TypeError``.
""" """
# Delegate the comparison to a class that knows how to deal with the type # Delegate the comparison to a class that knows how to deal with the type
@ -487,8 +523,6 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
if isinstance(expected, Decimal): if isinstance(expected, Decimal):
cls = ApproxDecimal # type: Type[ApproxBase] cls = ApproxDecimal # type: Type[ApproxBase]
elif isinstance(expected, Number):
cls = ApproxScalar
elif isinstance(expected, Mapping): elif isinstance(expected, Mapping):
cls = ApproxMapping cls = ApproxMapping
elif _is_numpy_array(expected): elif _is_numpy_array(expected):
@ -501,7 +535,7 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
): ):
cls = ApproxSequencelike cls = ApproxSequencelike
else: else:
raise _non_numeric_type_error(expected, at=None) cls = ApproxScalar
return cls(expected, rel, abs, nan_ok) return cls(expected, rel, abs, nan_ok)

View File

@ -1,4 +1,5 @@
import operator import operator
import sys
from decimal import Decimal from decimal import Decimal
from fractions import Fraction from fractions import Fraction
from operator import eq from operator import eq
@ -329,6 +330,9 @@ class TestApprox:
assert (1, 2) != approx((1,)) assert (1, 2) != approx((1,))
assert (1, 2) != approx((1, 2, 3)) assert (1, 2) != approx((1, 2, 3))
def test_tuple_vs_other(self):
assert 1 != approx((1,))
def test_dict(self): def test_dict(self):
actual = {"a": 1 + 1e-7, "b": 2 + 1e-8} actual = {"a": 1 + 1e-7, "b": 2 + 1e-8}
# Dictionaries became ordered in python3.6, so switch up the order here # Dictionaries became ordered in python3.6, so switch up the order here
@ -346,6 +350,13 @@ class TestApprox:
assert {"a": 1, "b": 2} != approx({"a": 1, "c": 2}) assert {"a": 1, "b": 2} != approx({"a": 1, "c": 2})
assert {"a": 1, "b": 2} != approx({"a": 1, "b": 2, "c": 3}) assert {"a": 1, "b": 2} != approx({"a": 1, "b": 2, "c": 3})
def test_dict_nonnumeric(self):
assert {"a": 1.0, "b": None} == pytest.approx({"a": 1.0, "b": None})
assert {"a": 1.0, "b": 1} != pytest.approx({"a": 1.0, "b": None})
def test_dict_vs_other(self):
assert 1 != approx({"a": 0})
def test_numpy_array(self): def test_numpy_array(self):
np = pytest.importorskip("numpy") np = pytest.importorskip("numpy")
@ -463,20 +474,67 @@ class TestApprox:
["*At index 0 diff: 3 != 4 ± {}".format(expected), "=* 1 failed in *="] ["*At index 0 diff: 3 != 4 ± {}".format(expected), "=* 1 failed in *="]
) )
@pytest.mark.parametrize(
"x, name",
[
pytest.param([[1]], "data structures", id="nested-list"),
pytest.param({"key": {"key": 1}}, "dictionaries", id="nested-dict"),
],
)
def test_expected_value_type_error(self, x, name):
with pytest.raises(
TypeError,
match=r"pytest.approx\(\) does not support nested {}:".format(name),
):
approx(x)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x", "x",
[ [
pytest.param(None), pytest.param(None),
pytest.param("string"), pytest.param("string"),
pytest.param(["string"], id="nested-str"), pytest.param(["string"], id="nested-str"),
pytest.param([[1]], id="nested-list"),
pytest.param({"key": "string"}, id="dict-with-string"), pytest.param({"key": "string"}, id="dict-with-string"),
pytest.param({"key": {"key": 1}}, id="nested-dict"),
], ],
) )
def test_expected_value_type_error(self, x): def test_nonnumeric_okay_if_equal(self, x):
with pytest.raises(TypeError): assert x == approx(x)
approx(x)
@pytest.mark.parametrize(
"x",
[
pytest.param("string"),
pytest.param(["string"], id="nested-str"),
pytest.param({"key": "string"}, id="dict-with-string"),
],
)
def test_nonnumeric_false_if_unequal(self, x):
"""For nonnumeric types, x != pytest.approx(y) reduces to x != y"""
assert "ab" != approx("abc")
assert ["ab"] != approx(["abc"])
# in particular, both of these should return False
assert {"a": 1.0} != approx({"a": None})
assert {"a": None} != approx({"a": 1.0})
assert 1.0 != approx(None)
assert None != approx(1.0) # noqa: E711
assert 1.0 != approx([None])
assert None != approx([1.0]) # noqa: E711
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires ordered dicts")
def test_nonnumeric_dict_repr(self):
"""Dicts with non-numerics and infinites have no tolerances"""
x1 = {"foo": 1.0000005, "bar": None, "foobar": inf}
assert (
repr(approx(x1))
== "approx({'foo': 1.0000005 ± 1.0e-06, 'bar': None, 'foobar': inf})"
)
def test_nonnumeric_list_repr(self):
"""Lists with non-numerics and infinites have no tolerances"""
x1 = [1.0000005, None, inf]
assert repr(approx(x1)) == "approx([1.0000005 ± 1.0e-06, None, inf])"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"op", "op",