1
1
from datetime import timedelta
2
2
from urllib .request import urlopen
3
+ from urllib .parse import urljoin
3
4
import json
4
5
import jwt
5
6
@@ -27,6 +28,7 @@ def authenticate_credentials(self, key):
27
28
class JWTAuthentication (BaseAuthentication ):
28
29
REQUIRED_CLAIMS = ["exp" , "nbf" , "aud" , "iss" , "sub" ]
29
30
SUPPORTED_ALGORITHMS = ["RS256" , "RS384" , "RS512" ]
31
+ AUTH_SCHEME = "Bearer"
30
32
31
33
def authenticate (self , request ):
32
34
try :
@@ -37,8 +39,8 @@ def authenticate(self, request):
37
39
return self .get_user (validated_token ), validated_token
38
40
39
41
def get_public_key (self , kid ):
40
- response = urlopen (settings . JWK_ENDPOINT )
41
- jwks = json .loads (response .read ())
42
+ r = urlopen (self . get_jwk_endpoint () )
43
+ jwks = json .loads (r .read ())
42
44
for jwk in jwks .get ("keys" ):
43
45
if jwk ["kid" ] == kid :
44
46
return jwt .algorithms .RSAAlgorithm .from_jwk (json .dumps (jwk ))
@@ -53,24 +55,20 @@ def get_raw_token(self, request):
53
55
scheme , token = auth_header .split ()
54
56
except ValueError as e :
55
57
raise ValueError (f"Failed to parse Authorization header: { e } " )
56
- if scheme != settings . JWT_AUTH_SCHEME :
58
+ if scheme != self . AUTH_SCHEME :
57
59
raise ValueError (f"Invalid Authorization scheme '{ scheme } '" )
58
60
return token
59
61
60
62
def decode_token (self , raw_token ):
61
- header = jwt .get_unverified_header (raw_token )
62
- kid = header .get ("kid" )
63
- if not kid :
64
- raise AuthenticationFailed ("Token must include the 'kid' header" )
65
- public_key = self .get_public_key (kid )
63
+ kid = self .get_kid (raw_token )
66
64
try :
67
65
validated_token = jwt .decode (
68
66
jwt = raw_token ,
69
67
algorithms = self .SUPPORTED_ALGORITHMS ,
70
- key = public_key ,
68
+ key = self . get_public_key ( kid ) ,
71
69
options = {"require" : self .REQUIRED_CLAIMS },
72
70
audience = settings .JWT_AUDIENCE ,
73
- issuer = settings . JWT_ISSUER ,
71
+ issuer = self . get_openid_issuer () ,
74
72
)
75
73
return validated_token
76
74
except jwt .exceptions .PyJWTError as e :
@@ -82,3 +80,23 @@ def get_user(self, token):
82
80
return User .objects .get (username = username )
83
81
except User .DoesNotExist :
84
82
raise AuthenticationFailed (f"No user found for username '{ username } '" )
83
+
84
+ def get_openid_config (self ):
85
+ url = urljoin (settings .OIDC_ENDPOINT , ".well-known/openid-configuration" )
86
+ r = urlopen (url )
87
+ return json .loads (r .read ())
88
+
89
+ def get_jwk_endpoint (self ):
90
+ openid_config = self .get_openid_config ()
91
+ return openid_config ["jwks_uri" ]
92
+
93
+ def get_openid_issuer (self ):
94
+ openid_config = self .get_openid_config ()
95
+ return openid_config ["issuer" ]
96
+
97
+ def get_kid (self , token ):
98
+ header = jwt .get_unverified_header (token )
99
+ kid = header .get ("kid" )
100
+ if not kid :
101
+ raise AuthenticationFailed ("Token must include the 'kid' header" )
102
+ return kid
0 commit comments