22import asyncio
33import base64
44import logging
5- import uuid
65from pathlib import Path
7- from urllib .parse import urlparse
6+ from urllib .parse import urlparse , quote
87import aiohttp
98from aiortc import MediaStreamTrack
109from pydantic import BaseModel
1110
1211from .webrtc_manager import WebRTCManager , WebRTCConfiguration
13- from .messages import PromptMessage
12+ from .messages import PromptMessage , SessionIdMessage
13+ from .subscribe import (
14+ SubscribeClient ,
15+ SubscribeOptions ,
16+ encode_subscribe_token ,
17+ decode_subscribe_token ,
18+ )
1419from .types import ConnectionState , RealtimeConnectOptions
1520from ..types import FileInput
1621from ..errors import DecartSDKError , InvalidInputError , WebRTCError
@@ -51,26 +56,41 @@ async def _image_to_base64(
5156 image_bytes , _ = await file_input_to_bytes (image , http_session )
5257 return base64 .b64encode (image_bytes ).decode ("utf-8" )
5358
54- return image
55-
56- image_bytes , _ = await file_input_to_bytes (image , http_session )
57- return base64 .b64encode (image_bytes ).decode ("utf-8" )
59+ raise InvalidInputError (
60+ f"Invalid image input: string is not a data URI, URL, or valid file path"
61+ )
5862
5963
6064class RealtimeClient :
6165 def __init__ (
6266 self ,
6367 manager : WebRTCManager ,
64- session_id : str ,
6568 http_session : Optional [aiohttp .ClientSession ] = None ,
6669 is_avatar_live : bool = False ,
6770 ):
6871 self ._manager = manager
69- self .session_id = session_id
7072 self ._http_session = http_session
7173 self ._is_avatar_live = is_avatar_live
7274 self ._connection_callbacks : list [Callable [[ConnectionState ], None ]] = []
7375 self ._error_callbacks : list [Callable [[DecartSDKError ], None ]] = []
76+ self ._session_id : Optional [str ] = None
77+ self ._subscribe_token : Optional [str ] = None
78+ self ._buffering = True
79+ self ._buffer : list [tuple [str , object ]] = []
80+
81+ @property
82+ def session_id (self ) -> Optional [str ]:
83+ return self ._session_id
84+
85+ @property
86+ def subscribe_token (self ) -> Optional [str ]:
87+ return self ._subscribe_token
88+
89+ def _handle_session_id (self , msg : SessionIdMessage ) -> None :
90+ self ._session_id = msg .session_id
91+ self ._subscribe_token = encode_subscribe_token (
92+ msg .session_id , msg .server_ip , msg .server_port
93+ )
7494
7595 @classmethod
7696 async def connect (
@@ -81,20 +101,20 @@ async def connect(
81101 options : RealtimeConnectOptions ,
82102 integration : Optional [str ] = None ,
83103 ) -> "RealtimeClient" :
84- session_id = str (uuid .uuid4 ())
85104 ws_url = f"{ base_url } { options .model .url_path } "
86- ws_url += f"?api_key={ api_key } &model={ options .model .name } "
105+ ws_url += f"?api_key={ quote ( api_key ) } &model={ quote ( options .model .name ) } "
87106
88107 is_avatar_live = options .model .name == "avatar-live"
89108
90109 config = WebRTCConfiguration (
91110 webrtc_url = ws_url ,
92111 api_key = api_key ,
93- session_id = session_id ,
112+ session_id = "" ,
94113 fps = options .model .fps ,
95114 on_remote_stream = options .on_remote_stream ,
96115 on_connection_state_change = None ,
97116 on_error = None ,
117+ on_session_id = None ,
98118 initial_state = options .initial_state ,
99119 customize_offer = options .customize_offer ,
100120 integration = integration ,
@@ -107,13 +127,13 @@ async def connect(
107127 manager = WebRTCManager (config )
108128 client = cls (
109129 manager = manager ,
110- session_id = session_id ,
111130 http_session = http_session ,
112131 is_avatar_live = is_avatar_live ,
113132 )
114133
115134 config .on_connection_state_change = client ._emit_connection_change
116135 config .on_error = lambda error : client ._emit_error (WebRTCError (str (error ), cause = error ))
136+ config .on_session_id = client ._handle_session_id
117137
118138 try :
119139 # For avatar-live, convert and send avatar image before WebRTC connection
@@ -143,28 +163,97 @@ async def connect(
143163 if options .initial_state .prompt :
144164 await client .set_prompt (
145165 options .initial_state .prompt .text ,
146- enrich = options .initial_state .prompt .enrich ,
166+ enhance = options .initial_state .prompt .enhance ,
147167 )
148168 except Exception as e :
169+ await manager .cleanup ()
149170 await http_session .close ()
150171 raise WebRTCError (str (e ), cause = e )
151172
173+ client ._flush ()
152174 return client
153175
154- def _emit_connection_change (self , state : ConnectionState ) -> None :
155- for callback in self ._connection_callbacks :
176+ @classmethod
177+ async def subscribe (
178+ cls ,
179+ base_url : str ,
180+ api_key : str ,
181+ options : SubscribeOptions ,
182+ integration : Optional [str ] = None ,
183+ ) -> SubscribeClient :
184+ token_data = decode_subscribe_token (options .token )
185+ subscribe_url = (
186+ f"{ base_url } /subscribe/{ quote (token_data .sid )} "
187+ f"?IP={ quote (token_data .ip )} "
188+ f"&port={ quote (str (token_data .port ))} "
189+ f"&api_key={ quote (api_key )} "
190+ )
191+
192+ config = WebRTCConfiguration (
193+ webrtc_url = subscribe_url ,
194+ api_key = api_key ,
195+ session_id = token_data .sid ,
196+ fps = 0 ,
197+ on_remote_stream = options .on_remote_stream ,
198+ on_connection_state_change = None ,
199+ on_error = None ,
200+ integration = integration ,
201+ )
202+
203+ manager = WebRTCManager (config )
204+ sub_client = SubscribeClient (manager )
205+
206+ config .on_connection_state_change = sub_client ._emit_connection_change
207+ config .on_error = sub_client ._emit_error
208+
209+ try :
210+ await manager .connect (None )
211+ except Exception as e :
212+ await manager .cleanup ()
213+ raise WebRTCError (str (e ), cause = e )
214+
215+ sub_client ._flush ()
216+ return sub_client
217+
218+ def _flush (self ) -> None :
219+ # Defer to next tick so caller can register handlers before buffered events fire
220+ asyncio .get_running_loop ().call_soon (self ._do_flush )
221+
222+ def _do_flush (self ) -> None :
223+ self ._buffering = False
224+ for event , data in self ._buffer :
225+ if event == "connection_change" :
226+ self ._dispatch_connection_change (data ) # type: ignore[arg-type]
227+ elif event == "error" :
228+ self ._dispatch_error (data ) # type: ignore[arg-type]
229+ self ._buffer .clear ()
230+
231+ def _dispatch_connection_change (self , state : ConnectionState ) -> None :
232+ for callback in list (self ._connection_callbacks ):
156233 try :
157234 callback (state )
158235 except Exception as e :
159236 logger .exception (f"Error in connection_change callback: { e } " )
160237
161- def _emit_error (self , error : DecartSDKError ) -> None :
162- for callback in self ._error_callbacks :
238+ def _dispatch_error (self , error : DecartSDKError ) -> None :
239+ for callback in list ( self ._error_callbacks ) :
163240 try :
164241 callback (error )
165242 except Exception as e :
166243 logger .exception (f"Error in error callback: { e } " )
167244
245+ def _emit_connection_change (self , state : ConnectionState ) -> None :
246+ if self ._buffering :
247+ self ._buffer .append (("connection_change" , state ))
248+ else :
249+ self ._dispatch_connection_change (state )
250+
251+ def _emit_error (self , error : DecartSDKError ) -> None :
252+ if self ._buffering :
253+ self ._buffer .append (("error" , error ))
254+ else :
255+ self ._dispatch_error (error )
256+
168257 async def set (self , input : SetInput ) -> None :
169258 if input .prompt is None and input .image is None :
170259 raise InvalidInputError ("At least one of 'prompt' or 'image' must be provided" )
@@ -187,15 +276,29 @@ async def set(self, input: SetInput) -> None:
187276 },
188277 )
189278
190- async def set_prompt (self , prompt : str , enrich : bool = True ) -> None :
279+ async def set_prompt (
280+ self ,
281+ prompt : str ,
282+ enhance : bool = True ,
283+ enrich : Optional [bool ] = None ,
284+ ) -> None :
285+ if enrich is not None :
286+ import warnings
287+
288+ warnings .warn (
289+ "set_prompt(enrich=...) is deprecated, use set_prompt(enhance=...) instead" ,
290+ DeprecationWarning ,
291+ stacklevel = 2 ,
292+ )
293+ enhance = enrich
191294 if not prompt or not prompt .strip ():
192295 raise InvalidInputError ("Prompt cannot be empty" )
193296
194297 event , result = self ._manager .register_prompt_wait (prompt )
195298
196299 try :
197300 await self ._manager .send_message (
198- PromptMessage (type = "prompt" , prompt = prompt , enhance_prompt = enrich )
301+ PromptMessage (type = "prompt" , prompt = prompt , enhance_prompt = enhance )
199302 )
200303
201304 try :
@@ -208,17 +311,26 @@ async def set_prompt(self, prompt: str, enrich: bool = True) -> None:
208311 finally :
209312 self ._manager .unregister_prompt_wait (prompt )
210313
211- async def set_image (self , image : FileInput ) -> None :
212- if not self ._is_avatar_live :
213- raise InvalidInputError ("set_image() is only available for avatar-live model" )
214-
215- if not self ._http_session :
216- raise InvalidInputError ("HTTP session not available" )
314+ async def set_image (
315+ self ,
316+ image : Optional [FileInput ],
317+ prompt : Optional [str ] = None ,
318+ enhance : bool = True ,
319+ timeout : float = UPDATE_TIMEOUT_S ,
320+ ) -> None :
321+ image_base64 : Optional [str ] = None
322+ if image is not None :
323+ if not self ._http_session :
324+ raise InvalidInputError ("HTTP session not available" )
325+ image_bytes , _ = await file_input_to_bytes (image , self ._http_session )
326+ image_base64 = base64 .b64encode (image_bytes ).decode ("utf-8" )
217327
218- image_bytes , _ = await file_input_to_bytes (image , self ._http_session )
219- image_base64 = base64 .b64encode (image_bytes ).decode ("utf-8" )
328+ opts : dict = {"timeout" : timeout }
329+ if prompt is not None :
330+ opts ["prompt" ] = prompt
331+ opts ["enhance" ] = enhance
220332
221- await self ._manager .set_image (image_base64 )
333+ await self ._manager .set_image (image_base64 , opts )
222334
223335 def is_connected (self ) -> bool :
224336 return self ._manager .is_connected ()
@@ -227,6 +339,8 @@ def get_connection_state(self) -> ConnectionState:
227339 return self ._manager .get_connection_state ()
228340
229341 async def disconnect (self ) -> None :
342+ self ._buffering = False
343+ self ._buffer .clear ()
230344 await self ._manager .cleanup ()
231345 if self ._http_session and not self ._http_session .closed :
232346 await self ._http_session .close ()
0 commit comments