Fixed #9002 -- Added a RequestFactory. This allows you to create request instances so you can unit test views as standalone functions. Thanks to Simon Willison for the suggestion and snippet on which this patch was originally based.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@14191 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Russell Keith-Magee 2010-10-12 23:37:47 +00:00
parent 120aae2209
commit eec45e8b71
4 changed files with 244 additions and 132 deletions

View File

@ -2,6 +2,6 @@
Django Unit Test and Doctest framework. Django Unit Test and Doctest framework.
""" """
from django.test.client import Client from django.test.client import Client, RequestFactory
from django.test.testcases import TestCase, TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature from django.test.testcases import TestCase, TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import Approximate from django.test.utils import Approximate

View File

@ -156,55 +156,29 @@ def encode_file(boundary, key, file):
file.read() file.read()
] ]
class Client(object):
class RequestFactory(object):
""" """
A class that can act as a client for testing purposes. Class that lets you create mock Request objects for use in testing.
It allows the user to compose GET and POST requests, and Usage:
obtain the response that the server gave to those requests.
The server Response objects are annotated with the details
of the contexts and templates that were rendered during the
process of serving the request.
Client objects are stateful - they will retain cookie (and rf = RequestFactory()
thus session) details for the lifetime of the Client instance. get_request = rf.get('/hello/')
post_request = rf.post('/submit/', {'foo': 'bar'})
This is not intended as a replacement for Twill/Selenium or Once you have a request object you can pass it to any view function,
the like - it is here to allow testing against the just as if that view had been hooked up using a URLconf.
contexts and templates produced by a view, rather than the
HTML rendered to the end-user.
""" """
def __init__(self, enforce_csrf_checks=False, **defaults): def __init__(self, **defaults):
self.handler = ClientHandler(enforce_csrf_checks)
self.defaults = defaults self.defaults = defaults
self.cookies = SimpleCookie() self.cookies = SimpleCookie()
self.exc_info = None
self.errors = StringIO() self.errors = StringIO()
def store_exc_info(self, **kwargs): def _base_environ(self, **request):
""" """
Stores exceptions when they are generated by a view. The base environment for a request.
"""
self.exc_info = sys.exc_info()
def _session(self):
"""
Obtains the current session variables.
"""
if 'django.contrib.sessions' in settings.INSTALLED_APPS:
engine = import_module(settings.SESSION_ENGINE)
cookie = self.cookies.get(settings.SESSION_COOKIE_NAME, None)
if cookie:
return engine.SessionStore(cookie.value)
return {}
session = property(_session)
def request(self, **request):
"""
The master request method. Composes the environment dictionary
and passes to the handler, returning the result of the handler.
Assumes defaults for the query environment, which can be overridden
using the arguments to the request.
""" """
environ = { environ = {
'HTTP_COOKIE': self.cookies.output(header='', sep='; '), 'HTTP_COOKIE': self.cookies.output(header='', sep='; '),
@ -225,6 +199,171 @@ class Client(object):
} }
environ.update(self.defaults) environ.update(self.defaults)
environ.update(request) environ.update(request)
return environ
def request(self, **request):
"Construct a generic request object."
return WSGIRequest(self._base_environ(**request))
def get(self, path, data={}, **extra):
"Construct a GET request"
parsed = urlparse(path)
r = {
'CONTENT_TYPE': 'text/html; charset=utf-8',
'PATH_INFO': urllib.unquote(parsed[2]),
'QUERY_STRING': urlencode(data, doseq=True) or parsed[4],
'REQUEST_METHOD': 'GET',
'wsgi.input': FakePayload('')
}
r.update(extra)
return self.request(**r)
def post(self, path, data={}, content_type=MULTIPART_CONTENT,
**extra):
"Construct a POST request."
if content_type is MULTIPART_CONTENT:
post_data = encode_multipart(BOUNDARY, data)
else:
# Encode the content so that the byte representation is correct.
match = CONTENT_TYPE_RE.match(content_type)
if match:
charset = match.group(1)
else:
charset = settings.DEFAULT_CHARSET
post_data = smart_str(data, encoding=charset)
parsed = urlparse(path)
r = {
'CONTENT_LENGTH': len(post_data),
'CONTENT_TYPE': content_type,
'PATH_INFO': urllib.unquote(parsed[2]),
'QUERY_STRING': parsed[4],
'REQUEST_METHOD': 'POST',
'wsgi.input': FakePayload(post_data),
}
r.update(extra)
return self.request(**r)
def head(self, path, data={}, **extra):
"Construct a HEAD request."
parsed = urlparse(path)
r = {
'CONTENT_TYPE': 'text/html; charset=utf-8',
'PATH_INFO': urllib.unquote(parsed[2]),
'QUERY_STRING': urlencode(data, doseq=True) or parsed[4],
'REQUEST_METHOD': 'HEAD',
'wsgi.input': FakePayload('')
}
r.update(extra)
return self.request(**r)
def options(self, path, data={}, **extra):
"Constrict an OPTIONS request"
parsed = urlparse(path)
r = {
'PATH_INFO': urllib.unquote(parsed[2]),
'QUERY_STRING': urlencode(data, doseq=True) or parsed[4],
'REQUEST_METHOD': 'OPTIONS',
'wsgi.input': FakePayload('')
}
r.update(extra)
return self.request(**r)
def put(self, path, data={}, content_type=MULTIPART_CONTENT,
**extra):
"Construct a PUT request."
if content_type is MULTIPART_CONTENT:
post_data = encode_multipart(BOUNDARY, data)
else:
post_data = data
# Make `data` into a querystring only if it's not already a string. If
# it is a string, we'll assume that the caller has already encoded it.
query_string = None
if not isinstance(data, basestring):
query_string = urlencode(data, doseq=True)
parsed = urlparse(path)
r = {
'CONTENT_LENGTH': len(post_data),
'CONTENT_TYPE': content_type,
'PATH_INFO': urllib.unquote(parsed[2]),
'QUERY_STRING': query_string or parsed[4],
'REQUEST_METHOD': 'PUT',
'wsgi.input': FakePayload(post_data),
}
r.update(extra)
return self.request(**r)
def delete(self, path, data={}, **extra):
"Construct a DELETE request."
parsed = urlparse(path)
r = {
'PATH_INFO': urllib.unquote(parsed[2]),
'QUERY_STRING': urlencode(data, doseq=True) or parsed[4],
'REQUEST_METHOD': 'DELETE',
'wsgi.input': FakePayload('')
}
r.update(extra)
return self.request(**r)
class Client(RequestFactory):
"""
A class that can act as a client for testing purposes.
It allows the user to compose GET and POST requests, and
obtain the response that the server gave to those requests.
The server Response objects are annotated with the details
of the contexts and templates that were rendered during the
process of serving the request.
Client objects are stateful - they will retain cookie (and
thus session) details for the lifetime of the Client instance.
This is not intended as a replacement for Twill/Selenium or
the like - it is here to allow testing against the
contexts and templates produced by a view, rather than the
HTML rendered to the end-user.
"""
def __init__(self, enforce_csrf_checks=False, **defaults):
super(Client, self).__init__(**defaults)
self.handler = ClientHandler(enforce_csrf_checks)
self.exc_info = None
def store_exc_info(self, **kwargs):
"""
Stores exceptions when they are generated by a view.
"""
self.exc_info = sys.exc_info()
def _session(self):
"""
Obtains the current session variables.
"""
if 'django.contrib.sessions' in settings.INSTALLED_APPS:
engine = import_module(settings.SESSION_ENGINE)
cookie = self.cookies.get(settings.SESSION_COOKIE_NAME, None)
if cookie:
return engine.SessionStore(cookie.value)
return {}
session = property(_session)
def request(self, **request):
"""
The master request method. Composes the environment dictionary
and passes to the handler, returning the result of the handler.
Assumes defaults for the query environment, which can be overridden
using the arguments to the request.
"""
environ = self._base_environ(**request)
# Curry a data dictionary into an instance of the template renderer # Curry a data dictionary into an instance of the template renderer
# callback function. # callback function.
@ -290,22 +429,11 @@ class Client(object):
signals.template_rendered.disconnect(dispatch_uid="template-render") signals.template_rendered.disconnect(dispatch_uid="template-render")
got_request_exception.disconnect(dispatch_uid="request-exception") got_request_exception.disconnect(dispatch_uid="request-exception")
def get(self, path, data={}, follow=False, **extra): def get(self, path, data={}, follow=False, **extra):
""" """
Requests a response from the server using GET. Requests a response from the server using GET.
""" """
parsed = urlparse(path) response = super(Client, self).get(path, data=data, **extra)
r = {
'CONTENT_TYPE': 'text/html; charset=utf-8',
'PATH_INFO': urllib.unquote(parsed[2]),
'QUERY_STRING': urlencode(data, doseq=True) or parsed[4],
'REQUEST_METHOD': 'GET',
'wsgi.input': FakePayload('')
}
r.update(extra)
response = self.request(**r)
if follow: if follow:
response = self._handle_redirects(response, **extra) response = self._handle_redirects(response, **extra)
return response return response
@ -315,29 +443,7 @@ class Client(object):
""" """
Requests a response from the server using POST. Requests a response from the server using POST.
""" """
if content_type is MULTIPART_CONTENT: response = super(Client, self).post(path, data=data, content_type=content_type, **extra)
post_data = encode_multipart(BOUNDARY, data)
else:
# Encode the content so that the byte representation is correct.
match = CONTENT_TYPE_RE.match(content_type)
if match:
charset = match.group(1)
else:
charset = settings.DEFAULT_CHARSET
post_data = smart_str(data, encoding=charset)
parsed = urlparse(path)
r = {
'CONTENT_LENGTH': len(post_data),
'CONTENT_TYPE': content_type,
'PATH_INFO': urllib.unquote(parsed[2]),
'QUERY_STRING': parsed[4],
'REQUEST_METHOD': 'POST',
'wsgi.input': FakePayload(post_data),
}
r.update(extra)
response = self.request(**r)
if follow: if follow:
response = self._handle_redirects(response, **extra) response = self._handle_redirects(response, **extra)
return response return response
@ -346,17 +452,7 @@ class Client(object):
""" """
Request a response from the server using HEAD. Request a response from the server using HEAD.
""" """
parsed = urlparse(path) response = super(Client, self).head(path, data=data, **extra)
r = {
'CONTENT_TYPE': 'text/html; charset=utf-8',
'PATH_INFO': urllib.unquote(parsed[2]),
'QUERY_STRING': urlencode(data, doseq=True) or parsed[4],
'REQUEST_METHOD': 'HEAD',
'wsgi.input': FakePayload('')
}
r.update(extra)
response = self.request(**r)
if follow: if follow:
response = self._handle_redirects(response, **extra) response = self._handle_redirects(response, **extra)
return response return response
@ -365,16 +461,7 @@ class Client(object):
""" """
Request a response from the server using OPTIONS. Request a response from the server using OPTIONS.
""" """
parsed = urlparse(path) response = super(Client, self).options(path, data=data, **extra)
r = {
'PATH_INFO': urllib.unquote(parsed[2]),
'QUERY_STRING': urlencode(data, doseq=True) or parsed[4],
'REQUEST_METHOD': 'OPTIONS',
'wsgi.input': FakePayload('')
}
r.update(extra)
response = self.request(**r)
if follow: if follow:
response = self._handle_redirects(response, **extra) response = self._handle_redirects(response, **extra)
return response return response
@ -384,29 +471,7 @@ class Client(object):
""" """
Send a resource to the server using PUT. Send a resource to the server using PUT.
""" """
if content_type is MULTIPART_CONTENT: response = super(Client, self).put(path, data=data, content_type=content_type, **extra)
post_data = encode_multipart(BOUNDARY, data)
else:
post_data = data
# Make `data` into a querystring only if it's not already a string. If
# it is a string, we'll assume that the caller has already encoded it.
query_string = None
if not isinstance(data, basestring):
query_string = urlencode(data, doseq=True)
parsed = urlparse(path)
r = {
'CONTENT_LENGTH': len(post_data),
'CONTENT_TYPE': content_type,
'PATH_INFO': urllib.unquote(parsed[2]),
'QUERY_STRING': query_string or parsed[4],
'REQUEST_METHOD': 'PUT',
'wsgi.input': FakePayload(post_data),
}
r.update(extra)
response = self.request(**r)
if follow: if follow:
response = self._handle_redirects(response, **extra) response = self._handle_redirects(response, **extra)
return response return response
@ -415,23 +480,14 @@ class Client(object):
""" """
Send a DELETE request to the server. Send a DELETE request to the server.
""" """
parsed = urlparse(path) response = super(Client, self).delete(path, data=data, **extra)
r = {
'PATH_INFO': urllib.unquote(parsed[2]),
'QUERY_STRING': urlencode(data, doseq=True) or parsed[4],
'REQUEST_METHOD': 'DELETE',
'wsgi.input': FakePayload('')
}
r.update(extra)
response = self.request(**r)
if follow: if follow:
response = self._handle_redirects(response, **extra) response = self._handle_redirects(response, **extra)
return response return response
def login(self, **credentials): def login(self, **credentials):
""" """
Sets the Client to appear as if it has successfully logged into a site. Sets the Factory to appear as if it has successfully logged into a site.
Returns True if login is possible; False if the provided credentials Returns True if login is possible; False if the provided credentials
are incorrect, or the user is inactive, or if the sessions framework is are incorrect, or the user is inactive, or if the sessions framework is
@ -506,4 +562,3 @@ class Client(object):
if response.redirect_chain[-1] in response.redirect_chain[0:-1]: if response.redirect_chain[-1] in response.redirect_chain[0:-1]:
break break
return response return response

