Fix representation of tuples in approx

Closes #9917
This commit is contained in:
Zach OBrien 2022-06-14 05:54:32 -04:00 committed by GitHub
parent bb94e83b49
commit 96412d19ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 56 additions and 8 deletions

View File

@ -361,5 +361,6 @@ Yoav Caspi
Yuval Shimon
Zac Hatfield-Dodds
Zachary Kneupper
Zachary OBrien
Zoltán Máté
Zsolt Cserna

View File

@ -0,0 +1 @@
Fixed string representation for :func:`pytest.approx` when used to compare tuples.

View File

@ -133,9 +133,11 @@ class ApproxBase:
# raise if there are any non-numeric elements in the sequence.
def _recursive_list_map(f, x):
if isinstance(x, list):
return [_recursive_list_map(f, xi) for xi in x]
def _recursive_sequence_map(f, x):
"""Recursively map a function over a sequence of arbitary depth"""
if isinstance(x, (list, tuple)):
seq_type = type(x)
return seq_type(_recursive_sequence_map(f, xi) for xi in x)
else:
return f(x)
@ -144,7 +146,9 @@ class ApproxNumpy(ApproxBase):
"""Perform approximate comparisons where the expected value is numpy array."""
def __repr__(self) -> str:
list_scalars = _recursive_list_map(self._approx_scalar, self.expected.tolist())
list_scalars = _recursive_sequence_map(
self._approx_scalar, self.expected.tolist()
)
return f"approx({list_scalars!r})"
def _repr_compare(self, other_side: "ndarray") -> List[str]:
@ -164,7 +168,7 @@ class ApproxNumpy(ApproxBase):
return value
np_array_shape = self.expected.shape
approx_side_as_list = _recursive_list_map(
approx_side_as_seq = _recursive_sequence_map(
self._approx_scalar, self.expected.tolist()
)
@ -179,7 +183,7 @@ class ApproxNumpy(ApproxBase):
max_rel_diff = -math.inf
different_ids = []
for index in itertools.product(*(range(i) for i in np_array_shape)):
approx_value = get_value_from_nested_list(approx_side_as_list, index)
approx_value = get_value_from_nested_list(approx_side_as_seq, index)
other_value = get_value_from_nested_list(other_side, index)
if approx_value != other_value:
abs_diff = abs(approx_value.expected - other_value)
@ -194,7 +198,7 @@ class ApproxNumpy(ApproxBase):
(
str(index),
str(get_value_from_nested_list(other_side, index)),
str(get_value_from_nested_list(approx_side_as_list, index)),
str(get_value_from_nested_list(approx_side_as_seq, index)),
)
for index in different_ids
]
@ -326,7 +330,7 @@ class ApproxSequenceLike(ApproxBase):
f"Lengths: {len(self.expected)} and {len(other_side)}",
]
approx_side_as_map = _recursive_list_map(self._approx_scalar, self.expected)
approx_side_as_map = _recursive_sequence_map(self._approx_scalar, self.expected)
number_of_elements = len(approx_side_as_map)
max_abs_diff = -math.inf

View File

@ -2,12 +2,14 @@ import operator
from contextlib import contextmanager
from decimal import Decimal
from fractions import Fraction
from math import sqrt
from operator import eq
from operator import ne
from typing import Optional
import pytest
from _pytest.pytester import Pytester
from _pytest.python_api import _recursive_sequence_map
from pytest import approx
inf, nan = float("inf"), float("nan")
@ -133,6 +135,18 @@ class TestApprox:
],
)
assert_approx_raises_regex(
(1, 2.2, 4),
(1, 3.2, 4),
[
r" comparison failed. Mismatched elements: 1 / 3:",
rf" Max absolute difference: {SOME_FLOAT}",
rf" Max relative difference: {SOME_FLOAT}",
r" Index \| Obtained\s+\| Expected ",
rf" 1 \| {SOME_FLOAT} \| {SOME_FLOAT} ± {SOME_FLOAT}",
],
)
# Specific test for comparison with 0.0 (relative diff will be 'inf')
assert_approx_raises_regex(
[0.0],
@ -878,3 +892,31 @@ class TestApprox:
"""pytest.approx() should raise an error on unordered sequences (#9692)."""
with pytest.raises(TypeError, match="only supports ordered sequences"):
assert {1, 2, 3} == approx({1, 2, 3})
class TestRecursiveSequenceMap:
def test_map_over_scalar(self):
assert _recursive_sequence_map(sqrt, 16) == 4
def test_map_over_empty_list(self):
assert _recursive_sequence_map(sqrt, []) == []
def test_map_over_list(self):
assert _recursive_sequence_map(sqrt, [4, 16, 25, 676]) == [2, 4, 5, 26]
def test_map_over_tuple(self):
assert _recursive_sequence_map(sqrt, (4, 16, 25, 676)) == (2, 4, 5, 26)
def test_map_over_nested_lists(self):
assert _recursive_sequence_map(sqrt, [4, [25, 64], [[49]]]) == [
2,
[5, 8],
[[7]],
]
def test_map_over_mixed_sequence(self):
assert _recursive_sequence_map(sqrt, [4, (25, 64), [(49)]]) == [
2,
(5, 8),
[(7)],
]