This commit is contained in:
Yuval Shimon 2021-12-12 14:38:45 +02:00
parent c7be96dae4
commit 897395afd5
4 changed files with 64 additions and 2 deletions

View File

@ -0,0 +1 @@
Pytest will now avoid specialized assert formatting when it is detected that the default __eq__ is overridden

View File

@ -135,6 +135,26 @@ def isiterable(obj: Any) -> bool:
return False
def has_default_eq(
obj: object,
) -> bool:
"""Check if an instance of an object contains the default eq
First, we check if the object's __eq__ attribute has __code__, if so, we check the equally of the method code filename (__code__.co_filename)
to the default onces generated by the dataclass and attr module
for dataclasses the default co_filename is <string>, for attrs class, the __eq__ should contain "attrs eq generated"
"""
# inspired from https://github.com/willmcgugan/rich/blob/07d51ffc1aee6f16bd2e5a25b4e82850fb9ed778/rich/pretty.py#L68
if not hasattr(obj.__eq__, "__code__"): # the obj does not have a code attribute
return True
code_filename = obj.__eq__.__code__.co_filename
if isattrs(obj):
return "attrs eq generated" in code_filename
if isdatacls(obj):
return code_filename == "<string>"
def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[str]]:
"""Return specialised explanations for some operators/operands."""
verbose = config.getoption("verbose")
@ -427,6 +447,8 @@ def _compare_eq_dict(
def _compare_eq_cls(left: Any, right: Any, verbose: int) -> List[str]:
if not has_default_eq(left):
return []
if isdatacls(left):
all_fields = left.__dataclass_fields__
fields_to_check = [field for field, info in all_fields.items() if info.compare]
@ -437,7 +459,6 @@ def _compare_eq_cls(left: Any, right: Any, verbose: int) -> List[str]:
fields_to_check = left._fields
else:
assert False
indent = " "
same = []
diff = []

View File

@ -0,0 +1,17 @@
from dataclasses import dataclass
from dataclasses import field
def test_dataclasses() -> None:
@dataclass
class SimpleDataObject:
field_a: int = field()
field_b: str = field()
def __eq__(self, __o: object) -> bool:
return super().__eq__(__o)
left = SimpleDataObject(1, "b")
right = SimpleDataObject(1, "c")
assert left == right

View File

@ -899,6 +899,16 @@ class TestAssert_reprcompare_dataclass:
result = pytester.runpytest(p, "-vv")
result.assert_outcomes(failed=0, passed=1)
@pytest.mark.skipif(sys.version_info < (3, 7), reason="Dataclasses in Python3.7+")
def test_data_classes_with_custom_eq(self, pytester: Pytester) -> None:
p = pytester.copy_example(
"dataclasses/test_compare_dataclasses_with_custom_eq.py"
)
# issue 9362
result = pytester.runpytest(p, "-vv")
result.assert_outcomes(failed=1, passed=0)
result.stdout.no_re_match_line(".*Differing attributes.*")
class TestAssert_reprcompare_attrsclass:
def test_attrs(self) -> None:
@ -982,7 +992,6 @@ class TestAssert_reprcompare_attrsclass:
right = SimpleDataObject(1, "b")
lines = callequal(left, right, verbose=2)
print(lines)
assert lines is not None
assert lines[2].startswith("Matching attributes:")
assert "Omitting" not in lines[1]
@ -1007,6 +1016,20 @@ class TestAssert_reprcompare_attrsclass:
lines = callequal(left, right)
assert lines is None
def test_attrs_with_custom_eq(self) -> None:
@attr.define
class SimpleDataObject:
field_a = attr.ib()
def __eq__(self, other):
return self.field_a == other.field_a
left = SimpleDataObject(1)
right = SimpleDataObject(2)
# issue 9362
lines = callequal(left, right, verbose=2)
assert lines is None
class TestAssert_reprcompare_namedtuple:
def test_namedtuple(self) -> None: