Skip to content

Commit b0baa47

Browse files
committed
added prompt ack event and updated setPrompt
1 parent 563d1d4 commit b0baa47

File tree

7 files changed

+198
-9
lines changed

7 files changed

+198
-9
lines changed

decart/realtime/client.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Callable, Optional
2+
import asyncio
23
import logging
34
import uuid
45
from aiortc import MediaStreamTrack
@@ -78,10 +79,30 @@ def _emit_error(self, error: DecartSDKError) -> None:
7879
except Exception as e:
7980
logger.exception(f"Error in error callback: {e}")
8081

81-
async def set_prompt(self, prompt: str, enrich: bool = True) -> None:
82+
async def set_prompt(
83+
self, prompt: str, enrich: bool = True, max_timeout: float = 15.0
84+
) -> None:
8285
if not prompt or not prompt.strip():
8386
raise InvalidInputError("Prompt cannot be empty")
84-
await self._manager.send_message(PromptMessage(type="prompt", prompt=prompt))
87+
if max_timeout <= 0 or max_timeout > 60:
88+
raise InvalidInputError("max_timeout must be between 0 and 60 seconds")
89+
90+
event, result = self._manager.register_prompt_wait(prompt)
91+
92+
try:
93+
await self._manager.send_message(
94+
PromptMessage(type="prompt", prompt=prompt, enhance_prompt=enrich)
95+
)
96+
97+
try:
98+
await asyncio.wait_for(event.wait(), timeout=max_timeout)
99+
except asyncio.TimeoutError:
100+
raise DecartSDKError("Prompt acknowledgment timed out")
101+
102+
if not result["success"]:
103+
raise DecartSDKError(result["error"] or "Prompt failed")
104+
finally:
105+
self._manager.unregister_prompt_wait(prompt)
85106

86107
def is_connected(self) -> bool:
87108
return self._manager.is_connected()

decart/realtime/messages.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Literal, Union, Annotated
1+
from typing import Literal, Optional, Union, Annotated
22
from pydantic import BaseModel, Field, TypeAdapter
33

44
try:
@@ -42,9 +42,18 @@ class SessionIdMessage(BaseModel):
4242
server_ip: str
4343

4444

45+
class PromptAckMessage(BaseModel):
46+
"""Acknowledgment for prompt update from server."""
47+
48+
type: Literal["prompt_ack"]
49+
prompt: str
50+
success: bool
51+
error: Optional[str] = None
52+
53+
4554
# Discriminated union for incoming messages
4655
IncomingMessage = Annotated[
47-
Union[AnswerMessage, IceCandidateMessage, SessionIdMessage],
56+
Union[AnswerMessage, IceCandidateMessage, SessionIdMessage, PromptAckMessage],
4857
Field(discriminator="type"),
4958
]
5059

@@ -67,6 +76,7 @@ class PromptMessage(BaseModel):
6776

6877
type: Literal["prompt"]
6978
prompt: str
79+
enhance_prompt: bool = True
7080

7181

7282
# Outgoing message union (no discriminator needed - we know what we're sending)

