From 9b2aab99de321f7cdf9af2fa98a80a5e4d07751f Mon Sep 17 00:00:00 2001 From: Nick Murtagh Date: Tue, 26 Apr 2016 14:51:52 +0100 Subject: [PATCH 1/3] Python 3 Support --- .travis.yml | 9 ++- jose.py | 176 ++++++++++++++++++++++++++++------------------- requirements.txt | 1 + setup.py | 18 +++-- tests.py | 122 ++++++++++++++++---------------- tox.ini | 4 +- 6 files changed, 191 insertions(+), 139 deletions(-) diff --git a/.travis.yml b/.travis.yml index ede2df1..e9c52f8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,13 @@ language: python python: - "2.7" + - "3.3" + - "3.4" + - "3.5" + - "3.5-dev" + - "nightly" + - "pypy" install: - pip install tox -script: tox +script: + - tox -e travis diff --git a/jose.py b/jose.py index c93089a..7c88f35 100644 --- a/jose.py +++ b/jose.py @@ -1,20 +1,15 @@ +import binascii +import datetime import logging -logger = logging.getLogger(__name__) - -try: - from cjson import encode as json_encode, decode as json_decode -except ImportError: # pragma: nocover - logger.warn('cjson not found, falling back to stdlib json') - from json import loads as json_decode, dumps as json_encode - +import six import zlib -import datetime from base64 import urlsafe_b64encode, urlsafe_b64decode from collections import namedtuple from copy import deepcopy -from time import time +from json import loads as json_decode, dumps as json_encode from struct import pack +from time import time from Crypto.Hash import HMAC, SHA256, SHA384, SHA512 from Crypto.Cipher import PKCS1_OAEP, AES @@ -22,6 +17,8 @@ from Crypto.Random import get_random_bytes from Crypto.Signature import PKCS1_v1_5 as PKCS1_v1_5_SIG +logger = logging.getLogger(__name__) + __all__ = ['encrypt', 'decrypt', 'sign', 'verify'] @@ -63,7 +60,8 @@ class Error(Exception): """ The base error type raised by jose """ - pass + def __init__(self, message): + self.message = message class Expired(Error): @@ -85,7 +83,7 @@ def serialize_compact(jwt): :returns: A string, representing the compact serialization of a :class:`~jose.JWE` or :class:`~jose.JWS`. """ - return '.'.join(jwt) + return six.b('.').join(jwt) def deserialize_compact(jwt): @@ -95,7 +93,7 @@ def deserialize_compact(jwt): :rtype: :class:`~jose.JWT`. :raises: :class:`~jose.Error` if the JWT is malformed """ - parts = jwt.split('.') + parts = jwt.split(six.b('.')) # http://tools.ietf.org/html/ # draft-ietf-jose-json-web-encryption-23#section-9 @@ -109,8 +107,8 @@ def deserialize_compact(jwt): return token_type(*parts) -def encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP', - enc='A128CBC-HS256', rng=get_random_bytes, compression=None): +def encrypt(claims, jwk, adata=six.b(''), add_header=None, alg='RSA-OAEP', + enc='A128CBC-HS256', rng=get_random_bytes, compression=None): """ Encrypts the given claims and produces a :class:`~jose.JWE` :param claims: A `dict` representing the claims for this @@ -139,14 +137,15 @@ def encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP', assert _TEMP_VER_KEY not in claims claims[_TEMP_VER_KEY] = _TEMP_VER - header = dict((add_header or {}).items() + [ - ('enc', enc), ('alg', alg)]) + header = dict( + list((add_header or {}).items()) + [('enc', enc), ('alg', alg)] + ) # promote the temp key to the header assert _TEMP_VER_KEY not in header header[_TEMP_VER_KEY] = claims[_TEMP_VER_KEY] - plaintext = json_encode(claims) + plaintext = six.b(json_encode(claims)) # compress (if required) if compression is not None: @@ -162,24 +161,29 @@ def encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP', ((cipher, _), key_size), ((hash_fn, _), hash_mod) = JWA[enc] iv = rng(AES.block_size) encryption_key = rng(hash_mod.digest_size) + encryption_key_index = hash_mod.digest_size // 2 - ciphertext = cipher(plaintext, encryption_key[-hash_mod.digest_size/2:], iv) - hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata), - encryption_key[:-hash_mod.digest_size/2], hash_mod) + ciphertext = cipher( + plaintext, encryption_key[-encryption_key_index:], iv + ) + hash = hash_fn( + _jwe_hash_str(ciphertext, iv, adata), + encryption_key[:-encryption_key_index], hash_mod + ) # cek encryption (cipher, _), _ = JWA[alg] encryption_key_ciphertext = cipher(encryption_key, jwk) - return JWE(*map(b64encode_url, - (json_encode(header), - encryption_key_ciphertext, - iv, - ciphertext, - auth_tag(hash)))) + jwe_components = ( + json_encode(header), encryption_key_ciphertext, iv, ciphertext, + auth_tag(hash) + ) + return JWE(*map(b64encode_url, jwe_components)) -def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None): +def decrypt(jwe, jwk, adata=six.b(''), validate_claims=True, + expiry_seconds=None): """ Decrypts a deserialized :class:`~jose.JWE` :param jwe: An instance of :class:`~jose.JWE` @@ -199,8 +203,9 @@ def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None): :raises: :class:`~jose.Error` if there is an error decrypting the JWE """ header, encryption_key_ciphertext, iv, ciphertext, tag = map( - b64decode_url, jwe) - header = json_decode(header) + b64decode_url, jwe + ) + header = json_decode(header.decode()) # decrypt cek (_, decipher), _ = JWA[header['alg']] @@ -211,9 +216,13 @@ def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None): version = header.get(_TEMP_VER_KEY) if version: - plaintext = decipher(ciphertext, encryption_key[-mod.digest_size/2:], iv) - hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata, version), - encryption_key[:-mod.digest_size/2], mod=mod) + plaintext = decipher( + ciphertext, encryption_key[-mod.digest_size // 2:], iv + ) + hash = hash_fn( + _jwe_hash_str(ciphertext, iv, adata, version), + encryption_key[:-mod.digest_size // 2], mod=mod + ) else: plaintext = decipher(ciphertext, encryption_key[:-mod.digest_size], iv) hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata, version), @@ -231,7 +240,7 @@ def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None): plaintext = decompress(plaintext) - claims = json_decode(plaintext) + claims = json_decode(plaintext.decode()) try: del claims[_TEMP_VER_KEY] except KeyError: @@ -257,11 +266,12 @@ def sign(claims, jwk, add_header=None, alg='HS256'): """ (hash_fn, _), mod = JWA[alg] - header = dict((add_header or {}).items() + [('alg', alg)]) + header = dict(list((add_header or {}).items()) + [('alg', alg)]) header, payload = map(b64encode_url, map(json_encode, (header, claims))) - sig = b64encode_url(hash_fn(_jws_hash_str(header, payload), jwk['k'], - mod=mod)) + sig = b64encode_url( + hash_fn(_jws_hash_str(header, payload), jwk['k'], mod=mod) + ) return JWS(header, payload, sig) @@ -285,17 +295,18 @@ def verify(jws, jwk, alg, validate_claims=True, expiry_seconds=None): :raises: :class:`~jose.Error` if there is an error decrypting the JWE """ header, payload, sig = map(b64decode_url, jws) - header = json_decode(header) + header = json_decode(header.decode()) if alg != header['alg']: raise Error('Invalid algorithm') (_, verify_fn), mod = JWA[header['alg']] - if not verify_fn(_jws_hash_str(jws.header, jws.payload), - jwk['k'], sig, mod=mod): + if not verify_fn( + _jws_hash_str(jws.header, jws.payload), jwk['k'], sig, mod=mod + ): raise Error('Mismatched signatures') - claims = json_decode(b64decode_url(jws.payload)) + claims = json_decode(b64decode_url(jws.payload).decode()) _validate(claims, validate_claims, expiry_seconds) return JWT(header, claims) @@ -305,10 +316,9 @@ def b64decode_url(istr): """ JWT Tokens may be truncated without the usual trailing padding '=' symbols. Compensate by padding to the nearest 4 bytes. """ - istr = encode_safe(istr) try: - return urlsafe_b64decode(istr + '=' * (4 - (len(istr) % 4))) - except TypeError as e: + return urlsafe_b64decode(istr + six.b('=') * (4 - (len(istr) % 4))) + except (TypeError, binascii.Error) as e: raise Error('Unable to decode base64: %s' % (e)) @@ -316,16 +326,22 @@ def b64encode_url(istr): """ JWT Tokens may be truncated without the usual trailing padding '=' symbols. Compensate by padding to the nearest 4 bytes. """ - return urlsafe_b64encode(encode_safe(istr)).rstrip('=') + return urlsafe_b64encode(encode_safe(istr)).rstrip(six.b('=')) -def encode_safe(istr, encoding='utf8'): - try: - return istr.encode(encoding) - except UnicodeDecodeError: - # this will fail if istr is already encoded - pass - return istr +if six.PY3: + def encode_safe(istr, encoding='utf8'): + if not isinstance(istr, bytes): + return bytes(istr, encoding=encoding) + return istr +else: + def encode_safe(istr, encoding='utf8'): + try: + return istr.encode(encoding) + except UnicodeDecodeError: + # this will fail if istr is already encoded + pass + return istr def auth_tag(hmac): @@ -336,11 +352,15 @@ def auth_tag(hmac): def pad_pkcs7(s): sz = AES.block_size - (len(s) % AES.block_size) - return s + (chr(sz) * sz) + return s + (six.int2byte(sz) * sz) -def unpad_pkcs7(s): - return s[:-ord(s[-1])] +if six.PY3: + def unpad_pkcs7(s): + return s[:-s[-1]] +else: + def unpad_pkcs7(s): + return s[:-ord(s[-1])] def encrypt_oaep(plaintext, jwk): @@ -391,14 +411,24 @@ def decrypt_aescbc(ciphertext, key, iv): return unpad_pkcs7(AES.new(key, AES.MODE_CBC, iv).decrypt(ciphertext)) -def const_compare(stra, strb): - if len(stra) != len(strb): - return False +if six.PY3: + def const_compare(stra, strb): + if len(stra) != len(strb): + return False + + res = 0 + for a, b in zip(stra, strb): + res |= a ^ b + return res == 0 +else: + def const_compare(stra, strb): + if len(stra) != len(strb): + return False - res = 0 - for a, b in zip(stra, strb): - res |= ord(a) ^ ord(b) - return res == 0 + res = 0 + for a, b in zip(stra, strb): + res |= ord(a) ^ ord(b) + return res == 0 class _JWA(object): @@ -525,34 +555,36 @@ def _validate(claims, validate_claims, expiry_seconds): _check_not_before(now, not_before) -def _jwe_hash_str(ciphertext, iv, adata='', version=_TEMP_VER): +def _jwe_hash_str(ciphertext, iv, adata=six.b(''), version=_TEMP_VER): # http://tools.ietf.org/html/ # draft-ietf-jose-json-web-algorithms-24#section-5.2.2.1 # Both tokens without version and with version 1 should be ignored in # the future as they use incorrect hashing. The version parameter # should also be removed. if not version: - return '.'.join((adata, iv, ciphertext, str(len(adata)))) + return six.b('.').join( + (adata, iv, ciphertext, six.b(str(len(adata)))) + ) elif version == 1: - return '.'.join((adata, iv, ciphertext, pack("!Q", len(adata) * 8))) - return ''.join((adata, iv, ciphertext, pack("!Q", len(adata) * 8))) + return six.b('.').join( + (adata, iv, ciphertext, pack("!Q", len(adata) * 8)) + ) + return six.b('').join( + (adata, iv, ciphertext, pack("!Q", len(adata) * 8)) + ) def _jws_hash_str(header, claims): - return '.'.join((header, claims)) + return six.b('.').join((header, claims)) def cli_decrypt(jwt, key): - print decrypt(deserialize_compact(jwt), {'k':key}, - validate_claims=False) + print(decrypt(deserialize_compact(jwt), {'k': key}, validate_claims=False)) def _cli(): import inspect - import sys - from argparse import ArgumentParser - from copy import copy parser = ArgumentParser() subparsers = parser.add_subparsers(dest='subparser_name') diff --git a/requirements.txt b/requirements.txt index 9fadf90..fa08d5d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ pycrypto >= 2.6 +six diff --git a/setup.py b/setup.py index 1cf1b23..48888e9 100644 --- a/setup.py +++ b/setup.py @@ -5,8 +5,14 @@ from setuptools.command.bdist_rpm import bdist_rpm as _bdist_rpm here = os.path.abspath(os.path.dirname(__file__)) -REQUIRES = filter(lambda s: len(s) > 0, - open(os.path.join(here, 'requirements.txt')).read().split('\n')) + +REQUIRES = list( + filter( + lambda s: len(s) > 0, + open(os.path.join(here, 'requirements.txt')).read().split('\n') + ) +) + pkg_name = 'jose' pyver = ''.join(('python', '.'.join(map(str, sys.version_info[:2])))) @@ -42,7 +48,8 @@ def finalize_package_data(self): if sys.argv[-1] == 'bdist_rpm': pkg_name = '-'.join((pyver.replace('.', ''), pkg_name)) - setup(name=pkg_name, + setup( + name=pkg_name, version='1.0.0', author='Demian Brecht', author_email='dbrecht@demonware.net', @@ -58,9 +65,10 @@ def finalize_package_data(self): 'Operating System :: OS Independent', 'Programming Language :: Python :: 2 :: Only', 'Topic :: Security', - 'Topic :: Software Development :: Libraries',], + 'Topic :: Software Development :: Libraries' + ], cmdclass={'bdist_rpm': bdist_rpm}, - entry_points = { + entry_points={ 'console_scripts': ( 'jose = jose:_cli', ) diff --git a/tests.py b/tests.py index 04a4f72..81e086e 100644 --- a/tests.py +++ b/tests.py @@ -1,3 +1,4 @@ +import six import json import unittest @@ -24,12 +25,14 @@ claims = {'john': 'cleese'} -def legacy_encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP', - enc='A128CBC-HS256', rng=get_random_bytes, compression=None, version=None): +def legacy_encrypt(claims, jwk, adata=six.b(''), add_header=None, + alg='RSA-OAEP', enc='A128CBC-HS256', rng=get_random_bytes, + compression=None, version=None): # see https://github.com/Demonware/jose/pull/3/files - header = dict((add_header or {}).items() + [ - ('enc', enc), ('alg', alg)]) + header = dict( + list((add_header or {}).items()) + [('enc', enc), ('alg', alg)] + ) if version == 1: claims = deepcopy(claims) @@ -40,7 +43,7 @@ def legacy_encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP', assert jose._TEMP_VER_KEY not in header header[jose._TEMP_VER_KEY] = version - plaintext = jose.json_encode(claims) + plaintext = six.b(jose.json_encode(claims)) # compress (if required) if compression is not None: @@ -57,15 +60,17 @@ def legacy_encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP', iv = rng(AES.block_size) if version == 1: encryption_key = rng(hash_mod.digest_size) - cipher_key = encryption_key[-hash_mod.digest_size/2:] - mac_key = encryption_key[:-hash_mod.digest_size/2] + cipher_key = encryption_key[-hash_mod.digest_size // 2:] + mac_key = encryption_key[:-hash_mod.digest_size // 2] else: encryption_key = rng((key_size // 8) + hash_mod.digest_size) cipher_key = encryption_key[:-hash_mod.digest_size] mac_key = encryption_key[-hash_mod.digest_size:] ciphertext = cipher(plaintext, cipher_key, iv) - hash = hash_fn(jose._jwe_hash_str(ciphertext, iv, adata, version), mac_key, hash_mod) + hash = hash_fn( + jose._jwe_hash_str(ciphertext, iv, adata, version), mac_key, hash_mod + ) # cek encryption (cipher, _), _ = jose.JWA[alg] @@ -99,8 +104,6 @@ def test_jwe(self): self.assertEqual(e.message, 'Incorrect decryption.') def test_version1(self): - bad_key = {'k': RSA.generate(2048).exportKey('PEM')} - jwe = legacy_encrypt(claims, rsa_pub_key, version=1) token = jose.serialize_compact(jwe) @@ -113,7 +116,7 @@ def test_version1(self): class TestSerializeDeserialize(unittest.TestCase): def test_serialize(self): try: - jose.deserialize_compact('1.2.3.4') + jose.deserialize_compact(six.b('1.2.3.4')) self.fail() except jose.Error as e: self.assertEqual(e.message, 'Malformed JWT') @@ -131,7 +134,7 @@ def test_jwe(self): # make sure the body can't be loaded as json (should be encrypted) try: - json.loads(jose.b64decode_url(jwe.ciphertext)) + json.loads(jose.b64decode_url(jwe.ciphertext).decode()) self.fail() except ValueError: pass @@ -161,7 +164,7 @@ def test_jwe_add_header(self): self.assertEqual(jwt.header['foo'], add_header['foo']) def test_jwe_adata(self): - adata = '42' + adata = six.b('42') for (alg, jwk), enc in product(self.algs, self.encs): et = jose.serialize_compact(jose.encrypt(claims, rsa_pub_key, adata=adata)) @@ -180,15 +183,11 @@ def test_jwe_adata(self): def test_jwe_invalid_base64(self): try: - jose.decrypt('aaa', rsa_priv_key) - self.fail() # expecting error due to invalid base64 + jose.decrypt(six.b('aaa'), rsa_priv_key) except jose.Error as e: - pass - - self.assertEquals( - e.args[0], - 'Unable to decode base64: Incorrect padding' - ) + self.assertTrue(e.message.startswith('Unable to decode base64')) + else: + self.fail() # expecting error due to invalid base64 def test_jwe_no_error_with_exp_claim(self): claims = {jose.CLAIM_EXPIRATION_TIME: int(time()) + 5} @@ -201,16 +200,15 @@ def test_jwe_expired_error_with_exp_claim(self): try: jose.decrypt(jose.deserialize_compact(et), rsa_priv_key) - self.fail() # expecting expired token except jose.Expired as e: - pass - - self.assertEquals( - e.args[0], - 'Token expired at {}'.format( - jose._format_timestamp(claims[jose.CLAIM_EXPIRATION_TIME]) + self.assertEquals( + e.message, + 'Token expired at {}'.format( + jose._format_timestamp(claims[jose.CLAIM_EXPIRATION_TIME]) + ) ) - ) + else: + self.fail() # expecting expired token def test_jwe_no_error_with_iat_claim(self): claims = {jose.CLAIM_ISSUED_AT: int(time()) - 15} @@ -227,17 +225,16 @@ def test_jwe_expired_error_with_iat_claim(self): try: jose.decrypt(jose.deserialize_compact(et), rsa_priv_key, expiry_seconds=expiry_seconds) - self.fail() # expecting expired token except jose.Expired as e: - pass - - expiration_time = claims[jose.CLAIM_ISSUED_AT] + expiry_seconds - self.assertEquals( - e.args[0], - 'Token expired at {}'.format( - jose._format_timestamp(expiration_time) + expiration_time = claims[jose.CLAIM_ISSUED_AT] + expiry_seconds + self.assertEquals( + e.message, + 'Token expired at {}'.format( + jose._format_timestamp(expiration_time) + ) ) - ) + else: + self.fail() # expecting expired token def test_jwe_no_error_with_nbf_claim(self): claims = {jose.CLAIM_NOT_BEFORE: int(time()) - 5} @@ -250,16 +247,15 @@ def test_jwe_not_yet_valid_error_with_nbf_claim(self): try: jose.decrypt(jose.deserialize_compact(et), rsa_priv_key) - self.fail() # expecting not valid yet except jose.NotYetValid as e: - pass - - self.assertEquals( - e.args[0], - 'Token not valid until {}'.format( - jose._format_timestamp(claims[jose.CLAIM_NOT_BEFORE]) + self.assertEquals( + e.message, + 'Token not valid until {}'.format( + jose._format_timestamp(claims[jose.CLAIM_NOT_BEFORE]) + ) ) - ) + else: + self.fail() # expecting not valid yet def test_jwe_ignores_expired_token_if_validate_claims_is_false(self): claims = {jose.CLAIM_EXPIRATION_TIME: int(time()) - 5} @@ -276,15 +272,15 @@ def test_format_timestamp(self): def test_jwe_compression(self): local_claims = copy(claims) - for v in xrange(1000): + for v in range(1000): local_claims['dummy_' + str(v)] = '0' * 100 jwe = jose.serialize_compact(jose.encrypt(local_claims, rsa_pub_key)) - _, _, _, uncompressed_ciphertext, _ = jwe.split('.') + _, _, _, uncompressed_ciphertext, _ = jwe.split(six.b('.')) jwe = jose.serialize_compact(jose.encrypt(local_claims, rsa_pub_key, compression='DEF')) - _, _, _, compressed_ciphertext, _ = jwe.split('.') + _, _, _, compressed_ciphertext, _ = jwe.split(six.b('.')) self.assertTrue(len(compressed_ciphertext) < len(uncompressed_ciphertext)) @@ -320,7 +316,7 @@ class TestJWS(unittest.TestCase): def test_jws_sym(self): algs = ('HS256', 'HS384', 'HS512',) - jwk = {'k': 'password'} + jwk = {'k': six.b('password')} for alg in algs: st = jose.serialize_compact(jose.sign(claims, jwk, alg=alg)) @@ -339,21 +335,26 @@ def test_jws_asym(self): def test_jws_signature_mismatch_error(self): alg = 'HS256' - jwk = {'k': 'password'} + jwk = {'k': six.b('password')} jws = jose.sign(claims, jwk, alg=alg) try: - jose.verify(jose.JWS(jws.header, jws.payload, 'asd'), jwk, alg) + jose.verify( + jose.JWS(jws.header, jws.payload, six.b('asd')), + jwk, alg + ) except jose.Error as e: self.assertEqual(e.message, 'Mismatched signatures') def test_jws_invalid_algorithm_error(self): sign_alg = 'HS256' verify_alg = 'RS256' - jwk = {'k': 'password'} + jwk = {'k': six.b('password')} jws = jose.sign(claims, jwk, alg=sign_alg) try: - jose.verify(jose.JWS(jws.header, jws.payload, 'asd'), jwk, - verify_alg) + jose.verify( + jose.JWS(jws.header, jws.payload, six.b('asd')), + jwk, verify_alg + ) except jose.Error as e: self.assertEqual(e.message, 'Invalid algorithm') @@ -365,18 +366,21 @@ def test_b64encode_url_utf8(self): self.assertEqual(jose.b64decode_url(encoded), istr) def test_b64encode_url_ascii(self): - istr = 'eric idle' + istr = six.b('eric idle') encoded = jose.b64encode_url(istr) self.assertEqual(jose.b64decode_url(encoded), istr) def test_b64encode_url(self): - istr = '{"alg": "RSA-OAEP", "enc": "A128CBC-HS256"}' + istr = six.b('{"alg": "RSA-OAEP", "enc": "A128CBC-HS256"}') + + base64_encoded = b64encode(istr).decode() + base64_url_encoded = jose.b64encode_url(istr).decode() # sanity check - self.assertEqual(b64encode(istr)[-1], '=') + self.assertEqual(base64_encoded[-1], '=') # actual test - self.assertNotEqual(jose.b64encode_url(istr), '=') + self.assertNotEqual(base64_url_encoded[-1], '=') class TestJWA(unittest.TestCase): diff --git a/tox.ini b/tox.ini index 56a1d96..c159bf6 100644 --- a/tox.ini +++ b/tox.ini @@ -2,8 +2,8 @@ ignore = E128 [tox] -envlist=py27 +envlist=py27,py35,travis [testenv] deps=nose -commands=nosetests +commands=nosetests --no-path-adjustment --nocapture From b001c852bb3932b83dd2eee126139c4760c6e33a Mon Sep 17 00:00:00 2001 From: Nick Murtagh Date: Tue, 26 Apr 2016 15:55:52 +0100 Subject: [PATCH 2/3] Changes after review --- CHANGES | 2 ++ jose.py | 84 +++++++++++++++++++++++++------------------------------- setup.py | 2 +- tests.py | 10 ++----- 4 files changed, 43 insertions(+), 55 deletions(-) diff --git a/CHANGES b/CHANGES index 82bd801..a48b6cd 100644 --- a/CHANGES +++ b/CHANGES @@ -1,5 +1,7 @@ CHANGES ======= +1.1.0 (2016-04-26) +- Python 3 Support 1.0.0 (2015-10-06) ------------------ diff --git a/jose.py b/jose.py index 7c88f35..c78c6f7 100644 --- a/jose.py +++ b/jose.py @@ -1,5 +1,6 @@ import binascii import datetime +import json import logging import six import zlib @@ -7,7 +8,6 @@ from base64 import urlsafe_b64encode, urlsafe_b64decode from collections import namedtuple from copy import deepcopy -from json import loads as json_decode, dumps as json_encode from struct import pack from time import time @@ -137,15 +137,13 @@ def encrypt(claims, jwk, adata=six.b(''), add_header=None, alg='RSA-OAEP', assert _TEMP_VER_KEY not in claims claims[_TEMP_VER_KEY] = _TEMP_VER - header = dict( - list((add_header or {}).items()) + [('enc', enc), ('alg', alg)] - ) + header = dict(add_header or {}, enc=enc, alg=alg) # promote the temp key to the header assert _TEMP_VER_KEY not in header header[_TEMP_VER_KEY] = claims[_TEMP_VER_KEY] - plaintext = six.b(json_encode(claims)) + plaintext = json_encode(claims) # compress (if required) if compression is not None: @@ -205,7 +203,7 @@ def decrypt(jwe, jwk, adata=six.b(''), validate_claims=True, header, encryption_key_ciphertext, iv, ciphertext, tag = map( b64decode_url, jwe ) - header = json_decode(header.decode()) + header = json_decode(header) # decrypt cek (_, decipher), _ = JWA[header['alg']] @@ -240,7 +238,7 @@ def decrypt(jwe, jwk, adata=six.b(''), validate_claims=True, plaintext = decompress(plaintext) - claims = json_decode(plaintext.decode()) + claims = json_decode(plaintext) try: del claims[_TEMP_VER_KEY] except KeyError: @@ -265,8 +263,7 @@ def sign(claims, jwk, add_header=None, alg='HS256'): :rtype: :class:`~jose.JWS` """ (hash_fn, _), mod = JWA[alg] - - header = dict(list((add_header or {}).items()) + [('alg', alg)]) + header = dict(add_header or {}, alg=alg) header, payload = map(b64encode_url, map(json_encode, (header, claims))) sig = b64encode_url( @@ -295,7 +292,7 @@ def verify(jws, jwk, alg, validate_claims=True, expiry_seconds=None): :raises: :class:`~jose.Error` if there is an error decrypting the JWE """ header, payload, sig = map(b64decode_url, jws) - header = json_decode(header.decode()) + header = json_decode(header) if alg != header['alg']: raise Error('Invalid algorithm') @@ -306,7 +303,7 @@ def verify(jws, jwk, alg, validate_claims=True, expiry_seconds=None): ): raise Error('Mismatched signatures') - claims = json_decode(b64decode_url(jws.payload).decode()) + claims = json_decode(b64decode_url(jws.payload)) _validate(claims, validate_claims, expiry_seconds) return JWT(header, claims) @@ -326,22 +323,21 @@ def b64encode_url(istr): """ JWT Tokens may be truncated without the usual trailing padding '=' symbols. Compensate by padding to the nearest 4 bytes. """ - return urlsafe_b64encode(encode_safe(istr)).rstrip(six.b('=')) + return urlsafe_b64encode(istr).rstrip(six.b('=')) -if six.PY3: - def encode_safe(istr, encoding='utf8'): - if not isinstance(istr, bytes): - return bytes(istr, encoding=encoding) - return istr -else: - def encode_safe(istr, encoding='utf8'): - try: - return istr.encode(encoding) - except UnicodeDecodeError: - # this will fail if istr is already encoded - pass - return istr +def json_encode(x): + """ + Dict -> Binary + """ + return json.dumps(x).encode() + + +def json_decode(x): + """ + Binary -> Dict + """ + return json.loads(x.decode()) def auth_tag(hmac): @@ -355,12 +351,16 @@ def pad_pkcs7(s): return s + (six.int2byte(sz) * sz) -if six.PY3: - def unpad_pkcs7(s): - return s[:-s[-1]] +if six.PY2: + def _ord(x): + return ord(x) else: - def unpad_pkcs7(s): - return s[:-ord(s[-1])] + def _ord(x): + return x + + +def unpad_pkcs7(s): + return s[:-_ord(s[-1])] def encrypt_oaep(plaintext, jwk): @@ -411,24 +411,14 @@ def decrypt_aescbc(ciphertext, key, iv): return unpad_pkcs7(AES.new(key, AES.MODE_CBC, iv).decrypt(ciphertext)) -if six.PY3: - def const_compare(stra, strb): - if len(stra) != len(strb): - return False +def const_compare(stra, strb): + if len(stra) != len(strb): + return False - res = 0 - for a, b in zip(stra, strb): - res |= a ^ b - return res == 0 -else: - def const_compare(stra, strb): - if len(stra) != len(strb): - return False - - res = 0 - for a, b in zip(stra, strb): - res |= ord(a) ^ ord(b) - return res == 0 + res = 0 + for a, b in zip(stra, strb): + res |= _ord(a) ^ _ord(b) + return res == 0 class _JWA(object): diff --git a/setup.py b/setup.py index 48888e9..165ae1b 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,7 @@ def finalize_package_data(self): setup( name=pkg_name, - version='1.0.0', + version='1.1.0', author='Demian Brecht', author_email='dbrecht@demonware.net', py_modules=['jose'], diff --git a/tests.py b/tests.py index 81e086e..5ba1c9b 100644 --- a/tests.py +++ b/tests.py @@ -22,17 +22,13 @@ 'k': rsa_key.publickey().exportKey('PEM'), } -claims = {'john': 'cleese'} +claims = {'john': u'cleese\u20ac'} def legacy_encrypt(claims, jwk, adata=six.b(''), add_header=None, alg='RSA-OAEP', enc='A128CBC-HS256', rng=get_random_bytes, compression=None, version=None): - # see https://github.com/Demonware/jose/pull/3/files - - header = dict( - list((add_header or {}).items()) + [('enc', enc), ('alg', alg)] - ) + header = dict(add_header or {}, enc=enc, alg=alg) if version == 1: claims = deepcopy(claims) @@ -43,7 +39,7 @@ def legacy_encrypt(claims, jwk, adata=six.b(''), add_header=None, assert jose._TEMP_VER_KEY not in header header[jose._TEMP_VER_KEY] = version - plaintext = six.b(jose.json_encode(claims)) + plaintext = jose.json_encode(claims) # compress (if required) if compression is not None: From 17c519bf910456acf07c4fc6daf04bccf42724e0 Mon Sep 17 00:00:00 2001 From: Nick Murtagh Date: Tue, 26 Apr 2016 19:25:32 +0100 Subject: [PATCH 3/3] flake8 --- jose.py | 27 ++++++--------------- setup.py | 5 ++-- tests.py | 74 ++++++++++++++++++++++++++++++-------------------------- tox.ini | 10 +++++--- 4 files changed, 57 insertions(+), 59 deletions(-) diff --git a/jose.py b/jose.py index c78c6f7..8ecd2a3 100644 --- a/jose.py +++ b/jose.py @@ -23,24 +23,9 @@ __all__ = ['encrypt', 'decrypt', 'sign', 'verify'] -# XXX: The attribute order is IMPORTANT in the following namedtuple -# definitions. DO NOT change them, unless you really know what you're doing. - -JWE = namedtuple('JWE', - 'header ' - 'cek ' - 'iv ' - 'ciphertext ' - 'tag ') - -JWS = namedtuple('JWS', - 'header ' - 'payload ' - 'signature ') - -JWT = namedtuple('JWT', - 'header ' - 'claims ') +JWE = namedtuple('JWE', ['header', 'cek', 'iv', 'ciphertext', 'tag']) +JWS = namedtuple('JWS', ['header', 'payload', 'signature']) +JWT = namedtuple('JWT', ['header', 'claims']) CLAIM_ISSUER = 'iss' @@ -223,8 +208,10 @@ def decrypt(jwe, jwk, adata=six.b(''), validate_claims=True, ) else: plaintext = decipher(ciphertext, encryption_key[:-mod.digest_size], iv) - hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata, version), - encryption_key[-mod.digest_size:], mod=mod) + hash = hash_fn( + _jwe_hash_str(ciphertext, iv, adata, version), + encryption_key[-mod.digest_size:], mod=mod + ) if not const_compare(auth_tag(hash), tag): raise Error('Mismatched authentication tags') diff --git a/setup.py b/setup.py index 165ae1b..e869155 100644 --- a/setup.py +++ b/setup.py @@ -39,8 +39,9 @@ def finalize_package_data(self): self.python = pyver if self.release is None: - self.release = '.'.join((os.environ.get('JOSE_RELEASE', '1'), - 'demonware')) + self.release = '.'.join( + (os.environ.get('JOSE_RELEASE', '1'), 'demonware') + ) _bdist_rpm.finalize_package_data(self) diff --git a/tests.py b/tests.py index 5ba1c9b..f124328 100644 --- a/tests.py +++ b/tests.py @@ -72,12 +72,12 @@ def legacy_encrypt(claims, jwk, adata=six.b(''), add_header=None, (cipher, _), _ = jose.JWA[alg] encryption_key_ciphertext = cipher(encryption_key, jwk) - return jose.JWE(*map(jose.b64encode_url, - (jose.json_encode(header), - encryption_key_ciphertext, - iv, - ciphertext, - jose.auth_tag(hash)))) + jwe_components = ( + jose.json_encode(header), encryption_key_ciphertext, iv, ciphertext, + jose.auth_tag(hash) + ) + + return jose.JWE(*map(jose.b64encode_url, jwe_components)) class TestLegacyDecrypt(unittest.TestCase): @@ -153,8 +153,9 @@ def test_jwe_add_header(self): add_header = {'foo': 'bar'} for (alg, jwk), enc in product(self.algs, self.encs): - et = jose.serialize_compact(jose.encrypt(claims, rsa_pub_key, - add_header=add_header)) + et = jose.serialize_compact( + jose.encrypt(claims, rsa_pub_key, add_header=add_header) + ) jwt = jose.decrypt(jose.deserialize_compact(et), rsa_priv_key) self.assertEqual(jwt.header['foo'], add_header['foo']) @@ -162,15 +163,19 @@ def test_jwe_add_header(self): def test_jwe_adata(self): adata = six.b('42') for (alg, jwk), enc in product(self.algs, self.encs): - et = jose.serialize_compact(jose.encrypt(claims, rsa_pub_key, - adata=adata)) - jwt = jose.decrypt(jose.deserialize_compact(et), rsa_priv_key, - adata=adata) + et = jose.serialize_compact( + jose.encrypt(claims, rsa_pub_key, adata=adata) + ) + jwt = jose.decrypt( + jose.deserialize_compact(et), rsa_priv_key, + adata=adata + ) # make sure signaures don't match when adata isn't passed in try: - hdr, dt = jose.decrypt(jose.deserialize_compact(et), - rsa_priv_key) + hdr, dt = jose.decrypt( + jose.deserialize_compact(et), rsa_priv_key + ) self.fail() except jose.Error as e: self.assertEqual(e.message, 'Mismatched authentication tags') @@ -210,8 +215,9 @@ def test_jwe_no_error_with_iat_claim(self): claims = {jose.CLAIM_ISSUED_AT: int(time()) - 15} et = jose.serialize_compact(jose.encrypt(claims, rsa_pub_key)) - jose.decrypt(jose.deserialize_compact(et), rsa_priv_key, - expiry_seconds=20) + jose.decrypt( + jose.deserialize_compact(et), rsa_priv_key, expiry_seconds=20 + ) def test_jwe_expired_error_with_iat_claim(self): expiry_seconds = 10 @@ -219,8 +225,10 @@ def test_jwe_expired_error_with_iat_claim(self): et = jose.serialize_compact(jose.encrypt(claims, rsa_pub_key)) try: - jose.decrypt(jose.deserialize_compact(et), rsa_priv_key, - expiry_seconds=expiry_seconds) + jose.decrypt( + jose.deserialize_compact(et), rsa_priv_key, + expiry_seconds=expiry_seconds + ) except jose.Expired as e: expiration_time = claims[jose.CLAIM_ISSUED_AT] + expiry_seconds self.assertEquals( @@ -256,8 +264,9 @@ def test_jwe_not_yet_valid_error_with_nbf_claim(self): def test_jwe_ignores_expired_token_if_validate_claims_is_false(self): claims = {jose.CLAIM_EXPIRATION_TIME: int(time()) - 5} et = jose.serialize_compact(jose.encrypt(claims, rsa_pub_key)) - jose.decrypt(jose.deserialize_compact(et), rsa_priv_key, - validate_claims=False) + jose.decrypt( + jose.deserialize_compact(et), rsa_priv_key, validate_claims=False + ) def test_format_timestamp(self): self.assertEquals( @@ -274,12 +283,14 @@ def test_jwe_compression(self): jwe = jose.serialize_compact(jose.encrypt(local_claims, rsa_pub_key)) _, _, _, uncompressed_ciphertext, _ = jwe.split(six.b('.')) - jwe = jose.serialize_compact(jose.encrypt(local_claims, rsa_pub_key, - compression='DEF')) + jwe = jose.serialize_compact( + jose.encrypt(local_claims, rsa_pub_key, compression='DEF') + ) _, _, _, compressed_ciphertext, _ = jwe.split(six.b('.')) - self.assertTrue(len(compressed_ciphertext) < - len(uncompressed_ciphertext)) + self.assertTrue( + len(compressed_ciphertext) < len(uncompressed_ciphertext) + ) jwt = jose.decrypt(jose.deserialize_compact(jwe), rsa_priv_key) self.assertEqual(jwt.claims, local_claims) @@ -324,8 +335,9 @@ def test_jws_asym(self): algs = ('RS256', 'RS384', 'RS512') for alg in algs: - st = jose.serialize_compact(jose.sign(claims, rsa_priv_key, - alg=alg)) + st = jose.serialize_compact( + jose.sign(claims, rsa_priv_key, alg=alg) + ) jwt = jose.verify(jose.deserialize_compact(st), rsa_pub_key, alg) self.assertEqual(jwt.claims, claims) @@ -387,10 +399,8 @@ def test_lookup(self): self.assertEqual(jose.JWA['HS256'], 'HS256') self.assertEqual(jose.JWA['RSA-OAEP'], 'RSA-OAEP') - self.assertEqual(jose.JWA['A128CBC-HS256'], - ('A128CBC', 'HS256')) - self.assertEqual(jose.JWA['A128CBC+HS256'], - ('A128CBC', 'HS256')) + self.assertEqual(jose.JWA['A128CBC-HS256'], ('A128CBC', 'HS256')) + self.assertEqual(jose.JWA['A128CBC+HS256'], ('A128CBC', 'HS256')) jose._JWA._impl = impl @@ -400,7 +410,3 @@ def test_invalid_error(self): self.fail() except jose.Error as e: self.assertTrue(e.message.startswith('Unsupported')) - - -if __name__ == '__main__': - unittest.main() diff --git a/tox.ini b/tox.ini index c159bf6..6042656 100644 --- a/tox.ini +++ b/tox.ini @@ -1,9 +1,13 @@ [flake8] -ignore = E128 +exclude=.tox,docs [tox] envlist=py27,py35,travis [testenv] -deps=nose -commands=nosetests --no-path-adjustment --nocapture +deps= + nose + flake8 +commands= + flake8 + nosetests --no-path-adjustment --nocapture