Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions examples/speech_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ async def main() -> None:
lmnt = AsyncLmnt()

# Construct the streaming connection with our desired voice
session = await lmnt.speech.sessions.create(voice="leah", return_extras=True)
session = await lmnt.speech.sessions.create(voice="leah")
write_task = asyncio.create_task(write_messages(session))
read_task = asyncio.create_task(read_messages(session))

Expand All @@ -31,10 +31,10 @@ async def write_messages(session: SpeechSession) -> None:
async def read_messages(session: SpeechSession) -> None:
with open("stream-output.mp3", "wb") as audio_file:
async for message in session:
audio_bytes = len(message.audio)
print(f" ** Received from LMNT -- {audio_bytes} bytes ** ")
print(f" ** Durations: {message.durations} ** ")
audio_file.write(message.audio)
if message.type == "audio":
audio_bytes = len(message.audio)
print(f" ** Received from LMNT -- {audio_bytes} bytes ** ")
audio_file.write(message.audio)


if __name__ == "__main__":
Expand Down
125 changes: 75 additions & 50 deletions src/lmnt/lib/websocket_streaming.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,56 @@
from __future__ import annotations

import json
from typing import Any, Dict, List, Final, Optional
from dataclasses import dataclass
from typing import Any, Dict, List, Final, Union, Literal, Optional

import websockets
from websockets.typing import Data
from pydantic import Field, BaseModel

URL_STREAMING: Final = "wss://api.lmnt.com/v1/ai/speech/stream"


@dataclass
class Duration:
class Duration(BaseModel):
"""Duration information for a segment of synthesized speech."""

text: str
start: float # Start time in seconds
duration: float # Duration in seconds
start: float = Field(description="Start time in seconds")
duration: float = Field(description="Duration in seconds")


@dataclass
class SpeechSessionResponse:
"""Response from a speech session connection."""

class AudioMessage(BaseModel):
"""Audio message containing synthesized speech data."""

type: Literal["audio"] = "audio"
audio: bytes


class ExtrasMessage(BaseModel):
"""Extras message containing metadata about the synthesis."""

type: Literal["extras"] = "extras"
durations: Optional[List[Duration]] = None
warning: Optional[str] = None
buffer_empty: Optional[bool] = None


class ErrorMessage(BaseModel):
"""Error message containing error information."""

type: Literal["error"] = "error"
error: str


class CompleteMessage(BaseModel):
"""Complete message for commands (reset/flush)."""

type: Literal["complete"] = "complete"
complete: Literal["reset", "flush"]
nonce: int


SpeechSessionResponse = Union[AudioMessage, ExtrasMessage, ErrorMessage, CompleteMessage]


class UnexpectedMessageError(Exception):
"""Exception raised when an unexpected message is received from the server."""

Expand Down Expand Up @@ -54,13 +78,15 @@ def __init__(
self.sample_rate = sample_rate
self.return_extras = return_extras
self.websocket: Optional[Any] = None
self.nonce: int = 0

async def connect(self) -> None:
"""Establish the websocket connection."""
self.websocket = await websockets.connect(URL_STREAMING)
init_msg = {
"X-API-Key": self.api_key,
"voice": self.voice,
"protocol_version": 2,
}
if self.format:
init_msg["format"] = self.format
Expand All @@ -77,13 +103,21 @@ async def append_text(self, text: str) -> None:
"""Append text to be synthesized."""
await self._send_message({"text": text})

async def flush(self) -> None:
async def flush(self) -> int:
"""Flush the current text buffer."""
await self._send_message({"flush": True})
self.nonce += 1
await self._send_message({"command": "flush", "nonce": self.nonce})
return self.nonce

async def reset(self) -> int:
"""Reset the current text buffer."""
self.nonce += 1
await self._send_message({"command": "reset", "nonce": self.nonce})
return self.nonce

async def finish(self) -> None:
"""Mark the session as finished."""
await self._send_message({"eof": True})
await self._send_message({"command": "eof"})

async def close(self) -> None:
"""Close the websocket connection."""
Expand All @@ -96,7 +130,7 @@ async def _send_message(self, message: Dict[str, Any]) -> None:
if self.websocket is not None:
await self.websocket.send(json.dumps(message))

def __aiter__(self) -> "SpeechSession":
def __aiter__(self) -> SpeechSession:
"""Return the async iterator."""
return self

Expand All @@ -105,43 +139,34 @@ async def __anext__(self) -> SpeechSessionResponse:
if not self.websocket:
raise StopAsyncIteration
try:
if self.return_extras:
extras_message = await self.websocket.recv()
if not isinstance(extras_message, str):
raise UnexpectedMessageError("Expected string for extras message")
extras = self._parse_and_check_error(extras_message)
audio_msg = await self.websocket.recv()
audio_data = self._process_audio_data(audio_msg)
durations = None
if extras.get("durations"):
durations = [Duration(**d) for d in extras["durations"]]
return SpeechSessionResponse(
audio=audio_data,
durations=durations,
warning=extras.get("warning"),
buffer_empty=extras.get("buffer_empty"),
)
message = await self.websocket.recv()

if isinstance(message, str):
return self._parse_text_message(message)
elif isinstance(message, bytes):
return AudioMessage(type="audio", audio=message)
else:
audio_msg = await self.websocket.recv()
audio_data = self._process_audio_data(audio_msg)
return SpeechSessionResponse(audio=audio_data)
raise UnexpectedMessageError(f"Unexpected message type: {type(message)}")

except websockets.exceptions.ConnectionClosed as err:
raise StopAsyncIteration from err

def _process_audio_data(self, audio_msg: Data) -> bytes:
"""Process the audio data. Handles binary audio data and JSON error messages."""
if isinstance(audio_msg, bytes):
return audio_msg
else:
self._parse_and_check_error(str(audio_msg))
raise UnexpectedMessageError(str(audio_msg))

def _parse_and_check_error(self, message: str) -> Dict[str, Any]:
"""JSON parse a message and check for errors."""
def _parse_text_message(self, text_data: str) -> SpeechSessionResponse:
"""Parse a text message from the server."""
try:
msg_json: Dict[str, Any] = json.loads(message)
message_json = json.loads(text_data)
except json.JSONDecodeError as err:
raise ValueError(f"Invalid JSON received from server: {message}") from err
if "error" in msg_json:
raise ValueError(msg_json["error"])
return msg_json
raise ValueError(f"Invalid JSON received from server: {text_data}") from err

if "error" in message_json:
return ErrorMessage.model_construct(**message_json)

if "complete" in message_json:
return CompleteMessage.model_construct(**message_json)

if self.return_extras and (
message_json.get("durations") or message_json.get("warning") or message_json.get("buffer_empty") is not None
):
return ExtrasMessage.model_construct(**message_json)

raise UnexpectedMessageError(f"Unexpected message: {text_data}")