Fixed #10571 -- Factored out the payload encoding code to make sure it is used for PUT requests. Thanks to kennu for the report, pterk for the patch, and wildfire for the review comments.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@16651 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Russell Keith-Magee 2011-08-23 00:52:45 +00:00
parent 0f767f9a99
commit d310f91ee7
2 changed files with 32 additions and 24 deletions

View File

@ -207,6 +207,18 @@ class RequestFactory(object):
"Construct a generic request object." "Construct a generic request object."
return WSGIRequest(self._base_environ(**request)) return WSGIRequest(self._base_environ(**request))
def _encode_data(self, data, content_type, ):
if content_type is MULTIPART_CONTENT:
return encode_multipart(BOUNDARY, data)
else:
# Encode the content so that the byte representation is correct.
match = CONTENT_TYPE_RE.match(content_type)
if match:
charset = match.group(1)
else:
charset = settings.DEFAULT_CHARSET
return smart_str(data, encoding=charset)
def _get_path(self, parsed): def _get_path(self, parsed):
# If there are parameters, add them # If there are parameters, add them
if parsed[3]: if parsed[3]:
@ -232,16 +244,7 @@ class RequestFactory(object):
**extra): **extra):
"Construct a POST request." "Construct a POST request."
if content_type is MULTIPART_CONTENT: post_data = self._encode_data(data, content_type)
post_data = encode_multipart(BOUNDARY, data)
else:
# Encode the content so that the byte representation is correct.
match = CONTENT_TYPE_RE.match(content_type)
if match:
charset = match.group(1)
else:
charset = settings.DEFAULT_CHARSET
post_data = smart_str(data, encoding=charset)
parsed = urlparse(path) parsed = urlparse(path)
r = { r = {
@ -286,25 +289,16 @@ class RequestFactory(object):
**extra): **extra):
"Construct a PUT request." "Construct a PUT request."
if content_type is MULTIPART_CONTENT: put_data = self._encode_data(data, content_type)
post_data = encode_multipart(BOUNDARY, data)
else:
post_data = data
# Make `data` into a querystring only if it's not already a string. If
# it is a string, we'll assume that the caller has already encoded it.
query_string = None
if not isinstance(data, basestring):
query_string = urlencode(data, doseq=True)
parsed = urlparse(path) parsed = urlparse(path)
r = { r = {
'CONTENT_LENGTH': len(post_data), 'CONTENT_LENGTH': len(put_data),
'CONTENT_TYPE': content_type, 'CONTENT_TYPE': content_type,
'PATH_INFO': self._get_path(parsed), 'PATH_INFO': self._get_path(parsed),
'QUERY_STRING': query_string or parsed[4], 'QUERY_STRING': parsed[4],
'REQUEST_METHOD': 'PUT', 'REQUEST_METHOD': 'PUT',
'wsgi.input': FakePayload(post_data), 'wsgi.input': FakePayload(put_data),
} }
r.update(extra) r.update(extra)
return self.request(**r) return self.request(**r)

View File

@ -770,7 +770,9 @@ class RequestMethodStringDataTests(TestCase):
class QueryStringTests(TestCase): class QueryStringTests(TestCase):
def test_get_like_requests(self): def test_get_like_requests(self):
for method_name in ('get','head','options','put','delete'): # See: https://code.djangoproject.com/ticket/10571.
# Removed 'put' and 'delete' here as they are 'GET-like requests'
for method_name in ('get','head','options'):
# A GET-like request can pass a query string as data # A GET-like request can pass a query string as data
method = getattr(self.client, method_name) method = getattr(self.client, method_name)
response = method("/test_client_regress/request_data/", data={'foo':'whiz'}) response = method("/test_client_regress/request_data/", data={'foo':'whiz'})
@ -827,6 +829,9 @@ class UnicodePayloadTests(TestCase):
response = self.client.post("/test_client_regress/parse_unicode_json/", json, response = self.client.post("/test_client_regress/parse_unicode_json/", json,
content_type="application/json") content_type="application/json")
self.assertEqual(response.content, json) self.assertEqual(response.content, json)
response = self.client.put("/test_client_regress/parse_unicode_json/", json,
content_type="application/json")
self.assertEqual(response.content, json)
def test_unicode_payload_utf8(self): def test_unicode_payload_utf8(self):
"A non-ASCII unicode data encoded as UTF-8 can be POSTed" "A non-ASCII unicode data encoded as UTF-8 can be POSTed"
@ -835,6 +840,9 @@ class UnicodePayloadTests(TestCase):
response = self.client.post("/test_client_regress/parse_unicode_json/", json, response = self.client.post("/test_client_regress/parse_unicode_json/", json,
content_type="application/json; charset=utf-8") content_type="application/json; charset=utf-8")
self.assertEqual(response.content, json.encode('utf-8')) self.assertEqual(response.content, json.encode('utf-8'))
response = self.client.put("/test_client_regress/parse_unicode_json/", json,
content_type="application/json; charset=utf-8")
self.assertEqual(response.content, json.encode('utf-8'))
def test_unicode_payload_utf16(self): def test_unicode_payload_utf16(self):
"A non-ASCII unicode data encoded as UTF-16 can be POSTed" "A non-ASCII unicode data encoded as UTF-16 can be POSTed"
@ -843,6 +851,9 @@ class UnicodePayloadTests(TestCase):
response = self.client.post("/test_client_regress/parse_unicode_json/", json, response = self.client.post("/test_client_regress/parse_unicode_json/", json,
content_type="application/json; charset=utf-16") content_type="application/json; charset=utf-16")
self.assertEqual(response.content, json.encode('utf-16')) self.assertEqual(response.content, json.encode('utf-16'))
response = self.client.put("/test_client_regress/parse_unicode_json/", json,
content_type="application/json; charset=utf-16")
self.assertEqual(response.content, json.encode('utf-16'))
def test_unicode_payload_non_utf(self): def test_unicode_payload_non_utf(self):
"A non-ASCII unicode data as a non-UTF based encoding can be POSTed" "A non-ASCII unicode data as a non-UTF based encoding can be POSTed"
@ -851,6 +862,9 @@ class UnicodePayloadTests(TestCase):
response = self.client.post("/test_client_regress/parse_unicode_json/", json, response = self.client.post("/test_client_regress/parse_unicode_json/", json,
content_type="application/json; charset=koi8-r") content_type="application/json; charset=koi8-r")
self.assertEqual(response.content, json.encode('koi8-r')) self.assertEqual(response.content, json.encode('koi8-r'))
response = self.client.put("/test_client_regress/parse_unicode_json/", json,
content_type="application/json; charset=koi8-r")
self.assertEqual(response.content, json.encode('koi8-r'))
class DummyFile(object): class DummyFile(object):
def __init__(self, filename): def __init__(self, filename):