Fixed #18852 -- Restored backwards compatibility

in django.core.signing. Specifically, kept the same return types
(str/unicode) under Python 2. Related to [92b2dec918].
This commit is contained in:
Aymeric Augustin 2012-08-25 13:02:52 +02:00
parent 62e1c5a441
commit 28ea4d4b07
2 changed files with 41 additions and 27 deletions

View File

@ -32,6 +32,7 @@ start of the base64 JSON.
There are 65 url-safe characters: the 64 used by url-safe base64 and the ':'. There are 65 url-safe characters: the 64 used by url-safe base64 and the ':'.
These functions make use of all of them. These functions make use of all of them.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
import base64 import base64
@ -43,7 +44,7 @@ from django.conf import settings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.utils import baseconv from django.utils import baseconv
from django.utils.crypto import constant_time_compare, salted_hmac from django.utils.crypto import constant_time_compare, salted_hmac
from django.utils.encoding import smart_bytes from django.utils.encoding import force_bytes, force_str, force_text
from django.utils.importlib import import_module from django.utils.importlib import import_module
@ -62,12 +63,12 @@ class SignatureExpired(BadSignature):
def b64_encode(s): def b64_encode(s):
return base64.urlsafe_b64encode(smart_bytes(s)).decode('ascii').strip('=') return base64.urlsafe_b64encode(s).strip(b'=')
def b64_decode(s): def b64_decode(s):
pad = '=' * (-len(s) % 4) pad = b'=' * (-len(s) % 4)
return base64.urlsafe_b64decode(smart_bytes(s + pad)).decode('ascii') return base64.urlsafe_b64decode(s + pad)
def base64_hmac(salt, value, key): def base64_hmac(salt, value, key):
@ -116,20 +117,20 @@ def dumps(obj, key=None, salt='django.core.signing', serializer=JSONSerializer,
value or re-using a salt value across different parts of your value or re-using a salt value across different parts of your
application without good cause is a security risk. application without good cause is a security risk.
""" """
data = serializer().dumps(obj) data = force_bytes(serializer().dumps(obj))
# Flag for if it's been compressed or not # Flag for if it's been compressed or not
is_compressed = False is_compressed = False
if compress: if compress:
# Avoid zlib dependency unless compress is being used # Avoid zlib dependency unless compress is being used
compressed = zlib.compress(smart_bytes(data)) compressed = zlib.compress(data)
if len(compressed) < (len(data) - 1): if len(compressed) < (len(data) - 1):
data = compressed data = compressed
is_compressed = True is_compressed = True
base64d = b64_encode(data) base64d = b64_encode(data)
if is_compressed: if is_compressed:
base64d = '.' + base64d base64d = b'.' + base64d
return TimestampSigner(key, salt=salt).sign(base64d) return TimestampSigner(key, salt=salt).sign(base64d)
@ -137,37 +138,45 @@ def loads(s, key=None, salt='django.core.signing', serializer=JSONSerializer, ma
""" """
Reverse of dumps(), raises BadSignature if signature fails Reverse of dumps(), raises BadSignature if signature fails
""" """
base64d = TimestampSigner(key, salt=salt).unsign(s, max_age=max_age) # TimestampSigner.unsign always returns unicode but base64 and zlib
# compression operate on bytes.
base64d = force_bytes(TimestampSigner(key, salt=salt).unsign(s, max_age=max_age))
decompress = False decompress = False
if base64d[0] == '.': if base64d[0] == b'.':
# It's compressed; uncompress it first # It's compressed; uncompress it first
base64d = base64d[1:] base64d = base64d[1:]
decompress = True decompress = True
data = b64_decode(base64d) data = b64_decode(base64d)
if decompress: if decompress:
data = zlib.decompress(data) data = zlib.decompress(data)
return serializer().loads(data) return serializer().loads(force_str(data))
class Signer(object): class Signer(object):
def __init__(self, key=None, sep=':', salt=None): def __init__(self, key=None, sep=':', salt=None):
self.sep = sep # Use of native strings in all versions of Python
self.key = key or settings.SECRET_KEY self.sep = str(sep)
self.salt = salt or ('%s.%s' % self.key = str(key or settings.SECRET_KEY)
(self.__class__.__module__, self.__class__.__name__)) self.salt = str(salt or
'%s.%s' % (self.__class__.__module__, self.__class__.__name__))
def signature(self, value): def signature(self, value):
return base64_hmac(self.salt + 'signer', value, self.key) signature = base64_hmac(self.salt + 'signer', value, self.key)
# Convert the signature from bytes to str only on Python 3
return force_str(signature)
def sign(self, value): def sign(self, value):
return '%s%s%s' % (value, self.sep, self.signature(value)) value = force_str(value)
return str('%s%s%s') % (value, self.sep, self.signature(value))
def unsign(self, signed_value): def unsign(self, signed_value):
signed_value = force_str(signed_value)
if not self.sep in signed_value: if not self.sep in signed_value:
raise BadSignature('No "%s" found in value' % self.sep) raise BadSignature('No "%s" found in value' % self.sep)
value, sig = signed_value.rsplit(self.sep, 1) value, sig = signed_value.rsplit(self.sep, 1)
if constant_time_compare(sig, self.signature(value)): if constant_time_compare(sig, self.signature(value)):
return value return force_text(value)
raise BadSignature('Signature "%s" does not match' % sig) raise BadSignature('Signature "%s" does not match' % sig)
@ -177,8 +186,9 @@ class TimestampSigner(Signer):
return baseconv.base62.encode(int(time.time())) return baseconv.base62.encode(int(time.time()))
def sign(self, value): def sign(self, value):
value = '%s%s%s' % (value, self.sep, self.timestamp()) value = force_str(value)
return '%s%s%s' % (value, self.sep, self.signature(value)) value = str('%s%s%s') % (value, self.sep, self.timestamp())
return super(TimestampSigner, self).sign(value)
def unsign(self, value, max_age=None): def unsign(self, value, max_age=None):
result = super(TimestampSigner, self).unsign(value) result = super(TimestampSigner, self).unsign(value)

View File

@ -4,8 +4,8 @@ import time
from django.core import signing from django.core import signing
from django.test import TestCase from django.test import TestCase
from django.utils.encoding import force_str
from django.utils import six from django.utils import six
from django.utils.encoding import force_text
class TestSigner(TestCase): class TestSigner(TestCase):
@ -22,7 +22,7 @@ class TestSigner(TestCase):
self.assertEqual( self.assertEqual(
signer.signature(s), signer.signature(s),
signing.base64_hmac(signer.salt + 'signer', s, signing.base64_hmac(signer.salt + 'signer', s,
'predictable-secret') 'predictable-secret').decode()
) )
self.assertNotEqual(signer.signature(s), signer2.signature(s)) self.assertNotEqual(signer.signature(s), signer2.signature(s))
@ -32,7 +32,8 @@ class TestSigner(TestCase):
self.assertEqual( self.assertEqual(
signer.signature('hello'), signer.signature('hello'),
signing.base64_hmac('extra-salt' + 'signer', signing.base64_hmac('extra-salt' + 'signer',
'hello', 'predictable-secret')) 'hello', 'predictable-secret').decode()
)
self.assertNotEqual( self.assertNotEqual(
signing.Signer('predictable-secret', salt='one').signature('hello'), signing.Signer('predictable-secret', salt='one').signature('hello'),
signing.Signer('predictable-secret', salt='two').signature('hello')) signing.Signer('predictable-secret', salt='two').signature('hello'))
@ -40,17 +41,20 @@ class TestSigner(TestCase):
def test_sign_unsign(self): def test_sign_unsign(self):
"sign/unsign should be reversible" "sign/unsign should be reversible"
signer = signing.Signer('predictable-secret') signer = signing.Signer('predictable-secret')
examples = ( examples = [
'q;wjmbk;wkmb', 'q;wjmbk;wkmb',
'3098247529087', '3098247529087',
'3098247:529:087:', '3098247:529:087:',
'jkw osanteuh ,rcuh nthu aou oauh ,ud du', 'jkw osanteuh ,rcuh nthu aou oauh ,ud du',
'\u2019', '\u2019',
) ]
if not six.PY3:
examples.append(b'a byte string')
for example in examples: for example in examples:
self.assertNotEqual( signed = signer.sign(example)
force_text(example), force_text(signer.sign(example))) self.assertIsInstance(signed, str)
self.assertEqual(example, signer.unsign(signer.sign(example))) self.assertNotEqual(force_str(example), signed)
self.assertEqual(example, signer.unsign(signed))
def unsign_detects_tampering(self): def unsign_detects_tampering(self):
"unsign should raise an exception if the value has been tampered with" "unsign should raise an exception if the value has been tampered with"