Fix test failure introduced by 980ae2ab29.

This commit is contained in:
Baptiste Mispelon 2013-05-19 16:51:36 +02:00
parent b06f6c1618
commit 3cb1e9b93c
1 changed files with 9 additions and 4 deletions

View File

@ -58,15 +58,20 @@ class AuthViewsTestCase(TestCase):
form_errors = list(itertools.chain(*response.context['form'].errors.values())) form_errors = list(itertools.chain(*response.context['form'].errors.values()))
self.assertIn(force_text(error), form_errors) self.assertIn(force_text(error), form_errors)
def assertURLEqual(self, url, expected): def assertURLEqual(self, url, expected, parse_qs=False):
""" """
Given two URLs, make sure all their components (the ones given by Given two URLs, make sure all their components (the ones given by
urlparse) are equal, only comparing components that are present in both urlparse) are equal, only comparing components that are present in both
URLs. URLs.
If `parse_qs` is True, then the querystrings are parsed with QueryDict.
This is useful if you don't want the order of parameters to matter.
Otherwise, the query strings are compared as-is.
""" """
fields = ParseResult._fields fields = ParseResult._fields
for attr, x, y in zip(fields, urlparse(url), urlparse(expected)): for attr, x, y in zip(fields, urlparse(url), urlparse(expected)):
if parse_qs and attr == 'query':
x, y = QueryDict(x), QueryDict(y)
if x and y and x != y: if x and y and x != y:
self.fail("%r != %r (%s doesn't match)" % (url, expected, attr)) self.fail("%r != %r (%s doesn't match)" % (url, expected, attr))
@ -459,10 +464,10 @@ class LoginTest(AuthViewsTestCase):
@skipIfCustomUser @skipIfCustomUser
class LoginURLSettings(AuthViewsTestCase): class LoginURLSettings(AuthViewsTestCase):
"""Tests for settings.LOGIN_URL.""" """Tests for settings.LOGIN_URL."""
def assertLoginURLEquals(self, url): def assertLoginURLEquals(self, url, parse_qs=False):
response = self.client.get('/login_required/') response = self.client.get('/login_required/')
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
self.assertURLEqual(response.url, url) self.assertURLEqual(response.url, url, parse_qs=parse_qs)
@override_settings(LOGIN_URL='/login/') @override_settings(LOGIN_URL='/login/')
def test_standard_login_url(self): def test_standard_login_url(self):
@ -486,7 +491,7 @@ class LoginURLSettings(AuthViewsTestCase):
@override_settings(LOGIN_URL='/login/?pretty=1') @override_settings(LOGIN_URL='/login/?pretty=1')
def test_login_url_with_querystring(self): def test_login_url_with_querystring(self):
self.assertLoginURLEquals('/login/?pretty=1&next=/login_required/') self.assertLoginURLEquals('/login/?pretty=1&next=/login_required/', parse_qs=True)
@override_settings(LOGIN_URL='http://remote.example.com/login/?next=/default/') @override_settings(LOGIN_URL='http://remote.example.com/login/?next=/default/')
def test_remote_login_url_with_next_querystring(self): def test_remote_login_url_with_next_querystring(self):