Skip to content

Commit

Permalink
adding session manager to clean expired strays
Browse files Browse the repository at this point in the history
  • Loading branch information
scicco committed Oct 15, 2024
1 parent f51a40e commit 074fb78
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 1 deletion.
47 changes: 46 additions & 1 deletion core/cat/auth/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,40 @@ async def get_user_stray(self, user: AuthUserInfo, connection: Request | WebSock
@abstractmethod
def not_allowed(self, connection: Request | WebSocket):
pass


def refresh_user_expiration(self, connection: Request, user: AuthUserInfo) -> None:
"""
this method uses session manager to refresh user expiration time
"""

session_manager = connection.app.state.session_manager
#
# TODO: remove this once we have properly tested this
#
# this should be changed into a meaningful amount of minutes it's just for testing
# if user.name contains "test" get the last character to extract minutes
#
# user: test1 will expire after 1 minutes
# user: test2 will expire after 2 minutes
# user: test3 will expire after 3 minutes
# user: test4 will expire after 4 minutes
# and so on
#
if "test" in user.name:
minutes = int(user.name[-1])
else:
minutes = 60 * 24
session_manager.add_or_refresh(user, minutes)
pass

def evict_expired_users(self, connection: Request) -> None:
"""
this method uses session manager to evict expired users
"""

session_manager = connection.app.state.session_manager
session_manager.evict_expired_sessions()
pass

class HTTPAuth(ConnectionAuth):

Expand Down Expand Up @@ -111,6 +144,10 @@ async def get_user_stray(self, user: AuthUserInfo, connection: Request) -> Stray
# TODOV2: user_id should be the user.id
user_id=user.name, user_data=user, main_loop=event_loop
)

self.refresh_user_expiration(connection, user)
self.evict_expired_users(connection)

return strays[user.id]

def not_allowed(self, connection: Request):
Expand Down Expand Up @@ -144,6 +181,10 @@ async def get_user_stray(self, user: AuthUserInfo, connection: WebSocket) -> Str

# Set new ws connection
stray.reset_connection(connection)

self.refresh_user_expiration(connection, user)
self.evict_expired_users(connection)

log.info(
f"New websocket connection for user '{user.id}', the old one has been closed."
)
Expand All @@ -157,6 +198,10 @@ async def get_user_stray(self, user: AuthUserInfo, connection: WebSocket) -> Str
main_loop=asyncio.get_running_loop(),
)
strays[user.id] = stray

self.refresh_user_expiration(connection, user)
self.evict_expired_users(connection)

return stray

def not_allowed(self, connection: WebSocket):
Expand Down
76 changes: 76 additions & 0 deletions core/cat/auth/session_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from datetime import datetime, timedelta, timezone
from cat.utils import singleton
from cat.log import log

from cat.auth.permissions import AuthUserInfo


@singleton
class SessionManager:
"""
This class is responsible for strays session management
- adding new sessions
- expiring sessions
"""

def __init__(self, strays: any) -> None:
self.strays = strays
self.sessions = {}

def add_or_refresh(self, user: AuthUserInfo, minutes: int = 60 * 24) -> None:
"""
add new session setting expiration time to now plus x minutes
each session uses the expiration time as key and the value is a dict with user_id and stray
"""

current_time = datetime.now(timezone.utc) + timedelta(minutes=minutes)

# if user_id is already in the session we removed it first
for timestamp in self.sessions:
if self.sessions[timestamp]["user_id"] == user.id:
self.sessions.pop(timestamp)
break

# then we create a new session
self.sessions[current_time] = {"user_id": user.id, "name": user.name}
pass

def evict_expired_sessions(self) -> int:
"""
this method removes expired sessions
"""

log.info("active sessions:")
log.info("*" * 20)

time_limit = datetime.now(timezone.utc)

for timestamp in self.sessions:
if timestamp >= time_limit:
log.info(f"expiration date: {timestamp} -> {self.sessions[timestamp]}")
log.info("*" * 20)

# retrieve all expired sessions
keys_to_remove = [
timestamp for timestamp in self.sessions if timestamp < time_limit
]

# for each expired session key
for key in keys_to_remove:
# get the user_id
user_id = self.sessions[key]["user_id"]
if user_id in self.strays.keys():
log.info(f"deleting expired user: user_id: {user_id} name: {self.strays[user_id].user_id}")

# remove the users' stray
del self.strays[user_id]

# then remove the session using the key (timestamp)
self.sessions.pop(key)

expired_count = len(keys_to_remove)
if expired_count > 0:
log.info(f"{expired_count} sessions expired.")

return expired_count
3 changes: 3 additions & 0 deletions core/cat/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from cat.auth.session_manager import SessionManager

from cat.log import log
from cat.env import get_env, fix_legacy_env_variables
Expand Down Expand Up @@ -47,6 +48,8 @@ async def lifespan(app: FastAPI):
# Dict of pseudo-sessions (key is the user_id)
app.state.strays = {}

app.state.session_manager = SessionManager(app.state.strays)

# set a reference to asyncio event loop
app.state.event_loop = asyncio.get_running_loop()

Expand Down

0 comments on commit 074fb78

Please sign in to comment.