diff --git a/changelog/3763.bugfix.rst b/changelog/3763.bugfix.rst new file mode 100644 index 000000000..589346d2a --- /dev/null +++ b/changelog/3763.bugfix.rst @@ -0,0 +1 @@ +Fix ``TypeError`` when the assertion message is ``bytes`` in python 3. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index bc18aa1fc..4f96b9e8c 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -425,20 +425,18 @@ def _format_assertmsg(obj): # contains a newline it gets escaped, however if an object has a # .__repr__() which contains newlines it does not get escaped. # However in either case we want to preserve the newline. - if isinstance(obj, six.text_type) or isinstance(obj, six.binary_type): - s = obj - is_repr = False - else: - s = py.io.saferepr(obj) - is_repr = True - if isinstance(s, six.text_type): - t = six.text_type - else: - t = six.binary_type - s = s.replace(t("\n"), t("\n~")).replace(t("%"), t("%%")) - if is_repr: - s = s.replace(t("\\n"), t("\n~")) - return s + replaces = [(u"\n", u"\n~"), (u"%", u"%%")] + if not isinstance(obj, six.string_types): + obj = py.io.saferepr(obj) + replaces.append((u"\\n", u"\n~")) + + if isinstance(obj, bytes): + replaces = [(r1.encode(), r2.encode()) for r1, r2 in replaces] + + for r1, r2 in replaces: + obj = obj.replace(r1, r2) + + return obj def _should_repr_global_name(obj): diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 274b1ac53..6cec7f003 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -246,6 +246,15 @@ class TestAssertionRewrite(object): ["*AssertionError: To be escaped: %", "*assert 1 == 2"] ) + @pytest.mark.skipif( + sys.version_info < (3,), reason="bytes is a string type in python 2" + ) + def test_assertion_messages_bytes(self, testdir): + testdir.makepyfile("def test_bytes_assertion():\n assert False, b'ohai!'\n") + result = testdir.runpytest() + assert result.ret == 1 + result.stdout.fnmatch_lines(["*AssertionError: b'ohai!'", "*assert False"]) + def test_boolop(self): def f(): f = g = False