Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/cactus/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ mod vad;

pub use error::Error;
pub use hypr_language::Language;
pub use llm::{CompleteOptions, CompletionResult, Message, complete_stream};
pub use llm::{CompleteOptions, CompletionResult, CompletionStream, Message, complete_stream};
pub use model::{Model, ModelBuilder};
pub use stt::{
CloudConfig, StreamResult, TranscribeEvent, TranscribeOptions, Transcriber,
TranscriptionResult, constrain_to, transcribe_stream,
TranscriptionResult, TranscriptionSession, constrain_to, transcribe_stream,
};
pub use vad::{VadOptions, VadResult, VadSegment};

Expand Down
2 changes: 1 addition & 1 deletion crates/cactus/src/llm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mod stream;

pub use hypr_llm_types::Message;
pub use result::CompletionResult;
pub use stream::complete_stream;
pub use stream::{CompletionStream, complete_stream};

#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct CompleteOptions {
Expand Down
6 changes: 3 additions & 3 deletions crates/cactus/src/llm/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ pub struct CompletionResult {
#[serde(default)]
pub decode_tps: f64,
#[serde(default)]
pub prefill_tokens: u32,
pub prefill_tokens: u64,
#[serde(default)]
pub decode_tokens: u32,
pub decode_tokens: u64,
#[serde(default)]
pub total_tokens: u32,
pub total_tokens: u64,
}
64 changes: 57 additions & 7 deletions crates/cactus/src/llm/stream.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use futures_util::Stream;
use hypr_llm_types::{Response, StreamingParser};
use tokio::sync::mpsc::UnboundedSender;
use tokio_stream::wrappers::UnboundedReceiverStream;
Expand Down Expand Up @@ -78,15 +81,58 @@ fn run_stream_worker(
worker.run(&messages, &options);
}

/// A streaming LLM completion session.
///
/// Implements [`Stream`] yielding [`Response`] items. Cancelling the session
/// (via [`CompletionStream::cancel`], the exposed token, or by dropping it)
/// signals the underlying inference worker to stop; on drop the worker thread
/// is joined from a detached background thread so the caller never blocks.
pub struct CompletionStream {
    // Receiver half of the unbounded channel the worker pushes `Response`s into.
    inner: UnboundedReceiverStream<Response>,
    // Shared token; cancelling it tells the worker to stop generating tokens.
    cancellation_token: CancellationToken,
    // Worker thread handle. `Some` until `Drop` takes it and hands it to a
    // background joiner thread (so panics are still observed and logged).
    handle: Option<std::thread::JoinHandle<()>>,
}

impl CompletionStream {
    /// Ask the worker to stop generating tokens.
    ///
    /// Idempotent; the stream will end once the worker notices the signal
    /// and drops its sender.
    pub fn cancel(&self) {
        self.cancellation_token.cancel();
    }

    /// Borrow the session's cancellation token for external wiring,
    /// e.g. attaching a `drop_guard` or combining with other tokens.
    pub fn cancellation_token(&self) -> &CancellationToken {
        &self.cancellation_token
    }
}

impl Stream for CompletionStream {
    type Item = Response;

    /// Delegates polling to the inner receiver stream. All fields are
    /// `Unpin`, so `Self` is `Unpin` and projecting via `get_mut` is sound.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_next(cx)
    }
}

impl Drop for CompletionStream {
fn drop(&mut self) {
self.cancellation_token.cancel();
if let Some(handle) = self.handle.take() {
// Detach: don't block the (possibly async) caller.
// Spawn a background thread to join so we still log panics.
std::thread::spawn(move || {
if let Err(panic) = handle.join() {
tracing::error!(?panic, "cactus_completion_worker_panicked");
}
});
}
}
}

pub fn complete_stream(
model: &Arc<Model>,
messages: Vec<Message>,
options: CompleteOptions,
) -> Result<(
impl futures_util::Stream<Item = Response> + 'static,
CancellationToken,
std::thread::JoinHandle<()>,
)> {
) -> Result<CompletionStream> {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let cancellation_token = CancellationToken::new();

Expand All @@ -97,6 +143,10 @@ pub fn complete_stream(
run_stream_worker(model, messages, options, worker_cancellation_token, tx);
});

let stream = UnboundedReceiverStream::new(rx);
Ok((stream, cancellation_token, handle))
let inner = UnboundedReceiverStream::new(rx);
Ok(CompletionStream {
inner,
cancellation_token,
handle: Some(handle),
})
}
11 changes: 7 additions & 4 deletions crates/cactus/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,13 @@ impl Model {
}

