mirror of https://github.com/django/django.git
[py3] Made signing infrastructure pass tests with Python 3
This commit is contained in:
parent
7275576235
commit
92b2dec918
|
@ -32,6 +32,8 @@ start of the base64 JSON.
|
||||||
There are 65 url-safe characters: the 64 used by url-safe base64 and the ':'.
|
There are 65 url-safe characters: the 64 used by url-safe base64 and the ':'.
|
||||||
These functions make use of all of them.
|
These functions make use of all of them.
|
||||||
"""
|
"""
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
@ -41,7 +43,7 @@ from django.conf import settings
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
from django.utils import baseconv
|
from django.utils import baseconv
|
||||||
from django.utils.crypto import constant_time_compare, salted_hmac
|
from django.utils.crypto import constant_time_compare, salted_hmac
|
||||||
from django.utils.encoding import force_text, smart_bytes
|
from django.utils.encoding import smart_bytes
|
||||||
from django.utils.importlib import import_module
|
from django.utils.importlib import import_module
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,12 +62,12 @@ class SignatureExpired(BadSignature):
|
||||||
|
|
||||||
|
|
||||||
def b64_encode(s):
|
def b64_encode(s):
|
||||||
return base64.urlsafe_b64encode(s).strip('=')
|
return base64.urlsafe_b64encode(smart_bytes(s)).decode('ascii').strip('=')
|
||||||
|
|
||||||
|
|
||||||
def b64_decode(s):
|
def b64_decode(s):
|
||||||
pad = '=' * (-len(s) % 4)
|
pad = '=' * (-len(s) % 4)
|
||||||
return base64.urlsafe_b64decode(s + pad)
|
return base64.urlsafe_b64decode(smart_bytes(s + pad)).decode('ascii')
|
||||||
|
|
||||||
|
|
||||||
def base64_hmac(salt, value, key):
|
def base64_hmac(salt, value, key):
|
||||||
|
@ -121,7 +123,7 @@ def dumps(obj, key=None, salt='django.core.signing', serializer=JSONSerializer,
|
||||||
|
|
||||||
if compress:
|
if compress:
|
||||||
# Avoid zlib dependency unless compress is being used
|
# Avoid zlib dependency unless compress is being used
|
||||||
compressed = zlib.compress(data)
|
compressed = zlib.compress(smart_bytes(data))
|
||||||
if len(compressed) < (len(data) - 1):
|
if len(compressed) < (len(data) - 1):
|
||||||
data = compressed
|
data = compressed
|
||||||
is_compressed = True
|
is_compressed = True
|
||||||
|
@ -135,8 +137,7 @@ def loads(s, key=None, salt='django.core.signing', serializer=JSONSerializer, ma
|
||||||
"""
|
"""
|
||||||
Reverse of dumps(), raises BadSignature if signature fails
|
Reverse of dumps(), raises BadSignature if signature fails
|
||||||
"""
|
"""
|
||||||
base64d = smart_bytes(
|
base64d = TimestampSigner(key, salt=salt).unsign(s, max_age=max_age)
|
||||||
TimestampSigner(key, salt=salt).unsign(s, max_age=max_age))
|
|
||||||
decompress = False
|
decompress = False
|
||||||
if base64d[0] == '.':
|
if base64d[0] == '.':
|
||||||
# It's compressed; uncompress it first
|
# It's compressed; uncompress it first
|
||||||
|
@ -159,16 +160,14 @@ class Signer(object):
|
||||||
return base64_hmac(self.salt + 'signer', value, self.key)
|
return base64_hmac(self.salt + 'signer', value, self.key)
|
||||||
|
|
||||||
def sign(self, value):
|
def sign(self, value):
|
||||||
value = smart_bytes(value)
|
|
||||||
return '%s%s%s' % (value, self.sep, self.signature(value))
|
return '%s%s%s' % (value, self.sep, self.signature(value))
|
||||||
|
|
||||||
def unsign(self, signed_value):
|
def unsign(self, signed_value):
|
||||||
signed_value = smart_bytes(signed_value)
|
|
||||||
if not self.sep in signed_value:
|
if not self.sep in signed_value:
|
||||||
raise BadSignature('No "%s" found in value' % self.sep)
|
raise BadSignature('No "%s" found in value' % self.sep)
|
||||||
value, sig = signed_value.rsplit(self.sep, 1)
|
value, sig = signed_value.rsplit(self.sep, 1)
|
||||||
if constant_time_compare(sig, self.signature(value)):
|
if constant_time_compare(sig, self.signature(value)):
|
||||||
return force_text(value)
|
return value
|
||||||
raise BadSignature('Signature "%s" does not match' % sig)
|
raise BadSignature('Signature "%s" does not match' % sig)
|
||||||
|
|
||||||
|
|
||||||
|
@ -178,7 +177,7 @@ class TimestampSigner(Signer):
|
||||||
return baseconv.base62.encode(int(time.time()))
|
return baseconv.base62.encode(int(time.time()))
|
||||||
|
|
||||||
def sign(self, value):
|
def sign(self, value):
|
||||||
value = smart_bytes('%s%s%s' % (value, self.sep, self.timestamp()))
|
value = '%s%s%s' % (value, self.sep, self.timestamp())
|
||||||
return '%s%s%s' % (value, self.sep, self.signature(value))
|
return '%s%s%s' % (value, self.sep, self.signature(value))
|
||||||
|
|
||||||
def unsign(self, value, max_age=None):
|
def unsign(self, value, max_age=None):
|
||||||
|
|
|
@ -4,6 +4,7 @@ import time
|
||||||
|
|
||||||
from django.core import signing
|
from django.core import signing
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
from django.utils import six
|
||||||
from django.utils.encoding import force_text
|
from django.utils.encoding import force_text
|
||||||
|
|
||||||
|
|
||||||
|
@ -69,15 +70,18 @@ class TestSigner(TestCase):
|
||||||
|
|
||||||
def test_dumps_loads(self):
|
def test_dumps_loads(self):
|
||||||
"dumps and loads be reversible for any JSON serializable object"
|
"dumps and loads be reversible for any JSON serializable object"
|
||||||
objects = (
|
objects = [
|
||||||
['a', 'list'],
|
['a', 'list'],
|
||||||
b'a string',
|
|
||||||
'a unicode string \u2019',
|
'a unicode string \u2019',
|
||||||
{'a': 'dictionary'},
|
{'a': 'dictionary'},
|
||||||
)
|
]
|
||||||
|
if not six.PY3:
|
||||||
|
objects.append(b'a byte string')
|
||||||
for o in objects:
|
for o in objects:
|
||||||
self.assertNotEqual(o, signing.dumps(o))
|
self.assertNotEqual(o, signing.dumps(o))
|
||||||
self.assertEqual(o, signing.loads(signing.dumps(o)))
|
self.assertEqual(o, signing.loads(signing.dumps(o)))
|
||||||
|
self.assertNotEqual(o, signing.dumps(o, compress=True))
|
||||||
|
self.assertEqual(o, signing.loads(signing.dumps(o, compress=True)))
|
||||||
|
|
||||||
def test_decode_detects_tampering(self):
|
def test_decode_detects_tampering(self):
|
||||||
"loads should raise exception for tampered objects"
|
"loads should raise exception for tampered objects"
|
||||||
|
|
Loading…
Reference in New Issue