diff --git a/social_core/backends/open_id_connect.py b/social_core/backends/open_id_connect.py index e1e0f561..edba10fe 100644 --- a/social_core/backends/open_id_connect.py +++ b/social_core/backends/open_id_connect.py @@ -4,8 +4,13 @@ from calendar import timegm import jwt +from jwt import ( + ExpiredSignatureError, + InvalidAudienceError, + InvalidTokenError, + PyJWTError, +) from jwt.utils import base64url_decode -from jwt import ExpiredSignatureError, InvalidTokenError, PyJWTError, InvalidAudienceError from social_core.backends.oauth import BaseOAuth2 from social_core.exceptions import AuthTokenError @@ -190,7 +195,9 @@ def find_valid_key(self, id_token): rsakey = jwt.PyJWK(key) message, encoded_sig = id_token.rsplit(".", 1) decoded_sig = base64url_decode(encoded_sig.encode("utf-8")) - if rsakey.Algorithm.verify(message.encode("utf-8"), rsakey.key, decoded_sig): + if rsakey.Algorithm.verify( + message.encode("utf-8"), rsakey.key, decoded_sig + ): return key return None @@ -229,7 +236,7 @@ def validate_and_return_id_token(self, id_token, access_token): # pyjwt does not validate OIDC claims # see https://github.com/jpadilla/pyjwt/pull/296 - if claims.get("at_hash") != self.calc_at_hash(access_token, key['alg']): + if claims.get("at_hash") != self.calc_at_hash(access_token, key["alg"]): raise AuthTokenError(self, "Invalid access token") self.validate_claims(claims) @@ -271,4 +278,8 @@ def calc_at_hash(access_token, algorithm): """ alg_obj = jwt.get_algorithm_by_name(algorithm) digest = alg_obj.compute_hash_digest(access_token.encode("utf-8")) - return base64.urlsafe_b64encode(digest[: (len(digest) // 2)]).decode("utf-8").rstrip("=") + return ( + base64.urlsafe_b64encode(digest[: (len(digest) // 2)]) + .decode("utf-8") + .rstrip("=") + ) diff --git a/social_core/tests/backends/test_auth0.py b/social_core/tests/backends/test_auth0.py index 5c358ee3..659a9ef1 100644 --- a/social_core/tests/backends/test_auth0.py +++ b/social_core/tests/backends/test_auth0.py @@ -45,7 +45,7 @@ class Auth0OAuth2Test(OAuth2Test): "picture": "http://example.com/image.png", "sub": "123456", "iss": f"https://{DOMAIN}/", - "aud": "a-key" + "aud": "a-key", }, jwt.PyJWK(JWK_KEY).key, algorithm="RS256", diff --git a/social_core/tests/backends/test_open_id_connect.py b/social_core/tests/backends/test_open_id_connect.py index ff2d2bf8..739c5efe 100644 --- a/social_core/tests/backends/test_open_id_connect.py +++ b/social_core/tests/backends/test_open_id_connect.py @@ -6,8 +6,8 @@ from calendar import timegm from urllib.parse import urlparse -from httpretty import HTTPretty import jwt +from httpretty import HTTPretty from social_core.backends.open_id_connect import OpenIdConnectAuth @@ -155,7 +155,9 @@ def prepare_access_token_body( body["id_token"] = jwt.encode( id_token, - key=jwt.PyJWK(dict(self.key, iat=timegm(issue_datetime.utctimetuple()), nonce=nonce)).key, + key=jwt.PyJWK( + dict(self.key, iat=timegm(issue_datetime.utctimetuple()), nonce=nonce) + ).key, algorithm="RS256", headers=dict(kid=kid) if kid else None, )