Skip to content

Commit 9ee142f

Browse files
committed
implement progressive call results
1 parent 6bc7d35 commit 9ee142f

File tree

7 files changed

+289
-48
lines changed

7 files changed

+289
-48
lines changed
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import asyncio
2+
3+
from xconn import run
4+
from xconn.types import Result, Invocation
5+
from xconn.async_client import connect_anonymous
6+
7+
8+
async def invocation_handler(invocation: Invocation) -> Result:
9+
file_size = 100
10+
for i in range(0, file_size + 1, 10):
11+
progress = i * 100 // file_size
12+
try:
13+
await invocation.send_progress([progress], {})
14+
except Exception as err:
15+
return Result(["wamp.error.canceled", str(err)])
16+
await asyncio.sleep(0.5)
17+
18+
return Result(["Download complete!"])
19+
20+
21+
async def main() -> None:
22+
test_procedure_progress_download = "io.xconn.progress.download"
23+
24+
# create and connect a callee client to server
25+
callee = await connect_anonymous("ws://localhost:8080/ws", "realm1")
26+
27+
await callee.register(test_procedure_progress_download, invocation_handler)
28+
print(f"Registered procedure '{test_procedure_progress_download}'")
29+
30+
31+
if __name__ == "__main__":
32+
run(main())
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from xconn import run
2+
from xconn.types import Result
3+
from xconn.async_client import connect_anonymous
4+
5+
6+
async def progress_handler(res: Result) -> None:
7+
progress = res.args[0]
8+
print(f"Download progress: {progress}%")
9+
10+
11+
async def main() -> None:
12+
test_procedure_progress_download = "io.xconn.progress.download"
13+
14+
# create and connect a callee client to server
15+
caller = await connect_anonymous("ws://localhost:8080/ws", "realm1")
16+
17+
await caller.call_progress(test_procedure_progress_download, progress_handler)
18+
19+
20+
if __name__ == "__main__":
21+
run(main())
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import sys
2+
import time
3+
import signal
4+
5+
from xconn.client import connect_anonymous
6+
from xconn.types import Result, Invocation
7+
8+
9+
def invocation_handler(invocation: Invocation) -> Result:
10+
file_size = 100
11+
for i in range(0, file_size + 1, 10):
12+
progress = i * 100 // file_size
13+
try:
14+
invocation.send_progress([progress], {})
15+
except Exception as err:
16+
return Result(["wamp.error.canceled", str(err)])
17+
time.sleep(0.5)
18+
19+
return Result(["Download complete!"])
20+
21+
22+
if __name__ == "__main__":
23+
test_procedure_progress_download = "io.xconn.progress.download"
24+
25+
# create and connect a callee client to server
26+
callee = connect_anonymous("ws://localhost:8080/ws", "realm1")
27+
28+
download_progress_registration = callee.register(test_procedure_progress_download, invocation_handler)
29+
print(f"Registered procedure '{test_procedure_progress_download}'")
30+
31+
def handle_sigint(signum, frame):
32+
print("SIGINT received. Cleaning up...")
33+
34+
# unregister procedure "io.xconn.progress.download"
35+
download_progress_registration.unregister()
36+
37+
# close connection to the server
38+
callee.leave()
39+
40+
sys.exit(0)
41+
42+
43+
# register signal handler
44+
signal.signal(signal.SIGINT, handle_sigint)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from xconn.types import Result
2+
from xconn.client import connect_anonymous
3+
4+
5+
def progress_handler(res: Result) -> None:
6+
progress = res.args[0]
7+
print(f"Download progress: {progress}%")
8+
9+
10+
if __name__ == "__main__":
11+
test_procedure_progress_download = "io.xconn.progress.download"
12+
13+
# create and connect a callee client to server
14+
caller = connect_anonymous("ws://localhost:8080/ws", "realm1")
15+
16+
caller.call_progress(test_procedure_progress_download, progress_handler)
17+
18+
caller.leave()

xconn/async_session.py

Lines changed: 88 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import inspect
55
from dataclasses import dataclass
66
from asyncio import Future, get_event_loop
7-
from typing import Callable, Union, Awaitable, Any
7+
from typing import Callable, Awaitable, Any
88

99
from wampproto import messages, idgen, session
1010

@@ -70,17 +70,15 @@ def __init__(self, base_session: types.IAsyncBaseSession):
7070
# RPC data structures
7171
self._call_requests: dict[int, Future[types.Result]] = {}
7272
self._register_requests: dict[int, RegisterRequest] = {}
73-
self._registrations: dict[
74-
int,
75-
Union[Callable[[types.Invocation], types.Result], Callable[[types.Invocation], Awaitable[types.Result]]],
76-
] = {}
73+
self._registrations: dict[int, Callable[[types.Invocation], Awaitable[types.Result]]] = {}
7774
self._unregister_requests: dict[int, types.UnregisterRequest] = {}
7875

7976
# PubSub data structures
8077
self._publish_requests: dict[int, Future[None]] = {}
8178
self._subscribe_requests: dict[int, SubscribeRequest] = {}
8279
self._subscriptions: dict[int, Callable[[types.Event], Awaitable[None]]] = {}
8380
self._unsubscribe_requests: dict[int, types.UnsubscribeRequest] = {}
81+
self._progress_handlers: dict[int, Callable[[types.Result], Awaitable[None]]] = {}
8482

