1
1
import asyncio
2
2
from abc import ABC , abstractmethod
3
+ from datetime import datetime , timedelta , timezone
3
4
from enum import auto
4
5
from functools import partial
5
6
from random import choice
10
11
import auth0 # type: ignore[import-untyped]
11
12
import jwt
12
13
from auth0 .authentication import Database , GetToken # type: ignore[import-untyped]
13
- from auth0 .management import Auth0 as _Auth0 # type: ignore[import-untyped]
14
+ from auth0 .management import Auth0 # type: ignore[import-untyped]
14
15
from loguru import logger
15
16
16
17
from battleship .server .config import Config
@@ -121,6 +122,9 @@ async def assign_role(self, user_id: str, role: UserRole) -> None:
121
122
122
123
123
124
class Auth0API :
125
+ TOKEN_REFRESH_LEEWAY = timedelta (seconds = 60 )
126
+ TOKEN_WATCH_INTERVAL = timedelta (seconds = 10 )
127
+
124
128
def __init__ (self , domain : str , client_id : str , client_secret : str , realm : str , audience : str ):
125
129
self .domain = domain
126
130
self .client_id = client_id
@@ -138,7 +142,11 @@ def __init__(self, domain: str, client_id: str, client_secret: str, realm: str,
138
142
self .client_id ,
139
143
self .client_secret ,
140
144
)
141
- self ._mgmt : _Auth0 | None = None
145
+
146
+ token , expires_at = self ._fetch_management_token (self .audience )
147
+ self .mgmt = Auth0 (self .domain , token )
148
+ self .mgmt_token_expires_at = expires_at
149
+ self ._mgmt_token_watcher_task = asyncio .create_task (self ._mgmt_token_watcher ())
142
150
143
151
@classmethod
144
152
def from_config (cls , config : Config ) -> "Auth0API" :
@@ -151,10 +159,13 @@ def from_config(cls, config: Config) -> "Auth0API":
151
159
)
152
160
153
161
@property
154
- def mgmt (self ) -> _Auth0 :
155
- if self ._mgmt is None :
156
- self ._mgmt = _Auth0 (self .domain , self ._fetch_management_token (self .audience ))
157
- return self ._mgmt
162
+ def mgmt_token_expires_at (self ) -> datetime :
163
+ return self ._mgmt_token_expires_at
164
+
165
+ @mgmt_token_expires_at .setter
166
+ def mgmt_token_expires_at (self , expires_at : datetime ) -> None :
167
+ logger .info ("Set new Auth0 management token. Expires at {0}." , expires_at )
168
+ self ._mgmt_token_expires_at = expires_at
158
169
159
170
async def add_roles (self , user_id : str , * roles : str ) -> JSONPayload :
160
171
func = partial (self .mgmt .users .add_roles , id = user_id , roles = roles )
@@ -190,9 +201,27 @@ async def refresh_token(self, refresh_token: str) -> JSONPayload:
190
201
data = await asyncio .to_thread (func )
191
202
return cast (JSONPayload , data )
192
203
193
- def _fetch_management_token (self , audience : str ) -> str :
204
+ def _fetch_management_token (self , audience : str ) -> tuple [ str , datetime ] :
194
205
data = self .gettoken .client_credentials (audience )
195
- return cast (str , data ["access_token" ])
206
+ token , expires_in = cast (str , data ["access_token" ]), cast (int , data ["expires_in" ])
207
+ expires_at = datetime .now (timezone .utc ) + timedelta (seconds = expires_in )
208
+ return token , expires_at
209
+
210
+ @logger .catch
211
+ async def _mgmt_token_watcher (self ) -> None :
212
+ watch_interval = self .TOKEN_WATCH_INTERVAL .total_seconds ()
213
+ logger .info ("Run Auth0 management token watcher every {0} seconds." , watch_interval )
214
+
215
+ while True :
216
+ await asyncio .sleep (watch_interval )
217
+
218
+ now = datetime .now (timezone .utc )
219
+
220
+ if now > (self ._mgmt_token_expires_at - self .TOKEN_REFRESH_LEEWAY ):
221
+ logger .info ("Auth0 management token expires soon. Update it now." )
222
+ token , expires_at = self ._fetch_management_token (self .audience )
223
+ self .mgmt = Auth0 (self .domain , token )
224
+ self .mgmt_token_expires_at = expires_at
196
225
197
226
198
227
def _make_random_nickname (postfix_length : int = 7 ) -> str :
0 commit comments