diff --git a/backend/api/consumers.py b/backend/api/consumers.py index 8c9bdd86..58697fb0 100644 --- a/backend/api/consumers.py +++ b/backend/api/consumers.py @@ -14,6 +14,7 @@ from django.db.models import Q, Count from django.utils import timezone +from django.core.cache import cache from .models import Conversation, User, Relationship, Match, UserSettings from .util import generate_id, get_safe_profile, get_user_id_from_token @@ -49,6 +50,10 @@ async def connect(self): self.user = await sync_to_async(User.objects.get)(userID=userID) self.user_group_name = f"status_{self.user.userID}" + connection_count_key = f"status_user_connections_{self.user.userID}" + connection_count = cache.get(connection_count_key, 0) + cache.set(connection_count_key, connection_count + 1, timeout=None) + await self.channel_layer.group_add( self.user_group_name, self.channel_name @@ -59,16 +64,26 @@ async def connect(self): logger.info(f"[{self.__class__.__name__}] User {self.user.username} connected") async def disconnect(self, close_code): + if self.heartbeat_task and not self.heartbeat_task.done(): self.heartbeat_task.cancel() if self.user is not None: + connection_count_key = f"status_user_connections_{self.user.userID}" + connection_count = cache.get(connection_count_key, 0) - 1 + await self.channel_layer.group_discard( self.user_group_name, self.channel_name ) - await self.update_user_status(False, None) - await self.notify_friends_connection(self.user) - logger.info(f"[{self.__class__.__name__}] User {self.user.username} disconnected") + + if connection_count > 0: + cache.set(connection_count_key, connection_count, timeout=None) + logger.info(f"[{self.__class__.__name__}] User {self.user.username} disconnected, {connection_count} connections remaining") + else: + cache.delete(connection_count_key) + await self.update_user_status(False, None) + await self.notify_friends_connection(self.user) + logger.info(f"[{self.__class__.__name__}] User {self.user.username} disconnected") async def receive(self, text_data): try: @@ -195,24 +210,38 @@ async def connect(self): self.user = await sync_to_async(User.objects.get)(userID=userID) self.user_group_name = f"chat_{self.user.userID}" - await self.channel_layer.group_add( - self.user_group_name, - self.channel_name - ) + connection_count_key = f"chat_user_connections_{self.user.userID}" + connection_count = cache.get(connection_count_key, 0) + cache.set(connection_count_key, connection_count + 1, timeout=None) - logger.info(f"[{self.__class__.__name__}] User {self.user.username} connected") - await self.ensure_conversations_exist(self.user) - logger.info(f"[{self.__class__.__name__}] User {self.user.username} conversations ensured") + if connection_count <= 0: + logger.info(f"[{self.__class__.__name__}] User {self.user.username} connected") + await self.channel_layer.group_add( + self.user_group_name, + self.channel_name + ) + + await self.ensure_conversations_exist(self.user) + logger.info(f"[{self.__class__.__name__}] User {self.user.username} conversations ensured") + else: + logger.info(f"[{self.__class__.__name__}] User {self.user.username} reconnected") await self.accept() async def disconnect(self, close_code): if self.user: - await self.channel_layer.group_discard( - self.user_group_name, - self.channel_name - ) + connection_count_key = f"chat_user_connections_{self.user.userID}" + connection_count = cache.get(connection_count_key, 0) - 1 - logger.info(f"[{self.__class__.__name__}] User {self.user.username} disconnected") + if connection_count > 0: + cache.set(connection_count_key, connection_count, timeout=None) + logger.info(f"[{self.__class__.__name__}] User {self.user.username} disconnected, {connection_count} connections remaining") + else: + cache.delete(connection_count_key) + await self.channel_layer.group_discard( + self.user_group_name, + self.channel_name + ) + logger.info(f"[{self.__class__.__name__}] User {self.user.username} disconnected") async def receive(self, text_data): try: