Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions examples/asyncio/rpc_progressive_call_results/callee.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import asyncio

from xconn import run
from xconn.types import Result, Invocation
from xconn.async_client import connect_anonymous


async def invocation_handler(invocation: Invocation) -> Result:
file_size = 100
for i in range(0, file_size + 1, 10):
progress = i * 100 // file_size
try:
await invocation.send_progress([progress], {})
except Exception as err:
return Result(["wamp.error.canceled", str(err)])
await asyncio.sleep(0.5)

return Result(["Download complete!"])


async def main() -> None:
test_procedure_progress_download = "io.xconn.progress.download"

# create and connect a callee client to server
callee = await connect_anonymous("ws://localhost:8080/ws", "realm1")

await callee.register(test_procedure_progress_download, invocation_handler)
print(f"Registered procedure '{test_procedure_progress_download}'")


if __name__ == "__main__":
run(main())
21 changes: 21 additions & 0 deletions examples/asyncio/rpc_progressive_call_results/caller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from xconn import run
from xconn.types import Result
from xconn.async_client import connect_anonymous


async def progress_handler(res: Result) -> None:
progress = res.args[0]
print(f"Download progress: {progress}%")


async def main() -> None:
test_procedure_progress_download = "io.xconn.progress.download"

# create and connect a callee client to server
caller = await connect_anonymous("ws://localhost:8080/ws", "realm1")

await caller.call_progress(test_procedure_progress_download, progress_handler)


if __name__ == "__main__":
run(main())
44 changes: 44 additions & 0 deletions examples/sync/rpc_progressive_call_results/callee.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import sys
import time
import signal

from xconn.client import connect_anonymous
from xconn.types import Result, Invocation


def invocation_handler(invocation: Invocation) -> Result:
file_size = 100
for i in range(0, file_size + 1, 10):
progress = i * 100 // file_size
try:
invocation.send_progress([progress], {})
except Exception as err:
return Result(["wamp.error.canceled", str(err)])
time.sleep(0.5)

return Result(["Download complete!"])


if __name__ == "__main__":
test_procedure_progress_download = "io.xconn.progress.download"

# create and connect a callee client to server
callee = connect_anonymous("ws://localhost:8080/ws", "realm1")

download_progress_registration = callee.register(test_procedure_progress_download, invocation_handler)
print(f"Registered procedure '{test_procedure_progress_download}'")

def handle_sigint(signum, frame):
print("SIGINT received. Cleaning up...")

# unregister procedure "io.xconn.progress.download"
download_progress_registration.unregister()

# close connection to the server
callee.leave()

sys.exit(0)


# register signal handler
signal.signal(signal.SIGINT, handle_sigint)
18 changes: 18 additions & 0 deletions examples/sync/rpc_progressive_call_results/caller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from xconn.types import Result
from xconn.client import connect_anonymous


def progress_handler(res: Result) -> None:
progress = res.args[0]
print(f"Download progress: {progress}%")


if __name__ == "__main__":
test_procedure_progress_download = "io.xconn.progress.download"

# create and connect a callee client to server
caller = connect_anonymous("ws://localhost:8080/ws", "realm1")

caller.call_progress(test_procedure_progress_download, progress_handler)

caller.leave()
113 changes: 88 additions & 25 deletions xconn/async_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import inspect
from dataclasses import dataclass
from asyncio import Future, get_event_loop
from typing import Callable, Union, Awaitable, Any
from typing import Callable, Awaitable, Any

from wampproto import messages, idgen, session

Expand Down Expand Up @@ -70,17 +70,15 @@ def __init__(self, base_session: types.IAsyncBaseSession):
# RPC data structures
self._call_requests: dict[int, Future[types.Result]] = {}
self._register_requests: dict[int, RegisterRequest] = {}
self._registrations: dict[
int,
Union[Callable[[types.Invocation], types.Result], Callable[[types.Invocation], Awaitable[types.Result]]],
] = {}
self._registrations: dict[int, Callable[[types.Invocation], Awaitable[types.Result]]] = {}
self._unregister_requests: dict[int, types.UnregisterRequest] = {}