pub(crate) fn lock_inference(&self) -> InferenceGuard<'_> {
let guard = self
.inference_lock
.lock()
.unwrap_or_else(|e| e.into_inner());
let guard = self.inference_lock.lock().unwrap_or_else(|e| {
tracing::warn!(
"inference mutex was poisoned (a previous FFI call likely panicked); \
recovering, but model state may be inconsistent"
);
e.into_inner()
});
InferenceGuard {
handle: self.handle,
_guard: guard,
Expand Down
2 changes: 1 addition & 1 deletion crates/cactus/src/stt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ mod transcriber;
mod whisper;

pub use result::TranscriptionResult;
pub use stream::{TranscribeEvent, transcribe_stream};
pub use stream::{TranscribeEvent, TranscriptionSession, transcribe_stream};
pub use transcriber::{CloudConfig, StreamResult, Transcriber};

use hypr_language::Language;
Expand Down
71 changes: 63 additions & 8 deletions crates/cactus/src/stt/stream.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use futures_util::Stream;
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::CancellationToken;

Expand All @@ -15,18 +18,65 @@ pub struct TranscribeEvent {
pub chunk_duration_secs: f64,
}

/// A live transcription session.
///
/// Feed audio via [`TranscriptionSession::audio_tx`] and consume
/// [`TranscribeEvent`]s by polling the session as a [`Stream`].
/// Dropping the session cancels the worker; the worker thread is then
/// joined from a detached background thread so the caller never blocks.
pub struct TranscriptionSession {
    // Sender half for pushing f32 audio samples to the worker thread.
    audio_tx: tokio::sync::mpsc::Sender<Vec<f32>>,
    // Receiver half of the event channel the worker pushes results into.
    inner: ReceiverStream<Result<TranscribeEvent>>,
    // Shared token; cancelling it tells the worker to stop transcribing.
    cancellation_token: CancellationToken,
    // Worker thread handle. `Some` until `Drop` takes it and hands it to a
    // background joiner thread (so panics are still observed and logged).
    handle: Option<std::thread::JoinHandle<()>>,
}

impl TranscriptionSession {
    /// Ask the worker to stop transcribing.
    ///
    /// Idempotent; the event stream ends once the worker notices the
    /// signal and drops its sender.
    pub fn cancel(&self) {
        self.cancellation_token.cancel();
    }

    /// Channel on which callers push audio samples to the worker.
    pub fn audio_tx(&self) -> &tokio::sync::mpsc::Sender<Vec<f32>> {
        &self.audio_tx
    }

    /// Borrow the session's cancellation token for external wiring.
    pub fn cancellation_token(&self) -> &CancellationToken {
        &self.cancellation_token
    }
}

impl Stream for TranscriptionSession {
    type Item = Result<TranscribeEvent>;

    /// Delegates polling to the inner receiver stream. All fields are
    /// `Unpin`, so `Self` is `Unpin` and projecting via `get_mut` is sound.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_next(cx)
    }
}

impl Drop for TranscriptionSession {
fn drop(&mut self) {
self.cancellation_token.cancel();
if let Some(handle) = self.handle.take() {
// Detach: don't block the (possibly async) caller.
// Spawn a background thread to join so we still log panics.
std::thread::spawn(move || {
if let Err(panic) = handle.join() {
tracing::error!(?panic, "cactus_transcribe_worker_panicked");
}
});
}
}
}

pub fn transcribe_stream(
model: Arc<Model>,
options: TranscribeOptions,
cloud: CloudConfig,
chunk_size_ms: u32,
sample_rate: u32,
) -> (
tokio::sync::mpsc::Sender<Vec<f32>>,
impl futures_util::Stream<Item = Result<TranscribeEvent>> + Unpin,
CancellationToken,
std::thread::JoinHandle<()>,
) {
) -> TranscriptionSession {
let (audio_tx, audio_rx) = tokio::sync::mpsc::channel::<Vec<f32>>(64);
let (event_tx, event_rx) = tokio::sync::mpsc::channel(64);
let cancellation_token = CancellationToken::new();
Expand All @@ -45,8 +95,13 @@ pub fn transcribe_stream(
);
});

let stream = ReceiverStream::new(event_rx);
(audio_tx, stream, cancellation_token, handle)
let inner = ReceiverStream::new(event_rx);
TranscriptionSession {
audio_tx,
inner,
cancellation_token,
handle: Some(handle),
}
}

