diff --git a/crates/cactus/src/lib.rs b/crates/cactus/src/lib.rs index 5697ab44da..c7e8b99392 100644 --- a/crates/cactus/src/lib.rs +++ b/crates/cactus/src/lib.rs @@ -7,11 +7,11 @@ mod vad; pub use error::Error; pub use hypr_language::Language; -pub use llm::{CompleteOptions, CompletionResult, Message, complete_stream}; +pub use llm::{CompleteOptions, CompletionResult, CompletionStream, Message, complete_stream}; pub use model::{Model, ModelBuilder}; pub use stt::{ CloudConfig, StreamResult, TranscribeEvent, TranscribeOptions, Transcriber, - TranscriptionResult, constrain_to, transcribe_stream, + TranscriptionResult, TranscriptionSession, constrain_to, transcribe_stream, }; pub use vad::{VadOptions, VadResult, VadSegment}; diff --git a/crates/cactus/src/llm/mod.rs b/crates/cactus/src/llm/mod.rs index b4826f4f3b..a304766c9f 100644 --- a/crates/cactus/src/llm/mod.rs +++ b/crates/cactus/src/llm/mod.rs @@ -4,7 +4,7 @@ mod stream; pub use hypr_llm_types::Message; pub use result::CompletionResult; -pub use stream::complete_stream; +pub use stream::{CompletionStream, complete_stream}; #[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)] pub struct CompleteOptions { diff --git a/crates/cactus/src/llm/result.rs b/crates/cactus/src/llm/result.rs index a87c147c62..3803f28491 100644 --- a/crates/cactus/src/llm/result.rs +++ b/crates/cactus/src/llm/result.rs @@ -29,9 +29,9 @@ pub struct CompletionResult { #[serde(default)] pub decode_tps: f64, #[serde(default)] - pub prefill_tokens: u32, + pub prefill_tokens: u64, #[serde(default)] - pub decode_tokens: u32, + pub decode_tokens: u64, #[serde(default)] - pub total_tokens: u32, + pub total_tokens: u64, } diff --git a/crates/cactus/src/llm/stream.rs b/crates/cactus/src/llm/stream.rs index fb4881a0e3..268bd895bc 100644 --- a/crates/cactus/src/llm/stream.rs +++ b/crates/cactus/src/llm/stream.rs @@ -1,5 +1,8 @@ +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; +use futures_util::Stream; use 
hypr_llm_types::{Response, StreamingParser};
 use tokio::sync::mpsc::UnboundedSender;
 use tokio_stream::wrappers::UnboundedReceiverStream;
@@ -78,15 +81,58 @@ fn run_stream_worker(
     worker.run(&messages, &options);
 }
 
+/// A streaming LLM completion session.
+///
+/// Implements [`Stream`] yielding [`Response`] items. Cancelling the stream
+/// (via [`CompletionStream::cancel`] or by dropping) stops the underlying
+/// inference and joins the worker thread.
+pub struct CompletionStream {
+    inner: UnboundedReceiverStream<Response>,
+    cancellation_token: CancellationToken,
+    handle: Option<std::thread::JoinHandle<()>>,
+}
+
+impl CompletionStream {
+    /// Returns a reference to the cancellation token for external use
+    /// (e.g. attaching a `drop_guard`).
+    pub fn cancellation_token(&self) -> &CancellationToken {
+        &self.cancellation_token
+    }
+
+    /// Signal the worker to stop generating tokens.
+    pub fn cancel(&self) {
+        self.cancellation_token.cancel();
+    }
+}
+
+impl Stream for CompletionStream {
+    type Item = Response;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        Pin::new(&mut self.inner).poll_next(cx)
+    }
+}
+
+impl Drop for CompletionStream {
+    fn drop(&mut self) {
+        self.cancellation_token.cancel();
+        if let Some(handle) = self.handle.take() {
+            // Detach: don't block the (possibly async) caller.
+            // Spawn a background thread to join so we still log panics.
+            std::thread::spawn(move || {
+                if let Err(panic) = handle.join() {
+                    tracing::error!(?panic, "cactus_completion_worker_panicked");
+                }
+            });
+        }
+    }
+}
+
 pub fn complete_stream(
     model: &Arc<Model>,
     messages: Vec<Message>,
     options: CompleteOptions,
-) -> Result<(
-    impl futures_util::Stream<Item = Response> + 'static,
-    CancellationToken,
-    std::thread::JoinHandle<()>,
-)> {
+) -> Result<CompletionStream> {
     let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
     let cancellation_token = CancellationToken::new();
 
@@ -97,6 +143,10 @@ pub fn complete_stream(
         run_stream_worker(model, messages, options, worker_cancellation_token, tx);
     });
 
