Skip to content

Commit

Permalink
adding session manager to clean expired strays
Browse files Browse the repository at this point in the history
  • Loading branch information
scicco committed Oct 15, 2024
1 parent f51a40e commit 6901632
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 1 deletion.
39 changes: 38 additions & 1 deletion core/cat/auth/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,31 @@ async def extract_credentials(self, connection: Request | WebSocket) -> Tuple[st
async def get_user_stray(self, user: AuthUserInfo, connection: Request | WebSocket) -> StrayCat:
pass

# this method uses session manager to refresh user expiration time
async def refresh_user_expiration(self, connection: Request, user: AuthUserInfo, strays: any) -> None:
session_manager = connection.app.state.session_manager
#this should be changed into a meaningful amount of minutes it's just for testing
# if user.name contains "test" get the last character to extract minutes
#
# user: test1 will expire after 1 minutes
# user: test2 will expire after 2 minutes
# user: test3 will expire after 3 minutes
# user: test4 will expire after 4 minutes
# and so on
#
if "test" in user.name:
minutes = int(user.name[-1])
else:
minutes = 60 * 24
session_manager.add(user.id, strays[user.id], minutes)
pass

# this method uses session manager to evict expired users
async def evict_expired_users(self, connection: Request) -> None:
session_manager = connection.app.state.session_manager
await session_manager.evict_expired_sessions()
pass

@abstractmethod
def not_allowed(self, connection: Request | WebSocket):
pass
Expand Down Expand Up @@ -111,6 +136,10 @@ async def get_user_stray(self, user: AuthUserInfo, connection: Request) -> Stray
# TODOV2: user_id should be the user.id
user_id=user.name, user_data=user, main_loop=event_loop
)

await self.refresh_user_expiration(connection, user, strays)
await self.evict_expired_users(connection)

return strays[user.id]

def not_allowed(self, connection: Request):
Expand Down Expand Up @@ -144,6 +173,10 @@ async def get_user_stray(self, user: AuthUserInfo, connection: WebSocket) -> Str

# Set new ws connection
stray.reset_connection(connection)

await self.refresh_user_expiration(connection, user, strays)
await self.evict_expired_users(connection)

log.info(
f"New websocket connection for user '{user.id}', the old one has been closed."
)
Expand All @@ -157,8 +190,12 @@ async def get_user_stray(self, user: AuthUserInfo, connection: WebSocket) -> Str
main_loop=asyncio.get_running_loop(),
)
strays[user.id] = stray

await self.refresh_user_expiration(connection, user, strays)
await self.evict_expired_users(connection, user)

return stray

def not_allowed(self, connection: WebSocket):
raise WebSocketException(code=1004, reason="Invalid Credentials")

Expand Down
74 changes: 74 additions & 0 deletions core/cat/auth/session_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from datetime import datetime, timedelta
from cat.utils import singleton
from cat.log import log

@singleton
class SessionManager:
"""
This class is responsible for strays session management
- adding new sessions
- getting sessions
- expiring sessions
"""

def __init__(self, strays: any) -> None:
self.strays = strays
self.sessions = {}
self.user_ids = []

def add(self, user_id: str, stray: any, minutes: int = 60 * 24) -> None:
"""
add new session setting expiration time to now plus x minutes
each session uses the expiration time as key and the value is a dict with user_id and stray
"""

current_time = datetime.now() + timedelta(minutes=minutes)

# if user_id is already in the session we removed it first
if user_id in self.user_ids:
timestamp_to_remove = [
timestamp for timestamp in self.sessions if self.sessions[timestamp]["user_id"] == user_id
]
for timestamp in timestamp_to_remove:
self.sessions.pop(timestamp)

# then we create a new session
self.sessions[current_time] = {"user_id": user_id, "stray": stray}

# and add the user_id to the list of user_ids
self.user_ids.append(user_id)
pass

async def evict_expired_sessions(self) -> int:
"""
this method removes expired sessions
"""

print("active sessions")
print("*" * 20)
for timestamp in self.sessions:
if timestamp >= datetime.now():
print(f"expiration date: {timestamp} -> {self.sessions[timestamp]}")
print("*" * 20)

keys_to_remove = [
timestamp for timestamp in self.sessions if timestamp < datetime.now()
]
for key in keys_to_remove:
user_id = self.sessions[key]["user_id"]
if user_id in self.strays.keys():
print(f"deleting expired user: {self.strays[user_id]}")
del self.strays[user_id]
if user_id in self.user_ids:
self.user_ids.remove(user_id)
self.sessions.pop(key)

expired_count = len(keys_to_remove)
if expired_count > 0:
log.info(
f"{expired_count} sessions expired."
)

return expired_count

3 changes: 3 additions & 0 deletions core/cat/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from cat.auth.session_manager import SessionManager

from cat.log import log
from cat.env import get_env, fix_legacy_env_variables
Expand Down Expand Up @@ -47,6 +48,8 @@ async def lifespan(app: FastAPI):
# Dict of pseudo-sessions (key is the user_id)
app.state.strays = {}

app.state.session_manager = SessionManager(app.state.strays)

# set a reference to asyncio event loop
app.state.event_loop = asyncio.get_running_loop()

Expand Down

0 comments on commit 6901632

Please sign in to comment.