diff --git a/src/_pytest/_code/source.py b/src/_pytest/_code/source.py index 6cc123202..f7dcdeff9 100644 --- a/src/_pytest/_code/source.py +++ b/src/_pytest/_code/source.py @@ -44,13 +44,10 @@ class Source: else: self.lines = deindent(getsource(obj).lines) - def __eq__(self, other): - try: - return self.lines == other.lines - except AttributeError: - if isinstance(other, str): - return str(self) == other - return False + def __eq__(self, other: object) -> bool: + if not isinstance(other, Source): + return NotImplemented + return self.lines == other.lines # Ignore type because of https://github.com/python/mypy/issues/4266. __hash__ = None # type: ignore diff --git a/testing/code/test_code.py b/testing/code/test_code.py index 5cbd89990..25a3e9aeb 100644 --- a/testing/code/test_code.py +++ b/testing/code/test_code.py @@ -6,6 +6,7 @@ import pytest from _pytest._code import Code from _pytest._code import ExceptionInfo from _pytest._code import Frame +from _pytest._code import Source from _pytest._code.code import ExceptionChainRepr from _pytest._code.code import ReprFuncArgs @@ -67,7 +68,7 @@ def test_getstatement_empty_fullsource() -> None: f = Frame(func()) with mock.patch.object(f.code.__class__, "fullsource", None): - assert f.statement == "" + assert f.statement == Source("") def test_code_from_func() -> None: diff --git a/testing/code/test_source.py b/testing/code/test_source.py index 014034dec..8616b2f25 100644 --- a/testing/code/test_source.py +++ b/testing/code/test_source.py @@ -227,9 +227,9 @@ class TestSourceParsingAndCompiling: ''')""" ) s = source.getstatement(0) - assert s == str(source) + assert s == source s = source.getstatement(1) - assert s == str(source) + assert s == source def test_getstatementrange_within_constructs(self) -> None: source = Source( @@ -445,7 +445,7 @@ def test_getsource_fallback() -> None: expected = """def x(): pass""" src = getsource(x) - assert src == expected + assert str(src) == expected def test_idem_compile_and_getsource() -> None: @@ -454,7 +454,7 @@ def test_idem_compile_and_getsource() -> None: expected = "def x(): pass" co = _pytest._code.compile(expected) src = getsource(co) - assert src == expected + assert str(src) == expected def test_compile_ast() -> None: