From 074fb78020428a7cdbb89230ccf49ca81558880a Mon Sep 17 00:00:00 2001 From: scicco Date: Tue, 15 Oct 2024 19:20:35 +0200 Subject: [PATCH] adding session manager to clean expired strays --- core/cat/auth/connection.py | 47 +++++++++++++++++++- core/cat/auth/session_manager.py | 76 ++++++++++++++++++++++++++++++++ core/cat/main.py | 3 ++ 3 files changed, 125 insertions(+), 1 deletion(-) create mode 100644 core/cat/auth/session_manager.py diff --git a/core/cat/auth/connection.py b/core/cat/auth/connection.py index 702a15ae..807defc6 100644 --- a/core/cat/auth/connection.py +++ b/core/cat/auth/connection.py @@ -69,7 +69,40 @@ async def get_user_stray(self, user: AuthUserInfo, connection: Request | WebSock @abstractmethod def not_allowed(self, connection: Request | WebSocket): pass - + + def refresh_user_expiration(self, connection: Request, user: AuthUserInfo) -> None: + """ + this method uses session manager to refresh user expiration time + """ + + session_manager = connection.app.state.session_manager + # + # TODO: remove this once we have properly tested this + # + # this should be changed into a meaningful amount of minutes it's just for testing + # if user.name contains "test" get the last character to extract minutes + # + # user: test1 will expire after 1 minutes + # user: test2 will expire after 2 minutes + # user: test3 will expire after 3 minutes + # user: test4 will expire after 4 minutes + # and so on + # + if "test" in user.name: + minutes = int(user.name[-1]) + else: + minutes = 60 * 24 + session_manager.add_or_refresh(user, minutes) + pass + + def evict_expired_users(self, connection: Request) -> None: + """ + this method uses session manager to evict expired users + """ + + session_manager = connection.app.state.session_manager + session_manager.evict_expired_sessions() + pass class HTTPAuth(ConnectionAuth): @@ -111,6 +144,10 @@ async def get_user_stray(self, user: AuthUserInfo, connection: Request) -> Stray # TODOV2: user_id should be the user.id user_id=user.name, user_data=user, main_loop=event_loop ) + + self.refresh_user_expiration(connection, user) + self.evict_expired_users(connection) + return strays[user.id] def not_allowed(self, connection: Request): @@ -144,6 +181,10 @@ async def get_user_stray(self, user: AuthUserInfo, connection: WebSocket) -> Str # Set new ws connection stray.reset_connection(connection) + + self.refresh_user_expiration(connection, user) + self.evict_expired_users(connection) + log.info( f"New websocket connection for user '{user.id}', the old one has been closed." ) @@ -157,6 +198,10 @@ async def get_user_stray(self, user: AuthUserInfo, connection: WebSocket) -> Str main_loop=asyncio.get_running_loop(), ) strays[user.id] = stray + + self.refresh_user_expiration(connection, user) + self.evict_expired_users(connection) + return stray def not_allowed(self, connection: WebSocket): diff --git a/core/cat/auth/session_manager.py b/core/cat/auth/session_manager.py new file mode 100644 index 00000000..02ca8124 --- /dev/null +++ b/core/cat/auth/session_manager.py @@ -0,0 +1,76 @@ +from datetime import datetime, timedelta, timezone +from cat.utils import singleton +from cat.log import log + +from cat.auth.permissions import AuthUserInfo + + +@singleton +class SessionManager: + """ + This class is responsible for strays session management + + - adding new sessions + - expiring sessions + """ + + def __init__(self, strays: any) -> None: + self.strays = strays + self.sessions = {} + + def add_or_refresh(self, user: AuthUserInfo, minutes: int = 60 * 24) -> None: + """ + add new session setting expiration time to now plus x minutes + each session uses the expiration time as key and the value is a dict with user_id and stray + """ + + current_time = datetime.now(timezone.utc) + timedelta(minutes=minutes) + + # if user_id is already in the session we removed it first + for timestamp in self.sessions: + if self.sessions[timestamp]["user_id"] == user.id: + self.sessions.pop(timestamp) + break + + # then we create a new session + self.sessions[current_time] = {"user_id": user.id, "name": user.name} + pass + + def evict_expired_sessions(self) -> int: + """ + this method removes expired sessions + """ + + log.info("active sessions:") + log.info("*" * 20) + + time_limit = datetime.now(timezone.utc) + + for timestamp in self.sessions: + if timestamp >= time_limit: + log.info(f"expiration date: {timestamp} -> {self.sessions[timestamp]}") + log.info("*" * 20) + + # retrieve all expired sessions + keys_to_remove = [ + timestamp for timestamp in self.sessions if timestamp < time_limit + ] + + # for each expired session key + for key in keys_to_remove: + # get the user_id + user_id = self.sessions[key]["user_id"] + if user_id in self.strays.keys(): + log.info(f"deleting expired user: user_id: {user_id} name: {self.strays[user_id].user_id}") + + # remove the users' stray + del self.strays[user_id] + + # then remove the session using the key (timestamp) + self.sessions.pop(key) + + expired_count = len(keys_to_remove) + if expired_count > 0: + log.info(f"{expired_count} sessions expired.") + + return expired_count diff --git a/core/cat/main.py b/core/cat/main.py index e7845b56..1403ae59 100644 --- a/core/cat/main.py +++ b/core/cat/main.py @@ -8,6 +8,7 @@ from fastapi.responses import JSONResponse from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware +from cat.auth.session_manager import SessionManager from cat.log import log from cat.env import get_env, fix_legacy_env_variables @@ -47,6 +48,8 @@ async def lifespan(app: FastAPI): # Dict of pseudo-sessions (key is the user_id) app.state.strays = {} + app.state.session_manager = SessionManager(app.state.strays) + # set a reference to asyncio event loop app.state.event_loop = asyncio.get_running_loop()