From 7122fa5613de8ef767eeb8106a5fee6dea0d9285 Mon Sep 17 00:00:00 2001 From: Anthony Sottile Date: Wed, 19 Sep 2018 09:44:26 -0700 Subject: [PATCH] Fix UnicodeDecodeError in assertion with mixed non-ascii bytes repr + text --- changelog/3999.bugfix.rst | 1 + src/_pytest/assertion/rewrite.py | 13 +++++++++---- testing/test_assertrewrite.py | 18 +++++++++++++++++- 3 files changed, 27 insertions(+), 5 deletions(-) create mode 100644 changelog/3999.bugfix.rst diff --git a/changelog/3999.bugfix.rst b/changelog/3999.bugfix.rst new file mode 100644 index 000000000..e072f729e --- /dev/null +++ b/changelog/3999.bugfix.rst @@ -0,0 +1 @@ +Fix ``UnicodeDecodeError`` in python2.x when a class returns a non-ascii binary ``__repr__`` in an assertion which also contains non-ascii text. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 3539bd55d..be8c6dc4d 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -8,6 +8,7 @@ import marshal import os import re import six +import string import struct import sys import types @@ -466,10 +467,14 @@ def _saferepr(obj): """ r = py.io.saferepr(obj) - if isinstance(r, six.text_type): - return r.replace(u"\n", u"\\n") - else: - return r.replace(b"\n", b"\\n") + # only occurs in python2.x, repr must return text in python3+ + if isinstance(r, bytes): + # Represent unprintable bytes as `\x##` + r = u"".join( + u"\\x{:x}".format(ord(c)) if c not in string.printable else c.decode() + for c in r + ) + return r.replace(u"\n", u"\\n") from _pytest.assertion.util import format_explanation as _format_explanation # noqa diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 394d30a05..a2cd8e81c 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- from __future__ import absolute_import, division, print_function import glob @@ -57,7 +58,7 @@ def getmsg(f, extra_ns=None, must_pass=False): except AssertionError: if must_pass: pytest.fail("shouldn't have raised") - s = str(sys.exc_info()[1]) + s = six.text_type(sys.exc_info()[1]) if not s.startswith("assert"): return "AssertionError: " + s return s @@ -608,6 +609,21 @@ class TestAssertionRewrite(object): assert r"where 1 = \n{ \n~ \n}.a" in util._format_lines([getmsg(f)])[0] + def test_custom_repr_non_ascii(self): + def f(): + class A(object): + name = u"รค" + + def __repr__(self): + return self.name.encode("UTF-8") # only legal in python2 + + a = A() + assert not a.name + + msg = getmsg(f) + assert "UnicodeDecodeError" not in msg + assert "UnicodeEncodeError" not in msg + class TestRewriteOnImport(object): def test_pycache_is_a_file(self, testdir):