-    let stream = UnboundedReceiverStream::new(rx);
-    Ok((stream, cancellation_token, handle))
+    let inner = UnboundedReceiverStream::new(rx);
+    Ok(CompletionStream {
+        inner,
+        cancellation_token,
+        handle: Some(handle),
+    })
 }
diff --git a/crates/cactus/src/model.rs b/crates/cactus/src/model.rs
index 4b523caf29..d9a9caf89f 100644
--- a/crates/cactus/src/model.rs
+++ b/crates/cactus/src/model.rs
@@ -85,10 +85,13 @@ impl Model {
     }
 
     pub(crate) fn lock_inference(&self) -> InferenceGuard<'_> {
-        let guard = self
-            .inference_lock
-            .lock()
-            .unwrap_or_else(|e| e.into_inner());
+        let guard = self.inference_lock.lock().unwrap_or_else(|e| {
+            tracing::warn!(
+                "inference mutex was poisoned (a previous FFI call likely panicked); \
+                 recovering, but model state may be inconsistent"
+            );
+            e.into_inner()
+        });
         InferenceGuard {
             handle: self.handle,
             _guard: guard,
diff --git a/crates/cactus/src/stt/mod.rs b/crates/cactus/src/stt/mod.rs
index e57217c1b5..bebed2ac4e 100644
--- a/crates/cactus/src/stt/mod.rs
+++ b/crates/cactus/src/stt/mod.rs
@@ -5,7 +5,7 @@ mod transcriber;
 mod whisper;
 
 pub use result::TranscriptionResult;
-pub use stream::{TranscribeEvent, transcribe_stream};
+pub use stream::{TranscribeEvent, TranscriptionSession, transcribe_stream};
 pub use transcriber::{CloudConfig, StreamResult, Transcriber};
 
 use hypr_language::Language;
diff --git a/crates/cactus/src/stt/stream.rs 
b/crates/cactus/src/stt/stream.rs
index 3139f37912..a9e3231e57 100644
--- a/crates/cactus/src/stt/stream.rs
+++ b/crates/cactus/src/stt/stream.rs
@@ -1,5 +1,8 @@
+use std::pin::Pin;
 use std::sync::Arc;
+use std::task::{Context, Poll};
 
+use futures_util::Stream;
 use tokio_stream::wrappers::ReceiverStream;
 use tokio_util::sync::CancellationToken;
 
@@ -15,18 +18,65 @@ pub struct TranscribeEvent {
     pub chunk_duration_secs: f64,
 }
 
+/// A live transcription session.
+///
+/// Feed audio via [`TranscriptionSession::audio_tx`] and consume
+/// [`TranscribeEvent`]s by polling the session as a [`Stream`].
+/// Dropping the session cancels the worker and joins its thread.
+pub struct TranscriptionSession {
+    audio_tx: tokio::sync::mpsc::Sender<Vec<f32>>,
+    inner: ReceiverStream<Result<TranscribeEvent>>,
+    cancellation_token: CancellationToken,
+    handle: Option<std::thread::JoinHandle<()>>,
+}
+
+impl TranscriptionSession {
+    /// Channel for sending audio samples to the transcription worker.
+    pub fn audio_tx(&self) -> &tokio::sync::mpsc::Sender<Vec<f32>> {
+        &self.audio_tx
+    }
+
+    /// Returns a reference to the cancellation token.
+    pub fn cancellation_token(&self) -> &CancellationToken {
+        &self.cancellation_token
+    }
+
+    /// Signal the worker to stop transcribing.
+    pub fn cancel(&self) {
+        self.cancellation_token.cancel();
+    }
+}
+
+impl Stream for TranscriptionSession {
+    type Item = Result<TranscribeEvent>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        Pin::new(&mut self.inner).poll_next(cx)
+    }
+}
+
+impl Drop for TranscriptionSession {
+    fn drop(&mut self) {
+        self.cancellation_token.cancel();
+        if let Some(handle) = self.handle.take() {
+            // Detach: don't block the (possibly async) caller.
+            // Spawn a background thread to join so we still log panics.
+            std::thread::spawn(move || {
+                if let Err(panic) = handle.join() {
+                    tracing::error!(?panic, "cactus_transcribe_worker_panicked");
+                }
+            });
+        }
+    }
+}
+
 pub fn transcribe_stream(
     model: Arc<Model>,
     options: TranscribeOptions,
     cloud: CloudConfig,
     chunk_size_ms: u32,
     sample_rate: u32,
-) -> (
-    tokio::sync::mpsc::Sender<Vec<f32>>,
-    impl futures_util::Stream<Item = Result<TranscribeEvent>> + Unpin,
-    CancellationToken,
-    std::thread::JoinHandle<()>,
-) {
+) -> TranscriptionSession {
     let (audio_tx, audio_rx) = tokio::sync::mpsc::channel::<Vec<f32>>(64);
     let (event_tx, event_rx) = tokio::sync::mpsc::channel(64);
     let cancellation_token = CancellationToken::new();
@@ -45,8 +95,13 @@ pub fn transcribe_stream(
         );
     });
 
