Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding session manager #943

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion core/cat/auth/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,40 @@ async def get_user_stray(self, user: AuthUserInfo, connection: Request | WebSock
@abstractmethod
def not_allowed(self, connection: Request | WebSocket):
pass


def refresh_user_expiration(self, connection: Request, user: AuthUserInfo) -> None:
"""
this method uses session manager to refresh user expiration time
"""

session_manager = connection.app.state.session_manager
#
# TODO: remove this once we have properly tested this
#
# this should be changed into a meaningful amount of minutes it's just for testing
# if user.name contains "test" get the last character to extract minutes
#
# user: test1 will expire after 1 minutes
# user: test2 will expire after 2 minutes
# user: test3 will expire after 3 minutes
# user: test4 will expire after 4 minutes
# and so on
#
if "test" in user.name:
minutes = int(user.name[-1])
else:
minutes = 60 * 24
session_manager.add_or_refresh(user, minutes)
pass

def evict_expired_users(self, connection: Request) -> None:
"""
this method uses session manager to evict expired users
"""

session_manager = connection.app.state.session_manager
session_manager.evict_expired_sessions()
pass

class HTTPAuth(ConnectionAuth):

Expand Down Expand Up @@ -111,6 +144,10 @@ async def get_user_stray(self, user: AuthUserInfo, connection: Request) -> Stray
# TODOV2: user_id should be the user.id
user_id=user.name, user_data=user, main_loop=event_loop
)

self.refresh_user_expiration(connection, user)
self.evict_expired_users(connection)

return strays[user.id]

def not_allowed(self, connection: Request):
Expand Down Expand Up @@ -144,6 +181,10 @@ async def get_user_stray(self, user: AuthUserInfo, connection: WebSocket) -> Str

# Set new ws connection
stray.reset_connection(connection)

self.refresh_user_expiration(connection, user)
self.evict_expired_users(connection)

log.info(
f"New websocket connection for user '{user.id}', the old one has been closed."
)
Expand All @@ -157,6 +198,10 @@ async def get_user_stray(self, user: AuthUserInfo, connection: WebSocket) -> Str
main_loop=asyncio.get_running_loop(),
)
strays[user.id] = stray

self.refresh_user_expiration(connection, user)
self.evict_expired_users(connection)

return stray

def not_allowed(self, connection: WebSocket):
Expand Down
78 changes: 78 additions & 0 deletions core/cat/auth/session_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from datetime import datetime, timedelta, timezone
from cat.utils import singleton
from cat.log import log

from cat.auth.permissions import AuthUserInfo


@singleton
class SessionManager:
"""
This class is responsible for strays session management

- adding new sessions
- expiring sessions
"""

def __init__(self, strays: any) -> None:
self.strays = strays
self.sessions = {}

def add_or_refresh(self, user: AuthUserInfo, minutes: int = 60 * 24) -> None:
"""
add new session setting expiration time to now plus x minutes
each session uses the expiration time as key and the value is a dict with user_id and stray
"""

current_time = datetime.now(timezone.utc) + timedelta(minutes=minutes)

# if user_id is already in the session we removed it first
for timestamp in self.sessions:
if self.sessions[timestamp]["user_id"] == user.id:
self.sessions.pop(timestamp)
break

# then we create a new session
self.sessions[current_time] = {"user_id": user.id, "name": user.name}
pass

def evict_expired_sessions(self) -> int:
"""
this method removes expired sessions
"""

log.info("active sessions:")
log.info("*" * 20)

time_limit = datetime.now(timezone.utc)

for timestamp in self.sessions:
if timestamp >= time_limit:
log.info(f"expiration date: {timestamp} -> {self.sessions[timestamp]}")
log.info("*" * 20)

# retrieve all expired sessions
keys_to_remove = [
timestamp for timestamp in self.sessions if timestamp < time_limit
]

# for each expired session key
for key in keys_to_remove:
# get the user_id
user_id = self.sessions[key]["user_id"]
if user_id in self.strays.keys():
log.info(
f"deleting expired user: {self.strays[user_id].user_id} user_id: {user_id}"
)

# remove the users' stray
del self.strays[user_id]

# then remove the session using the key (timestamp)
self.sessions.pop(key)

expired_count = len(keys_to_remove)
if expired_count > 0:
log.info(f"{expired_count} sessions expired.")

return expired_count
3 changes: 3 additions & 0 deletions core/cat/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from cat.auth.session_manager import SessionManager

from cat.log import log
from cat.env import get_env, fix_legacy_env_variables
Expand Down Expand Up @@ -47,6 +48,8 @@ async def lifespan(app: FastAPI):
# Dict of pseudo-sessions (key is the user_id)
app.state.strays = {}

app.state.session_manager = SessionManager(app.state.strays)

# set a reference to asyncio event loop
app.state.event_loop = asyncio.get_running_loop()

Expand Down