Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions xconn/async_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,24 @@ class SubscribeRequest:


class Subscription:
def __init__(self, subscription_id: int, session: AsyncSession):
def __init__(
self, subscription_id: int, session: AsyncSession, event_handler: Callable[[types.Event], Awaitable[None]]
):
self.subscription_id = subscription_id
self._session = session
self._event_handler = event_handler

async def unsubscribe(self) -> None:
if not await self._session._base_session.transport.is_connected():
raise Exception("cannot unsubscribe topic: session not established")

subscriptions = self._session._subscriptions.get(self.subscription_id, None)
if subscriptions is not None:
subscriptions.pop(self, None)
if len(subscriptions) != 0:
self._session._subscriptions[self.subscription_id] = subscriptions
return None

unsubscribe = messages.Unsubscribe(
messages.UnsubscribeFields(self._session._idgen.next(), self.subscription_id)
)
Expand All @@ -79,7 +89,7 @@ def __init__(self, base_session: types.IAsyncBaseSession):
# PubSub data structures
self._publish_requests: dict[int, Future[None]] = {}
self._subscribe_requests: dict[int, SubscribeRequest] = {}
self._subscriptions: dict[int, Callable[[types.Event], Awaitable[None]]] = {}
self._subscriptions: dict[int, dict[Subscription, Subscription]] = {}
self._unsubscribe_requests: dict[int, types.UnsubscribeRequest] = {}

self._goodbye_request = Future()
Expand Down Expand Up @@ -155,8 +165,14 @@ async def _process_incoming_message(self, msg: messages.Message):
await self._base_session.send(data)
elif isinstance(msg, messages.Subscribed):
request = self._subscribe_requests.pop(msg.request_id)
self._subscriptions[msg.subscription_id] = request.endpoint
request.future.set_result(Subscription(msg.subscription_id, self))
sub = Subscription(msg.subscription_id, self, request.endpoint)
subscriptions = self._subscriptions.get(msg.subscription_id, None)
if subscriptions is None:
self._subscriptions[msg.subscription_id] = {sub: sub}
else:
subscriptions[sub] = sub

request.future.set_result(sub)
elif isinstance(msg, messages.Unsubscribed):
request = self._unsubscribe_requests.pop(msg.request_id)
del self._subscriptions[request.subscription_id]
Expand All @@ -165,9 +181,11 @@ async def _process_incoming_message(self, msg: messages.Message):
request = self._publish_requests.pop(msg.request_id)
request.set_result(None)
elif isinstance(msg, messages.Event):
endpoint = self._subscriptions[msg.subscription_id]
try:
await endpoint(types.Event(msg.args, msg.kwargs, msg.details))
subscriptions = self._subscriptions[msg.subscription_id]
event = types.Event(msg.args, msg.kwargs, msg.details)
for subscription in subscriptions.keys():
await subscription._event_handler(event)
except Exception as e:
print(e)
elif isinstance(msg, messages.Error):
Expand Down
28 changes: 22 additions & 6 deletions xconn/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,22 @@ class SubscribeRequest:


class Subscription:
def __init__(self, subscription_id: int, session: Session):
def __init__(self, subscription_id: int, session: Session, event_handler: Callable[[types.Event], None]):
self.subscription_id = subscription_id
self._session = session
self._event_handler = event_handler

def unsubscribe(self) -> None:
if not self._session._base_session.transport.is_connected():
raise Exception("cannot unsubscribe topic: session not established")

subscriptions = self._session._subscriptions.get(self.subscription_id, None)
if subscriptions is not None:
subscriptions.pop(self, None)
if len(subscriptions) != 0:
self._session._subscriptions[self.subscription_id] = subscriptions
return None

unsubscribe = messages.Unsubscribe(
messages.UnsubscribeFields(self._session._idgen.next(), self.subscription_id)
)
Expand All @@ -75,7 +83,7 @@ def __init__(self, base_session: types.BaseSession):
# PubSub data structures
self._publish_requests: dict[int, Future[None]] = {}
self._subscribe_requests: dict[int, SubscribeRequest] = {}
self._subscriptions: dict[int, Callable[[types.Event], None]] = {}
self._subscriptions: dict[int, dict[Subscription, Subscription]] = {}
self._unsubscribe_requests: dict[int, types.UnsubscribeRequest] = {}

self._goodbye_request = Future()
Expand Down Expand Up @@ -150,8 +158,14 @@ def _process_incoming_message(self, msg: messages.Message):
self._base_session.send(data)
elif isinstance(msg, messages.Subscribed):
request = self._subscribe_requests.pop(msg.request_id)
self._subscriptions[msg.subscription_id] = request.endpoint
request.future.set_result(Subscription(msg.subscription_id, self))
sub = Subscription(msg.subscription_id, self, request.endpoint)
subscriptions = self._subscriptions.get(msg.subscription_id, None)
if subscriptions is None:
self._subscriptions[msg.subscription_id] = {sub: sub}
else:
subscriptions[sub] = sub

request.future.set_result(sub)
elif isinstance(msg, messages.Unsubscribed):
request = self._unsubscribe_requests.pop(msg.request_id)
del self._subscriptions[request.subscription_id]
Expand All @@ -161,8 +175,10 @@ def _process_incoming_message(self, msg: messages.Message):
request.set_result(None)
elif isinstance(msg, messages.Event):
try:
endpoint = self._subscriptions[msg.subscription_id]
endpoint(types.Event(msg.args, msg.kwargs, msg.details))
subscriptions = self._subscriptions[msg.subscription_id]
event = types.Event(msg.args, msg.kwargs, msg.details)
for subscription in subscriptions.keys():
subscription._event_handler(event)
except Exception as e:
print(e)
elif isinstance(msg, messages.Error):
Expand Down