Skip to content

Commit

Permalink
fix: Feature parity for a_decode_token and decode_token (#616)
Browse files Browse the repository at this point in the history
* Consistency for token decoding

* Mark as staticmethod

* Helper function to convert key

* Refactor key handling

* Add tests for validate=False

* Change test name

* Fix failing test

* Remove special case for str

* Some docstring

* docs: missing docstrings

---------

Co-authored-by: Richard Nemeth <ryshoooo@gmail.com>
  • Loading branch information
Krismix1 and ryshoooo authored Nov 17, 2024
1 parent 3b946c3 commit ac07820
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 22 deletions.
59 changes: 38 additions & 21 deletions src/keycloak/keycloak_openid.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class to handle authentication and token manipulation.
"""

import json
from typing import Optional
from typing import Optional, Union

from jwcrypto import jwk, jwt

Expand Down Expand Up @@ -581,6 +581,33 @@ def introspect(self, token, rpt=None, token_type_hint=None):
)
return raise_error_from_response(data_raw, KeycloakPostError)

@staticmethod
def _verify_token(token, key: Union[jwk.JWK, jwk.JWKSet, None], **kwargs):
"""Decode and optionally validate a token.
:param token: The token to verify
:type token: str
:param key: Which key should be used for validation.
If not provided, the validation is not performed and the token is implicitly valid.
:type key: Union[jwk.JWK, jwk.JWKSet, None]
:param kwargs: Additional keyword arguments for jwcrypto's JWT object
:type kwargs: dict
:returns: Decoded token
"""
# keep the function free of IO
# this way it can be used by `decode_token` and `a_decode_token`

if key is not None:
leeway = kwargs.pop("leeway", 60)
full_jwt = jwt.JWT(jwt=token, **kwargs)
full_jwt.leeway = leeway
full_jwt.validate(key)
return jwt.json_decode(full_jwt.claims)
else:
full_jwt = jwt.JWT(jwt=token, **kwargs)
full_jwt.token.objects["valid"] = True
return json.loads(full_jwt.token.payload.decode("utf-8"))

def decode_token(self, token, validate: bool = True, **kwargs):
"""Decode user token.
Expand All @@ -603,26 +630,19 @@ def decode_token(self, token, validate: bool = True, **kwargs):
:returns: Decoded token
:rtype: dict
"""
key = kwargs.pop("key", None)
if validate:
if "key" not in kwargs:
if key is None:
key = (
"-----BEGIN PUBLIC KEY-----\n"
+ self.public_key()
+ "\n-----END PUBLIC KEY-----"
)
key = jwk.JWK.from_pem(key.encode("utf-8"))
kwargs["key"] = key

key = kwargs.pop("key")
leeway = kwargs.pop("leeway", 60)
full_jwt = jwt.JWT(jwt=token, **kwargs)
full_jwt.leeway = leeway
full_jwt.validate(key)
return jwt.json_decode(full_jwt.claims)
else:
full_jwt = jwt.JWT(jwt=token, **kwargs)
full_jwt.token.objects["valid"] = True
return json.loads(full_jwt.token.payload.decode("utf-8"))
key = None

return self._verify_token(token, key, **kwargs)

def load_authorization_config(self, path):
"""Load Keycloak settings (authorization).
Expand Down Expand Up @@ -1273,22 +1293,19 @@ async def a_decode_token(self, token, validate: bool = True, **kwargs):
:returns: Decoded token
:rtype: dict
"""
key = kwargs.pop("key", None)
if validate:
if "key" not in kwargs:
if key is None:
key = (
"-----BEGIN PUBLIC KEY-----\n"
+ await self.a_public_key()
+ "\n-----END PUBLIC KEY-----"
)
key = jwk.JWK.from_pem(key.encode("utf-8"))
kwargs["key"] = key

full_jwt = jwt.JWT(jwt=token, **kwargs)
return jwt.json_decode(full_jwt.claims)
else:
full_jwt = jwt.JWT(jwt=token, **kwargs)
full_jwt.token.objects["valid"] = True
return json.loads(full_jwt.token.payload.decode("utf-8"))
key = None

return self._verify_token(token, key, **kwargs)

async def a_load_authorization_config(self, path):
"""Load Keycloak settings (authorization) asynchronously.
Expand Down
75 changes: 74 additions & 1 deletion tests/test_keycloak_openid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from typing import Tuple
from unittest import mock

import jwcrypto.jwk
import jwcrypto.jws
import pytest

from keycloak import KeycloakAdmin, KeycloakOpenID
Expand Down Expand Up @@ -317,6 +319,39 @@ def test_decode_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]):
assert decoded_refresh_token["typ"] == "Refresh", decoded_refresh_token


