Skip to content

Commit b8232d0

Browse files
committed
Make Auth0 management token refresh automatically
1 parent 86c8182 commit b8232d0

File tree

1 file changed

+37
-8
lines changed

1 file changed

+37
-8
lines changed

battleship/server/auth.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
from abc import ABC, abstractmethod
3+
from datetime import datetime, timedelta, timezone
34
from enum import auto
45
from functools import partial
56
from random import choice
@@ -10,7 +11,7 @@
1011
import auth0 # type: ignore[import-untyped]
1112
import jwt
1213
from auth0.authentication import Database, GetToken # type: ignore[import-untyped]
13-
from auth0.management import Auth0 as _Auth0 # type: ignore[import-untyped]
14+
from auth0.management import Auth0 # type: ignore[import-untyped]
1415
from loguru import logger
1516

1617
from battleship.server.config import Config
@@ -121,6 +122,9 @@ async def assign_role(self, user_id: str, role: UserRole) -> None:
121122

122123

123124
class Auth0API:
125+
TOKEN_REFRESH_LEEWAY = timedelta(seconds=60)
126+
TOKEN_WATCH_INTERVAL = timedelta(seconds=10)
127+
124128
def __init__(self, domain: str, client_id: str, client_secret: str, realm: str, audience: str):
125129
self.domain = domain
126130
self.client_id = client_id
@@ -138,7 +142,11 @@ def __init__(self, domain: str, client_id: str, client_secret: str, realm: str,
138142
self.client_id,
139143
self.client_secret,
140144
)
141-
self._mgmt: _Auth0 | None = None
145+
146+
token, expires_at = self._fetch_management_token(self.audience)
147+
self.mgmt = Auth0(self.domain, token)
148+
self.mgmt_token_expires_at = expires_at
149+
self._mgmt_token_watcher_task = asyncio.create_task(self._mgmt_token_watcher())
142150

143151
@classmethod
144152
def from_config(cls, config: Config) -> "Auth0API":
@@ -151,10 +159,13 @@ def from_config(cls, config: Config) -> "Auth0API":
151159
)
152160

153161
@property
154-
def mgmt(self) -> _Auth0:
155-
if self._mgmt is None:
156-
self._mgmt = _Auth0(self.domain, self._fetch_management_token(self.audience))
157-
return self._mgmt
162+
def mgmt_token_expires_at(self) -> datetime:
163+
return self._mgmt_token_expires_at
164+
165+
@mgmt_token_expires_at.setter
166+
def mgmt_token_expires_at(self, expires_at: datetime) -> None:
167+
logger.info("Set new Auth0 management token. Expires at {0}.", expires_at)
168+
self._mgmt_token_expires_at = expires_at
158169

159170
async def add_roles(self, user_id: str, *roles: str) -> JSONPayload:
160171
func = partial(self.mgmt.users.add_roles, id=user_id, roles=roles)
@@ -190,9 +201,27 @@ async def refresh_token(self, refresh_token: str) -> JSONPayload:
190201
data = await asyncio.to_thread(func)
191202
return cast(JSONPayload, data)
192203

193-
def _fetch_management_token(self, audience: str) -> str:
204+
def _fetch_management_token(self, audience: str) -> tuple[str, datetime]:
194205
data = self.gettoken.client_credentials(audience)
195-
return cast(str, data["access_token"])
206+
token, expires_in = cast(str, data["access_token"]), cast(int, data["expires_in"])
207+
expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
208+
return token, expires_at
209+
210+
@logger.catch
211+
async def _mgmt_token_watcher(self) -> None:
212+
watch_interval = self.TOKEN_WATCH_INTERVAL.total_seconds()
213+
logger.info("Run Auth0 management token watcher every {0} seconds.", watch_interval)
214+
215+
while True:
216+
await asyncio.sleep(watch_interval)
217+
218+
now = datetime.now(timezone.utc)
219+
220+
if now > (self._mgmt_token_expires_at - self.TOKEN_REFRESH_LEEWAY):
221+
logger.info("Auth0 management token expires soon. Update it now.")
222+
token, expires_at = self._fetch_management_token(self.audience)
223+
self.mgmt = Auth0(self.domain, token)
224+
self.mgmt_token_expires_at = expires_at
196225

197226

198227
def _make_random_nickname(postfix_length: int = 7) -> str:

0 commit comments

Comments
 (0)