Skip to content

Commit 2bfd372

Browse files
authored
feat(realtime): add unified set() method (#25)
- Port set() method to realtime client with proper timeout handling - Convert SetInput from dataclass to Pydantic BaseModel for validation - Remove unused SetAvatarImageMessage import - Update image URL parsing with proper data URI handling - Add comprehensive unit tests for set() method - Update webrtc_manager with timeout configuration
1 parent 6d0c675 commit 2bfd372

File tree

8 files changed

+452
-90
lines changed

8 files changed

+452
-90
lines changed

decart/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
try:
3030
from .realtime import (
3131
RealtimeClient,
32+
SetInput,
3233
RealtimeConnectOptions,
3334
ConnectionState,
3435
AvatarOptions,
@@ -38,6 +39,7 @@
3839
except ImportError:
3940
REALTIME_AVAILABLE = False
4041
RealtimeClient = None # type: ignore
42+
SetInput = None # type: ignore
4143
RealtimeConnectOptions = None # type: ignore
4244
ConnectionState = None # type: ignore
4345
AvatarOptions = None # type: ignore
@@ -76,6 +78,7 @@
7678
__all__.extend(
7779
[
7880
"RealtimeClient",
81+
"SetInput",
7982
"RealtimeConnectOptions",
8083
"ConnectionState",
8184
"AvatarOptions",

decart/realtime/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from .client import RealtimeClient
1+
from .client import RealtimeClient, SetInput
22
from .types import RealtimeConnectOptions, ConnectionState, AvatarOptions
33

44
__all__ = [
55
"RealtimeClient",
6+
"SetInput",
67
"RealtimeConnectOptions",
78
"ConnectionState",
89
"AvatarOptions",

decart/realtime/client.py

Lines changed: 67 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,61 @@
1-
from typing import Callable, Optional
1+
from typing import Callable, Optional, Union
22
import asyncio
33
import base64
44
import logging
55
import uuid
6+
from pathlib import Path
7+
from urllib.parse import urlparse
68
import aiohttp
79
from aiortc import MediaStreamTrack
10+
from pydantic import BaseModel
811

912
from .webrtc_manager import WebRTCManager, WebRTCConfiguration
10-
from .messages import PromptMessage, SetAvatarImageMessage
13+
from .messages import PromptMessage
1114
from .types import ConnectionState, RealtimeConnectOptions
1215
from ..types import FileInput
1316
from ..errors import DecartSDKError, InvalidInputError, WebRTCError
1417
from ..process.request import file_input_to_bytes
1518

1619
logger = logging.getLogger(__name__)
1720

21+
PROMPT_TIMEOUT_S = 15.0
22+
UPDATE_TIMEOUT_S = 30.0
23+
24+
25+
class SetInput(BaseModel):
26+
prompt: Optional[str] = None
27+
enhance: bool = True
28+
image: Optional[Union[bytes, str]] = None
29+
30+
31+
async def _image_to_base64(
32+
image: Union[bytes, str],
33+
http_session: aiohttp.ClientSession,
34+
) -> str:
35+
if isinstance(image, bytes):
36+
return base64.b64encode(image).decode("utf-8")
37+
38+
if isinstance(image, str):
39+
parsed = urlparse(image)
40+
41+
if parsed.scheme == "data":
42+
return image.split(",", 1)[1]
43+
44+
if parsed.scheme in ("http", "https"):
45+
async with http_session.get(image) as resp:
46+
resp.raise_for_status()
47+
data = await resp.read()
48+
return base64.b64encode(data).decode("utf-8")
49+
50+
if Path(image).exists():
51+
image_bytes, _ = await file_input_to_bytes(image, http_session)
52+
return base64.b64encode(image_bytes).decode("utf-8")
53+
54+
return image
55+
56+
image_bytes, _ = await file_input_to_bytes(image, http_session)
57+
return base64.b64encode(image_bytes).decode("utf-8")
58+
1859

1960
class RealtimeClient:
2061
def __init__(
@@ -124,6 +165,28 @@ def _emit_error(self, error: DecartSDKError) -> None:
124165
except Exception as e:
125166
logger.exception(f"Error in error callback: {e}")
126167

168+
async def set(self, input: SetInput) -> None:
169+
if input.prompt is None and input.image is None:
170+
raise InvalidInputError("At least one of 'prompt' or 'image' must be provided")
171+
172+
if input.prompt is not None and not input.prompt.strip():
173+
raise InvalidInputError("Prompt cannot be empty")
174+
175+
image_base64: Optional[str] = None
176+
if input.image is not None:
177+
if not self._http_session:
178+
raise InvalidInputError("HTTP session not available")
179+
image_base64 = await _image_to_base64(input.image, self._http_session)
180+
181+
await self._manager.set_image(
182+
image_base64,
183+
{
184+
"prompt": input.prompt,
185+
"enhance": input.enhance,
186+
"timeout": UPDATE_TIMEOUT_S,
187+
},
188+
)
189+
127190
async def set_prompt(self, prompt: str, enrich: bool = True) -> None:
128191
if not prompt or not prompt.strip():
129192
raise InvalidInputError("Prompt cannot be empty")
@@ -136,7 +199,7 @@ async def set_prompt(self, prompt: str, enrich: bool = True) -> None:
136199
)
137200

138201
try:
139-
await asyncio.wait_for(event.wait(), timeout=15.0)
202+
await asyncio.wait_for(event.wait(), timeout=PROMPT_TIMEOUT_S)
140203
except asyncio.TimeoutError:
141204
raise DecartSDKError("Prompt acknowledgment timed out")
142205

@@ -146,43 +209,16 @@ async def set_prompt(self, prompt: str, enrich: bool = True) -> None:
146209
self._manager.unregister_prompt_wait(prompt)
147210

148211
async def set_image(self, image: FileInput) -> None:
149-
"""Set or update the avatar image.
150-
151-
Only available for avatar-live model.
152-
153-
Args:
154-
image: The image to set. Can be bytes, Path, URL string, or file-like object.
155-
156-
Raises:
157-
InvalidInputError: If not using avatar-live model or image is invalid.
158-
DecartSDKError: If the server fails to acknowledge the image.
159-
"""
160212
if not self._is_avatar_live:
161213
raise InvalidInputError("set_image() is only available for avatar-live model")
162214

163215
if not self._http_session:
164216
raise InvalidInputError("HTTP session not available")
165217

166-
# Convert image to base64
167218
image_bytes, _ = await file_input_to_bytes(image, self._http_session)
168219
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
169220

170-
event, result = self._manager.register_image_set_wait()
171-
172-
try:
173-
await self._manager.send_message(
174-
SetAvatarImageMessage(type="set_image", image_data=image_base64)
175-
)
176-
177-
try:
178-
await asyncio.wait_for(event.wait(), timeout=15.0)
179-
except asyncio.TimeoutError:
180-
raise DecartSDKError("Image set acknowledgment timed out")
181-
182-
if not result["success"]:
183-
raise DecartSDKError(result.get("error") or "Failed to set avatar image")
184-
finally:
185-
self._manager.unregister_image_set_wait()
221+
await self._manager.set_image(image_base64)
186222

187223
def is_connected(self) -> bool:
188224
return self._manager.is_connected()

decart/realtime/messages.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,9 @@ class SetAvatarImageMessage(BaseModel):
128128
"""Set avatar image message."""
129129

130130
type: Literal["set_image"]
131-
image_data: str # Base64-encoded image
131+
image_data: Optional[str] = None
132+
prompt: Optional[str] = None
133+
enhance_prompt: Optional[bool] = None
132134

133135

134136
# Outgoing message union (no discriminator needed - we know what we're sending)
@@ -161,4 +163,4 @@ def message_to_json(message: OutgoingMessage) -> str:
161163
Returns:
162164
JSON string
163165
"""
164-
return message.model_dump_json()
166+
return message.model_dump_json(exclude_none=True)

decart/realtime/webrtc_connection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ async def connect(
5959
self,
6060
url: str,
6161
local_track: Optional[MediaStreamTrack],
62-
timeout: float = 30,
62+
timeout: float,
6363
integration: Optional[str] = None,
6464
is_avatar_live: bool = False,
6565
avatar_image_base64: Optional[str] = None,
@@ -107,7 +107,7 @@ async def connect(
107107
self._on_error(e)
108108
raise WebRTCError(str(e), cause=e)
109109

110-
async def _send_avatar_image_and_wait(self, image_base64: str, timeout: float = 15.0) -> None:
110+
async def _send_avatar_image_and_wait(self, image_base64: str, timeout: float = 30.0) -> None:
111111
"""Send avatar image and wait for acknowledgment."""
112112
event, result = self.register_image_set_wait()
113113

decart/realtime/webrtc_manager.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,11 @@ async def connect(
6060
initial_prompt: Optional[dict] = None,
6161
) -> bool:
6262
try:
63+
timeout = 60 * 5 # 5 minutes
6364
await self._connection.connect(
6465
url=self._config.webrtc_url,
6566
local_track=local_track,
67+
timeout=timeout,
6668
integration=self._config.integration,
6769
is_avatar_live=self._config.is_avatar_live,
6870
avatar_image_base64=avatar_image_base64,
@@ -83,6 +85,44 @@ def _create_connection(self) -> WebRTCConnection:
8385
customize_offer=self._config.customize_offer,
8486
)
8587

88+
async def set_image(
89+
self,
90+
image_base64: Optional[str],
91+
options: Optional[dict] = None,
92+
) -> None:
93+
from .messages import SetAvatarImageMessage
94+
95+
opts = options or {}
96+
timeout = opts.get("timeout", 30.0)
97+
98+
event, result = self._connection.register_image_set_wait()
99+
100+
try:
101+
message = SetAvatarImageMessage(
102+
type="set_image",
103+
image_data=image_base64,
104+
)
105+
if opts.get("prompt") is not None:
106+
message.prompt = opts["prompt"]
107+
if opts.get("enhance") is not None:
108+
message.enhance_prompt = opts["enhance"]
109+
110+
await self._connection.send(message)
111+
112+
try:
113+
await asyncio.wait_for(event.wait(), timeout=timeout)
114+
except asyncio.TimeoutError:
115+
from ..errors import DecartSDKError
116+
117+
raise DecartSDKError("Image send timed out")
118+
119+
if not result["success"]:
120+
from ..errors import DecartSDKError
121+
122+
raise DecartSDKError(result.get("error") or "Failed to set image")
123+
finally:
124+
self._connection.unregister_image_set_wait()
125+
86126
async def send_message(self, message: OutgoingMessage) -> None:
87127
await self._connection.send(message)
88128

0 commit comments

Comments
 (0)