# PubSub data structures
self._publish_requests: dict[int, Future[None]] = {}
self._subscribe_requests: dict[int, SubscribeRequest] = {}
self._subscriptions: dict[int, Callable[[types.Event], Awaitable[None]]] = {}
self._unsubscribe_requests: dict[int, types.UnsubscribeRequest] = {}
self._progress_handlers: dict[int, Callable[[types.Result], Awaitable[None]]] = {}

self._goodbye_request = Future()

Expand Down Expand Up @@ -120,29 +118,68 @@ async def _process_incoming_message(self, msg: messages.Message):
del self._registrations[request.registration_id]
request.future.set_result(None)
elif isinstance(msg, messages.Result):
request = self._call_requests.pop(msg.request_id)
request.set_result(types.Result(msg.args, msg.kwargs, msg.details))
progress = msg.details.get("progress", False)
if progress:
progress_handler = self._progress_handlers.get(msg.request_id, None)
if progress_handler is not None:
try:
await progress_handler(types.Result(msg.args, msg.kwargs, msg.details))
except Exception as e:
# TODO: implement call canceling
print(e)
else:
request = self._call_requests.pop(msg.request_id, None)
if request is not None:
request.set_result(types.Result(msg.args, msg.kwargs, msg.details))
self._progress_handlers.pop(msg.request_id, None)
elif isinstance(msg, messages.Invocation):
try:
endpoint = self._registrations[msg.registration_id]
result = await endpoint(types.Invocation(msg.args, msg.kwargs, msg.details))

if result is None:
data = self._session.send_message(messages.Yield(messages.YieldFields(msg.request_id)))
elif isinstance(result, types.Result):
data = self._session.send_message(
messages.Yield(messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details))
)
else:
message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str(
type(result)
)
msg_to_send = messages.Error(
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message])
)
data = self._session.send_message(msg_to_send)

await self._base_session.send(data)
invocation = types.Invocation(msg.args, msg.kwargs, msg.details)
receive_progress = msg.details.get("receive_progress", False)
if receive_progress:

async def _progress_func(args: list[Any] | None, kwargs: dict[str, Any] | None):
yield_msg = messages.Yield(
messages.YieldFields(msg.request_id, args, kwargs, {"progress": True})
)
data = self._session.send_message(yield_msg)
await self._base_session.send(data)

invocation.send_progress = _progress_func

async def handle_endpoint_invocation():
try:
result = await endpoint(invocation)
if result is None:
data = self._session.send_message(messages.Yield(messages.YieldFields(msg.request_id)))
elif isinstance(result, types.Result):
data = self._session.send_message(
messages.Yield(
messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details)
)
)
else:
message = (
"Endpoint returned invalid result type. Expected types.Result or None, got: "
+ str(type(result))
)
msg_to_send = messages.Error(
messages.ErrorFields(
msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message]
)
)
data = self._session.send_message(msg_to_send)
except Exception as e:
message = f"unexpected error calling endpoint {endpoint.__name__}, error is: {e}"
msg_to_send = messages.Error(
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message])
)
data = self._session.send_message(msg_to_send)
await self._base_session.send(data)

current_loop = get_event_loop()
current_loop.create_task(handle_endpoint_invocation())
except ApplicationError as e:
msg_to_send = messages.Error(messages.ErrorFields(msg.TYPE, msg.request_id, e.message, e.args))
data = self._session.send_message(msg_to_send)
Expand Down Expand Up @@ -217,6 +254,15 @@ async def register(

return await f

async def _call(self, call_msg: messages.Call) -> types.Result:
f = Future()
self._call_requests[call_msg.request_id] = f

data = self._session.send_message(call_msg)
await self._base_session.send(data)

return await f

async def call(
self,
procedure: str,
Expand All @@ -234,6 +280,23 @@ async def call(

return await f

async def call_progress(
self,
procedure: str,
progress_handler: Callable[[types.Result], Awaitable[None]],
args: list[Any] | None = None,
kwargs: dict[str, Any] | None = None,
options: dict[str, Any] | None = None,
) -> types.Result:
if options is None:
options = {}

options["receive_progress"] = True
call_msg = messages.Call(messages.CallFields(self._idgen.next(), procedure, args, kwargs, options))
self._progress_handlers[call_msg.request_id] = progress_handler

return await self._call(call_msg)

async def subscribe(
self, topic: str, event_handler: Callable[[types.Event], Awaitable[None]], options: dict | None = None
) -> Subscription:
Expand Down
Loading