8583
self._goodbye_request = Future()
8684

@@ -120,29 +118,68 @@ async def _process_incoming_message(self, msg: messages.Message):
120118
del self._registrations[request.registration_id]
121119
request.future.set_result(None)
122120
elif isinstance(msg, messages.Result):
123-
request = self._call_requests.pop(msg.request_id)
124-
request.set_result(types.Result(msg.args, msg.kwargs, msg.details))
121+
progress = msg.details.get("progress", False)
122+
if progress:
123+
progress_handler = self._progress_handlers.get(msg.request_id, None)
124+
if progress_handler is not None:
125+
try:
126+
await progress_handler(types.Result(msg.args, msg.kwargs, msg.details))
127+
except Exception as e:
128+
# TODO: implement call canceling
129+
print(e)
130+
else:
131+
request = self._call_requests.pop(msg.request_id, None)
132+
if request is not None:
133+
request.set_result(types.Result(msg.args, msg.kwargs, msg.details))
134+
self._progress_handlers.pop(msg.request_id, None)
125135
elif isinstance(msg, messages.Invocation):
126136
try:
127137
endpoint = self._registrations[msg.registration_id]
128-
result = await endpoint(types.Invocation(msg.args, msg.kwargs, msg.details))
129-
130-
if result is None:
131-
data = self._session.send_message(messages.Yield(messages.YieldFields(msg.request_id)))
132-
elif isinstance(result, types.Result):
133-
data = self._session.send_message(
134-
messages.Yield(messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details))
135-
)
136-
else:
137-
message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str(
138-
type(result)
139-
)
140-
msg_to_send = messages.Error(
141-
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message])
142-
)
143-
data = self._session.send_message(msg_to_send)
144-
145-
await self._base_session.send(data)
138+
invocation = types.Invocation(msg.args, msg.kwargs, msg.details)
139+
receive_progress = msg.details.get("receive_progress", False)
140+
if receive_progress:
141+
142+
async def _progress_func(args: list[Any] | None, kwargs: dict[str, Any] | None):
143+
yield_msg = messages.Yield(
144+
messages.YieldFields(msg.request_id, args, kwargs, {"progress": True})
145+
)
146+
data = self._session.send_message(yield_msg)
147+
await self._base_session.send(data)
148+
149+
invocation.send_progress = _progress_func
150+
151+
async def handle_endpoint_invocation():
152+
try:
153+
result = await endpoint(invocation)
154+
if result is None:
155+
data = self._session.send_message(messages.Yield(messages.YieldFields(msg.request_id)))
156+
elif isinstance(result, types.Result):
157+
data = self._session.send_message(
158+
messages.Yield(
159+
messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details)
160+
)
161+
)
162+
else:
163+
message = (
164+
"Endpoint returned invalid result type. Expected types.Result or None, got: "
165+
+ str(type(result))
166+
)
167+
msg_to_send = messages.Error(
168+
messages.ErrorFields(
169+
msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message]
170+
)
171+
)
172+
data = self._session.send_message(msg_to_send)
173+
except Exception as e:
174+
message = f"unexpected error calling endpoint {endpoint.__name__}, error is: {e}"
175+
msg_to_send = messages.Error(
176+
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message])
177+
)
178+
data = self._session.send_message(msg_to_send)
179+
await self._base_session.send(data)
180+
181+
current_loop = get_event_loop()
182+
current_loop.create_task(handle_endpoint_invocation())
146183
except ApplicationError as e:
147184
msg_to_send = messages.Error(messages.ErrorFields(msg.TYPE, msg.request_id, e.message, e.args))
148185
data = self._session.send_message(msg_to_send)
@@ -217,6 +254,15 @@ async def register(
217254

218255
return await f
219256

257+
async def _call(self, call_msg: messages.Call) -> types.Result:
258+
f = Future()
259+
self._call_requests[call_msg.request_id] = f
260+
261+
data = self._session.send_message(call_msg)
262+
await self._base_session.send(data)
263+
264+
return await f
265+
220266
async def call(
221267
self,
222268
procedure: str,
@@ -234,6 +280,23 @@ async def call(
234280

235281
return await f
236282

283+
async def call_progress(
284+
self,
285+
procedure: str,
286+
progress_handler: Callable[[types.Result], Awaitable[None]],
287+
args: list[Any] | None = None,
288+
kwargs: dict[str, Any] | None = None,
289+
options: dict[str, Any] | None = None,
290+
) -> types.Result:
291+
if options is None:
292+
options = {}
293+
294+
options["receive_progress"] = True
295+
call_msg = messages.Call(messages.CallFields(self._idgen.next(), procedure, args, kwargs, options))
296+
self._progress_handlers[call_msg.request_id] = progress_handler
297+
298+
return await self._call(call_msg)
299+
237300
async def subscribe(
238301
self, topic: str, event_handler: Callable[[types.Event], Awaitable[None]], options: dict | None = None
239302
) -> Subscription:

0 commit comments

Comments
 (0)