decart/realtime/webrtc_connection.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
OfferMessage,
2222
IceCandidateMessage,
2323
IceCandidatePayload,
24+
PromptAckMessage,
2425
OutgoingMessage,
2526
)
2627
from .types import ConnectionState
@@ -47,6 +48,7 @@ def __init__(
4748
self._customize_offer = customize_offer
4849
self._ws_task: Optional[asyncio.Task] = None
4950
self._ice_candidates_queue: list[RTCIceCandidate] = []
51+
self._pending_prompts: dict[str, tuple[asyncio.Event, dict]] = {}
5052

5153
async def connect(
5254
self,
@@ -176,6 +178,8 @@ async def _handle_message(self, data: dict) -> None:
176178
await self._handle_ice_candidate(message.candidate)
177179
elif message.type == "session_id":
178180
logger.debug(f"Session ID: {message.session_id}")
181+
elif message.type == "prompt_ack":
182+
self._handle_prompt_ack(message)
179183

180184
async def _handle_answer(self, sdp: str) -> None:
181185
logger.debug("Received answer from server")
@@ -207,6 +211,23 @@ async def _handle_ice_candidate(self, candidate_data: IceCandidatePayload) -> No
207211
logger.debug("Queuing ICE candidate (no remote description yet)")
208212
self._ice_candidates_queue.append(candidate)
209213

214+
def _handle_prompt_ack(self, message: PromptAckMessage) -> None:
215+
logger.debug(f"Received prompt_ack for: {message.prompt}, success: {message.success}")
216+
if message.prompt in self._pending_prompts:
217+
event, result = self._pending_prompts[message.prompt]
218+
result["success"] = message.success
219+
result["error"] = message.error
220+
event.set()
221+
222+
def register_prompt_wait(self, prompt: str) -> tuple[asyncio.Event, dict]:
223+
event = asyncio.Event()
224+
result: dict = {"success": False, "error": None}
225+
self._pending_prompts[prompt] = (event, result)
226+
return event, result
227+
228+
def unregister_prompt_wait(self, prompt: str) -> None:
229+
self._pending_prompts.pop(prompt, None)
230+
210231
async def _send_message(self, message: OutgoingMessage) -> None:
211232
if not self._ws or self._ws.closed:
212233
raise RuntimeError("WebSocket not connected")

decart/realtime/webrtc_manager.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import logging
23
from typing import Optional, Callable
34
from dataclasses import dataclass
@@ -84,3 +85,9 @@ def is_connected(self) -> bool:
8485

8586
def get_connection_state(self) -> ConnectionState:
8687
return self._connection.state
88+
89+
def register_prompt_wait(self, prompt: str) -> tuple[asyncio.Event, dict]:
90+
return self._connection.register_prompt_wait(prompt)
91+
92+
def unregister_prompt_wait(self, prompt: str) -> None:
93+
self._connection.unregister_prompt_wait(prompt)

examples/realtime_synthetic.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,11 @@ def on_error(error):
126126
await asyncio.sleep(5)
127127

128128
print("\n🎨 Changing style to 'Cyberpunk city'...")
129-
await realtime_client.set_prompt("Cyberpunk city")
129+
try:
130+
await realtime_client.set_prompt("Cyberpunk city")
131+
print("✓ Prompt set successfully")
132+
except Exception as e:
133+
print(f"⚠️ Failed to set prompt: {e}")
130134

131135
await asyncio.sleep(5)
132136

tests/test_realtime_unit.py

Lines changed: 128 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,25 @@ def test_realtime_models_available():
4444
@pytest.mark.asyncio
4545
async def test_realtime_client_creation_with_mock():
4646
"""Test client creation with mocked WebRTC"""
47+
import asyncio
48+
4749
client = DecartClient(api_key="test-key")
4850

4951
with patch("decart.realtime.client.WebRTCManager") as mock_manager_class:
5052
mock_manager = AsyncMock()
5153
mock_manager.connect = AsyncMock(return_value=True)
5254
mock_manager.is_connected = MagicMock(return_value=True)
5355
mock_manager.get_connection_state = MagicMock(return_value="connected")
56+
mock_manager.send_message = AsyncMock()
57+
58+
prompt_event = asyncio.Event()
59+
prompt_result = {"success": True, "error": None}
60+
prompt_event.set()
61+
62+
mock_manager.register_prompt_wait = MagicMock(
63+
return_value=(prompt_event, prompt_result)
64+
)
65+
mock_manager.unregister_prompt_wait = MagicMock()
5466
mock_manager_class.return_value = mock_manager
5567

5668
mock_track = MagicMock()
@@ -76,13 +88,24 @@ async def test_realtime_client_creation_with_mock():
7688

7789
@pytest.mark.asyncio
7890
async def test_realtime_set_prompt_with_mock():
79-
"""Test set_prompt with mocked WebRTC"""
91+
"""Test set_prompt with mocked WebRTC and prompt_ack"""
92+
import asyncio
93+
8094
client = DecartClient(api_key="test-key")
8195

8296
with patch("decart.realtime.client.WebRTCManager") as mock_manager_class:
8397
mock_manager = AsyncMock()
8498
mock_manager.connect = AsyncMock(return_value=True)
8599
mock_manager.send_message = AsyncMock()
100+
101+
prompt_event = asyncio.Event()
102+
prompt_result = {"success": True, "error": None}
103+
104+
def register_prompt_wait(prompt):
105+
return prompt_event, prompt_result
106+
107+
mock_manager.register_prompt_wait = MagicMock(side_effect=register_prompt_wait)
108+
mock_manager.unregister_prompt_wait = MagicMock()
86109
mock_manager_class.return_value = mock_manager
87110

88111
mock_track = MagicMock()
@@ -99,12 +122,19 @@ async def test_realtime_set_prompt_with_mock():
99122
),
100123
)
101124

