Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions jwt/jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import (
AbstractSet,
Tuple,
Callable,
)

from .exceptions import (
Expand Down Expand Up @@ -49,7 +50,8 @@ def _retrieve_alg(self, alg: str) -> AbstractSigningAlgorithm:
raise JWSDecodeError('Unsupported signing algorithm.')

def encode(self, message: bytes, key: AbstractJWKBase = None, alg='HS256',
optional_headers: dict = None) -> str:
optional_headers: dict = None,
dumps: Callable = json.dumps) -> str:
if alg not in self._supported_algs: # pragma: no cover
raise JWSEncodeError('unsupported algorithm: {}'.format(alg))
alg_impl = self._retrieve_alg(alg)
Expand All @@ -58,7 +60,7 @@ def encode(self, message: bytes, key: AbstractJWKBase = None, alg='HS256',
header['alg'] = alg

header_b64 = b64encode(
json.dumps(header, separators=(',', ':')).encode('ascii'))
dumps(header, separators=(',', ':')).encode('ascii'))
message_b64 = b64encode(message)
signing_message = header_b64 + '.' + message_b64

Expand All @@ -67,25 +69,27 @@ def encode(self, message: bytes, key: AbstractJWKBase = None, alg='HS256',

return signing_message + '.' + signature_b64

def _decode_segments(self, message: str) -> Tuple[dict, bytes, bytes, str]:
def _decode_segments(self, message: str,
loads: Callable) -> Tuple[dict, bytes, bytes, str]:
try:
signing_message, signature_b64 = message.rsplit('.', 1)
header_b64, message_b64 = signing_message.split('.')
except ValueError:
raise JWSDecodeError('malformed JWS payload')

header = json.loads(b64decode(header_b64).decode('ascii'))
header = loads(b64decode(header_b64).decode('ascii'))
message_bin = b64decode(message_b64)
signature = b64decode(signature_b64)
return header, message_bin, signature, signing_message

def decode(self, message: str, key: AbstractJWKBase = None,
do_verify=True, algorithms: AbstractSet[str]=None) -> bytes:
do_verify=True, algorithms: AbstractSet[str]=None,
loads: Callable = json.loads) -> bytes:
if algorithms is None:
algorithms = set(supported_signing_algorithms().keys())

header, message_bin, signature, signing_message = \
self._decode_segments(message)
self._decode_segments(message, loads=loads)

alg_value = header['alg']
if alg_value not in algorithms:
Expand Down
18 changes: 11 additions & 7 deletions jwt/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.

import json
from typing import AbstractSet
from typing import AbstractSet, Callable

from .exceptions import (
JWSEncodeError,
Expand All @@ -33,27 +33,31 @@ def __init__(self):
self._jws = JWS()

def encode(self, payload: dict, key: AbstractJWKBase = None, alg='HS256',
optional_headers: dict = None) -> str:
optional_headers: dict = None,
dumps: Callable = json.dumps) -> str:
try:
message = json.dumps(payload).encode('utf-8')
message = dumps(payload).encode('utf-8')
except ValueError as why:
raise JWTEncodeError('payload must be able to encode in JSON')

optional_headers = optional_headers and optional_headers.copy() or {}
optional_headers['typ'] = 'JWT'
try:
return self._jws.encode(message, key, alg, optional_headers)
return self._jws.encode(message, key, alg, optional_headers,
dumps=dumps)
except JWSEncodeError as why:
raise JWTEncodeError('failed to encode to JWT') from why

def decode(self, message: str, key: AbstractJWKBase = None,
do_verify=True, algorithms: AbstractSet[str]=None) -> dict:
do_verify=True, algorithms: AbstractSet[str]=None,
loads: Callable = json.loads) -> dict:
try:
message_bin = self._jws.decode(message, key, do_verify, algorithms)
message_bin = self._jws.decode(message, key, do_verify, algorithms,
loads=loads)
except JWSDecodeError as why:
raise JWTDecodeError('failed to decode JWT') from why
try:
payload = json.loads(message_bin.decode('utf-8'))
payload = loads(message_bin.decode('utf-8'))
return payload
except ValueError as why:
raise JWTDecodeError(
Expand Down