Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
language: python
python:
- "2.7"
- "3.3"
- "3.4"
- "3.5"
- "3.5-dev"
- "nightly"
- "pypy"
install:
- pip install tox
script: tox
script:
- tox -e travis
2 changes: 2 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
CHANGES
=======
1.1.0 (2016-04-26)
- Python 3 Support

1.0.0 (2015-10-06)
------------------
Expand Down
171 changes: 90 additions & 81 deletions jose.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,31 @@
import binascii
import datetime
import json
import logging
logger = logging.getLogger(__name__)

try:
from cjson import encode as json_encode, decode as json_decode
except ImportError: # pragma: nocover
logger.warn('cjson not found, falling back to stdlib json')
from json import loads as json_decode, dumps as json_encode

import six
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You only seem to be using six.b, so normally I would just pull out six.b and have it as a function locally.

import zlib
import datetime

from base64 import urlsafe_b64encode, urlsafe_b64decode
from collections import namedtuple
from copy import deepcopy
from time import time
from struct import pack
from time import time

from Crypto.Hash import HMAC, SHA256, SHA384, SHA512
from Crypto.Cipher import PKCS1_OAEP, AES
from Crypto.PublicKey import RSA
from Crypto.Random import get_random_bytes
from Crypto.Signature import PKCS1_v1_5 as PKCS1_v1_5_SIG

logger = logging.getLogger(__name__)

__all__ = ['encrypt', 'decrypt', 'sign', 'verify']


# XXX: The attribute order is IMPORTANT in the following namedtuple
# definitions. DO NOT change them, unless you really know what you're doing.

JWE = namedtuple('JWE',
'header '
'cek '
'iv '
'ciphertext '
'tag ')
__all__ = ['encrypt', 'decrypt', 'sign', 'verify']

JWS = namedtuple('JWS',
'header '
'payload '
'signature ')

JWT = namedtuple('JWT',
'header '
'claims ')
JWE = namedtuple('JWE', ['header', 'cek', 'iv', 'ciphertext', 'tag'])
JWS = namedtuple('JWS', ['header', 'payload', 'signature'])
JWT = namedtuple('JWT', ['header', 'claims'])


CLAIM_ISSUER = 'iss'
Expand All @@ -63,7 +45,8 @@
class Error(Exception):
""" The base error type raised by jose
"""
pass
def __init__(self, message):
self.message = message


class Expired(Error):
Expand All @@ -85,7 +68,7 @@ def serialize_compact(jwt):
:returns: A string, representing the compact serialization of a
:class:`~jose.JWE` or :class:`~jose.JWS`.
"""
return '.'.join(jwt)
return six.b('.').join(jwt)


def deserialize_compact(jwt):
Expand All @@ -95,7 +78,7 @@ def deserialize_compact(jwt):
:rtype: :class:`~jose.JWT`.
:raises: :class:`~jose.Error` if the JWT is malformed
"""
parts = jwt.split('.')
parts = jwt.split(six.b('.'))

# http://tools.ietf.org/html/
# draft-ietf-jose-json-web-encryption-23#section-9
Expand All @@ -109,8 +92,8 @@ def deserialize_compact(jwt):
return token_type(*parts)


def encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP',
enc='A128CBC-HS256', rng=get_random_bytes, compression=None):
def encrypt(claims, jwk, adata=six.b(''), add_header=None, alg='RSA-OAEP',
enc='A128CBC-HS256', rng=get_random_bytes, compression=None):
""" Encrypts the given claims and produces a :class:`~jose.JWE`

:param claims: A `dict` representing the claims for this
Expand Down Expand Up @@ -139,8 +122,7 @@ def encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP',
assert _TEMP_VER_KEY not in claims
claims[_TEMP_VER_KEY] = _TEMP_VER

header = dict((add_header or {}).items() + [
('enc', enc), ('alg', alg)])
header = dict(add_header or {}, enc=enc, alg=alg)

# promote the temp key to the header
assert _TEMP_VER_KEY not in header
Expand All @@ -162,24 +144,29 @@ def encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP',
((cipher, _), key_size), ((hash_fn, _), hash_mod) = JWA[enc]
iv = rng(AES.block_size)
encryption_key = rng(hash_mod.digest_size)
encryption_key_index = hash_mod.digest_size // 2

ciphertext = cipher(plaintext, encryption_key[-hash_mod.digest_size/2:], iv)
hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata),
encryption_key[:-hash_mod.digest_size/2], hash_mod)
ciphertext = cipher(
plaintext, encryption_key[-encryption_key_index:], iv
)
hash = hash_fn(
_jwe_hash_str(ciphertext, iv, adata),
encryption_key[:-encryption_key_index], hash_mod
)

# cek encryption
(cipher, _), _ = JWA[alg]
encryption_key_ciphertext = cipher(encryption_key, jwk)

return JWE(*map(b64encode_url,
(json_encode(header),
encryption_key_ciphertext,
iv,
ciphertext,
auth_tag(hash))))
jwe_components = (
json_encode(header), encryption_key_ciphertext, iv, ciphertext,
auth_tag(hash)
)
return JWE(*map(b64encode_url, jwe_components))


def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None):
def decrypt(jwe, jwk, adata=six.b(''), validate_claims=True,
expiry_seconds=None):
""" Decrypts a deserialized :class:`~jose.JWE`

:param jwe: An instance of :class:`~jose.JWE`
Expand All @@ -199,7 +186,8 @@ def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None):
:raises: :class:`~jose.Error` if there is an error decrypting the JWE
"""
header, encryption_key_ciphertext, iv, ciphertext, tag = map(
b64decode_url, jwe)
b64decode_url, jwe
)
header = json_decode(header)

# decrypt cek
Expand All @@ -211,13 +199,19 @@ def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None):

version = header.get(_TEMP_VER_KEY)
if version:
plaintext = decipher(ciphertext, encryption_key[-mod.digest_size/2:], iv)
hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata, version),
encryption_key[:-mod.digest_size/2], mod=mod)
plaintext = decipher(
ciphertext, encryption_key[-mod.digest_size // 2:], iv
)
hash = hash_fn(
_jwe_hash_str(ciphertext, iv, adata, version),
encryption_key[:-mod.digest_size // 2], mod=mod
)
else:
plaintext = decipher(ciphertext, encryption_key[:-mod.digest_size], iv)
hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata, version),
encryption_key[-mod.digest_size:], mod=mod)
hash = hash_fn(
_jwe_hash_str(ciphertext, iv, adata, version),
encryption_key[-mod.digest_size:], mod=mod
)

if not const_compare(auth_tag(hash), tag):
raise Error('Mismatched authentication tags')
Expand Down Expand Up @@ -256,12 +250,12 @@ def sign(claims, jwk, add_header=None, alg='HS256'):
:rtype: :class:`~jose.JWS`
"""
(hash_fn, _), mod = JWA[alg]

header = dict((add_header or {}).items() + [('alg', alg)])
header = dict(add_header or {}, alg=alg)
header, payload = map(b64encode_url, map(json_encode, (header, claims)))

sig = b64encode_url(hash_fn(_jws_hash_str(header, payload), jwk['k'],
mod=mod))
sig = b64encode_url(
hash_fn(_jws_hash_str(header, payload), jwk['k'], mod=mod)
)

return JWS(header, payload, sig)

Expand Down Expand Up @@ -291,8 +285,9 @@ def verify(jws, jwk, alg, validate_claims=True, expiry_seconds=None):

(_, verify_fn), mod = JWA[header['alg']]

if not verify_fn(_jws_hash_str(jws.header, jws.payload),
jwk['k'], sig, mod=mod):
if not verify_fn(
_jws_hash_str(jws.header, jws.payload), jwk['k'], sig, mod=mod
):
raise Error('Mismatched signatures')