fn run_transcribe_worker(
Expand Down
32 changes: 26 additions & 6 deletions crates/cactus/src/stt/transcriber.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,32 @@
use std::ffi::CString;
use std::ptr::NonNull;

use serde::Deserialize;

use crate::error::{Error, Result};
use crate::ffi_utils::{RESPONSE_BUF_SIZE, read_cstr_from_buf};
use crate::model::Model;

use super::TranscribeOptions;

/// Deserialise a JSON number that may arrive as either an integer or a float
/// (the C++ side stores token counts as `double` and serialises via
/// `operator<<`, which *usually* omits the decimal point for whole values but
/// is not guaranteed to).
///
/// Negative and non-finite values are rejected rather than silently clamped:
/// a bare `f as u64` cast saturates, so e.g. `-3.5` would otherwise
/// deserialise to `0` despite the "non-negative" contract.
fn deserialize_number_as_u64<'de, D>(deserializer: D) -> std::result::Result<u64, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let v = serde_json::Value::deserialize(deserializer)?;
    match v {
        serde_json::Value::Number(n) => n
            .as_u64()
            .or_else(|| {
                // Only accept floats that are genuinely non-negative (and
                // finite, defensively) before the saturating cast; the
                // fractional part is truncated as before.
                n.as_f64()
                    .filter(|f| f.is_finite() && *f >= 0.0)
                    .map(|f| f as u64)
            })
            .ok_or_else(|| serde::de::Error::custom("expected non-negative number")),
        _ => Err(serde::de::Error::custom("expected a number")),
    }
}

/// Cloud handoff configuration for streaming STT.
///
/// `api_key` enables real cloud transcription requests (sent via
Expand Down Expand Up @@ -88,12 +108,12 @@ pub struct StreamResult {
pub decode_tps: f64,
#[serde(default)]
pub ram_usage_mb: f64,
#[serde(default)]
pub prefill_tokens: f64,
#[serde(default)]
pub decode_tokens: f64,
#[serde(default)]
pub total_tokens: f64,
#[serde(default, deserialize_with = "deserialize_number_as_u64")]
pub prefill_tokens: u64,
#[serde(default, deserialize_with = "deserialize_number_as_u64")]
pub decode_tokens: u64,
#[serde(default, deserialize_with = "deserialize_number_as_u64")]
pub total_tokens: u64,
}

impl std::str::FromStr for StreamResult {
Expand Down
34 changes: 6 additions & 28 deletions crates/llm-cactus/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,17 @@ impl Service<Request<Body>> for CompleteService {
let options = build_options(&request);

if request.stream.unwrap_or(false) {
let (stream, cancellation_token, _worker_handle) =
let completion_stream =
match hypr_cactus::complete_stream(&model, messages, options) {
Ok(tuple) => tuple,
Ok(s) => s,
Err(e) => {
return Ok(
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
);
}
};

Ok(build_streaming_response(
stream,
cancellation_token,
&request.model,
))
Ok(build_streaming_response(completion_stream, &request.model))
} else {
Ok(build_non_streaming_response(&model, messages, options, &request.model).await)
}
Expand Down Expand Up @@ -134,8 +130,7 @@ fn model_name(model: &Option<String>) -> &str {
}

fn build_streaming_response(
stream: impl futures_util::Stream<Item = LlmResponse> + Send + 'static,
cancellation_token: tokio_util::sync::CancellationToken,
completion_stream: hypr_cactus::CompletionStream,
model: &Option<String>,
) -> Response {
let id = format!("chatcmpl-{}", uuid::Uuid::new_v4());
Expand All @@ -150,7 +145,7 @@ fn build_streaming_response(

type SseResult = Result<sse::Event, std::convert::Infallible>;

let data_events = stream.filter_map(move |item| {
let data_events = completion_stream.filter_map(move |item| {
let id = id_for_events.clone();
let model_name = model_for_events.clone();

Expand Down Expand Up @@ -211,24 +206,7 @@ fn build_streaming_response(
Ok::<_, std::convert::Infallible>(sse::Event::default().data("[DONE]")),
));

// drop_guard ensures inference is cancelled when the client disconnects and the
// stream is dropped (channel closure also triggers model.stop() inside the worker,
// so this is belt-and-suspenders).
let drop_guard = cancellation_token.drop_guard();

let event_stream = stream::unfold(
(
Box::pin(data_events.chain(stop_event).chain(done_event))
as Pin<Box<dyn futures_util::Stream<Item = SseResult> + Send>>,
Some(drop_guard),
),
|(mut s, guard)| async move {
match s.next().await {
Some(item) => Some((item, (s, guard))),
None => None,
}
},
);
let event_stream = data_events.chain(stop_event).chain(done_event);

sse::Sse::new(event_stream).into_response()
}
Expand Down
Loading
Loading