125+
async def set_event():
126+
await asyncio.sleep(0.01)
127+
prompt_event.set()
128+
129+
asyncio.create_task(set_event())
102130
await realtime_client.set_prompt("New prompt")
103131

104-
mock_manager.send_message.assert_called_once()
132+
mock_manager.send_message.assert_called()
105133
call_args = mock_manager.send_message.call_args[0][0]
106134
assert call_args.type == "prompt"
107135
assert call_args.prompt == "New prompt"
136+
assert call_args.enhance_prompt is True
137+
mock_manager.unregister_prompt_wait.assert_called_with("New prompt")
108138

109139

110140
@pytest.mark.asyncio
@@ -152,3 +182,99 @@ def on_error(error):
152182
realtime_client._emit_error(test_error)
153183
assert len(errors) == 1
154184
assert errors[0].message == "Test error"
185+
186+
187+
@pytest.mark.asyncio
188+
async def test_realtime_set_prompt_timeout():
189+
"""Test set_prompt raises on timeout"""
190+
import asyncio
191+
192+
client = DecartClient(api_key="test-key")
193+
194+
with patch("decart.realtime.client.WebRTCManager") as mock_manager_class:
195+
mock_manager = AsyncMock()
196+
mock_manager.connect = AsyncMock(return_value=True)
197+
mock_manager.send_message = AsyncMock()
198+
199+
prompt_event = asyncio.Event()
200+
prompt_result = {"success": False, "error": None}
201+
202+
def register_prompt_wait(prompt):
203+
return prompt_event, prompt_result
204+
205+
mock_manager.register_prompt_wait = MagicMock(side_effect=register_prompt_wait)
206+
mock_manager.unregister_prompt_wait = MagicMock()
207+
mock_manager_class.return_value = mock_manager
208+
209+
mock_track = MagicMock()
210+
211+
from decart.realtime.types import RealtimeConnectOptions
212+
213+
realtime_client = await RealtimeClient.connect(
214+
base_url=client.base_url,
215+
api_key=client.api_key,
216+
local_track=mock_track,
217+
options=RealtimeConnectOptions(
218+
model=models.realtime("mirage"),
219+
on_remote_stream=lambda t: None,
220+
),
221+
)
222+
223+
from decart.errors import DecartSDKError
224+
225+
with pytest.raises(DecartSDKError) as exc_info:
226+
await realtime_client.set_prompt("New prompt", max_timeout=0.01)
227+
228+
assert "timed out" in str(exc_info.value)
229+
mock_manager.unregister_prompt_wait.assert_called_with("New prompt")
230+
231+
232+
@pytest.mark.asyncio
233+
async def test_realtime_set_prompt_server_error():
234+
"""Test set_prompt raises on server error"""
235+
import asyncio
236+
237+
client = DecartClient(api_key="test-key")
238+
239+
with patch("decart.realtime.client.WebRTCManager") as mock_manager_class:
240+
mock_manager = AsyncMock()
241+
mock_manager.connect = AsyncMock(return_value=True)
242+
mock_manager.send_message = AsyncMock()
243+
244+
prompt_event = asyncio.Event()
245+
prompt_result = {"success": False, "error": "Server rejected prompt"}
246+
247+
def register_prompt_wait(prompt):
248+
return prompt_event, prompt_result
249+
250+
mock_manager.register_prompt_wait = MagicMock(side_effect=register_prompt_wait)
251+
mock_manager.unregister_prompt_wait = MagicMock()
252+
mock_manager_class.return_value = mock_manager
253+
254+
mock_track = MagicMock()
255+
256+
from decart.realtime.types import RealtimeConnectOptions
257+
258+
realtime_client = await RealtimeClient.connect(
259+
base_url=client.base_url,
260+
api_key=client.api_key,
261+
local_track=mock_track,
262+
options=RealtimeConnectOptions(
263+
model=models.realtime("mirage"),
264+
on_remote_stream=lambda t: None,
265+
),
266+
)
267+
268+
async def set_event():
269+
await asyncio.sleep(0.01)
270+
prompt_event.set()
271+
272+
asyncio.create_task(set_event())
273+
274+
from decart.errors import DecartSDKError
275+
276+
with pytest.raises(DecartSDKError) as exc_info:
277+
await realtime_client.set_prompt("New prompt")
278+
279+
assert "Server rejected prompt" in str(exc_info.value)
280+
mock_manager.unregister_prompt_wait.assert_called_with("New prompt")

uv.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)