From 69016320b8c19f79cffbaefa85e9bd5a75448169 Mon Sep 17 00:00:00 2001 From: scicco Date: Tue, 15 Oct 2024 19:20:35 +0200 Subject: [PATCH] adding session manager to clean expired strays --- core/cat/auth/connection.py | 39 ++++++++++++++++- core/cat/auth/session_manager.py | 74 ++++++++++++++++++++++++++++++++ core/cat/main.py | 3 ++ 3 files changed, 115 insertions(+), 1 deletion(-) create mode 100644 core/cat/auth/session_manager.py diff --git a/core/cat/auth/connection.py b/core/cat/auth/connection.py index 702a15ae..8d21ada4 100644 --- a/core/cat/auth/connection.py +++ b/core/cat/auth/connection.py @@ -66,6 +66,31 @@ async def extract_credentials(self, connection: Request | WebSocket) -> Tuple[st async def get_user_stray(self, user: AuthUserInfo, connection: Request | WebSocket) -> StrayCat: pass + # this method uses session manager to refresh user expiration time + async def refresh_user_expiration(self, connection: Request, user: AuthUserInfo, strays: any) -> None: + session_manager = connection.app.state.session_manager + #this should be changed into a meaningful amount of minutes it's just for testing + # if user.name contains "test" get the last character to extract minutes + # + # user: test1 will expire after 1 minutes + # user: test2 will expire after 2 minutes + # user: test3 will expire after 3 minutes + # user: test4 will expire after 4 minutes + # and so on + # + if "test" in user.name: + minutes = int(user.name[-1]) + else: + minutes = 60 * 24 + session_manager.add(user.id, strays[user.id], minutes) + pass + + # this method uses session manager to evict expired users + async def evict_expired_users(self, connection: Request) -> None: + session_manager = connection.app.state.session_manager + await session_manager.evict_expired_sessions() + pass + @abstractmethod def not_allowed(self, connection: Request | WebSocket): pass @@ -111,6 +136,10 @@ async def get_user_stray(self, user: AuthUserInfo, connection: Request) -> Stray # TODOV2: user_id should be the user.id user_id=user.name, user_data=user, main_loop=event_loop ) + + await self.refresh_user_expiration(connection, user, strays) + await self.evict_expired_users(connection) + return strays[user.id] def not_allowed(self, connection: Request): @@ -144,6 +173,10 @@ async def get_user_stray(self, user: AuthUserInfo, connection: WebSocket) -> Str # Set new ws connection stray.reset_connection(connection) + + await self.refresh_user_expiration(connection, user, strays) + await self.evict_expired_users(connection) + log.info( f"New websocket connection for user '{user.id}', the old one has been closed." ) @@ -157,8 +190,12 @@ async def get_user_stray(self, user: AuthUserInfo, connection: WebSocket) -> Str main_loop=asyncio.get_running_loop(), ) strays[user.id] = stray + + await self.refresh_user_expiration(connection, user, strays) + await self.evict_expired_users(connection, user) + return stray - + def not_allowed(self, connection: WebSocket): raise WebSocketException(code=1004, reason="Invalid Credentials") diff --git a/core/cat/auth/session_manager.py b/core/cat/auth/session_manager.py new file mode 100644 index 00000000..d2423341 --- /dev/null +++ b/core/cat/auth/session_manager.py @@ -0,0 +1,74 @@ +from datetime import datetime, timedelta +from cat.utils import singleton +from cat.log import log + +@singleton +class SessionManager: + """ + This class is responsible for strays session management + + - adding new sessions + - getting sessions + - expiring sessions + """ + + def __init__(self, strays: any) -> None: + self.strays = strays + self.sessions = {} + self.user_ids = [] + + def add(self, user_id: str, stray: any, minutes: int = 60 * 24) -> None: + """ + add new session setting expiration time to now plus x minutes + each session uses the expiration time as key and the value is a dict with user_id and stray + """ + + current_time = datetime.now() + timedelta(minutes=minutes) + + # if user_id is already in the session we removed it first + if user_id in self.user_ids: + timestamp_to_remove = [ + timestamp for timestamp in self.sessions if self.sessions[timestamp]["user_id"] == user_id + ] + for timestamp in timestamp_to_remove: + self.sessions.pop(timestamp) + + # then we create a new session + self.sessions[current_time] = {"user_id": user_id, "stray": stray} + + # and add the user_id to the list of user_ids + self.user_ids.append(user_id) + pass + + async def evict_expired_sessions(self) -> int: + """ + this method removes expired sessions + """ + + print("active sessions") + print("*" * 20) + for timestamp in self.sessions: + if timestamp >= datetime.now(): + print(f"expiration date: {timestamp} -> {self.sessions[timestamp]}") + print("*" * 20) + + keys_to_remove = [ + timestamp for timestamp in self.sessions if timestamp < datetime.now() + ] + for key in keys_to_remove: + user_id = self.sessions[key]["user_id"] + if user_id in self.strays.keys(): + print(f"deleting expired user: {self.strays[user_id]}") + del self.strays[user_id] + if user_id in self.user_ids: + self.user_ids.remove(user_id) + self.sessions.pop(key) + + expired_count = len(keys_to_remove) + if expired_count > 0: + log.info( + f"{expired_count} sessions expired." + ) + + return expired_count + diff --git a/core/cat/main.py b/core/cat/main.py index e7845b56..1403ae59 100644 --- a/core/cat/main.py +++ b/core/cat/main.py @@ -8,6 +8,7 @@ from fastapi.responses import JSONResponse from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware +from cat.auth.session_manager import SessionManager from cat.log import log from cat.env import get_env, fix_legacy_env_variables @@ -47,6 +48,8 @@ async def lifespan(app: FastAPI): # Dict of pseudo-sessions (key is the user_id) app.state.strays = {} + app.state.session_manager = SessionManager(app.state.strays) + # set a reference to asyncio event loop app.state.event_loop = asyncio.get_running_loop()