From 571a2de0c35fed4e876a819bf99d564ca6966896 Mon Sep 17 00:00:00 2001 From: Mahad Date: Wed, 6 Aug 2025 17:50:01 +0500 Subject: [PATCH] implement progressive call results --- .../rpc_progressive_call_results/callee.py | 32 +++++ .../rpc_progressive_call_results/caller.py | 21 ++++ .../rpc_progressive_call_results/callee.py | 44 +++++++ .../rpc_progressive_call_results/caller.py | 18 +++ xconn/async_session.py | 113 ++++++++++++++---- xconn/session.py | 104 ++++++++++++---- xconn/types.py | 5 +- 7 files changed, 289 insertions(+), 48 deletions(-) create mode 100644 examples/asyncio/rpc_progressive_call_results/callee.py create mode 100644 examples/asyncio/rpc_progressive_call_results/caller.py create mode 100644 examples/sync/rpc_progressive_call_results/callee.py create mode 100644 examples/sync/rpc_progressive_call_results/caller.py diff --git a/examples/asyncio/rpc_progressive_call_results/callee.py b/examples/asyncio/rpc_progressive_call_results/callee.py new file mode 100644 index 0000000..9481e62 --- /dev/null +++ b/examples/asyncio/rpc_progressive_call_results/callee.py @@ -0,0 +1,32 @@ +import asyncio + +from xconn import run +from xconn.types import Result, Invocation +from xconn.async_client import connect_anonymous + + +async def invocation_handler(invocation: Invocation) -> Result: + file_size = 100 + for i in range(0, file_size + 1, 10): + progress = i * 100 // file_size + try: + await invocation.send_progress([progress], {}) + except Exception as err: + return Result(["wamp.error.canceled", str(err)]) + await asyncio.sleep(0.5) + + return Result(["Download complete!"]) + + +async def main() -> None: + test_procedure_progress_download = "io.xconn.progress.download" + + # create and connect a callee client to server + callee = await connect_anonymous("ws://localhost:8080/ws", "realm1") + + await callee.register(test_procedure_progress_download, invocation_handler) + print(f"Registered procedure '{test_procedure_progress_download}'") + + +if __name__ == "__main__": + run(main()) diff --git a/examples/asyncio/rpc_progressive_call_results/caller.py b/examples/asyncio/rpc_progressive_call_results/caller.py new file mode 100644 index 0000000..9808cbe --- /dev/null +++ b/examples/asyncio/rpc_progressive_call_results/caller.py @@ -0,0 +1,21 @@ +from xconn import run +from xconn.types import Result +from xconn.async_client import connect_anonymous + + +async def progress_handler(res: Result) -> None: + progress = res.args[0] + print(f"Download progress: {progress}%") + + +async def main() -> None: + test_procedure_progress_download = "io.xconn.progress.download" + + # create and connect a callee client to server + caller = await connect_anonymous("ws://localhost:8080/ws", "realm1") + + await caller.call_progress(test_procedure_progress_download, progress_handler) + + +if __name__ == "__main__": + run(main()) diff --git a/examples/sync/rpc_progressive_call_results/callee.py b/examples/sync/rpc_progressive_call_results/callee.py new file mode 100644 index 0000000..d89cf6a --- /dev/null +++ b/examples/sync/rpc_progressive_call_results/callee.py @@ -0,0 +1,44 @@ +import sys +import time +import signal + +from xconn.client import connect_anonymous +from xconn.types import Result, Invocation + + +def invocation_handler(invocation: Invocation) -> Result: + file_size = 100 + for i in range(0, file_size + 1, 10): + progress = i * 100 // file_size + try: + invocation.send_progress([progress], {}) + except Exception as err: + return Result(["wamp.error.canceled", str(err)]) + time.sleep(0.5) + + return Result(["Download complete!"]) + + +if __name__ == "__main__": + test_procedure_progress_download = "io.xconn.progress.download" + + # create and connect a callee client to server + callee = connect_anonymous("ws://localhost:8080/ws", "realm1") + + download_progress_registration = callee.register(test_procedure_progress_download, invocation_handler) + print(f"Registered procedure '{test_procedure_progress_download}'") + + def handle_sigint(signum, frame): + print("SIGINT received. Cleaning up...") + + # unregister procedure "io.xconn.progress.download" + download_progress_registration.unregister() + + # close connection to the server + callee.leave() + + sys.exit(0) + + +# register signal handler +signal.signal(signal.SIGINT, handle_sigint) diff --git a/examples/sync/rpc_progressive_call_results/caller.py b/examples/sync/rpc_progressive_call_results/caller.py new file mode 100644 index 0000000..3490923 --- /dev/null +++ b/examples/sync/rpc_progressive_call_results/caller.py @@ -0,0 +1,18 @@ +from xconn.types import Result +from xconn.client import connect_anonymous + + +def progress_handler(res: Result) -> None: + progress = res.args[0] + print(f"Download progress: {progress}%") + + +if __name__ == "__main__": + test_procedure_progress_download = "io.xconn.progress.download" + + # create and connect a callee client to server + caller = connect_anonymous("ws://localhost:8080/ws", "realm1") + + caller.call_progress(test_procedure_progress_download, progress_handler) + + caller.leave() diff --git a/xconn/async_session.py b/xconn/async_session.py index 38c793e..27cc468 100644 --- a/xconn/async_session.py +++ b/xconn/async_session.py @@ -4,7 +4,7 @@ import inspect from dataclasses import dataclass from asyncio import Future, get_event_loop -from typing import Callable, Union, Awaitable, Any +from typing import Callable, Awaitable, Any from wampproto import messages, idgen, session @@ -70,10 +70,7 @@ def __init__(self, base_session: types.IAsyncBaseSession): # RPC data structures self._call_requests: dict[int, Future[types.Result]] = {} self._register_requests: dict[int, RegisterRequest] = {} - self._registrations: dict[ - int, - Union[Callable[[types.Invocation], types.Result], Callable[[types.Invocation], Awaitable[types.Result]]], - ] = {} + self._registrations: dict[int, Callable[[types.Invocation], Awaitable[types.Result]]] = {} self._unregister_requests: dict[int, types.UnregisterRequest] = {} # PubSub data structures @@ -81,6 +78,7 @@ def __init__(self, base_session: types.IAsyncBaseSession): self._subscribe_requests: dict[int, SubscribeRequest] = {} self._subscriptions: dict[int, Callable[[types.Event], Awaitable[None]]] = {} self._unsubscribe_requests: dict[int, types.UnsubscribeRequest] = {} + self._progress_handlers: dict[int, Callable[[types.Result], Awaitable[None]]] = {} self._goodbye_request = Future() @@ -120,29 +118,68 @@ async def _process_incoming_message(self, msg: messages.Message): del self._registrations[request.registration_id] request.future.set_result(None) elif isinstance(msg, messages.Result): - request = self._call_requests.pop(msg.request_id) - request.set_result(types.Result(msg.args, msg.kwargs, msg.details)) + progress = msg.details.get("progress", False) + if progress: + progress_handler = self._progress_handlers.get(msg.request_id, None) + if progress_handler is not None: + try: + await progress_handler(types.Result(msg.args, msg.kwargs, msg.details)) + except Exception as e: + # TODO: implement call canceling + print(e) + else: + request = self._call_requests.pop(msg.request_id, None) + if request is not None: + request.set_result(types.Result(msg.args, msg.kwargs, msg.details)) + self._progress_handlers.pop(msg.request_id, None) elif isinstance(msg, messages.Invocation): try: endpoint = self._registrations[msg.registration_id] - result = await endpoint(types.Invocation(msg.args, msg.kwargs, msg.details)) - - if result is None: - data = self._session.send_message(messages.Yield(messages.YieldFields(msg.request_id))) - elif isinstance(result, types.Result): - data = self._session.send_message( - messages.Yield(messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details)) - ) - else: - message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str( - type(result) - ) - msg_to_send = messages.Error( - messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message]) - ) - data = self._session.send_message(msg_to_send) - - await self._base_session.send(data) + invocation = types.Invocation(msg.args, msg.kwargs, msg.details) + receive_progress = msg.details.get("receive_progress", False) + if receive_progress: + + async def _progress_func(args: list[Any] | None, kwargs: dict[str, Any] | None): + yield_msg = messages.Yield( + messages.YieldFields(msg.request_id, args, kwargs, {"progress": True}) + ) + data = self._session.send_message(yield_msg) + await self._base_session.send(data) + + invocation.send_progress = _progress_func + + async def handle_endpoint_invocation(): + try: + result = await endpoint(invocation) + if result is None: + data = self._session.send_message(messages.Yield(messages.YieldFields(msg.request_id))) + elif isinstance(result, types.Result): + data = self._session.send_message( + messages.Yield( + messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details) + ) + ) + else: + message = ( + "Endpoint returned invalid result type. Expected types.Result or None, got: " + + str(type(result)) + ) + msg_to_send = messages.Error( + messages.ErrorFields( + msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message] + ) + ) + data = self._session.send_message(msg_to_send) + except Exception as e: + message = f"unexpected error calling endpoint {endpoint.__name__}, error is: {e}" + msg_to_send = messages.Error( + messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message]) + ) + data = self._session.send_message(msg_to_send) + await self._base_session.send(data) + + current_loop = get_event_loop() + current_loop.create_task(handle_endpoint_invocation()) except ApplicationError as e: msg_to_send = messages.Error(messages.ErrorFields(msg.TYPE, msg.request_id, e.message, e.args)) data = self._session.send_message(msg_to_send) @@ -217,6 +254,15 @@ async def register( return await f + async def _call(self, call_msg: messages.Call) -> types.Result: + f = Future() + self._call_requests[call_msg.request_id] = f + + data = self._session.send_message(call_msg) + await self._base_session.send(data) + + return await f + async def call( self, procedure: str, @@ -234,6 +280,23 @@ async def call( return await f + async def call_progress( + self, + procedure: str, + progress_handler: Callable[[types.Result], Awaitable[None]], + args: list[Any] | None = None, + kwargs: dict[str, Any] | None = None, + options: dict[str, Any] | None = None, + ) -> types.Result: + if options is None: + options = {} + + options["receive_progress"] = True + call_msg = messages.Call(messages.CallFields(self._idgen.next(), procedure, args, kwargs, options)) + self._progress_handlers[call_msg.request_id] = progress_handler + + return await self._call(call_msg) + async def subscribe( self, topic: str, event_handler: Callable[[types.Event], Awaitable[None]], options: dict | None = None ) -> Subscription: diff --git a/xconn/session.py b/xconn/session.py index 17f358c..dd12bb3 100644 --- a/xconn/session.py +++ b/xconn/session.py @@ -1,5 +1,6 @@ from __future__ import annotations +import threading from concurrent.futures import Future from threading import Thread from typing import Callable, Any @@ -77,6 +78,7 @@ def __init__(self, base_session: types.BaseSession): self._subscribe_requests: dict[int, SubscribeRequest] = {} self._subscriptions: dict[int, Callable[[types.Event], None]] = {} self._unsubscribe_requests: dict[int, types.UnsubscribeRequest] = {} + self._progress_handlers: dict[int, Callable[[types.Result], None]] = {} self._goodbye_request = Future() @@ -115,29 +117,67 @@ def _process_incoming_message(self, msg: messages.Message): del self._registrations[request.registration_id] request.future.set_result(None) elif isinstance(msg, messages.Result): - request = self._call_requests.pop(msg.request_id) - request.set_result(types.Result(msg.args, msg.kwargs, msg.details)) + progress = msg.details.get("progress", False) + if progress: + progress_handler = self._progress_handlers.get(msg.request_id, None) + if progress_handler is not None: + try: + progress_handler(types.Result(msg.args, msg.kwargs, msg.details)) + except Exception as e: + # TODO: implement call canceling + print(e) + else: + request = self._call_requests.pop(msg.request_id, None) + if request is not None: + request.set_result(types.Result(msg.args, msg.kwargs, msg.details)) + self._progress_handlers.pop(msg.request_id, None) elif isinstance(msg, messages.Invocation): try: endpoint = self._registrations[msg.registration_id] - result = endpoint(types.Invocation(msg.args, msg.kwargs, msg.details)) - - if result is None: - data = self._session.send_message(messages.Yield(messages.YieldFields(msg.request_id))) - elif isinstance(result, types.Result): - data = self._session.send_message( - messages.Yield(messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details)) - ) - else: - message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str( - type(result) - ) - msg_to_send = messages.Error( - messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message]) - ) - data = self._session.send_message(msg_to_send) - - self._base_session.send(data) + invocation = types.Invocation(msg.args, msg.kwargs, msg.details) + receive_progress = msg.details.get("receive_progress", False) + if receive_progress: + + def _progress_func(args: list[Any] | None, kwargs: dict[str, Any] | None): + yield_msg = messages.Yield( + messages.YieldFields(msg.request_id, args, kwargs, {"progress": True}) + ) + data = self._session.send_message(yield_msg) + self._base_session.send(data) + + invocation.send_progress = _progress_func + + def handle_endpoint_invocation(): + try: + result = endpoint(invocation) + if result is None: + data = self._session.send_message(messages.Yield(messages.YieldFields(msg.request_id))) + elif isinstance(result, types.Result): + data = self._session.send_message( + messages.Yield( + messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details) + ) + ) + else: + message = ( + "Endpoint returned invalid result type. Expected types.Result or None, got: " + + str(type(result)) + ) + msg_to_send = messages.Error( + messages.ErrorFields( + msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message] + ) + ) + data = self._session.send_message(msg_to_send) + except Exception as e: + message = f"unexpected error calling endpoint {endpoint.__name__}, error is: {e}" + msg_to_send = messages.Error( + messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message]) + ) + data = self._session.send_message(msg_to_send) + self._base_session.send(data) + + threading.Thread(target=handle_endpoint_invocation).start() except ApplicationError as e: msg_to_send = messages.Error(messages.ErrorFields(msg.TYPE, msg.request_id, e.message, e.args)) data = self._session.send_message(msg_to_send) @@ -228,14 +268,34 @@ def call( else: call = messages.Call(messages.CallFields(self._idgen.next(), procedure, args, kwargs, options=options)) - data = self._session.send_message(call) + return self._call(call) + def _call(self, call_msg: messages.Call) -> types.Result: f = Future() - self._call_requests[call.request_id] = f + self._call_requests[call_msg.request_id] = f + + data = self._session.send_message(call_msg) self._base_session.send(data) return f.result() + def call_progress( + self, + procedure: str, + progress_handler: Callable[[types.Result], None], + args: list[Any] | None = None, + kwargs: dict[str, Any] | None = None, + options: dict[str, Any] | None = None, + ) -> types.Result: + if options is None: + options = {} + + options["receive_progress"] = True + call_msg = messages.Call(messages.CallFields(self._idgen.next(), procedure, args, kwargs, options)) + self._progress_handlers[call_msg.request_id] = progress_handler + + return self._call(call_msg) + def register( self, procedure: str, diff --git a/xconn/types.py b/xconn/types.py index d2c382b..b5556b9 100644 --- a/xconn/types.py +++ b/xconn/types.py @@ -7,7 +7,7 @@ from collections import deque from dataclasses import dataclass from enum import Enum -from typing import Callable, Awaitable +from typing import Callable, Awaitable, Any, Union from aiohttp import web from wampproto import messages, joiner, serializers @@ -37,6 +37,9 @@ class Invocation: args: list | None kwargs: dict | None details: dict | None + send_progress: Union[ + Callable[[list[Any], dict[str, Any]], None], Callable[[list[Any], dict[str, Any]], Awaitable[None]], None + ] = None @dataclass