claims = json_decode(b64decode_url(jws.payload))
Expand All @@ -305,27 +300,31 @@ def b64decode_url(istr):
""" JWT Tokens may be truncated without the usual trailing padding '='
symbols. Compensate by padding to the nearest 4 bytes.
"""
istr = encode_safe(istr)
try:
return urlsafe_b64decode(istr + '=' * (4 - (len(istr) % 4)))
except TypeError as e:
return urlsafe_b64decode(istr + six.b('=') * (4 - (len(istr) % 4)))
except (TypeError, binascii.Error) as e:
raise Error('Unable to decode base64: %s' % (e))


def b64encode_url(istr):
""" JWT Tokens may be truncated without the usual trailing padding '='
symbols. Compensate by padding to the nearest 4 bytes.
"""
return urlsafe_b64encode(encode_safe(istr)).rstrip('=')
return urlsafe_b64encode(istr).rstrip(six.b('='))


def encode_safe(istr, encoding='utf8'):
try:
return istr.encode(encoding)
except UnicodeDecodeError:
# this will fail if istr is already encoded
pass
return istr
def json_encode(x):
"""
Dict -> Binary
"""
return json.dumps(x).encode()


def json_decode(x):
"""
Binary -> Dict
"""
return json.loads(x.decode())


def auth_tag(hmac):
Expand All @@ -336,11 +335,19 @@ def auth_tag(hmac):

def pad_pkcs7(s):
sz = AES.block_size - (len(s) % AES.block_size)
return s + (chr(sz) * sz)
return s + (six.int2byte(sz) * sz)


if six.PY2:
def _ord(x):
return ord(x)
else:
def _ord(x):
return x


def unpad_pkcs7(s):
return s[:-ord(s[-1])]
return s[:-_ord(s[-1])]


def encrypt_oaep(plaintext, jwk):
Expand Down Expand Up @@ -397,7 +404,7 @@ def const_compare(stra, strb):

res = 0
for a, b in zip(stra, strb):
res |= ord(a) ^ ord(b)
res |= _ord(a) ^ _ord(b)
return res == 0


Expand Down Expand Up @@ -525,34 +532,36 @@ def _validate(claims, validate_claims, expiry_seconds):
_check_not_before(now, not_before)


def _jwe_hash_str(ciphertext, iv, adata='', version=_TEMP_VER):
def _jwe_hash_str(ciphertext, iv, adata=six.b(''), version=_TEMP_VER):
# http://tools.ietf.org/html/
# draft-ietf-jose-json-web-algorithms-24#section-5.2.2.1
# Both tokens without version and with version 1 should be ignored in
# the future as they use incorrect hashing. The version parameter
# should also be removed.
if not version:
return '.'.join((adata, iv, ciphertext, str(len(adata))))
return six.b('.').join(
(adata, iv, ciphertext, six.b(str(len(adata))))
)
elif version == 1:
return '.'.join((adata, iv, ciphertext, pack("!Q", len(adata) * 8)))
return ''.join((adata, iv, ciphertext, pack("!Q", len(adata) * 8)))
return six.b('.').join(
(adata, iv, ciphertext, pack("!Q", len(adata) * 8))
)
return six.b('').join(
(adata, iv, ciphertext, pack("!Q", len(adata) * 8))
)


def _jws_hash_str(header, claims):
return '.'.join((header, claims))
return six.b('.').join((header, claims))


def cli_decrypt(jwt, key):
print decrypt(deserialize_compact(jwt), {'k':key},
validate_claims=False)
print(decrypt(deserialize_compact(jwt), {'k': key}, validate_claims=False))


def _cli():
import inspect
import sys

from argparse import ArgumentParser
from copy import copy

parser = ArgumentParser()
subparsers = parser.add_subparsers(dest='subparser_name')
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pycrypto >= 2.6
six
Loading