def test_decode_token_invalid_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]):
"""Test decode token with an invalid token.
:param oid_with_credentials: Keycloak OpenID client with pre-configured user credentials
:type oid_with_credentials: Tuple[KeycloakOpenID, str, str]
"""
oid, username, password = oid_with_credentials
token = oid.token(username=username, password=password)
access_token = token["access_token"]
decoded_access_token = oid.decode_token(token=access_token)

key = oid.public_key()
key = "-----BEGIN PUBLIC KEY-----\n" + key + "\n-----END PUBLIC KEY-----"
key = jwcrypto.jwk.JWK.from_pem(key.encode("utf-8"))

invalid_access_token = access_token + "a"
with pytest.raises(jwcrypto.jws.InvalidJWSSignature):
decoded_invalid_access_token = oid.decode_token(token=invalid_access_token, validate=True)

with pytest.raises(jwcrypto.jws.InvalidJWSSignature):
decoded_invalid_access_token = oid.decode_token(
token=invalid_access_token, validate=True, key=key
)

decoded_invalid_access_token = oid.decode_token(token=invalid_access_token, validate=False)
assert decoded_access_token == decoded_invalid_access_token

decoded_invalid_access_token = oid.decode_token(
token=invalid_access_token, validate=False, key=key
)
assert decoded_access_token == decoded_invalid_access_token


def test_load_authorization_config(oid_with_credentials_authz: Tuple[KeycloakOpenID, str, str]):
"""Test load authorization config.
Expand Down Expand Up @@ -765,7 +800,7 @@ async def test_a_introspect(oid_with_credentials: Tuple[KeycloakOpenID, str, str

@pytest.mark.asyncio
async def test_a_decode_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]):
"""Test decode token.
"""Test decode token asynchronously.
:param oid_with_credentials: Keycloak OpenID client with pre-configured user credentials
:type oid_with_credentials: Tuple[KeycloakOpenID, str, str]
Expand All @@ -781,6 +816,44 @@ async def test_a_decode_token(oid_with_credentials: Tuple[KeycloakOpenID, str, s
assert decoded_refresh_token["typ"] == "Refresh", decoded_refresh_token


@pytest.mark.asyncio
async def test_a_decode_token_invalid_token(oid_with_credentials: Tuple[KeycloakOpenID, str, str]):
"""Test decode token asynchronously an invalid token.
:param oid_with_credentials: Keycloak OpenID client with pre-configured user credentials
:type oid_with_credentials: Tuple[KeycloakOpenID, str, str]
"""
oid, username, password = oid_with_credentials
token = await oid.a_token(username=username, password=password)
access_token = token["access_token"]
decoded_access_token = await oid.a_decode_token(token=access_token)

key = await oid.a_public_key()
key = "-----BEGIN PUBLIC KEY-----\n" + key + "\n-----END PUBLIC KEY-----"
key = jwcrypto.jwk.JWK.from_pem(key.encode("utf-8"))

invalid_access_token = access_token + "a"
with pytest.raises(jwcrypto.jws.InvalidJWSSignature):
decoded_invalid_access_token = await oid.a_decode_token(
token=invalid_access_token, validate=True
)

with pytest.raises(jwcrypto.jws.InvalidJWSSignature):
decoded_invalid_access_token = await oid.a_decode_token(
token=invalid_access_token, validate=True, key=key
)

decoded_invalid_access_token = await oid.a_decode_token(
token=invalid_access_token, validate=False
)
assert decoded_access_token == decoded_invalid_access_token

decoded_invalid_access_token = await oid.a_decode_token(
token=invalid_access_token, validate=False, key=key
)
assert decoded_access_token == decoded_invalid_access_token


@pytest.mark.asyncio
async def test_a_load_authorization_config(
oid_with_credentials_authz: Tuple[KeycloakOpenID, str, str]
Expand Down

0 comments on commit ac07820

Please sign in to comment.