-    let stream = ReceiverStream::new(event_rx);
-    (audio_tx, stream, cancellation_token, handle)
+    let inner = ReceiverStream::new(event_rx);
+    TranscriptionSession {
+        audio_tx,
+        inner,
+        cancellation_token,
+        handle: Some(handle),
+    }
 }
 
 fn run_transcribe_worker(
diff --git a/crates/cactus/src/stt/transcriber.rs b/crates/cactus/src/stt/transcriber.rs
index d4b781dd1d..65e8d7efca 100644
--- a/crates/cactus/src/stt/transcriber.rs
+++ b/crates/cactus/src/stt/transcriber.rs
@@ -1,12 +1,32 @@
 use std::ffi::CString;
 use std::ptr::NonNull;
 
+use serde::Deserialize;
+
 use crate::error::{Error, Result};
 use crate::ffi_utils::{RESPONSE_BUF_SIZE, read_cstr_from_buf};
 use crate::model::Model;
 
 use super::TranscribeOptions;
 
+/// Deserialise a JSON number that may arrive as either an integer or a float
+/// (the C++ side stores token counts as `double` and serialises via
+/// `operator<<`, which *usually* omits the decimal point for whole values but
+/// is not guaranteed to).
+fn deserialize_number_as_u64<'de, D>(deserializer: D) -> std::result::Result<u64, D::Error>
+where
+    D: serde::Deserializer<'de>,
+{
+    let v = serde_json::Value::deserialize(deserializer)?;
+    match v {
+        serde_json::Value::Number(n) => n
+            .as_u64()
+            .or_else(|| n.as_f64().map(|f| f as u64))
+            .ok_or_else(|| serde::de::Error::custom("expected non-negative number")),
+        _ => Err(serde::de::Error::custom("expected a number")),
+    }
+}
+
 /// Cloud handoff configuration for streaming STT.
 ///
 /// `api_key` enables real cloud transcription requests (sent via
@@ -88,12 +108,12 @@ pub struct StreamResult {
     pub decode_tps: f64,
     #[serde(default)]
     pub ram_usage_mb: f64,
-    #[serde(default)]
-    pub prefill_tokens: f64,
-    #[serde(default)]
-    pub decode_tokens: f64,
-    #[serde(default)]
-    pub total_tokens: f64,
+    #[serde(default, deserialize_with = "deserialize_number_as_u64")]
+    pub prefill_tokens: u64,
+    #[serde(default, deserialize_with = "deserialize_number_as_u64")]
+    pub decode_tokens: u64,
+    #[serde(default, deserialize_with = "deserialize_number_as_u64")]
+    pub total_tokens: u64,
 }
 
 impl std::str::FromStr for StreamResult {
diff --git a/crates/llm-cactus/src/service.rs b/crates/llm-cactus/src/service.rs
index 75ce9315b6..06e4813474 100644
--- a/crates/llm-cactus/src/service.rs
+++ b/crates/llm-cactus/src/service.rs
@@ -64,9 +64,9 @@ impl Service> for CompleteService {
         let options = build_options(&request);
 
         if request.stream.unwrap_or(false) {
-            let (stream, cancellation_token, _worker_handle) =
+            let completion_stream =
                 match hypr_cactus::complete_stream(&model, messages, options) {
-                    Ok(tuple) => tuple,
+                    Ok(s) => s,
                     Err(e) => {
                         return Ok(
                             (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
@@ -74,11 +74,7 @@ impl Service> for CompleteService {
                 }
             };
-            Ok(build_streaming_response(
-                stream,
-                cancellation_token,
-                &request.model,
-            ))
+            Ok(build_streaming_response(completion_stream, &request.model))
         } else {
             Ok(build_non_streaming_response(&model, messages, options,
&request.model).await)
         }
     }
@@ -134,8 +130,7 @@ fn model_name(model: &Option) -> &str {
 }
 
 fn build_streaming_response(
-    stream: impl futures_util::Stream + Send + 'static,
-    cancellation_token: tokio_util::sync::CancellationToken,
+    completion_stream: hypr_cactus::CompletionStream,
     model: &Option<String>,
 ) -> Response {
     let id = format!("chatcmpl-{}", uuid::Uuid::new_v4());
@@ -150,7 +145,7 @@
 
     type SseResult = Result<sse::Event, std::convert::Infallible>;
 
-    let data_events = stream.filter_map(move |item| {
+    let data_events = completion_stream.filter_map(move |item| {
         let id = id_for_events.clone();
         let model_name = model_for_events.clone();
 
@@ -211,24 +206,7 @@ fn build_streaming_response(
             Ok::<_, std::convert::Infallible>(sse::Event::default().data("[DONE]")),
         ));
 
-    // drop_guard ensures inference is cancelled when the client disconnects and the
-    // stream is dropped (channel closure also triggers model.stop() inside the worker,
-    // so this is belt-and-suspenders).
-    let drop_guard = cancellation_token.drop_guard();
-
-    let event_stream = stream::unfold(
-        (
-            Box::pin(data_events.chain(stop_event).chain(done_event))
-                as Pin + Send>>,
-            Some(drop_guard),
-        ),
-        |(mut s, guard)| async move {
-            match s.next().await {
-                Some(item) => Some((item, (s, guard))),
-                None => None,
-            }
-        },
-    );
+    let event_stream = data_events.chain(stop_event).chain(done_event);
 
     sse::Sse::new(event_stream).into_response()
 }
diff --git a/crates/transcribe-cactus/src/service/streaming/session.rs b/crates/transcribe-cactus/src/service/streaming/session.rs
index f6bdc39f05..a54571e451 100644
--- a/crates/transcribe-cactus/src/service/streaming/session.rs
+++ b/crates/transcribe-cactus/src/service/streaming/session.rs
@@ -67,8 +67,6 @@ pub(super) async fn handle_websocket(
     let mut event_streams: futures_util::stream::SelectAll =
         futures_util::stream::SelectAll::new();
 
-    let mut worker_handles = Vec::with_capacity(total_channels);
-
     let model = match hypr_cactus::Model::builder(&model_path).build() {
         Ok(m)
=> std::sync::Arc::new(m), Err(e) => { @@ -79,17 +77,16 @@ pub(super) async fn handle_websocket( for ch_idx in 0..total_channels { let cloud_config = cactus_config.cloud.clone(); - let (audio_tx, event_rx, cancel_token, handle) = hypr_cactus::transcribe_stream( + let session = hypr_cactus::transcribe_stream( model.clone(), options.clone(), cloud_config, chunk_size_ms, SAMPLE_RATE, ); - audio_txs.push(audio_tx); - cancel_tokens.push(cancel_token); - worker_handles.push(handle); - event_streams.push(Box::pin(event_rx.map(move |e| (ch_idx, e)))); + audio_txs.push(session.audio_tx().clone()); + cancel_tokens.push(session.cancellation_token().clone()); + event_streams.push(Box::pin(session.map(move |e| (ch_idx, e)))); } let mut channel_states: Vec = (0..total_channels) @@ -122,15 +119,11 @@ pub(super) async fn handle_websocket( } } + // Dropping audio senders signals workers to finish, then dropping + // event_streams drops the TranscriptionSessions which cancel + join workers. drop(audio_txs); drop(event_streams); - for handle in worker_handles { - if let Err(panic) = handle.join() { - tracing::error!(?panic, "cactus_transcribe_worker_panicked"); - } - } - let total_audio_offset = channel_states.first().map_or(0.0, |s| s.audio_offset); send_ws_best_effort(