From 897395afd584b54394ab0b747ff03c80da07871d Mon Sep 17 00:00:00 2001 From: Yuval Shimon Date: Sun, 12 Dec 2021 14:38:45 +0200 Subject: [PATCH] fix 9326 --- changelog/9326.bugfix.rst | 1 + src/_pytest/assertion/util.py | 23 ++++++++++++++++- ...test_compare_dataclasses_with_custom_eq.py | 17 +++++++++++++ testing/test_assertion.py | 25 ++++++++++++++++++- 4 files changed, 64 insertions(+), 2 deletions(-) create mode 100644 changelog/9326.bugfix.rst create mode 100644 testing/example_scripts/dataclasses/test_compare_dataclasses_with_custom_eq.py diff --git a/changelog/9326.bugfix.rst b/changelog/9326.bugfix.rst new file mode 100644 index 000000000..1aaa424d3 --- /dev/null +++ b/changelog/9326.bugfix.rst @@ -0,0 +1 @@ +Pytest will now avoid specialized assert formatting when it is detected that the default __eq__ is overridden diff --git a/src/_pytest/assertion/util.py b/src/_pytest/assertion/util.py index 19f1089c2..ed5b01d1d 100644 --- a/src/_pytest/assertion/util.py +++ b/src/_pytest/assertion/util.py @@ -135,6 +135,26 @@ def isiterable(obj: Any) -> bool: return False +def has_default_eq( + obj: object, +) -> bool: + """Check if an instance of an object contains the default eq + + First, we check if the object's __eq__ attribute has __code__, if so, we check the equally of the method code filename (__code__.co_filename) + to the default onces generated by the dataclass and attr module + for dataclasses the default co_filename is , for attrs class, the __eq__ should contain "attrs eq generated" + """ + # inspired from https://github.com/willmcgugan/rich/blob/07d51ffc1aee6f16bd2e5a25b4e82850fb9ed778/rich/pretty.py#L68 + if not hasattr(obj.__eq__, "__code__"): # the obj does not have a code attribute + return True + + code_filename = obj.__eq__.__code__.co_filename + if isattrs(obj): + return "attrs eq generated" in code_filename + if isdatacls(obj): + return code_filename == "" + + def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[str]]: """Return specialised explanations for some operators/operands.""" verbose = config.getoption("verbose") @@ -427,6 +447,8 @@ def _compare_eq_dict( def _compare_eq_cls(left: Any, right: Any, verbose: int) -> List[str]: + if not has_default_eq(left): + return [] if isdatacls(left): all_fields = left.__dataclass_fields__ fields_to_check = [field for field, info in all_fields.items() if info.compare] @@ -437,7 +459,6 @@ def _compare_eq_cls(left: Any, right: Any, verbose: int) -> List[str]: fields_to_check = left._fields else: assert False - indent = " " same = [] diff = [] diff --git a/testing/example_scripts/dataclasses/test_compare_dataclasses_with_custom_eq.py b/testing/example_scripts/dataclasses/test_compare_dataclasses_with_custom_eq.py new file mode 100644 index 000000000..e026fe3d1 --- /dev/null +++ b/testing/example_scripts/dataclasses/test_compare_dataclasses_with_custom_eq.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass +from dataclasses import field + + +def test_dataclasses() -> None: + @dataclass + class SimpleDataObject: + field_a: int = field() + field_b: str = field() + + def __eq__(self, __o: object) -> bool: + return super().__eq__(__o) + + left = SimpleDataObject(1, "b") + right = SimpleDataObject(1, "c") + + assert left == right diff --git a/testing/test_assertion.py b/testing/test_assertion.py index e8717590d..29320c54b 100644 --- a/testing/test_assertion.py +++ b/testing/test_assertion.py @@ -899,6 +899,16 @@ class TestAssert_reprcompare_dataclass: result = pytester.runpytest(p, "-vv") result.assert_outcomes(failed=0, passed=1) + @pytest.mark.skipif(sys.version_info < (3, 7), reason="Dataclasses in Python3.7+") + def test_data_classes_with_custom_eq(self, pytester: Pytester) -> None: + p = pytester.copy_example( + "dataclasses/test_compare_dataclasses_with_custom_eq.py" + ) + # issue 9362 + result = pytester.runpytest(p, "-vv") + result.assert_outcomes(failed=1, passed=0) + result.stdout.no_re_match_line(".*Differing attributes.*") + class TestAssert_reprcompare_attrsclass: def test_attrs(self) -> None: @@ -982,7 +992,6 @@ class TestAssert_reprcompare_attrsclass: right = SimpleDataObject(1, "b") lines = callequal(left, right, verbose=2) - print(lines) assert lines is not None assert lines[2].startswith("Matching attributes:") assert "Omitting" not in lines[1] @@ -1007,6 +1016,20 @@ class TestAssert_reprcompare_attrsclass: lines = callequal(left, right) assert lines is None + def test_attrs_with_custom_eq(self) -> None: + @attr.define + class SimpleDataObject: + field_a = attr.ib() + + def __eq__(self, other): + return self.field_a == other.field_a + + left = SimpleDataObject(1) + right = SimpleDataObject(2) + # issue 9362 + lines = callequal(left, right, verbose=2) + assert lines is None + class TestAssert_reprcompare_namedtuple: def test_namedtuple(self) -> None: