11from typing import Callable , Optional
22import asyncio
3+ import base64
34import logging
45import uuid
6+ import aiohttp
57from aiortc import MediaStreamTrack
68
79from .webrtc_manager import WebRTCManager , WebRTCConfiguration
8- from .messages import PromptMessage
10+ from .messages import PromptMessage , SetAvatarImageMessage
911from .types import ConnectionState , RealtimeConnectOptions
12+ from ..types import FileInput
1013from ..errors import DecartSDKError , InvalidInputError , WebRTCError
14+ from ..process .request import file_input_to_bytes
1115
1216logger = logging .getLogger (__name__ )
1317
1418
1519class RealtimeClient :
16- def __init__ (self , manager : WebRTCManager , session_id : str ):
20+ def __init__ (
21+ self ,
22+ manager : WebRTCManager ,
23+ session_id : str ,
24+ http_session : Optional [aiohttp .ClientSession ] = None ,
25+ is_avatar_live : bool = False ,
26+ ):
1727 self ._manager = manager
1828 self .session_id = session_id
29+ self ._http_session = http_session
30+ self ._is_avatar_live = is_avatar_live
1931 self ._connection_callbacks : list [Callable [[ConnectionState ], None ]] = []
2032 self ._error_callbacks : list [Callable [[DecartSDKError ], None ]] = []
2133
@@ -24,14 +36,16 @@ async def connect(
2436 cls ,
2537 base_url : str ,
2638 api_key : str ,
27- local_track : MediaStreamTrack ,
39+ local_track : Optional [ MediaStreamTrack ] ,
2840 options : RealtimeConnectOptions ,
2941 integration : Optional [str ] = None ,
3042 ) -> "RealtimeClient" :
3143 session_id = str (uuid .uuid4 ())
3244 ws_url = f"{ base_url } { options .model .url_path } "
3345 ws_url += f"?api_key={ api_key } &model={ options .model .name } "
3446
47+ is_avatar_live = options .model .name == "avatar-live"
48+
3549 config = WebRTCConfiguration (
3650 webrtc_url = ws_url ,
3751 api_key = api_key ,
@@ -43,16 +57,33 @@ async def connect(
4357 initial_state = options .initial_state ,
4458 customize_offer = options .customize_offer ,
4559 integration = integration ,
60+ is_avatar_live = is_avatar_live ,
4661 )
4762
63+ # Create HTTP session for file conversions
64+ http_session = aiohttp .ClientSession ()
65+
4866 manager = WebRTCManager (config )
49- client = cls (manager = manager , session_id = session_id )
67+ client = cls (
68+ manager = manager ,
69+ session_id = session_id ,
70+ http_session = http_session ,
71+ is_avatar_live = is_avatar_live ,
72+ )
5073
5174 config .on_connection_state_change = client ._emit_connection_change
5275 config .on_error = lambda error : client ._emit_error (WebRTCError (str (error ), cause = error ))
5376
5477 try :
55- await manager .connect (local_track )
78+ # For avatar-live, convert and send avatar image before WebRTC connection
79+ avatar_image_base64 : Optional [str ] = None
80+ if is_avatar_live and options .avatar :
81+ image_bytes , _ = await file_input_to_bytes (
82+ options .avatar .avatar_image , http_session
83+ )
84+ avatar_image_base64 = base64 .b64encode (image_bytes ).decode ("utf-8" )
85+
86+ await manager .connect (local_track , avatar_image_base64 = avatar_image_base64 )
5687
5788 if options .initial_state :
5889 if options .initial_state .prompt :
@@ -61,6 +92,7 @@ async def connect(
6192 enrich = options .initial_state .prompt .enrich ,
6293 )
6394 except Exception as e :
95+ await http_session .close ()
6496 raise WebRTCError (str (e ), cause = e )
6597
6698 return client
@@ -100,6 +132,47 @@ async def set_prompt(self, prompt: str, enrich: bool = True) -> None:
100132 finally :
101133 self ._manager .unregister_prompt_wait (prompt )
102134
135+ async def set_image (self , image : FileInput ) -> None :
136+ """Set or update the avatar image.
137+
138+ Only available for avatar-live model.
139+
140+ Args:
141+ image: The image to set. Can be bytes, Path, URL string, or file-like object.
142+
143+ Raises:
144+ InvalidInputError: If not using avatar-live model or image is invalid.
145+ DecartSDKError: If the server fails to acknowledge the image.
146+ """
147+ if not self ._is_avatar_live :
148+ raise InvalidInputError ("set_image() is only available for avatar-live model" )
149+
150+ if not self ._http_session :
151+ raise InvalidInputError ("HTTP session not available" )
152+
153+ # Convert image to base64
154+ image_bytes , _ = await file_input_to_bytes (image , self ._http_session )
155+ image_base64 = base64 .b64encode (image_bytes ).decode ("utf-8" )
156+
157+ event , result = self ._manager .register_image_set_wait ()
158+
159+ try :
160+ await self ._manager .send_message (
161+ SetAvatarImageMessage (type = "set_image" , image_data = image_base64 )
162+ )
163+
164+ try :
165+ await asyncio .wait_for (event .wait (), timeout = 15.0 )
166+ except asyncio .TimeoutError :
167+ raise DecartSDKError ("Image set acknowledgment timed out" )
168+
169+ if not result ["success" ]:
170+ raise DecartSDKError (
171+ result .get ("status" ) or "Failed to set avatar image"
172+ )
173+ finally :
174+ self ._manager .unregister_image_set_wait ()
175+
103176 def is_connected (self ) -> bool :
104177 return self ._manager .is_connected ()
105178
@@ -108,6 +181,8 @@ def get_connection_state(self) -> ConnectionState:
108181
109182 async def disconnect (self ) -> None :
110183 await self ._manager .cleanup ()
184+ if self ._http_session and not self ._http_session .closed :
185+ await self ._http_session .close ()
111186
112187 def on (self , event : str , callback : Callable ) -> None :
113188 if event == "connection_change" :
0 commit comments