Skip to content

Commit 0fdfbe0

Browse files
committed
feat(realtime): port JS SDK realtime features and fix P1/P2 bugs
Port missing realtime client features from the JS SDK to the Python SDK, closing all identified gaps from the cross-SDK analysis. Also fixes all P1/P2 bugs found through external reviews and self-review.

New features:
- Subscribe mode (SubscribeClient, token encode/decode)
- Server-assigned session_id and subscribe_token
- Event buffer (deferred flush matching JS setTimeout(0))
- Generating/reconnecting connection states
- Auto-reconnect with tenacity (generation counter, CancelledError handling)
- set_image(None) to clear the image; prompt/enhance options on set_image
- Prompt.enrich renamed to Prompt.enhance (backward compat in set_prompt)

Bug fixes (P1):
- connect() retry no longer leaks connection objects
- _reconnect() resets _is_reconnecting via finally (CancelledError safe)
- WS close transitions state to disconnected (triggers reconnect)
- ICE restart preserves tracks, transceivers, and icecandidate handler

Bug fixes (P2):
- connect() classmethod cleans up manager on post-connect failure
- Callback dispatch iterates list copy (safe self-removal)
- cleanup() nulls _connection for cleaner post-disconnect errors
- _image_to_base64 validates unrecognized strings (raises InvalidInputError)
- Subscribe tokens use URL-safe base64
1 parent 6696299 commit 0fdfbe0

File tree

13 files changed

+654
-82
lines changed

13 files changed

+654
-82
lines changed

