44import inspect
55from dataclasses import dataclass
66from asyncio import Future , get_event_loop
7- from typing import Callable , Union , Awaitable , Any
7+ from typing import Callable , Awaitable , Any
88
99from wampproto import messages , idgen , session
1010
@@ -70,17 +70,15 @@ def __init__(self, base_session: types.IAsyncBaseSession):
7070 # RPC data structures
7171 self ._call_requests : dict [int , Future [types .Result ]] = {}
7272 self ._register_requests : dict [int , RegisterRequest ] = {}
73- self ._registrations : dict [
74- int ,
75- Union [Callable [[types .Invocation ], types .Result ], Callable [[types .Invocation ], Awaitable [types .Result ]]],
76- ] = {}
73+ self ._registrations : dict [int , Callable [[types .Invocation ], Awaitable [types .Result ]]] = {}
7774 self ._unregister_requests : dict [int , types .UnregisterRequest ] = {}
7875
7976 # PubSub data structures
8077 self ._publish_requests : dict [int , Future [None ]] = {}
8178 self ._subscribe_requests : dict [int , SubscribeRequest ] = {}
8279 self ._subscriptions : dict [int , Callable [[types .Event ], Awaitable [None ]]] = {}
8380 self ._unsubscribe_requests : dict [int , types .UnsubscribeRequest ] = {}
81+ self ._progress_handlers : dict [int , Callable [[types .Result ], Awaitable [None ]]] = {}
8482
8583 self ._goodbye_request = Future ()
8684
@@ -120,29 +118,68 @@ async def _process_incoming_message(self, msg: messages.Message):
120118 del self ._registrations [request .registration_id ]
121119 request .future .set_result (None )
122120 elif isinstance (msg , messages .Result ):
123- request = self ._call_requests .pop (msg .request_id )
124- request .set_result (types .Result (msg .args , msg .kwargs , msg .details ))
121+ progress = msg .details .get ("progress" , False )
122+ if progress :
123+ progress_handler = self ._progress_handlers .get (msg .request_id , None )
124+ if progress_handler is not None :
125+ try :
126+ await progress_handler (types .Result (msg .args , msg .kwargs , msg .details ))
127+ except Exception as e :
128+ # TODO: implement call canceling
129+ print (e )
130+ else :
131+ request = self ._call_requests .pop (msg .request_id , None )
132+ if request is not None :
133+ request .set_result (types .Result (msg .args , msg .kwargs , msg .details ))
134+ self ._progress_handlers .pop (msg .request_id , None )
125135 elif isinstance (msg , messages .Invocation ):
126136 try :
127137 endpoint = self ._registrations [msg .registration_id ]
128- result = await endpoint (types .Invocation (msg .args , msg .kwargs , msg .details ))
129-
130- if result is None :
131- data = self ._session .send_message (messages .Yield (messages .YieldFields (msg .request_id )))
132- elif isinstance (result , types .Result ):
133- data = self ._session .send_message (
134- messages .Yield (messages .YieldFields (msg .request_id , result .args , result .kwargs , result .details ))
135- )
136- else :
137- message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str (
138- type (result )
139- )
140- msg_to_send = messages .Error (
141- messages .ErrorFields (msg .TYPE , msg .request_id , xconn_uris .ERROR_INTERNAL_ERROR , [message ])
142- )
143- data = self ._session .send_message (msg_to_send )
144-
145- await self ._base_session .send (data )
138+ invocation = types .Invocation (msg .args , msg .kwargs , msg .details )
139+ receive_progress = msg .details .get ("receive_progress" , False )
140+ if receive_progress :
141+
142+ async def _progress_func (args : list [Any ] | None , kwargs : dict [str , Any ] | None ):
143+ yield_msg = messages .Yield (
144+ messages .YieldFields (msg .request_id , args , kwargs , {"progress" : True })
145+ )
146+ data = self ._session .send_message (yield_msg )
147+ await self ._base_session .send (data )
148+
149+ invocation .send_progress = _progress_func
150+
151+ async def handle_endpoint_invocation ():
152+ try :
153+ result = await endpoint (invocation )
154+ if result is None :
155+ data = self ._session .send_message (messages .Yield (messages .YieldFields (msg .request_id )))
156+ elif isinstance (result , types .Result ):
157+ data = self ._session .send_message (
158+ messages .Yield (
159+ messages .YieldFields (msg .request_id , result .args , result .kwargs , result .details )
160+ )
161+ )
162+ else :
163+ message = (
164+ "Endpoint returned invalid result type. Expected types.Result or None, got: "
165+ + str (type (result ))
166+ )
167+ msg_to_send = messages .Error (
168+ messages .ErrorFields (
169+ msg .TYPE , msg .request_id , xconn_uris .ERROR_INTERNAL_ERROR , [message ]
170+ )
171+ )
172+ data = self ._session .send_message (msg_to_send )
173+ except Exception as e :
174+ message = f"unexpected error calling endpoint { endpoint .__name__ } , error is: { e } "
175+ msg_to_send = messages .Error (
176+ messages .ErrorFields (msg .TYPE , msg .request_id , xconn_uris .ERROR_INTERNAL_ERROR , [message ])
177+ )
178+ data = self ._session .send_message (msg_to_send )
179+ await self ._base_session .send (data )
180+
181+ current_loop = get_event_loop ()
182+ current_loop .create_task (handle_endpoint_invocation ())
146183 except ApplicationError as e :
147184 msg_to_send = messages .Error (messages .ErrorFields (msg .TYPE , msg .request_id , e .message , e .args ))
148185 data = self ._session .send_message (msg_to_send )
@@ -217,6 +254,15 @@ async def register(
217254
218255 return await f
219256
257+ async def _call (self , call_msg : messages .Call ) -> types .Result :
258+ f = Future ()
259+ self ._call_requests [call_msg .request_id ] = f
260+
261+ data = self ._session .send_message (call_msg )
262+ await self ._base_session .send (data )
263+
264+ return await f
265+
220266 async def call (
221267 self ,
222268 procedure : str ,
@@ -234,6 +280,23 @@ async def call(
234280
235281 return await f
236282
283+ async def call_progress (
284+ self ,
285+ procedure : str ,
286+ progress_handler : Callable [[types .Result ], Awaitable [None ]],
287+ args : list [Any ] | None = None ,
288+ kwargs : dict [str , Any ] | None = None ,
289+ options : dict [str , Any ] | None = None ,
290+ ) -> types .Result :
291+ if options is None :
292+ options = {}
293+
294+ options ["receive_progress" ] = True
295+ call_msg = messages .Call (messages .CallFields (self ._idgen .next (), procedure , args , kwargs , options ))
296+ self ._progress_handlers [call_msg .request_id ] = progress_handler
297+
298+ return await self ._call (call_msg )
299+
237300 async def subscribe (
238301 self , topic : str , event_handler : Callable [[types .Event ], Awaitable [None ]], options : dict | None = None
239302 ) -> Subscription :
0 commit comments