View File

@ -1014,6 +1014,51 @@ The following is a simple unit test using the test client::
# Check that the rendered context contains 5 customers. # Check that the rendered context contains 5 customers.
self.assertEqual(len(response.context['customers']), 5) self.assertEqual(len(response.context['customers']), 5)
The request factory
-------------------
.. Class:: RequestFactory
The :class:`~django.test.client.RequestFactory` is a simplified
version of the test client that provides a way to generate a request
instance that can be used as the first argument to any view. This
means you can test a view function the same way as you would test any
other function -- as a black box, with exactly known inputs, testing
for specific outputs.
The API for the :class:`~django.test.client.RequestFactory` is a slightly
restricted subset of the test client API:
* It only has access to the HTTP methods :meth:`~Client.get()`,
:meth:`~Client.post()`, :meth:`~Client.put()`,
:meth:`~Client.delete()`, :meth:`~Client.head()` and
:meth:`~Client.options()`.
* These methods accept all the same arguments *except* for
``follows``. Since this is just a factory for producing
requests, it's up to you to handle the response.
Example
~~~~~~~
The following is a simple unit test using the request factory::
from django.utils import unittest
from django.test.client import RequestFactory
class SimpleTest(unittest.TestCase):
def setUp(self):
# Every test needs a client.
self.factory = RequestFactory()
def test_details(self):
# Issue a GET request.
request = self.factory.get('/customer/details')
# Test my_view() as if it were deployed at /customer/details
response = my_view(request)
self.assertEquals(response.status_code, 200)
TestCase TestCase
-------- --------

View File

@ -20,9 +20,12 @@ testing against the contexts and templates produced by a view,
rather than the HTML rendered to the end-user. rather than the HTML rendered to the end-user.
""" """
from django.test import Client, TestCase
from django.conf import settings from django.conf import settings
from django.core import mail from django.core import mail
from django.test import Client, TestCase, RequestFactory
from views import get_view
class ClientTest(TestCase): class ClientTest(TestCase):
fixtures = ['testdata.json'] fixtures = ['testdata.json']
@ -469,3 +472,12 @@ class CustomTestClientTest(TestCase):
"""A test case can specify a custom class for self.client.""" """A test case can specify a custom class for self.client."""
self.assertEqual(hasattr(self.client, "i_am_customized"), True) self.assertEqual(hasattr(self.client, "i_am_customized"), True)
class RequestFactoryTest(TestCase):
def test_request_factory(self):
factory = RequestFactory()
request = factory.get('/somewhere/')
response = get_view(request)
self.assertEqual(response.status_code, 200)
self.assertContains(response, 'This is a test')