Skip to content

Commit

Permalink
Replace jose with pyjwt (#819)
Browse files Browse the repository at this point in the history
* Replace jose with pyjwt

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
sevdog and pre-commit-ci[bot] committed Sep 12, 2023
1 parent 875b7bd commit 013d27d
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 26 deletions.
2 changes: 1 addition & 1 deletion requirements-base.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
requests>=2.9.1
oauthlib>=1.0.3
requests-oauthlib>=0.6.1
PyJWT>=2.0.0
PyJWT>=2.7.0
cryptography>=1.4
defusedxml>=0.5.0rc1
python3-openid>=3.0.10
1 change: 0 additions & 1 deletion requirements-openidconnect.txt

This file was deleted.

4 changes: 1 addition & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,12 @@ def read_tests_requirements(filename):


requirements = read_requirements("requirements-base.txt")
requirements_openidconnect = read_requirements("requirements-openidconnect.txt")
requirements_saml = read_requirements("requirements-saml.txt")
requirements_azuread = read_requirements("requirements-azuread.txt")

tests_requirements = read_tests_requirements("requirements.txt")

requirements_all = requirements_openidconnect + requirements_saml + requirements_azuread
requirements_all = requirements_saml + requirements_azuread

tests_requirements = tests_requirements + requirements_all

Expand All @@ -73,7 +72,6 @@ def read_tests_requirements(filename):
install_requires=requirements,
python_requires=">=3.8",
extras_require={
"openidconnect": [requirements_openidconnect],
"saml": [requirements_saml],
"azuread": [requirements_azuread],
"all": [requirements_all],
Expand Down
31 changes: 27 additions & 4 deletions social_core/backends/auth0.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Auth0 implementation based on:
https://auth0.com/docs/quickstart/webapp/django/01-login
"""
from jose import jwt
import jwt

from .oauth import BaseOAuth2

Expand Down Expand Up @@ -37,9 +37,32 @@ def get_user_details(self, response):
jwks = self.get_json(self.api_path(".well-known/jwks.json"))
issuer = self.api_path()
audience = self.setting("KEY") # CLIENT_ID
payload = jwt.decode(
id_token, jwks, algorithms=["RS256"], audience=audience, issuer=issuer
)
try:
# it could be a set of JWKs
keys = jwt.PyJWKSet.from_dict(jwks).keys
except jwt.PyJWKSetError:
# let any error raise from here
# try to get single JWK
keys = [jwt.PyJWK.from_dict(jwks, "RS256")]

signature_error = None
for key in keys:
try:
payload = jwt.decode(
id_token,
key.key,
algorithms=["RS256"],
audience=audience,
issuer=issuer,
)
except (jwt.InvalidSignatureError, jwt.InvalidAlgorithmError) as ex:
signature_error = ex
else:
break
else:
# raise last esception found during iteration
raise signature_error

fullname, first_name, last_name = self.get_user_names(payload["name"])
return {
"username": payload["nickname"],
Expand Down
50 changes: 40 additions & 10 deletions social_core/backends/open_id_connect.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import base64
import datetime
import json
from calendar import timegm

from jose import jwk, jwt
from jose.jwt import ExpiredSignatureError, JWTClaimsError, JWTError
from jose.utils import base64url_decode
import jwt
from jwt import (
ExpiredSignatureError,
InvalidAudienceError,
InvalidTokenError,
PyJWTError,
)
from jwt.utils import base64url_decode

from social_core.backends.oauth import BaseOAuth2
from social_core.exceptions import AuthTokenError
Expand Down Expand Up @@ -186,10 +192,12 @@ def find_valid_key(self, id_token):
if kid is None or kid == key.get("kid"):
if "alg" not in key:
key["alg"] = self.setting("JWT_ALGORITHMS", self.JWT_ALGORITHMS)[0]
rsakey = jwk.construct(key)
rsakey = jwt.PyJWK(key)
message, encoded_sig = id_token.rsplit(".", 1)
decoded_sig = base64url_decode(encoded_sig.encode("utf-8"))
if rsakey.verify(message.encode("utf-8"), decoded_sig):
if rsakey.Algorithm.verify(
message.encode("utf-8"), rsakey.key, decoded_sig
):
return key
return None

Expand All @@ -205,25 +213,32 @@ def validate_and_return_id_token(self, id_token, access_token):
if not key:
raise AuthTokenError(self, "Signature verification failed")

rsakey = jwk.construct(key)
rsakey = jwt.PyJWK(key)

try:
claims = jwt.decode(
id_token,
rsakey.to_pem().decode("utf-8"),
rsakey.key,
algorithms=self.setting("JWT_ALGORITHMS", self.JWT_ALGORITHMS),
audience=client_id,
issuer=self.id_token_issuer(),
access_token=access_token,
options=self.setting("JWT_DECODE_OPTIONS", self.JWT_DECODE_OPTIONS),
)
except ExpiredSignatureError:
raise AuthTokenError(self, "Signature has expired")
except JWTClaimsError as error:
except InvalidAudienceError:
# compatibility with jose error message
raise AuthTokenError(self, "Token error: Invalid audience")
except InvalidTokenError as error:
raise AuthTokenError(self, str(error))
except JWTError:
except PyJWTError:
raise AuthTokenError(self, "Invalid signature")

# pyjwt does not validate OIDC claims
# see https://github.com/jpadilla/pyjwt/pull/296
if claims.get("at_hash") != self.calc_at_hash(access_token, key["alg"]):
raise AuthTokenError(self, "Invalid access token")

self.validate_claims(claims)

return claims
Expand Down Expand Up @@ -253,3 +268,18 @@ def get_user_details(self, response):
"first_name": response.get("given_name"),
"last_name": response.get("family_name"),
}

@staticmethod
def calc_at_hash(access_token, algorithm):
"""
Calculates "at_hash" claim which is not done by pyjwt.
See https://pyjwt.readthedocs.io/en/stable/usage.html#oidc-login-flow
"""
alg_obj = jwt.get_algorithm_by_name(algorithm)
digest = alg_obj.compute_hash_digest(access_token.encode("utf-8"))
return (
base64.urlsafe_b64encode(digest[: (len(digest) // 2)])
.decode("utf-8")
.rstrip("=")
)
5 changes: 3 additions & 2 deletions social_core/tests/backends/test_auth0.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json

import jwt
from httpretty import HTTPretty
from jose import jwt

from .oauth import OAuth2Test

Expand Down Expand Up @@ -45,8 +45,9 @@ class Auth0OAuth2Test(OAuth2Test):
"picture": "http://example.com/image.png",
"sub": "123456",
"iss": f"https://{DOMAIN}/",
"aud": "a-key",
},
JWK_KEY,
jwt.PyJWK(JWK_KEY).key,
algorithm="RS256",
),
}
Expand Down
13 changes: 8 additions & 5 deletions social_core/tests/backends/test_open_id_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from calendar import timegm
from urllib.parse import urlparse

import jwt
from httpretty import HTTPretty
from jose import jwt

from social_core.backends.open_id_connect import OpenIdConnectAuth

Expand Down Expand Up @@ -150,13 +150,16 @@ def prepare_access_token_body(
nonce,
issuer,
)
# calc at_hash
id_token["at_hash"] = OpenIdConnectAuth.calc_at_hash("foobar", "RS256")

body["id_token"] = jwt.encode(
claims=id_token,
key=dict(self.key, iat=timegm(issue_datetime.utctimetuple()), nonce=nonce),
id_token,
key=jwt.PyJWK(
dict(self.key, iat=timegm(issue_datetime.utctimetuple()), nonce=nonce)
).key,
algorithm="RS256",
access_token="foobar",
headers=dict(kid=kid),
headers=dict(kid=kid) if kid else None,
)

if tamper_message:
Expand Down

0 comments on commit 013d27d

Please sign in to comment.