Fixed #7416: Modified test client to preserve session when a user logs in. Thanks to lakin.wecker@gmail.com for the report and Eric Holscher for the patch and test case.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@8372 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Russell Keith-Magee 2008-08-15 12:08:29 +00:00
parent d8bfabe98d
commit 50b548d01b
4 changed files with 46 additions and 2 deletions

View File

@ -299,7 +299,10 @@ class Client:
# Create a fake request to store login details.
request = HttpRequest()
request.session = engine.SessionStore()
if self.session:
request.session = self.session
else:
request.session = engine.SessionStore()
login(request, user)
# Set the cookie to represent the session.

View File

@ -325,3 +325,33 @@ class zzUrlconfSubstitutionTests(TestCase):
"URLconf is reverted to original value after modification in a TestCase"
url = reverse('arg_view', args=['somename'])
self.assertEquals(url, '/test_client_regress/arg_view/somename/')
class SessionTests(TestCase):
fixtures = ['testdata.json']
def test_session(self):
"The session isn't lost if a user logs in"
# The session doesn't exist to start.
response = self.client.get('/test_client_regress/check_session/')
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, 'NO')
# This request sets a session variable.
response = self.client.get('/test_client_regress/set_session/')
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, 'set_session')
# Check that the session has been modified
response = self.client.get('/test_client_regress/check_session/')
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, 'YES')
# Log in
login = self.client.login(username='testclient',password='password')
self.failUnless(login, 'Could not log in')
# Session should still contain the modified value
response = self.client.get('/test_client_regress/check_session/')
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, 'YES')

View File

@ -6,5 +6,7 @@ urlpatterns = patterns('',
(r'^staff_only/$', views.staff_only_view),
(r'^get_view/$', views.get_view),
url(r'^arg_view/(?P<name>.+)/$', views.view_with_argument, name='arg_view'),
(r'^login_protected_redirect_view/$', views.login_protected_redirect_view)
(r'^login_protected_redirect_view/$', views.login_protected_redirect_view),
(r'^set_session/$', views.set_session_view),
(r'^check_session/$', views.check_session_view),
)

View File

@ -34,3 +34,12 @@ def login_protected_redirect_view(request):
"A view that redirects all requests to the GET view"
return HttpResponseRedirect('/test_client_regress/get_view/')
login_protected_redirect_view = login_required(login_protected_redirect_view)
def set_session_view(request):
"A view that sets a session variable"
request.session['session_var'] = 'YES'
return HttpResponse('set_session')
def check_session_view(request):
"A view that reads a session variable"
return HttpResponse(request.session.get('session_var', 'NO'))