diff --git a/django/test/testcases.py b/django/test/testcases.py index 95ea4ae900b..6e16cfb2474 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -5,13 +5,13 @@ import errno import json import os import posixpath -import re import socket import sys import threading import unittest import warnings from collections import Counter +from contextlib import contextmanager from copy import copy from functools import wraps from unittest.util import safe_repr @@ -604,10 +604,16 @@ class SimpleTestCase(unittest.TestCase): msg_prefix + "Template '%s' was used unexpectedly in rendering" " the response" % template_name) + @contextmanager + def _assert_raises_message_cm(self, expected_exception, expected_message): + with self.assertRaises(expected_exception) as cm: + yield cm + self.assertIn(expected_message, str(cm.exception)) + def assertRaisesMessage(self, expected_exception, expected_message, *args, **kwargs): """ - Asserts that the message in a raised exception matches the passed - value. + Asserts that expected_message is found in the the message of a raised + exception. Args: expected_exception: Exception class expected to be raised. @@ -622,9 +628,17 @@ class SimpleTestCase(unittest.TestCase): 'The callable_obj kwarg is deprecated. Pass the callable ' 'as a positional argument instead.', RemovedInDjango20Warning ) - args = (callable_obj,) + args - return six.assertRaisesRegex(self, expected_exception, - re.escape(expected_message), *args, **kwargs) + elif len(args): + callable_obj = args[0] + args = args[1:] + + cm = self._assert_raises_message_cm(expected_exception, expected_message) + # Assertion used in context manager fashion. + if callable_obj is None: + return cm + # Assertion was passed a callable. + with cm: + callable_obj(*args, **kwargs) def assertFieldOutput(self, fieldclass, valid, invalid, field_args=None, field_kwargs=None, empty_value=''): diff --git a/tests/test_utils/tests.py b/tests/test_utils/tests.py index 682e0e08e69..a02bfb63976 100644 --- a/tests/test_utils/tests.py +++ b/tests/test_utils/tests.py @@ -749,6 +749,20 @@ class SkippingExtraTests(TestCase): class AssertRaisesMsgTest(SimpleTestCase): + def test_assert_raises_message(self): + msg = "'Expected message' not found in 'Unexpected message'" + # context manager form of assertRaisesMessage() + with self.assertRaisesMessage(AssertionError, msg): + with self.assertRaisesMessage(ValueError, "Expected message"): + raise ValueError("Unexpected message") + + # callable form + def func(): + raise ValueError("Unexpected message") + + with self.assertRaisesMessage(AssertionError, msg): + self.assertRaisesMessage(ValueError, "Expected message", func) + def test_special_re_chars(self): """assertRaisesMessage shouldn't interpret RE special chars.""" def func1():