decart/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030
from .realtime import (
3131
RealtimeClient,
3232
SetInput,
33+
SubscribeClient,
34+
SubscribeOptions,
35+
encode_subscribe_token,
36+
decode_subscribe_token,
3337
RealtimeConnectOptions,
3438
ConnectionState,
3539
AvatarOptions,
@@ -40,6 +44,10 @@
4044
REALTIME_AVAILABLE = False
4145
RealtimeClient = None # type: ignore
4246
SetInput = None # type: ignore
47+
SubscribeClient = None # type: ignore
48+
SubscribeOptions = None # type: ignore
49+
encode_subscribe_token = None # type: ignore
50+
decode_subscribe_token = None # type: ignore
4351
RealtimeConnectOptions = None # type: ignore
4452
ConnectionState = None # type: ignore
4553
AvatarOptions = None # type: ignore
@@ -79,6 +87,10 @@
7987
[
8088
"RealtimeClient",
8189
"SetInput",
90+
"SubscribeClient",
91+
"SubscribeOptions",
92+
"encode_subscribe_token",
93+
"decode_subscribe_token",
8294
"RealtimeConnectOptions",
8395
"ConnectionState",
8496
"AvatarOptions",

decart/realtime/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
from .client import RealtimeClient, SetInput
2+
from .subscribe import (
3+
SubscribeClient,
4+
SubscribeOptions,
5+
encode_subscribe_token,
6+
decode_subscribe_token,
7+
)
28
from .types import RealtimeConnectOptions, ConnectionState, AvatarOptions
39

410
__all__ = [
511
"RealtimeClient",
612
"SetInput",
13+
"SubscribeClient",
14+
"SubscribeOptions",
15+
"encode_subscribe_token",
16+
"decode_subscribe_token",
717
"RealtimeConnectOptions",
818
"ConnectionState",
919
"AvatarOptions",

decart/realtime/client.py

Lines changed: 143 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,20 @@
22
import asyncio
33
import base64
44
import logging
5-
import uuid
65
from pathlib import Path
7-
from urllib.parse import urlparse
6+
from urllib.parse import urlparse, quote
87
import aiohttp
98
from aiortc import MediaStreamTrack
109
from pydantic import BaseModel
1110

1211
from .webrtc_manager import WebRTCManager, WebRTCConfiguration
13-
from .messages import PromptMessage
12+
from .messages import PromptMessage, SessionIdMessage
13+
from .subscribe import (
14+
SubscribeClient,
15+
SubscribeOptions,
16+
encode_subscribe_token,
17+
decode_subscribe_token,
18+
)
1419
from .types import ConnectionState, RealtimeConnectOptions
1520
from ..types import FileInput
1621
from ..errors import DecartSDKError, InvalidInputError, WebRTCError
@@ -51,26 +56,41 @@ async def _image_to_base64(
5156
image_bytes, _ = await file_input_to_bytes(image, http_session)
5257
return base64.b64encode(image_bytes).decode("utf-8")
5358

54-
return image
55-
56-
image_bytes, _ = await file_input_to_bytes(image, http_session)
57-
return base64.b64encode(image_bytes).decode("utf-8")
59+
raise InvalidInputError(
60+
f"Invalid image input: string is not a data URI, URL, or valid file path"
61+
)
5862

5963

6064
class RealtimeClient:
6165
def __init__(
6266
self,
6367
manager: WebRTCManager,
64-
session_id: str,
6568
http_session: Optional[aiohttp.ClientSession] = None,
6669
is_avatar_live: bool = False,
6770
):
6871
self._manager = manager
69-
self.session_id = session_id
7072
self._http_session = http_session
7173
self._is_avatar_live = is_avatar_live
7274
self._connection_callbacks: list[Callable[[ConnectionState], None]] = []
7375
self._error_callbacks: list[Callable[[DecartSDKError], None]] = []
76+
self._session_id: Optional[str] = None
77+
self._subscribe_token: Optional[str] = None
78+
self._buffering = True
79+
self._buffer: list[tuple[str, object]] = []
80+
81+
@property
def session_id(self) -> Optional[str]:
    """Return the stored session id, or ``None`` if none has been set.

    The backing attribute starts as ``None`` and is assigned by
    ``_handle_session_id`` from the ``session_id`` field of the message
    it receives.
    """
    return self._session_id
84+
85+
@property
def subscribe_token(self) -> Optional[str]:
    """Return the stored subscribe token, or ``None`` if none has been set.

    The backing attribute starts as ``None`` and is assigned by
    ``_handle_session_id``, which builds it with ``encode_subscribe_token``
    from the message's session id, server IP and server port.
    """
    return self._subscribe_token
88+
89+
def _handle_session_id(self, msg: SessionIdMessage) -> None:
    """Store the session id from *msg* and derive the subscribe token.

    Args:
        msg: Message carrying ``session_id``, ``server_ip`` and
            ``server_port``.
    """
    sid = msg.session_id
    # Record the id first so it is kept even if token encoding fails.
    self._session_id = sid
    token = encode_subscribe_token(sid, msg.server_ip, msg.server_port)
    self._subscribe_token = token
7494

7595
@classmethod
7696
async def connect(
@@ -81,20 +101,20 @@ async def connect(
81101
options: RealtimeConnectOptions,
82102
integration: Optional[str] = None,
83103
) -> "RealtimeClient":
84-
session_id = str(uuid.uuid4())
85104
ws_url = f"{base_url}{options.model.url_path}"
86-
ws_url += f"?api_key={api_key}&model={options.model.name}"
105+
ws_url += f"?api_key={quote(api_key)}&model={quote(options.model.name)}"
87106

88107
is_avatar_live = options.model.name == "avatar-live"
89108

90109
config = WebRTCConfiguration(
91110
webrtc_url=ws_url,
92111
api_key=api_key,
93-
session_id=session_id,
112+
session_id="",
94113
fps=options.model.fps,
95114
on_remote_stream=options.on_remote_stream,
96115
on_connection_state_change=None,
97116
on_error=None,
117+
on_session_id=None,
98118
initial_state=options.initial_state,
99119
customize_offer=options.customize_offer,
100120
integration=integration,
@@ -107,13 +127,13 @@ async def connect(
107127
manager = WebRTCManager(config)
108128
client = cls(
109129
manager=manager,
110-
session_id=session_id,
111130
http_session=http_session,
112131
is_avatar_live=is_avatar_live,
113132
)
114133

115134
config.on_connection_state_change = client._emit_connection_change
116135
config.on_error = lambda error: client._emit_error(WebRTCError(str(error), cause=error))
136+
config.on_session_id = client._handle_session_id
117137

118138
try:
119139
# For avatar-live, convert and send avatar image before WebRTC connection
@@ -143,28 +163,97 @@ async def connect(
143163
if options.initial_state.prompt:
144164
await client.set_prompt(
145165
options.initial_state.prompt.text,
146-
enrich=options.initial_state.prompt.enrich,
166+
enhance=options.initial_state.prompt.enhance,
147167
)
148168
except Exception as e:
169+
await manager.cleanup()
149170
await http_session.close()
150171
raise WebRTCError(str(e), cause=e)
151172

173+
client._flush()
152174
return client
153175

154-
def _emit_connection_change(self, state: ConnectionState) -> None:
155-
for callback in self._connection_callbacks:
176+
@classmethod
async def subscribe(
    cls,
    base_url: str,
    api_key: str,
    options: SubscribeOptions,
    integration: Optional[str] = None,
) -> SubscribeClient:
    """Connect to an existing session described by a subscribe token.

    Args:
        base_url: Base URL that ``/subscribe/<session id>`` is appended to.
        api_key: API key, sent as the ``api_key`` query parameter and
            stored on the WebRTC configuration.
        options: Supplies the subscribe ``token`` and the
            ``on_remote_stream`` callback.
        integration: Optional value forwarded unchanged to
            ``WebRTCConfiguration``.

    Returns:
        A ``SubscribeClient`` wired to the connected manager. Its
        ``_flush()`` is invoked just before returning.

    Raises:
        WebRTCError: If ``manager.connect`` fails. The manager is cleaned
            up first and the original exception is chained as the cause.
    """
    token_data = decode_subscribe_token(options.token)

    # The token is caller-supplied. quote() leaves "/" unescaped by default,
    # which would let a crafted session id alter the request path, so escape
    # every reserved character when using it as a single path segment.
    session_segment = quote(token_data.sid, safe="")
    subscribe_url = (
        f"{base_url}/subscribe/{session_segment}"
        f"?IP={quote(token_data.ip)}"
        f"&port={quote(str(token_data.port))}"
        f"&api_key={quote(api_key)}"
    )

    config = WebRTCConfiguration(
        webrtc_url=subscribe_url,
        api_key=api_key,
        session_id=token_data.sid,
        fps=0,
        on_remote_stream=options.on_remote_stream,
        on_connection_state_change=None,
        on_error=None,
        integration=integration,
    )

    manager = WebRTCManager(config)
    sub_client = SubscribeClient(manager)

    # The callbacks need the client instance, so they are attached only
    # after both the manager and the client exist.
    config.on_connection_state_change = sub_client._emit_connection_change
    config.on_error = sub_client._emit_error

    try:
        # NOTE(review): None presumably means "no local track" — confirm
        # against WebRTCManager.connect.
        await manager.connect(None)
    except Exception as e:
        # Release whatever the manager allocated before surfacing the error.
        await manager.cleanup()
        raise WebRTCError(str(e), cause=e) from e

    sub_client._flush()
    return sub_client
217+
218+
def _flush(self) -> None:
    """Schedule ``_do_flush`` on the running event loop.

    The flush is deferred instead of run inline so that the caller can
    register handlers before buffered events are delivered.
    """
    loop = asyncio.get_running_loop()
    loop.call_soon(self._do_flush)
221+
222+
def _do_flush(self) -> None:
    """Stop buffering and deliver every buffered event in order.

    ``connection_change`` entries go to ``_dispatch_connection_change`` and
    ``error`` entries to ``_dispatch_error``; entries with any other tag
    are ignored. The buffer is emptied afterwards.
    """
    # Turn buffering off first so events emitted from here on are
    # dispatched directly instead of being queued.
    self._buffering = False
    handlers = {
        "connection_change": self._dispatch_connection_change,
        "error": self._dispatch_error,
    }
    for kind, payload in self._buffer:
        handler = handlers.get(kind)
        if handler is not None:
            handler(payload)  # type: ignore[arg-type]
    self._buffer.clear()
230+
231+
def _dispatch_connection_change(self, state: ConnectionState) -> None:
    """Call every registered connection-change callback with *state*.

    An exception raised by one callback is logged and does not stop the
    remaining callbacks from running.
    """
    # Iterate over a snapshot so a callback may unregister itself (or
    # others) without disturbing this loop.
    listeners = tuple(self._connection_callbacks)
    for listener in listeners:
        try:
            listener(state)
        except Exception as e:
            logger.exception(f"Error in connection_change callback: {e}")
160237

161-
def _emit_error(self, error: DecartSDKError) -> None:
162-
for callback in self._error_callbacks:
238+
def _dispatch_error(self, error: DecartSDKError) -> None:
    """Call every registered error callback with *error*.

    An exception raised by one callback is logged and does not stop the
    remaining callbacks from running.
    """
    # Iterate over a snapshot so a callback may unregister itself (or
    # others) without disturbing this loop.
    listeners = tuple(self._error_callbacks)
    for listener in listeners:
        try:
            listener(error)
        except Exception as e:
            logger.exception(f"Error in error callback: {e}")
167244

245+
def _emit_connection_change(self, state: ConnectionState) -> None:
    """Deliver a connection-state event, or queue it while buffering.

    While ``_buffering`` is true the event is appended to ``_buffer`` as a
    ``("connection_change", state)`` pair; otherwise it is dispatched to
    the registered callbacks immediately.
    """
    if not self._buffering:
        self._dispatch_connection_change(state)
        return
    self._buffer.append(("connection_change", state))
250+
251+
def _emit_error(self, error: DecartSDKError) -> None:
    """Deliver an error event, or queue it while buffering.

    While ``_buffering`` is true the event is appended to ``_buffer`` as an
    ``("error", error)`` pair; otherwise it is dispatched to the registered
    callbacks immediately.
    """
    if not self._buffering:
        self._dispatch_error(error)
        return
    self._buffer.append(("error", error))
256+
168257
async def set(self, input: SetInput) -> None:
169258
if input.prompt is None and input.image is None:
170259
raise InvalidInputError("At least one of 'prompt' or 'image' must be provided")
@@ -187,15 +276,29 @@ async def set(self, input: SetInput) -> None:
187276
},
188277
)
189278

190-
async def set_prompt(self, prompt: str, enrich: bool = True) -> None:
279+
async def set_prompt(
280+
self,
281+
prompt: str,
282+
enhance: bool = True,
283+
enrich: Optional[bool] = None,
284+
) -> None:
285+
if enrich is not None:
286+
import warnings
287+
288+
warnings.warn(
289+
"set_prompt(enrich=...) is deprecated, use set_prompt(enhance=...) instead",
290+
DeprecationWarning,
291+
stacklevel=2,
292+
)
293+
enhance = enrich
191294
if not prompt or not prompt.strip():
192295
raise InvalidInputError("Prompt cannot be empty")
193296

194297
event, result = self._manager.register_prompt_wait(prompt)
195298

196299
try:
197300
await self._manager.send_message(
198-
PromptMessage(type="prompt", prompt=prompt, enhance_prompt=enrich)
301+
PromptMessage(type="prompt", prompt=prompt, enhance_prompt=enhance)
199302
)
200303

201304
try:
@@ -208,17 +311,26 @@ async def set_prompt(self, prompt: str, enrich: bool = True) -> None:
208311
finally:
209312
self._manager.unregister_prompt_wait(prompt)
210313

211-
async def set_image(self, image: FileInput) -> None:
212-
if not self._is_avatar_live:
213-
raise InvalidInputError("set_image() is only available for avatar-live model")
214-
215-
if not self._http_session:
216-
raise InvalidInputError("HTTP session not available")
314+
async def set_image(
315+
self,
316+
image: Optional[FileInput],
317+
prompt: Optional[str] = None,
318+
enhance: bool = True,
319+
timeout: float = UPDATE_TIMEOUT_S,
320+
) -> None:
321+
image_base64: Optional[str] = None
322+
if image is not None:
323+
if not self._http_session:
324+
raise InvalidInputError("HTTP session not available")
325+
image_bytes, _ = await file_input_to_bytes(image, self._http_session)
326+
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
217327

218-
image_bytes, _ = await file_input_to_bytes(image, self._http_session)
219-
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
328+
opts: dict = {"timeout": timeout}
329+
if prompt is not None:
330+
opts["prompt"] = prompt
331+
opts["enhance"] = enhance
220332

221-
await self._manager.set_image(image_base64)
333+
await self._manager.set_image(image_base64, opts)
222334

223335
def is_connected(self) -> bool:
224336
return self._manager.is_connected()
@@ -227,6 +339,8 @@ def get_connection_state(self) -> ConnectionState:
227339
return self._manager.get_connection_state()
228340

229341
async def disconnect(self) -> None:
342+
self._buffering = False
343+
self._buffer.clear()
230344
await self._manager.cleanup()
231345
if self._http_session and not self._http_session.closed:
232346
await self._http_session.close()

decart/realtime/messages.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,12 @@ class IceRestartMessage(BaseModel):
8787
turn_config: TurnConfig
8888

8989

90+
class GenerationStartedMessage(BaseModel):
    """Incoming message by which the server signals that generation has started.

    The message carries no payload beyond its ``type`` tag.
    """

    # Tag value that selects this model inside the ``IncomingMessage``
    # union, which discriminates on the ``type`` field.
    type: Literal["generation_started"]
94+
95+
9096
# Discriminated union for incoming messages
9197
IncomingMessage = Annotated[
9298
Union[
@@ -98,6 +104,7 @@ class IceRestartMessage(BaseModel):
98104
ErrorMessage,
99105
ReadyMessage,
100106
IceRestartMessage,
107+
GenerationStartedMessage,
101108
],
102109
Field(discriminator="type"),
103110
]

0 commit comments

Comments
 (0)