From 013d27d291ae44063dcbfbbb3f1a96cc87251643 Mon Sep 17 00:00:00 2001 From: Devid <13779643+sevdog@users.noreply.github.com> Date: Tue, 12 Sep 2023 06:42:44 +0100 Subject: [PATCH] Replace jose with pyjwt (#819) * Replace jose with pyjwt * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- requirements-base.txt | 2 +- requirements-openidconnect.txt | 1 - setup.py | 4 +- social_core/backends/auth0.py | 31 ++++++++++-- social_core/backends/open_id_connect.py | 50 +++++++++++++++---- social_core/tests/backends/test_auth0.py | 5 +- .../tests/backends/test_open_id_connect.py | 13 +++-- 7 files changed, 80 insertions(+), 26 deletions(-) delete mode 100644 requirements-openidconnect.txt diff --git a/requirements-base.txt b/requirements-base.txt index e7846c758..9c0e31bfb 100644 --- a/requirements-base.txt +++ b/requirements-base.txt @@ -1,7 +1,7 @@ requests>=2.9.1 oauthlib>=1.0.3 requests-oauthlib>=0.6.1 -PyJWT>=2.0.0 +PyJWT>=2.7.0 cryptography>=1.4 defusedxml>=0.5.0rc1 python3-openid>=3.0.10 diff --git a/requirements-openidconnect.txt b/requirements-openidconnect.txt deleted file mode 100644 index 5843b4443..000000000 --- a/requirements-openidconnect.txt +++ /dev/null @@ -1 +0,0 @@ -python-jose>=3.0.0 diff --git a/setup.py b/setup.py index 9fc6cdf21..e448b94af 100644 --- a/setup.py +++ b/setup.py @@ -40,13 +40,12 @@ def read_tests_requirements(filename): requirements = read_requirements("requirements-base.txt") -requirements_openidconnect = read_requirements("requirements-openidconnect.txt") requirements_saml = read_requirements("requirements-saml.txt") requirements_azuread = read_requirements("requirements-azuread.txt") tests_requirements = read_tests_requirements("requirements.txt") -requirements_all = requirements_openidconnect + requirements_saml + requirements_azuread +requirements_all = requirements_saml + requirements_azuread tests_requirements = tests_requirements + requirements_all @@ -73,7 +72,6 @@ def read_tests_requirements(filename): install_requires=requirements, python_requires=">=3.8", extras_require={ - "openidconnect": [requirements_openidconnect], "saml": [requirements_saml], "azuread": [requirements_azuread], "all": [requirements_all], diff --git a/social_core/backends/auth0.py b/social_core/backends/auth0.py index 14eef639f..23ef453ba 100644 --- a/social_core/backends/auth0.py +++ b/social_core/backends/auth0.py @@ -2,7 +2,7 @@ Auth0 implementation based on: https://auth0.com/docs/quickstart/webapp/django/01-login """ -from jose import jwt +import jwt from .oauth import BaseOAuth2 @@ -37,9 +37,32 @@ def get_user_details(self, response): jwks = self.get_json(self.api_path(".well-known/jwks.json")) issuer = self.api_path() audience = self.setting("KEY") # CLIENT_ID - payload = jwt.decode( - id_token, jwks, algorithms=["RS256"], audience=audience, issuer=issuer - ) + try: + # it could be a set of JWKs + keys = jwt.PyJWKSet.from_dict(jwks).keys + except jwt.PyJWKSetError: + # let any error raise from here + # try to get single JWK + keys = [jwt.PyJWK.from_dict(jwks, "RS256")] + + signature_error = None + for key in keys: + try: + payload = jwt.decode( + id_token, + key.key, + algorithms=["RS256"], + audience=audience, + issuer=issuer, + ) + except (jwt.InvalidSignatureError, jwt.InvalidAlgorithmError) as ex: + signature_error = ex + else: + break + else: + # raise last esception found during iteration + raise signature_error + fullname, first_name, last_name = self.get_user_names(payload["name"]) return { "username": payload["nickname"], diff --git a/social_core/backends/open_id_connect.py b/social_core/backends/open_id_connect.py index 5e3733892..edba10fed 100644 --- a/social_core/backends/open_id_connect.py +++ b/social_core/backends/open_id_connect.py @@ -1,10 +1,16 @@ +import base64 import datetime import json from calendar import timegm -from jose import jwk, jwt -from jose.jwt import ExpiredSignatureError, JWTClaimsError, JWTError -from jose.utils import base64url_decode +import jwt +from jwt import ( + ExpiredSignatureError, + InvalidAudienceError, + InvalidTokenError, + PyJWTError, +) +from jwt.utils import base64url_decode from social_core.backends.oauth import BaseOAuth2 from social_core.exceptions import AuthTokenError @@ -186,10 +192,12 @@ def find_valid_key(self, id_token): if kid is None or kid == key.get("kid"): if "alg" not in key: key["alg"] = self.setting("JWT_ALGORITHMS", self.JWT_ALGORITHMS)[0] - rsakey = jwk.construct(key) + rsakey = jwt.PyJWK(key) message, encoded_sig = id_token.rsplit(".", 1) decoded_sig = base64url_decode(encoded_sig.encode("utf-8")) - if rsakey.verify(message.encode("utf-8"), decoded_sig): + if rsakey.Algorithm.verify( + message.encode("utf-8"), rsakey.key, decoded_sig + ): return key return None @@ -205,25 +213,32 @@ def validate_and_return_id_token(self, id_token, access_token): if not key: raise AuthTokenError(self, "Signature verification failed") - rsakey = jwk.construct(key) + rsakey = jwt.PyJWK(key) try: claims = jwt.decode( id_token, - rsakey.to_pem().decode("utf-8"), + rsakey.key, algorithms=self.setting("JWT_ALGORITHMS", self.JWT_ALGORITHMS), audience=client_id, issuer=self.id_token_issuer(), - access_token=access_token, options=self.setting("JWT_DECODE_OPTIONS", self.JWT_DECODE_OPTIONS), ) except ExpiredSignatureError: raise AuthTokenError(self, "Signature has expired") - except JWTClaimsError as error: + except InvalidAudienceError: + # compatibility with jose error message + raise AuthTokenError(self, "Token error: Invalid audience") + except InvalidTokenError as error: raise AuthTokenError(self, str(error)) - except JWTError: + except PyJWTError: raise AuthTokenError(self, "Invalid signature") + # pyjwt does not validate OIDC claims + # see https://github.com/jpadilla/pyjwt/pull/296 + if claims.get("at_hash") != self.calc_at_hash(access_token, key["alg"]): + raise AuthTokenError(self, "Invalid access token") + self.validate_claims(claims) return claims @@ -253,3 +268,18 @@ def get_user_details(self, response): "first_name": response.get("given_name"), "last_name": response.get("family_name"), } + + @staticmethod + def calc_at_hash(access_token, algorithm): + """ + Calculates "at_hash" claim which is not done by pyjwt. + + See https://pyjwt.readthedocs.io/en/stable/usage.html#oidc-login-flow + """ + alg_obj = jwt.get_algorithm_by_name(algorithm) + digest = alg_obj.compute_hash_digest(access_token.encode("utf-8")) + return ( + base64.urlsafe_b64encode(digest[: (len(digest) // 2)]) + .decode("utf-8") + .rstrip("=") + ) diff --git a/social_core/tests/backends/test_auth0.py b/social_core/tests/backends/test_auth0.py index 4c3725eee..659a9ef1d 100644 --- a/social_core/tests/backends/test_auth0.py +++ b/social_core/tests/backends/test_auth0.py @@ -1,7 +1,7 @@ import json +import jwt from httpretty import HTTPretty -from jose import jwt from .oauth import OAuth2Test @@ -45,8 +45,9 @@ class Auth0OAuth2Test(OAuth2Test): "picture": "http://example.com/image.png", "sub": "123456", "iss": f"https://{DOMAIN}/", + "aud": "a-key", }, - JWK_KEY, + jwt.PyJWK(JWK_KEY).key, algorithm="RS256", ), } diff --git a/social_core/tests/backends/test_open_id_connect.py b/social_core/tests/backends/test_open_id_connect.py index 7b375d8a1..739c5efed 100644 --- a/social_core/tests/backends/test_open_id_connect.py +++ b/social_core/tests/backends/test_open_id_connect.py @@ -6,8 +6,8 @@ from calendar import timegm from urllib.parse import urlparse +import jwt from httpretty import HTTPretty -from jose import jwt from social_core.backends.open_id_connect import OpenIdConnectAuth @@ -150,13 +150,16 @@ def prepare_access_token_body( nonce, issuer, ) + # calc at_hash + id_token["at_hash"] = OpenIdConnectAuth.calc_at_hash("foobar", "RS256") body["id_token"] = jwt.encode( - claims=id_token, - key=dict(self.key, iat=timegm(issue_datetime.utctimetuple()), nonce=nonce), + id_token, + key=jwt.PyJWK( + dict(self.key, iat=timegm(issue_datetime.utctimetuple()), nonce=nonce) + ).key, algorithm="RS256", - access_token="foobar", - headers=dict(kid=kid), + headers=dict(kid=kid) if kid else None, ) if tamper_message: