Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 3, 2023
1 parent 04e1104 commit c7d920b
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
19 changes: 15 additions & 4 deletions social_core/backends/open_id_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@
from calendar import timegm

import jwt
from jwt import (
ExpiredSignatureError,
InvalidAudienceError,
InvalidTokenError,
PyJWTError,
)
from jwt.utils import base64url_decode
from jwt import ExpiredSignatureError, InvalidTokenError, PyJWTError, InvalidAudienceError

from social_core.backends.oauth import BaseOAuth2
from social_core.exceptions import AuthTokenError
Expand Down Expand Up @@ -190,7 +195,9 @@ def find_valid_key(self, id_token):
rsakey = jwt.PyJWK(key)
message, encoded_sig = id_token.rsplit(".", 1)
decoded_sig = base64url_decode(encoded_sig.encode("utf-8"))
if rsakey.Algorithm.verify(message.encode("utf-8"), rsakey.key, decoded_sig):
if rsakey.Algorithm.verify(
message.encode("utf-8"), rsakey.key, decoded_sig
):
return key
return None

Expand Down Expand Up @@ -229,7 +236,7 @@ def validate_and_return_id_token(self, id_token, access_token):

# pyjwt does not validate OIDC claims
# see https://github.com/jpadilla/pyjwt/pull/296
if claims.get("at_hash") != self.calc_at_hash(access_token, key['alg']):
if claims.get("at_hash") != self.calc_at_hash(access_token, key["alg"]):
raise AuthTokenError(self, "Invalid access token")

self.validate_claims(claims)
Expand Down Expand Up @@ -271,4 +278,8 @@ def calc_at_hash(access_token, algorithm):
"""
alg_obj = jwt.get_algorithm_by_name(algorithm)
digest = alg_obj.compute_hash_digest(access_token.encode("utf-8"))
return base64.urlsafe_b64encode(digest[: (len(digest) // 2)]).decode("utf-8").rstrip("=")
return (
base64.urlsafe_b64encode(digest[: (len(digest) // 2)])
.decode("utf-8")
.rstrip("=")
)
2 changes: 1 addition & 1 deletion social_core/tests/backends/test_auth0.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class Auth0OAuth2Test(OAuth2Test):
"picture": "http://example.com/image.png",
"sub": "123456",
"iss": f"https://{DOMAIN}/",
"aud": "a-key"
"aud": "a-key",
},
jwt.PyJWK(JWK_KEY).key,
algorithm="RS256",
Expand Down
6 changes: 4 additions & 2 deletions social_core/tests/backends/test_open_id_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from calendar import timegm
from urllib.parse import urlparse

from httpretty import HTTPretty
import jwt
from httpretty import HTTPretty

from social_core.backends.open_id_connect import OpenIdConnectAuth

Expand Down Expand Up @@ -155,7 +155,9 @@ def prepare_access_token_body(

body["id_token"] = jwt.encode(
id_token,
key=jwt.PyJWK(dict(self.key, iat=timegm(issue_datetime.utctimetuple()), nonce=nonce)).key,
key=jwt.PyJWK(
dict(self.key, iat=timegm(issue_datetime.utctimetuple()), nonce=nonce)
).key,
algorithm="RS256",
headers=dict(kid=kid) if kid else None,
)
Expand Down

0 comments on commit c7d920b

Please sign in to comment.