From a127a22d13ce637b10244f2bf60e0df0b0313f57 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Wed, 1 Jul 2020 20:20:10 +0300 Subject: [PATCH] code/source: remove support for comparing Source with str Cross-type comparisons like this are a bad idea. This isn't used. --- src/_pytest/_code/source.py | 11 ++++------- testing/code/test_code.py | 3 ++- testing/code/test_source.py | 8 ++++---- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/_pytest/_code/source.py b/src/_pytest/_code/source.py index 6cc123202..f7dcdeff9 100644 --- a/src/_pytest/_code/source.py +++ b/src/_pytest/_code/source.py @@ -44,13 +44,10 @@ class Source: else: self.lines = deindent(getsource(obj).lines) - def __eq__(self, other): - try: - return self.lines == other.lines - except AttributeError: - if isinstance(other, str): - return str(self) == other - return False + def __eq__(self, other: object) -> bool: + if not isinstance(other, Source): + return NotImplemented + return self.lines == other.lines # Ignore type because of https://github.com/python/mypy/issues/4266. __hash__ = None # type: ignore diff --git a/testing/code/test_code.py b/testing/code/test_code.py index 5cbd89990..25a3e9aeb 100644 --- a/testing/code/test_code.py +++ b/testing/code/test_code.py @@ -6,6 +6,7 @@ import pytest from _pytest._code import Code from _pytest._code import ExceptionInfo from _pytest._code import Frame +from _pytest._code import Source from _pytest._code.code import ExceptionChainRepr from _pytest._code.code import ReprFuncArgs @@ -67,7 +68,7 @@ def test_getstatement_empty_fullsource() -> None: f = Frame(func()) with mock.patch.object(f.code.__class__, "fullsource", None): - assert f.statement == "" + assert f.statement == Source("") def test_code_from_func() -> None: diff --git a/testing/code/test_source.py b/testing/code/test_source.py index 014034dec..8616b2f25 100644 --- a/testing/code/test_source.py +++ b/testing/code/test_source.py @@ -227,9 +227,9 @@ class TestSourceParsingAndCompiling: ''')""" ) s = source.getstatement(0) - assert s == str(source) + assert s == source s = source.getstatement(1) - assert s == str(source) + assert s == source def test_getstatementrange_within_constructs(self) -> None: source = Source( @@ -445,7 +445,7 @@ def test_getsource_fallback() -> None: expected = """def x(): pass""" src = getsource(x) - assert src == expected + assert str(src) == expected def test_idem_compile_and_getsource() -> None: @@ -454,7 +454,7 @@ def test_idem_compile_and_getsource() -> None: expected = "def x(): pass" co = _pytest._code.compile(expected) src = getsource(co) - assert src == expected + assert str(src) == expected def test_compile_ast() -> None: