From 4b2ac99cbc17784a4918e2e8433650d307433e3b Mon Sep 17 00:00:00 2001 From: Mahad Date: Fri, 5 Sep 2025 20:01:17 +0500 Subject: [PATCH] fix: allow multiple subscriptions to the same URI in a single session --- xconn/async_session.py | 30 ++++++++++++++++++++++++------ xconn/session.py | 28 ++++++++++++++++++++++------ 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/xconn/async_session.py b/xconn/async_session.py index 909dec9..c5d5ec8 100644 --- a/xconn/async_session.py +++ b/xconn/async_session.py @@ -45,14 +45,24 @@ class SubscribeRequest: class Subscription: - def __init__(self, subscription_id: int, session: AsyncSession): + def __init__( + self, subscription_id: int, session: AsyncSession, event_handler: Callable[[types.Event], Awaitable[None]] + ): self.subscription_id = subscription_id self._session = session + self._event_handler = event_handler async def unsubscribe(self) -> None: if not await self._session._base_session.transport.is_connected(): raise Exception("cannot unsubscribe topic: session not established") + subscriptions = self._session._subscriptions.get(self.subscription_id, None) + if subscriptions is not None: + subscriptions.pop(self, None) + if len(subscriptions) != 0: + self._session._subscriptions[self.subscription_id] = subscriptions + return None + unsubscribe = messages.Unsubscribe( messages.UnsubscribeFields(self._session._idgen.next(), self.subscription_id) ) @@ -79,7 +89,7 @@ def __init__(self, base_session: types.IAsyncBaseSession): # PubSub data structures self._publish_requests: dict[int, Future[None]] = {} self._subscribe_requests: dict[int, SubscribeRequest] = {} - self._subscriptions: dict[int, Callable[[types.Event], Awaitable[None]]] = {} + self._subscriptions: dict[int, dict[Subscription, Subscription]] = {} self._unsubscribe_requests: dict[int, types.UnsubscribeRequest] = {} self._goodbye_request = Future() @@ -155,8 +165,14 @@ async def _process_incoming_message(self, msg: messages.Message): await self._base_session.send(data) elif isinstance(msg, messages.Subscribed): request = self._subscribe_requests.pop(msg.request_id) - self._subscriptions[msg.subscription_id] = request.endpoint - request.future.set_result(Subscription(msg.subscription_id, self)) + sub = Subscription(msg.subscription_id, self, request.endpoint) + subscriptions = self._subscriptions.get(msg.subscription_id, None) + if subscriptions is None: + self._subscriptions[msg.subscription_id] = {sub: sub} + else: + subscriptions[sub] = sub + + request.future.set_result(sub) elif isinstance(msg, messages.Unsubscribed): request = self._unsubscribe_requests.pop(msg.request_id) del self._subscriptions[request.subscription_id] @@ -165,9 +181,11 @@ async def _process_incoming_message(self, msg: messages.Message): request = self._publish_requests.pop(msg.request_id) request.set_result(None) elif isinstance(msg, messages.Event): - endpoint = self._subscriptions[msg.subscription_id] try: - await endpoint(types.Event(msg.args, msg.kwargs, msg.details)) + subscriptions = self._subscriptions[msg.subscription_id] + event = types.Event(msg.args, msg.kwargs, msg.details) + for subscription in subscriptions.keys(): + await subscription._event_handler(event) except Exception as e: print(e) elif isinstance(msg, messages.Error): diff --git a/xconn/session.py b/xconn/session.py index 375086e..af35e56 100644 --- a/xconn/session.py +++ b/xconn/session.py @@ -44,14 +44,22 @@ class SubscribeRequest: class Subscription: - def __init__(self, subscription_id: int, session: Session): + def __init__(self, subscription_id: int, session: Session, event_handler: Callable[[types.Event], None]): self.subscription_id = subscription_id self._session = session + self._event_handler = event_handler def unsubscribe(self) -> None: if not self._session._base_session.transport.is_connected(): raise Exception("cannot unsubscribe topic: session not established") + subscriptions = self._session._subscriptions.get(self.subscription_id, None) + if subscriptions is not None: + subscriptions.pop(self, None) + if len(subscriptions) != 0: + self._session._subscriptions[self.subscription_id] = subscriptions + return None + unsubscribe = messages.Unsubscribe( messages.UnsubscribeFields(self._session._idgen.next(), self.subscription_id) ) @@ -75,7 +83,7 @@ def __init__(self, base_session: types.BaseSession): # PubSub data structures self._publish_requests: dict[int, Future[None]] = {} self._subscribe_requests: dict[int, SubscribeRequest] = {} - self._subscriptions: dict[int, Callable[[types.Event], None]] = {} + self._subscriptions: dict[int, dict[Subscription, Subscription]] = {} self._unsubscribe_requests: dict[int, types.UnsubscribeRequest] = {} self._goodbye_request = Future() @@ -150,8 +158,14 @@ def _process_incoming_message(self, msg: messages.Message): self._base_session.send(data) elif isinstance(msg, messages.Subscribed): request = self._subscribe_requests.pop(msg.request_id) - self._subscriptions[msg.subscription_id] = request.endpoint - request.future.set_result(Subscription(msg.subscription_id, self)) + sub = Subscription(msg.subscription_id, self, request.endpoint) + subscriptions = self._subscriptions.get(msg.subscription_id, None) + if subscriptions is None: + self._subscriptions[msg.subscription_id] = {sub: sub} + else: + subscriptions[sub] = sub + + request.future.set_result(sub) elif isinstance(msg, messages.Unsubscribed): request = self._unsubscribe_requests.pop(msg.request_id) del self._subscriptions[request.subscription_id] @@ -161,8 +175,10 @@ def _process_incoming_message(self, msg: messages.Message): request.set_result(None) elif isinstance(msg, messages.Event): try: - endpoint = self._subscriptions[msg.subscription_id] - endpoint(types.Event(msg.args, msg.kwargs, msg.details)) + subscriptions = self._subscriptions[msg.subscription_id] + event = types.Event(msg.args, msg.kwargs, msg.details) + for subscription in subscriptions.keys(): + subscription._event_handler(event) except Exception as e: print(e) elif isinstance(msg, messages.Error):