[py3] Made signing infrastructure pass tests with Python 3

This commit is contained in:
Claude Paroz 2012-08-09 20:08:47 +02:00
parent 7275576235
commit 92b2dec918
2 changed files with 16 additions and 13 deletions

View File

@ -32,6 +32,8 @@ start of the base64 JSON.
There are 65 url-safe characters: the 64 used by url-safe base64 and the ':'. There are 65 url-safe characters: the 64 used by url-safe base64 and the ':'.
These functions make use of all of them. These functions make use of all of them.
""" """
from __future__ import unicode_literals
import base64 import base64
import json import json
import time import time
@ -41,7 +43,7 @@ from django.conf import settings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.utils import baseconv from django.utils import baseconv
from django.utils.crypto import constant_time_compare, salted_hmac from django.utils.crypto import constant_time_compare, salted_hmac
from django.utils.encoding import force_text, smart_bytes from django.utils.encoding import smart_bytes
from django.utils.importlib import import_module from django.utils.importlib import import_module
@ -60,12 +62,12 @@ class SignatureExpired(BadSignature):
def b64_encode(s): def b64_encode(s):
return base64.urlsafe_b64encode(s).strip('=') return base64.urlsafe_b64encode(smart_bytes(s)).decode('ascii').strip('=')
def b64_decode(s): def b64_decode(s):
pad = '=' * (-len(s) % 4) pad = '=' * (-len(s) % 4)
return base64.urlsafe_b64decode(s + pad) return base64.urlsafe_b64decode(smart_bytes(s + pad)).decode('ascii')
def base64_hmac(salt, value, key): def base64_hmac(salt, value, key):
@ -121,7 +123,7 @@ def dumps(obj, key=None, salt='django.core.signing', serializer=JSONSerializer,
if compress: if compress:
# Avoid zlib dependency unless compress is being used # Avoid zlib dependency unless compress is being used
compressed = zlib.compress(data) compressed = zlib.compress(smart_bytes(data))
if len(compressed) < (len(data) - 1): if len(compressed) < (len(data) - 1):
data = compressed data = compressed
is_compressed = True is_compressed = True
@ -135,8 +137,7 @@ def loads(s, key=None, salt='django.core.signing', serializer=JSONSerializer, ma
""" """
Reverse of dumps(), raises BadSignature if signature fails Reverse of dumps(), raises BadSignature if signature fails
""" """
base64d = smart_bytes( base64d = TimestampSigner(key, salt=salt).unsign(s, max_age=max_age)
TimestampSigner(key, salt=salt).unsign(s, max_age=max_age))
decompress = False decompress = False
if base64d[0] == '.': if base64d[0] == '.':
# It's compressed; uncompress it first # It's compressed; uncompress it first
@ -159,16 +160,14 @@ class Signer(object):
return base64_hmac(self.salt + 'signer', value, self.key) return base64_hmac(self.salt + 'signer', value, self.key)
def sign(self, value): def sign(self, value):
value = smart_bytes(value)
return '%s%s%s' % (value, self.sep, self.signature(value)) return '%s%s%s' % (value, self.sep, self.signature(value))
def unsign(self, signed_value): def unsign(self, signed_value):
signed_value = smart_bytes(signed_value)
if not self.sep in signed_value: if not self.sep in signed_value:
raise BadSignature('No "%s" found in value' % self.sep) raise BadSignature('No "%s" found in value' % self.sep)
value, sig = signed_value.rsplit(self.sep, 1) value, sig = signed_value.rsplit(self.sep, 1)
if constant_time_compare(sig, self.signature(value)): if constant_time_compare(sig, self.signature(value)):
return force_text(value) return value
raise BadSignature('Signature "%s" does not match' % sig) raise BadSignature('Signature "%s" does not match' % sig)
@ -178,7 +177,7 @@ class TimestampSigner(Signer):
return baseconv.base62.encode(int(time.time())) return baseconv.base62.encode(int(time.time()))
def sign(self, value): def sign(self, value):
value = smart_bytes('%s%s%s' % (value, self.sep, self.timestamp())) value = '%s%s%s' % (value, self.sep, self.timestamp())
return '%s%s%s' % (value, self.sep, self.signature(value)) return '%s%s%s' % (value, self.sep, self.signature(value))
def unsign(self, value, max_age=None): def unsign(self, value, max_age=None):

View File

@ -4,6 +4,7 @@ import time
from django.core import signing from django.core import signing
from django.test import TestCase from django.test import TestCase
from django.utils import six
from django.utils.encoding import force_text from django.utils.encoding import force_text
@ -69,15 +70,18 @@ class TestSigner(TestCase):
def test_dumps_loads(self): def test_dumps_loads(self):
"dumps and loads be reversible for any JSON serializable object" "dumps and loads be reversible for any JSON serializable object"
objects = ( objects = [
['a', 'list'], ['a', 'list'],
b'a string',
'a unicode string \u2019', 'a unicode string \u2019',
{'a': 'dictionary'}, {'a': 'dictionary'},
) ]
if not six.PY3:
objects.append(b'a byte string')
for o in objects: for o in objects:
self.assertNotEqual(o, signing.dumps(o)) self.assertNotEqual(o, signing.dumps(o))
self.assertEqual(o, signing.loads(signing.dumps(o))) self.assertEqual(o, signing.loads(signing.dumps(o)))
self.assertNotEqual(o, signing.dumps(o, compress=True))
self.assertEqual(o, signing.loads(signing.dumps(o, compress=True)))
def test_decode_detects_tampering(self): def test_decode_detects_tampering(self):
"loads should raise exception for tampered objects" "loads should raise exception for tampered objects"