From ccd30fde861df2c451a47b17877192241efa554b Mon Sep 17 00:00:00 2001 From: kaikato Date: Tue, 5 Aug 2025 19:49:53 +0000 Subject: [PATCH] feat: support resets in websocket speech sessions --- examples/speech_session.py | 10 +-- src/lmnt/lib/websocket_streaming.py | 125 +++++++++++++++++----------- 2 files changed, 80 insertions(+), 55 deletions(-) diff --git a/examples/speech_session.py b/examples/speech_session.py index cc7ed97..5f55874 100644 --- a/examples/speech_session.py +++ b/examples/speech_session.py @@ -9,7 +9,7 @@ async def main() -> None: lmnt = AsyncLmnt() # Construct the streaming connection with our desired voice - session = await lmnt.speech.sessions.create(voice="leah", return_extras=True) + session = await lmnt.speech.sessions.create(voice="leah") write_task = asyncio.create_task(write_messages(session)) read_task = asyncio.create_task(read_messages(session)) @@ -31,10 +31,10 @@ async def write_messages(session: SpeechSession) -> None: async def read_messages(session: SpeechSession) -> None: with open("stream-output.mp3", "wb") as audio_file: async for message in session: - audio_bytes = len(message.audio) - print(f" ** Received from LMNT -- {audio_bytes} bytes ** ") - print(f" ** Durations: {message.durations} ** ") - audio_file.write(message.audio) + if message.type == "audio": + audio_bytes = len(message.audio) + print(f" ** Received from LMNT -- {audio_bytes} bytes ** ") + audio_file.write(message.audio) if __name__ == "__main__": diff --git a/src/lmnt/lib/websocket_streaming.py b/src/lmnt/lib/websocket_streaming.py index 673b4a5..28963f0 100644 --- a/src/lmnt/lib/websocket_streaming.py +++ b/src/lmnt/lib/websocket_streaming.py @@ -1,32 +1,56 @@ +from __future__ import annotations + import json -from typing import Any, Dict, List, Final, Optional -from dataclasses import dataclass +from typing import Any, Dict, List, Final, Union, Literal, Optional import websockets -from websockets.typing import Data +from pydantic import Field, BaseModel URL_STREAMING: Final = "wss://api.lmnt.com/v1/ai/speech/stream" -@dataclass -class Duration: +class Duration(BaseModel): """Duration information for a segment of synthesized speech.""" - + text: str - start: float # Start time in seconds - duration: float # Duration in seconds + start: float = Field(description="Start time in seconds") + duration: float = Field(description="Duration in seconds") -@dataclass -class SpeechSessionResponse: - """Response from a speech session connection.""" - +class AudioMessage(BaseModel): + """Audio message containing synthesized speech data.""" + + type: Literal["audio"] = "audio" audio: bytes + + +class ExtrasMessage(BaseModel): + """Extras message containing metadata about the synthesis.""" + + type: Literal["extras"] = "extras" durations: Optional[List[Duration]] = None warning: Optional[str] = None buffer_empty: Optional[bool] = None +class ErrorMessage(BaseModel): + """Error message containing error information.""" + + type: Literal["error"] = "error" + error: str + + +class CompleteMessage(BaseModel): + """Complete message for commands (reset/flush).""" + + type: Literal["complete"] = "complete" + complete: Literal["reset", "flush"] + nonce: int + + +SpeechSessionResponse = Union[AudioMessage, ExtrasMessage, ErrorMessage, CompleteMessage] + + class UnexpectedMessageError(Exception): """Exception raised when an unexpected message is received from the server.""" @@ -54,6 +78,7 @@ def __init__( self.sample_rate = sample_rate self.return_extras = return_extras self.websocket: Optional[Any] = None + self.nonce: int = 0 async def connect(self) -> None: """Establish the websocket connection.""" @@ -61,6 +86,7 @@ async def connect(self) -> None: init_msg = { "X-API-Key": self.api_key, "voice": self.voice, + "protocol_version": 2, } if self.format: init_msg["format"] = self.format @@ -77,13 +103,21 @@ async def append_text(self, text: str) -> None: """Append text to be synthesized.""" await self._send_message({"text": text}) - async def flush(self) -> None: + async def flush(self) -> int: """Flush the current text buffer.""" - await self._send_message({"flush": True}) + self.nonce += 1 + await self._send_message({"command": "flush", "nonce": self.nonce}) + return self.nonce + + async def reset(self) -> int: + """Reset the current text buffer.""" + self.nonce += 1 + await self._send_message({"command": "reset", "nonce": self.nonce}) + return self.nonce async def finish(self) -> None: """Mark the session as finished.""" - await self._send_message({"eof": True}) + await self._send_message({"command": "eof"}) async def close(self) -> None: """Close the websocket connection.""" @@ -96,7 +130,7 @@ async def _send_message(self, message: Dict[str, Any]) -> None: if self.websocket is not None: await self.websocket.send(json.dumps(message)) - def __aiter__(self) -> "SpeechSession": + def __aiter__(self) -> SpeechSession: """Return the async iterator.""" return self @@ -105,43 +139,34 @@ async def __anext__(self) -> SpeechSessionResponse: if not self.websocket: raise StopAsyncIteration try: - if self.return_extras: - extras_message = await self.websocket.recv() - if not isinstance(extras_message, str): - raise UnexpectedMessageError("Expected string for extras message") - extras = self._parse_and_check_error(extras_message) - audio_msg = await self.websocket.recv() - audio_data = self._process_audio_data(audio_msg) - durations = None - if extras.get("durations"): - durations = [Duration(**d) for d in extras["durations"]] - return SpeechSessionResponse( - audio=audio_data, - durations=durations, - warning=extras.get("warning"), - buffer_empty=extras.get("buffer_empty"), - ) + message = await self.websocket.recv() + + if isinstance(message, str): + return self._parse_text_message(message) + elif isinstance(message, bytes): + return AudioMessage(type="audio", audio=message) else: - audio_msg = await self.websocket.recv() - audio_data = self._process_audio_data(audio_msg) - return SpeechSessionResponse(audio=audio_data) + raise UnexpectedMessageError(f"Unexpected message type: {type(message)}") + except websockets.exceptions.ConnectionClosed as err: raise StopAsyncIteration from err - def _process_audio_data(self, audio_msg: Data) -> bytes: - """Process the audio data. Handles binary audio data and JSON error messages.""" - if isinstance(audio_msg, bytes): - return audio_msg - else: - self._parse_and_check_error(str(audio_msg)) - raise UnexpectedMessageError(str(audio_msg)) - - def _parse_and_check_error(self, message: str) -> Dict[str, Any]: - """JSON parse a message and check for errors.""" + def _parse_text_message(self, text_data: str) -> SpeechSessionResponse: + """Parse a text message from the server.""" try: - msg_json: Dict[str, Any] = json.loads(message) + message_json = json.loads(text_data) except json.JSONDecodeError as err: - raise ValueError(f"Invalid JSON received from server: {message}") from err - if "error" in msg_json: - raise ValueError(msg_json["error"]) - return msg_json + raise ValueError(f"Invalid JSON received from server: {text_data}") from err + + if "error" in message_json: + return ErrorMessage.model_construct(**message_json) + + if "complete" in message_json: + return CompleteMessage.model_construct(**message_json) + + if self.return_extras and ( + message_json.get("durations") or message_json.get("warning") or message_json.get("buffer_empty") is not None + ): + return ExtrasMessage.model_construct(**message_json) + + raise UnexpectedMessageError(f"Unexpected message: {text_data}")