From c577732327f6c4da4980b6f6ecbb94bc6ac9c926 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sun, 22 Feb 2026 01:44:56 +0000 Subject: [PATCH 1/5] share single model across channels instead of creating one per channel Co-Authored-By: yujonglee --- .../src/service/streaming/session.rs | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/crates/transcribe-cactus/src/service/streaming/session.rs b/crates/transcribe-cactus/src/service/streaming/session.rs index c9aa0212f5..044ac5c3d7 100644 --- a/crates/transcribe-cactus/src/service/streaming/session.rs +++ b/crates/transcribe-cactus/src/service/streaming/session.rs @@ -69,14 +69,15 @@ pub(super) async fn handle_websocket( let mut worker_handles = Vec::with_capacity(total_channels); + let model = match hypr_cactus::Model::builder(&model_path).build() { + Ok(m) => std::sync::Arc::new(m), + Err(e) => { + tracing::error!(error = %e, "failed to load model"); + return; + } + }; + for ch_idx in 0..total_channels { - let model = match hypr_cactus::Model::builder(&model_path).build() { - Ok(m) => std::sync::Arc::new(m), - Err(e) => { - tracing::error!(error = %e, "failed to load model for channel {ch_idx}"); - return; - } - }; let cloud_config = if cactus_config.cloud_handoff { hypr_cactus::CloudConfig::default() } else { @@ -86,7 +87,7 @@ pub(super) async fn handle_websocket( } }; let (audio_tx, event_rx, cancel_token, handle) = hypr_cactus::transcribe_stream( - model, + model.clone(), options.clone(), cloud_config, chunk_size_ms, From 0ac0c9679180904b8cd369f644d3318bdcc94dbf Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sun, 22 Feb 2026 01:50:39 +0000 Subject: [PATCH 2/5] add 2-channel e2e tests and local-stt-e2e workflow Co-Authored-By: yujonglee --- .github/workflows/local_stt_e2e.yaml | 47 +++ .../src/service/streaming/tests.rs | 344 ++++++++++++++++++ 2 files changed, 
391 insertions(+) create mode 100644 .github/workflows/local_stt_e2e.yaml diff --git a/.github/workflows/local_stt_e2e.yaml b/.github/workflows/local_stt_e2e.yaml new file mode 100644 index 0000000000..f5f5cf2f42 --- /dev/null +++ b/.github/workflows/local_stt_e2e.yaml @@ -0,0 +1,47 @@ +on: + workflow_dispatch: + push: + branches: [main] + paths: + - crates/transcribe-cactus/** + - crates/cactus/** + - crates/cactus-sys/** + pull_request: + paths: + - crates/transcribe-cactus/** + - crates/cactus/** + - crates/cactus-sys/** + +jobs: + local-stt-e2e: + runs-on: depot-ubuntu-24.04-arm-8 + defaults: + run: + shell: bash + steps: + - uses: actions/checkout@v4 + with: + submodules: recursive + - uses: ./.github/actions/rust_install + with: + platform: linux + - run: | + sudo apt-get update + sudo apt-get install -y cmake build-essential libcurl4-openssl-dev libclang-dev + - run: | + pip3 install --break-system-packages huggingface-hub + pip3 install --break-system-packages -e vendor/cactus/python/ --no-deps + - uses: actions/cache@v4 + with: + path: vendor/cactus/weights/ + key: cactus-models-arm-v1 + - run: cactus download UsefulSensors/moonshine-base + - run: cargo test -p transcribe-cactus e2e_streaming -- --ignored --nocapture + env: + CACTUS_STT_MODEL: ${{ github.workspace }}/vendor/cactus/weights/moonshine-base + - run: cargo test -p transcribe-cactus e2e_streaming_dual_channel -- --ignored --nocapture + env: + CACTUS_STT_MODEL: ${{ github.workspace }}/vendor/cactus/weights/moonshine-base + - run: cargo test -p transcribe-cactus e2e_websocket_dual_channel -- --ignored --nocapture + env: + CACTUS_STT_MODEL: ${{ github.workspace }}/vendor/cactus/weights/moonshine-base diff --git a/crates/transcribe-cactus/src/service/streaming/tests.rs b/crates/transcribe-cactus/src/service/streaming/tests.rs index 0bdb741534..a309a71d55 100644 --- a/crates/transcribe-cactus/src/service/streaming/tests.rs +++ b/crates/transcribe-cactus/src/service/streaming/tests.rs @@ -402,6 
+402,350 @@ fn e2e_streaming() { assert!(!full_transcript.is_empty(), "expected non-empty transcript"); } +// cargo test -p transcribe-cactus e2e_streaming_dual_channel -- --ignored --nocapture +#[ignore = "requires local cactus model files"] +#[test] +fn e2e_streaming_dual_channel() { + let model_path = std::env::var("CACTUS_STT_MODEL") + .unwrap_or_else(|_| "/tmp/cactus-model/moonshine-base-cactus".to_string()); + let model_path = PathBuf::from(model_path); + assert!( + model_path.exists(), + "model not found: {}", + model_path.display() + ); + + let model = std::sync::Arc::new(hypr_cactus::Model::new(&model_path).unwrap()); + let options = hypr_cactus::TranscribeOptions::default(); + let chunk_size_ms = 300u32; + + let (audio_tx_0, mut event_stream_0, _cancel_0, _handle_0) = hypr_cactus::transcribe_stream( + model.clone(), + options.clone(), + hypr_cactus::CloudConfig::default(), + chunk_size_ms, + SAMPLE_RATE, + ); + let (audio_tx_1, mut event_stream_1, _cancel_1, _handle_1) = hypr_cactus::transcribe_stream( + model, + options, + hypr_cactus::CloudConfig::default(), + chunk_size_ms, + SAMPLE_RATE, + ); + + let samples_ch0 = bytes_to_f32_samples(hypr_data::english_1::AUDIO); + let samples_ch1 = bytes_to_f32_samples(hypr_data::english_2::AUDIO); + let chunk_size = 8_000; + + println!( + "\n--- dual channel: ch0={:.1}s ch1={:.1}s ---", + samples_ch0.len() as f64 / 16_000.0, + samples_ch1.len() as f64 / 16_000.0, + ); + + let t0 = std::time::Instant::now(); + + let sender = std::thread::spawn(move || { + let max_chunks = samples_ch0 + .chunks(chunk_size) + .count() + .max(samples_ch1.chunks(chunk_size).count()); + let mut ch0_iter = samples_ch0.chunks(chunk_size); + let mut ch1_iter = samples_ch1.chunks(chunk_size); + + for i in 0..max_chunks { + if let Some(chunk) = ch0_iter.next() { + audio_tx_0.blocking_send(chunk.to_vec()).expect("ch0 send"); + } + if let Some(chunk) = ch1_iter.next() { + audio_tx_1.blocking_send(chunk.to_vec()).expect("ch1 send"); + } + if i % 
20 == 0 { + println!( + "[{:>6.1}s] sent chunk {}", + t0.elapsed().as_secs_f64(), + i, + ); + } + } + println!( + "[{:>6.1}s] all chunks sent", + t0.elapsed().as_secs_f64(), + ); + }); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("failed to create tokio runtime"); + + let t0 = std::time::Instant::now(); + let mut transcript_0 = String::new(); + let mut transcript_1 = String::new(); + let mut events_0 = 0u32; + let mut events_1 = 0u32; + + loop { + let event = rt.block_on(async { + tokio::select! { + e = event_stream_0.next() => e.map(|e| (0, e)), + e = event_stream_1.next() => e.map(|e| (1, e)), + } + }); + + let Some((ch, event)) = event else { break }; + match event { + Ok(hypr_cactus::TranscribeEvent { result: r, .. }) => { + let confirmed = r.confirmed.trim(); + let pending = r.pending.trim(); + if !confirmed.is_empty() || !pending.is_empty() { + println!( + "[{:>6.1}s] ch{ch} confirmed={:?} pending={:?}", + t0.elapsed().as_secs_f64(), + confirmed, + pending, + ); + } + if !confirmed.is_empty() { + let transcript = if ch == 0 { + &mut transcript_0 + } else { + &mut transcript_1 + }; + if !transcript.is_empty() { + transcript.push(' '); + } + transcript.push_str(confirmed); + } + if ch == 0 { + events_0 += 1; + } else { + events_1 += 1; + } + } + Err(e) => panic!("streaming error ch{ch}: {e}"), + } + } + + sender.join().expect("sender thread panicked"); + + println!("\n--- DUAL CHANNEL RESULTS ---"); + println!("ch0 ({events_0} events): {transcript_0}"); + println!("ch1 ({events_1} events): {transcript_1}"); + println!("---\n"); + + assert!(events_0 > 0, "expected events from channel 0"); + assert!(events_1 > 0, "expected events from channel 1"); +} + +// cargo test -p transcribe-cactus e2e_websocket_dual_channel -- --ignored --nocapture +#[ignore = "requires local cactus model files"] +#[test] +fn e2e_websocket_dual_channel() { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() 
+ .expect("failed to create tokio runtime"); + rt.block_on(async { + let model_path = std::env::var("CACTUS_STT_MODEL") + .unwrap_or_else(|_| "/tmp/cactus-model/moonshine-base-cactus".to_string()); + let model_path = PathBuf::from(model_path); + assert!( + model_path.exists(), + "model not found: {}", + model_path.display() + ); + + let service = HandleError::new( + TranscribeService::builder().model_path(model_path).build(), + |err: String| async move { (StatusCode::INTERNAL_SERVER_ERROR, err) }, + ); + let app = Router::new().route_service("/v1/listen", service); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .unwrap(); + let addr = listener.local_addr().unwrap(); + + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + tokio::spawn(async move { + let _ = axum::serve(listener, app) + .with_graceful_shutdown(async { + let _ = shutdown_rx.await; + }) + .await; + }); + + let ws_url = format!( + "ws://{}/v1/listen?channels=2&sample_rate=16000&chunk_size_ms=300", + addr + ); + let (ws_stream, _) = connect_async(&ws_url).await.expect("ws connect failed"); + let (mut ws_tx, mut ws_rx) = ws_stream.split(); + + // Build interleaved stereo PCM: ch0=english_1, ch1=english_2 + let audio_ch0 = hypr_data::english_1::AUDIO; + let audio_ch1 = hypr_data::english_2::AUDIO; + let frames_ch0 = audio_ch0.len() / 2; // i16 = 2 bytes + let frames_ch1 = audio_ch1.len() / 2; + let num_frames = frames_ch0.min(frames_ch1); + + let mut interleaved = Vec::with_capacity(num_frames * 4); + for i in 0..num_frames { + // ch0 sample (i16 LE) + interleaved.push(audio_ch0[i * 2]); + interleaved.push(audio_ch0[i * 2 + 1]); + // ch1 sample (i16 LE) + interleaved.push(audio_ch1[i * 2]); + interleaved.push(audio_ch1[i * 2 + 1]); + } + + let t0 = std::time::Instant::now(); + let chunk_bytes = 64_000; // 1s of stereo 16kHz i16 PCM (4 bytes per frame) + let num_chunks = 5; + println!( + "\n--- ws dual: sending {} chunks of {}B ({:.1}s each) ---", + 
num_chunks, + chunk_bytes, + chunk_bytes as f64 / 64_000.0, + ); + + let (close_tx, close_rx) = tokio::sync::oneshot::channel::<()>(); + let close_tx = std::cell::Cell::new(Some(close_tx)); + + let writer = tokio::spawn(async move { + for (i, chunk) in interleaved + .chunks(chunk_bytes) + .take(num_chunks) + .enumerate() + { + ws_tx + .send(WsMessage::Binary(chunk.to_vec().into())) + .await + .unwrap(); + println!("[{:>5.1}s] ws sent chunk {}", t0.elapsed().as_secs_f64(), i); + } + + let _ = close_rx.await; + println!( + "[{:>5.1}s] ws sending CloseStream", + t0.elapsed().as_secs_f64() + ); + let _ = ws_tx + .send(WsMessage::Text( + r#"{"type":"CloseStream"}"#.to_string().into(), + )) + .await; + }); + + let mut results_count = 0u32; + let mut saw_terminal = false; + let mut saw_error: Option = None; + let mut close_sent = false; + let mut channels_seen = std::collections::HashSet::new(); + + while let Ok(Some(Ok(msg))) = + tokio::time::timeout(Duration::from_secs(120), ws_rx.next()).await + { + match msg { + WsMessage::Text(text) => { + let Ok(v) = serde_json::from_str::(&text) else { + continue; + }; + let msg_type = v.get("type").and_then(|t| t.as_str()).unwrap_or("?"); + match msg_type { + "Results" => { + let transcript = v + .pointer("/channel/alternatives/0/transcript") + .and_then(|t| t.as_str()) + .unwrap_or(""); + let ch_idx = v + .pointer("/channel_index/0") + .and_then(|c| c.as_i64()) + .unwrap_or(-1); + channels_seen.insert(ch_idx); + println!( + "[{:>5.1}s] ws recv Results ch={} {:?}", + t0.elapsed().as_secs_f64(), + ch_idx, + transcript, + ); + results_count += 1; + + if results_count >= 6 && !close_sent { + close_sent = true; + if let Some(tx) = close_tx.take() { + let _ = tx.send(()); + } + } + } + "Metadata" => { + println!( + "[{:>5.1}s] ws recv Metadata (terminal)", + t0.elapsed().as_secs_f64() + ); + saw_terminal = true; + break; + } + "SpeechStarted" => { + println!( + "[{:>5.1}s] ws recv SpeechStarted", + t0.elapsed().as_secs_f64() + ); + } + 
"UtteranceEnd" => { + println!( + "[{:>5.1}s] ws recv UtteranceEnd", + t0.elapsed().as_secs_f64() + ); + } + "Error" => { + saw_error = v + .get("error_message") + .and_then(|m| m.as_str()) + .map(str::to_owned); + println!( + "[{:>5.1}s] ws recv Error: {:?}", + t0.elapsed().as_secs_f64(), + saw_error + ); + break; + } + other => { + println!( + "[{:>5.1}s] ws recv {other}", + t0.elapsed().as_secs_f64() + ); + } + } + } + WsMessage::Close(_) => break, + _ => {} + } + } + + let _ = writer.await; + let _ = shutdown_tx.send(()); + println!( + "[{:>5.1}s] done ({} Results, terminal={}, channels_seen={:?})\n", + t0.elapsed().as_secs_f64(), + results_count, + saw_terminal, + channels_seen, + ); + + assert!(saw_error.is_none(), "ws error: {:?}", saw_error); + assert!(results_count > 0, "expected at least one Results message"); + assert!(saw_terminal, "expected terminal Metadata message"); + assert!( + channels_seen.contains(&0) && channels_seen.contains(&1), + "expected results from both channels, got {:?}", + channels_seen, + ); + }); +} + // cargo test -p transcribe-cactus e2e_websocket -- --ignored --nocapture #[ignore = "requires local cactus model files"] #[test] From 8327a868bff03627f8beaeb2442a9aebc526cc42 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sun, 22 Feb 2026 01:57:30 +0000 Subject: [PATCH 3/5] fix formatting with dprint Co-Authored-By: yujonglee --- .../src/service/streaming/tests.rs | 26 ++++--------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/crates/transcribe-cactus/src/service/streaming/tests.rs b/crates/transcribe-cactus/src/service/streaming/tests.rs index a309a71d55..85c7c51021 100644 --- a/crates/transcribe-cactus/src/service/streaming/tests.rs +++ b/crates/transcribe-cactus/src/service/streaming/tests.rs @@ -462,17 +462,10 @@ fn e2e_streaming_dual_channel() { audio_tx_1.blocking_send(chunk.to_vec()).expect("ch1 send"); } if i % 20 == 0 { - println!( - 
"[{:>6.1}s] sent chunk {}", - t0.elapsed().as_secs_f64(), - i, - ); + println!("[{:>6.1}s] sent chunk {}", t0.elapsed().as_secs_f64(), i,); } } - println!( - "[{:>6.1}s] all chunks sent", - t0.elapsed().as_secs_f64(), - ); + println!("[{:>6.1}s] all chunks sent", t0.elapsed().as_secs_f64(),); }); let rt = tokio::runtime::Builder::new_current_thread() @@ -563,9 +556,7 @@ fn e2e_websocket_dual_channel() { ); let app = Router::new().route_service("/v1/listen", service); - let listener = tokio::net::TcpListener::bind("127.0.0.1:0") - .await - .unwrap(); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); @@ -615,11 +606,7 @@ fn e2e_websocket_dual_channel() { let close_tx = std::cell::Cell::new(Some(close_tx)); let writer = tokio::spawn(async move { - for (i, chunk) in interleaved - .chunks(chunk_bytes) - .take(num_chunks) - .enumerate() - { + for (i, chunk) in interleaved.chunks(chunk_bytes).take(num_chunks).enumerate() { ws_tx .send(WsMessage::Binary(chunk.to_vec().into())) .await @@ -713,10 +700,7 @@ fn e2e_websocket_dual_channel() { break; } other => { - println!( - "[{:>5.1}s] ws recv {other}", - t0.elapsed().as_secs_f64() - ); + println!("[{:>5.1}s] ws recv {other}", t0.elapsed().as_secs_f64()); } } } From b5869dc056e3e592123eba4404ab99add3038a2d Mon Sep 17 00:00:00 2001 From: Yujong Lee Date: Sun, 22 Feb 2026 12:01:53 +0900 Subject: [PATCH 4/5] various testing improvements --- .github/workflows/local_stt_e2e.yaml | 20 +- Cargo.lock | 17 + crates/transcribe-cactus/Cargo.toml | 9 + crates/transcribe-cactus/src/config.rs | 4 +- .../src/service/streaming/message.rs | 121 +++ .../src/service/streaming/mod.rs | 3 - .../src/service/streaming/response.rs | 169 ++++ .../src/service/streaming/session.rs | 9 +- .../src/service/streaming/tests.rs | 916 ------------------ crates/transcribe-cactus/tests/batch.rs | 74 ++ 
crates/transcribe-cactus/tests/common/mod.rs | 9 + crates/transcribe-cactus/tests/live.rs | 320 ++++++ 12 files changed, 732 insertions(+), 939 deletions(-) delete mode 100644 crates/transcribe-cactus/src/service/streaming/tests.rs create mode 100644 crates/transcribe-cactus/tests/batch.rs create mode 100644 crates/transcribe-cactus/tests/common/mod.rs create mode 100644 crates/transcribe-cactus/tests/live.rs diff --git a/.github/workflows/local_stt_e2e.yaml b/.github/workflows/local_stt_e2e.yaml index f5f5cf2f42..ef8bde29a4 100644 --- a/.github/workflows/local_stt_e2e.yaml +++ b/.github/workflows/local_stt_e2e.yaml @@ -15,6 +15,11 @@ on: jobs: local-stt-e2e: runs-on: depot-ubuntu-24.04-arm-8 + strategy: + matrix: + model: + - name: whisper-small + repo: openai/whisper-small defaults: run: shell: bash @@ -34,14 +39,9 @@ jobs: - uses: actions/cache@v4 with: path: vendor/cactus/weights/ - key: cactus-models-arm-v1 - - run: cactus download UsefulSensors/moonshine-base - - run: cargo test -p transcribe-cactus e2e_streaming -- --ignored --nocapture + key: cactus-models-${{ matrix.model.name }}-arm-v1 + - run: cactus download ${{ matrix.model.repo }} + - run: cargo test -p transcribe-cactus -- --ignored --nocapture env: - CACTUS_STT_MODEL: ${{ github.workspace }}/vendor/cactus/weights/moonshine-base - - run: cargo test -p transcribe-cactus e2e_streaming_dual_channel -- --ignored --nocapture - env: - CACTUS_STT_MODEL: ${{ github.workspace }}/vendor/cactus/weights/moonshine-base - - run: cargo test -p transcribe-cactus e2e_websocket_dual_channel -- --ignored --nocapture - env: - CACTUS_STT_MODEL: ${{ github.workspace }}/vendor/cactus/weights/moonshine-base + CACTUS_STT_MODEL: ${{ github.workspace }}/vendor/cactus/weights/${{ matrix.model.name }} + CACTUS_CLOUD_API_KEY: ${{ secrets.CACTUS_CLOUD_API_KEY }} diff --git a/Cargo.lock b/Cargo.lock index 4d557a748c..f5a42f463e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -15600,6 +15600,21 @@ version = "0.3.6" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" +[[package]] +name = "sequential-macro" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb5facc5f409a55d25bf271c853402a00e1187097d326757043f5dd711944d07" + +[[package]] +name = "sequential-test" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0d9c0d773bc7e7733264f460e5dfa00b2510421ddd6284db0749eef8dfb79e9" +dependencies = [ + "sequential-macro", +] + [[package]] name = "serde" version = "1.0.228" @@ -20186,6 +20201,8 @@ dependencies = [ "futures-util", "language", "owhisper-interface", + "reqwest 0.13.2", + "sequential-test", "serde_json", "serde_qs", "tempfile", diff --git a/crates/transcribe-cactus/Cargo.toml b/crates/transcribe-cactus/Cargo.toml index 8c670d2556..89d6b56e5b 100644 --- a/crates/transcribe-cactus/Cargo.toml +++ b/crates/transcribe-cactus/Cargo.toml @@ -24,5 +24,14 @@ tower = { workspace = true } tracing = { workspace = true } [dev-dependencies] +hypr-audio-utils = { workspace = true } +hypr-cactus = { workspace = true } hypr-data = { workspace = true } + +axum = { workspace = true, features = ["ws"] } +futures-util = { workspace = true } +reqwest = { workspace = true, features = ["json"] } +sequential-test = "0.2" +serde_json = { workspace = true } +tokio = { workspace = true } tokio-tungstenite = { workspace = true } diff --git a/crates/transcribe-cactus/src/config.rs b/crates/transcribe-cactus/src/config.rs index f1fe3e578f..9bab874233 100644 --- a/crates/transcribe-cactus/src/config.rs +++ b/crates/transcribe-cactus/src/config.rs @@ -1,13 +1,13 @@ #[derive(Clone, Debug)] pub struct CactusConfig { - pub cloud_handoff: bool, + pub cloud: hypr_cactus::CloudConfig, pub min_chunk_sec: f32, } impl Default for CactusConfig { fn default() -> Self { Self { - cloud_handoff: true, + cloud: 
hypr_cactus::CloudConfig::default(), min_chunk_sec: 2.5, } } diff --git a/crates/transcribe-cactus/src/service/streaming/message.rs b/crates/transcribe-cactus/src/service/streaming/message.rs index a545be110d..3e1f79a4b9 100644 --- a/crates/transcribe-cactus/src/service/streaming/message.rs +++ b/crates/transcribe-cactus/src/service/streaming/message.rs @@ -56,3 +56,124 @@ pub(super) fn process_incoming_message(msg: &Message, channels: u8) -> IncomingM _ => IncomingMessage::Audio(AudioExtract::Empty), } } + +#[cfg(test)] +mod tests { + use axum::extract::ws::Message; + use owhisper_interface::ControlMessage; + + use super::*; + + #[test] + fn control_message_finalize_parsed() { + let msg = Message::Text(r#"{"type":"Finalize"}"#.into()); + match process_incoming_message(&msg, 1) { + IncomingMessage::Control(ControlMessage::Finalize) => {} + other => panic!( + "expected Finalize, got {:?}", + std::mem::discriminant(&other) + ), + } + } + + #[test] + fn control_message_keep_alive_parsed() { + let msg = Message::Text(r#"{"type":"KeepAlive"}"#.into()); + match process_incoming_message(&msg, 1) { + IncomingMessage::Control(ControlMessage::KeepAlive) => {} + other => panic!( + "expected KeepAlive, got {:?}", + std::mem::discriminant(&other) + ), + } + } + + #[test] + fn control_message_close_stream_parsed() { + let msg = Message::Text(r#"{"type":"CloseStream"}"#.into()); + match process_incoming_message(&msg, 1) { + IncomingMessage::Control(ControlMessage::CloseStream) => {} + other => panic!( + "expected CloseStream, got {:?}", + std::mem::discriminant(&other) + ), + } + } + + #[test] + fn audio_chunk_parsed_over_control() { + let chunk = owhisper_interface::ListenInputChunk::End; + let json = serde_json::to_string(&chunk).unwrap(); + let msg = Message::Text(json.into()); + match process_incoming_message(&msg, 1) { + IncomingMessage::Audio(AudioExtract::End) => {} + other => panic!( + "expected Audio(End), got {:?}", + std::mem::discriminant(&other) + ), + } + } + + 
#[test] + fn close_frame_yields_end() { + let msg = Message::Close(None); + match process_incoming_message(&msg, 1) { + IncomingMessage::Audio(AudioExtract::End) => {} + other => panic!( + "expected Audio(End), got {:?}", + std::mem::discriminant(&other) + ), + } + } + + #[test] + fn binary_single_channel_yields_mono() { + let samples: Vec = vec![1000, 2000, 3000]; + let data: Vec = samples.iter().flat_map(|s| s.to_le_bytes()).collect(); + let msg = Message::Binary(data.into()); + match process_incoming_message(&msg, 1) { + IncomingMessage::Audio(AudioExtract::Mono(s)) => assert!(!s.is_empty()), + other => panic!( + "expected Audio(Mono), got {:?}", + std::mem::discriminant(&other) + ), + } + } + + #[test] + fn binary_dual_channel_yields_dual() { + // 2 interleaved i16 samples (4 bytes per frame: ch0, ch1) + let samples: Vec = vec![1000, -1000, 2000, -2000]; + let data: Vec = samples.iter().flat_map(|s| s.to_le_bytes()).collect(); + let msg = Message::Binary(data.into()); + match process_incoming_message(&msg, 2) { + IncomingMessage::Audio(AudioExtract::Dual { ch0, ch1 }) => { + assert_eq!(ch0.len(), 2); + assert_eq!(ch1.len(), 2); + assert!(ch0[0] > 0.0); + assert!(ch1[0] < 0.0); + } + other => panic!( + "expected Audio(Dual), got {:?}", + std::mem::discriminant(&other) + ), + } + } + + #[test] + fn dual_audio_json_yields_dual() { + let chunk = owhisper_interface::ListenInputChunk::DualAudio { + mic: vec![0x00, 0x10], + speaker: vec![0x00, 0x20], + }; + let json = serde_json::to_string(&chunk).unwrap(); + let msg = Message::Text(json.into()); + match process_incoming_message(&msg, 1) { + IncomingMessage::Audio(AudioExtract::Dual { .. 
}) => {} + other => panic!( + "expected Audio(Dual), got {:?}", + std::mem::discriminant(&other) + ), + } + } +} diff --git a/crates/transcribe-cactus/src/service/streaming/mod.rs b/crates/transcribe-cactus/src/service/streaming/mod.rs index b429f4cca7..e6131389de 100644 --- a/crates/transcribe-cactus/src/service/streaming/mod.rs +++ b/crates/transcribe-cactus/src/service/streaming/mod.rs @@ -2,9 +2,6 @@ mod message; mod response; mod session; -#[cfg(test)] -mod tests; - use std::{ future::Future, path::PathBuf, diff --git a/crates/transcribe-cactus/src/service/streaming/response.rs b/crates/transcribe-cactus/src/service/streaming/response.rs index 0e7b99aee1..852432454c 100644 --- a/crates/transcribe-cactus/src/service/streaming/response.rs +++ b/crates/transcribe-cactus/src/service/streaming/response.rs @@ -163,3 +163,172 @@ pub(super) fn format_timestamp_now() -> String { year, month, day, hours, minutes, seconds, millis ) } + +#[cfg(test)] +mod tests { + use std::path::Path; + + use owhisper_interface::stream::StreamResponse; + + use super::*; + + #[test] + fn session_metadata_has_required_fields() { + let meta = build_session_metadata(Path::new("/some/path/whisper-large-v3")); + assert!(!meta.request_id.is_empty()); + assert!(!meta.model_uuid.is_empty()); + assert_eq!(meta.model_info.name, "whisper-large-v3"); + assert_eq!(meta.model_info.arch, "cactus"); + assert!(meta.extra.is_some()); + } + + #[test] + fn format_timestamp_produces_iso8601() { + let ts = format_timestamp_now(); + assert!(ts.ends_with('Z')); + assert!(ts.contains('T')); + assert_eq!(ts.len(), 24); + } + + #[test] + fn transcript_response_serializes_as_results() { + let meta = build_session_metadata(Path::new("/models/whisper-small")); + let resp = build_transcript_response( + "hello world", + 0.0, + 1.5, + 0.95, + Some("en"), + true, + true, + false, + &meta, + &[0, 1], + None, + ); + + let v: serde_json::Value = + serde_json::from_str(&serde_json::to_string(&resp).unwrap()).unwrap(); + 
assert_eq!(v["type"], "Results"); + assert_eq!(v["is_final"], true); + assert_eq!(v["speech_final"], true); + assert_eq!(v["from_finalize"], false); + assert_eq!(v["start"], 0.0); + assert_eq!(v["duration"], 1.5); + assert_eq!(v["channel"]["alternatives"][0]["transcript"], "hello world"); + assert_eq!( + v["channel"]["alternatives"][0]["words"] + .as_array() + .unwrap() + .len(), + 2 + ); + assert_eq!(v["channel"]["alternatives"][0]["languages"][0], "en"); + assert!(!v["metadata"]["request_id"].as_str().unwrap().is_empty()); + assert_eq!(v["metadata"]["model_info"]["name"], "whisper-small"); + assert_eq!(v["metadata"]["model_info"]["arch"], "cactus"); + assert!( + v["metadata"]["extra"]["started_unix_millis"] + .as_u64() + .is_some() + ); + assert_eq!(v["channel_index"], serde_json::json!([0, 1])); + } + + #[test] + fn transcript_response_from_finalize_flag() { + let meta = build_session_metadata(Path::new("/models/test")); + let resp = build_transcript_response( + "test", + 1.0, + 0.5, + 0.9, + None, + true, + true, + true, + &meta, + &[0, 2], + None, + ); + let v: serde_json::Value = + serde_json::from_str(&serde_json::to_string(&resp).unwrap()).unwrap(); + assert_eq!(v["from_finalize"], true); + assert_eq!(v["channel_index"], serde_json::json!([0, 2])); + } + + #[test] + fn transcript_response_channel_index() { + let meta = build_session_metadata(Path::new("/models/test")); + let resp = build_transcript_response( + "speaker text", + 0.0, + 1.0, + 0.8, + None, + true, + true, + false, + &meta, + &[1, 2], + None, + ); + let v: serde_json::Value = + serde_json::from_str(&serde_json::to_string(&resp).unwrap()).unwrap(); + assert_eq!(v["channel_index"], serde_json::json!([1, 2])); + } + + #[test] + fn terminal_response_serializes_as_metadata() { + let resp = StreamResponse::TerminalResponse { + request_id: "test-id".to_string(), + created: "2026-01-01T00:00:00.000Z".to_string(), + duration: 10.5, + channels: 1, + }; + let v: serde_json::Value = + 
serde_json::from_str(&serde_json::to_string(&resp).unwrap()).unwrap(); + assert_eq!(v["type"], "Metadata"); + assert_eq!(v["request_id"], "test-id"); + assert_eq!(v["duration"], 10.5); + assert_eq!(v["channels"], 1); + } + + #[test] + fn error_response_serializes() { + let resp = StreamResponse::ErrorResponse { + error_code: None, + error_message: "model failed".to_string(), + provider: "cactus".to_string(), + }; + let v: serde_json::Value = + serde_json::from_str(&serde_json::to_string(&resp).unwrap()).unwrap(); + assert_eq!(v["type"], "Error"); + assert_eq!(v["error_message"], "model failed"); + assert_eq!(v["provider"], "cactus"); + } + + #[test] + fn speech_started_response_serializes() { + let resp = StreamResponse::SpeechStartedResponse { + channel: vec![0], + timestamp: 1.23, + }; + let v: serde_json::Value = + serde_json::from_str(&serde_json::to_string(&resp).unwrap()).unwrap(); + assert_eq!(v["type"], "SpeechStarted"); + assert_eq!(v["timestamp"], 1.23); + } + + #[test] + fn utterance_end_response_serializes() { + let resp = StreamResponse::UtteranceEndResponse { + channel: vec![0], + last_word_end: 5.67, + }; + let v: serde_json::Value = + serde_json::from_str(&serde_json::to_string(&resp).unwrap()).unwrap(); + assert_eq!(v["type"], "UtteranceEnd"); + assert_eq!(v["last_word_end"], 5.67); + } +} diff --git a/crates/transcribe-cactus/src/service/streaming/session.rs b/crates/transcribe-cactus/src/service/streaming/session.rs index 044ac5c3d7..f6bdc39f05 100644 --- a/crates/transcribe-cactus/src/service/streaming/session.rs +++ b/crates/transcribe-cactus/src/service/streaming/session.rs @@ -78,14 +78,7 @@ pub(super) async fn handle_websocket( }; for ch_idx in 0..total_channels { - let cloud_config = if cactus_config.cloud_handoff { - hypr_cactus::CloudConfig::default() - } else { - hypr_cactus::CloudConfig { - threshold: Some(0.0), - ..Default::default() - } - }; + let cloud_config = cactus_config.cloud.clone(); let (audio_tx, event_rx, cancel_token, 
handle) = hypr_cactus::transcribe_stream( model.clone(), options.clone(), diff --git a/crates/transcribe-cactus/src/service/streaming/tests.rs b/crates/transcribe-cactus/src/service/streaming/tests.rs deleted file mode 100644 index 85c7c51021..0000000000 --- a/crates/transcribe-cactus/src/service/streaming/tests.rs +++ /dev/null @@ -1,916 +0,0 @@ -use std::path::{Path, PathBuf}; -use std::time::Duration; - -use axum::extract::ws::Message; -use axum::{Router, error_handling::HandleError, http::StatusCode}; -use futures_util::{SinkExt, StreamExt}; -use tokio_tungstenite::{connect_async, tungstenite::Message as WsMessage}; - -use hypr_audio_utils::bytes_to_f32_samples; -use owhisper_interface::ControlMessage; -use owhisper_interface::stream::StreamResponse; - -use super::TranscribeService; -use super::message::{AudioExtract, IncomingMessage, process_incoming_message}; -use super::response::{build_session_metadata, build_transcript_response, format_timestamp_now}; -use super::session::SAMPLE_RATE; - -#[test] -fn control_message_finalize_parsed() { - let msg = Message::Text(r#"{"type":"Finalize"}"#.into()); - match process_incoming_message(&msg, 1) { - IncomingMessage::Control(ControlMessage::Finalize) => {} - other => panic!( - "expected Finalize, got {:?}", - std::mem::discriminant(&other) - ), - } -} - -#[test] -fn control_message_keep_alive_parsed() { - let msg = Message::Text(r#"{"type":"KeepAlive"}"#.into()); - match process_incoming_message(&msg, 1) { - IncomingMessage::Control(ControlMessage::KeepAlive) => {} - other => panic!( - "expected KeepAlive, got {:?}", - std::mem::discriminant(&other) - ), - } -} - -#[test] -fn control_message_close_stream_parsed() { - let msg = Message::Text(r#"{"type":"CloseStream"}"#.into()); - match process_incoming_message(&msg, 1) { - IncomingMessage::Control(ControlMessage::CloseStream) => {} - other => panic!( - "expected CloseStream, got {:?}", - std::mem::discriminant(&other) - ), - } -} - -#[test] -fn 
audio_chunk_parsed_over_control() { - let chunk = owhisper_interface::ListenInputChunk::End; - let json = serde_json::to_string(&chunk).unwrap(); - let msg = Message::Text(json.into()); - match process_incoming_message(&msg, 1) { - IncomingMessage::Audio(AudioExtract::End) => {} - other => panic!( - "expected Audio(End), got {:?}", - std::mem::discriminant(&other) - ), - } -} - -#[test] -fn close_frame_yields_end() { - let msg = Message::Close(None); - match process_incoming_message(&msg, 1) { - IncomingMessage::Audio(AudioExtract::End) => {} - other => panic!( - "expected Audio(End), got {:?}", - std::mem::discriminant(&other) - ), - } -} - -#[test] -fn binary_single_channel_yields_mono() { - let samples: Vec = vec![1000, 2000, 3000]; - let data: Vec = samples.iter().flat_map(|s| s.to_le_bytes()).collect(); - let msg = Message::Binary(data.into()); - match process_incoming_message(&msg, 1) { - IncomingMessage::Audio(AudioExtract::Mono(s)) => assert!(!s.is_empty()), - other => panic!( - "expected Audio(Mono), got {:?}", - std::mem::discriminant(&other) - ), - } -} - -#[test] -fn binary_dual_channel_yields_dual() { - // 2 interleaved i16 samples (4 bytes per frame: ch0, ch1) - let samples: Vec = vec![1000, -1000, 2000, -2000]; - let data: Vec = samples.iter().flat_map(|s| s.to_le_bytes()).collect(); - let msg = Message::Binary(data.into()); - match process_incoming_message(&msg, 2) { - IncomingMessage::Audio(AudioExtract::Dual { ch0, ch1 }) => { - assert_eq!(ch0.len(), 2); - assert_eq!(ch1.len(), 2); - assert!(ch0[0] > 0.0); - assert!(ch1[0] < 0.0); - } - other => panic!( - "expected Audio(Dual), got {:?}", - std::mem::discriminant(&other) - ), - } -} - -#[test] -fn dual_audio_json_yields_dual() { - let chunk = owhisper_interface::ListenInputChunk::DualAudio { - mic: vec![0x00, 0x10], - speaker: vec![0x00, 0x20], - }; - let json = serde_json::to_string(&chunk).unwrap(); - let msg = Message::Text(json.into()); - match process_incoming_message(&msg, 1) { - 
IncomingMessage::Audio(AudioExtract::Dual { .. }) => {} - other => panic!( - "expected Audio(Dual), got {:?}", - std::mem::discriminant(&other) - ), - } -} - -#[test] -fn transcript_response_serializes_as_results() { - let meta = build_session_metadata(Path::new("/models/whisper-small")); - let resp = build_transcript_response( - "hello world", - 0.0, - 1.5, - 0.95, - Some("en"), - true, - true, - false, - &meta, - &[0, 1], - None, - ); - - let json = serde_json::to_string(&resp).unwrap(); - let v: serde_json::Value = serde_json::from_str(&json).unwrap(); - - assert_eq!(v["type"], "Results"); - assert_eq!(v["is_final"], true); - assert_eq!(v["speech_final"], true); - assert_eq!(v["from_finalize"], false); - assert_eq!(v["start"], 0.0); - assert_eq!(v["duration"], 1.5); - assert_eq!(v["channel"]["alternatives"][0]["transcript"], "hello world"); - assert_eq!( - v["channel"]["alternatives"][0]["words"] - .as_array() - .unwrap() - .len(), - 2 - ); - assert_eq!(v["channel"]["alternatives"][0]["languages"][0], "en"); - assert!(!v["metadata"]["request_id"].as_str().unwrap().is_empty()); - assert_eq!(v["metadata"]["model_info"]["name"], "whisper-small"); - assert_eq!(v["metadata"]["model_info"]["arch"], "cactus"); - assert!( - v["metadata"]["extra"]["started_unix_millis"] - .as_u64() - .is_some() - ); - assert_eq!(v["channel_index"], serde_json::json!([0, 1])); -} - -#[test] -fn transcript_response_from_finalize_flag() { - let meta = build_session_metadata(Path::new("/models/test")); - let resp = build_transcript_response( - "test", - 1.0, - 0.5, - 0.9, - None, - true, - true, - true, - &meta, - &[0, 2], - None, - ); - - let json = serde_json::to_string(&resp).unwrap(); - let v: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(v["from_finalize"], true); - assert_eq!(v["channel_index"], serde_json::json!([0, 2])); -} - -#[test] -fn transcript_response_channel_1_of_2() { - let meta = build_session_metadata(Path::new("/models/test")); - let resp = 
build_transcript_response( - "speaker text", - 0.0, - 1.0, - 0.8, - None, - true, - true, - false, - &meta, - &[1, 2], - None, - ); - - let json = serde_json::to_string(&resp).unwrap(); - let v: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(v["channel_index"], serde_json::json!([1, 2])); -} - -#[test] -fn terminal_response_serializes_as_metadata() { - let resp = StreamResponse::TerminalResponse { - request_id: "test-id".to_string(), - created: "2026-01-01T00:00:00.000Z".to_string(), - duration: 10.5, - channels: 1, - }; - - let json = serde_json::to_string(&resp).unwrap(); - let v: serde_json::Value = serde_json::from_str(&json).unwrap(); - - assert_eq!(v["type"], "Metadata"); - assert_eq!(v["request_id"], "test-id"); - assert_eq!(v["duration"], 10.5); - assert_eq!(v["channels"], 1); -} - -#[test] -fn error_response_serializes() { - let resp = StreamResponse::ErrorResponse { - error_code: None, - error_message: "model failed".to_string(), - provider: "cactus".to_string(), - }; - - let json = serde_json::to_string(&resp).unwrap(); - let v: serde_json::Value = serde_json::from_str(&json).unwrap(); - - assert_eq!(v["type"], "Error"); - assert_eq!(v["error_message"], "model failed"); - assert_eq!(v["provider"], "cactus"); -} - -#[test] -fn speech_started_response_serializes() { - let resp = StreamResponse::SpeechStartedResponse { - channel: vec![0], - timestamp: 1.23, - }; - - let json = serde_json::to_string(&resp).unwrap(); - let v: serde_json::Value = serde_json::from_str(&json).unwrap(); - - assert_eq!(v["type"], "SpeechStarted"); - assert_eq!(v["timestamp"], 1.23); -} - -#[test] -fn utterance_end_response_serializes() { - let resp = StreamResponse::UtteranceEndResponse { - channel: vec![0], - last_word_end: 5.67, - }; - - let json = serde_json::to_string(&resp).unwrap(); - let v: serde_json::Value = serde_json::from_str(&json).unwrap(); - - assert_eq!(v["type"], "UtteranceEnd"); - assert_eq!(v["last_word_end"], 5.67); -} - -#[test] -fn 
session_metadata_has_required_fields() { - let meta = build_session_metadata(Path::new("/some/path/whisper-large-v3")); - assert!(!meta.request_id.is_empty()); - assert!(!meta.model_uuid.is_empty()); - assert_eq!(meta.model_info.name, "whisper-large-v3"); - assert_eq!(meta.model_info.arch, "cactus"); - assert!(meta.extra.is_some()); -} - -#[test] -fn format_timestamp_produces_iso8601() { - let ts = format_timestamp_now(); - assert!(ts.ends_with('Z')); - assert!(ts.contains('T')); - assert_eq!(ts.len(), 24); -} - -// cargo test -p transcribe-cactus e2e_streaming -- --ignored --nocapture -#[ignore = "requires local cactus model files"] -#[test] -fn e2e_streaming() { - let model_path = std::env::var("CACTUS_STT_MODEL") - .unwrap_or_else(|_| "/tmp/cactus-model/moonshine-base-cactus".to_string()); - let model_path = PathBuf::from(model_path); - assert!( - model_path.exists(), - "model not found: {}", - model_path.display() - ); - - let model = std::sync::Arc::new(hypr_cactus::Model::new(&model_path).unwrap()); - let options = hypr_cactus::TranscribeOptions::default(); - let chunk_size_ms = 300u32; - - let (audio_tx, mut event_stream, _cancel, _worker_handle) = hypr_cactus::transcribe_stream( - model, - options, - hypr_cactus::CloudConfig::default(), - chunk_size_ms, - SAMPLE_RATE, - ); - - let samples = bytes_to_f32_samples(hypr_data::english_1::AUDIO); - let chunk_size = 8_000; - let total_chunks = (samples.len() + chunk_size - 1) / chunk_size; - let audio_duration = samples.len() as f64 / 16_000.0; - println!( - "\n--- feeding ALL audio: {:.1}s ({} chunks of {:.1}s) ---", - audio_duration, - total_chunks, - chunk_size as f64 / 16_000.0, - ); - - let t0 = std::time::Instant::now(); - - let sender = std::thread::spawn(move || { - for (i, chunk) in samples.chunks(chunk_size).enumerate() { - audio_tx.blocking_send(chunk.to_vec()).expect("send failed"); - if i % 20 == 0 { - println!( - "[{:>6.1}s] sent chunk {}/{}", - t0.elapsed().as_secs_f64(), - i, - total_chunks - ); - 
} - } - println!( - "[{:>6.1}s] all {} chunks sent", - t0.elapsed().as_secs_f64(), - total_chunks - ); - }); - - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("failed to create tokio runtime"); - - let t0 = std::time::Instant::now(); - let mut full_transcript = String::new(); - let mut event_count = 0u32; - while let Some(event) = rt.block_on(event_stream.next()) { - match event { - Ok(hypr_cactus::TranscribeEvent { result: r, .. }) => { - let confirmed = r.confirmed.trim(); - let pending = r.pending.trim(); - if !confirmed.is_empty() || !pending.is_empty() { - println!( - "[{:>6.1}s] confirmed={:?} pending={:?}", - t0.elapsed().as_secs_f64(), - confirmed, - pending, - ); - } - if !confirmed.is_empty() { - if !full_transcript.is_empty() { - full_transcript.push(' '); - } - full_transcript.push_str(confirmed); - } - event_count += 1; - } - Err(e) => panic!("streaming error: {e}"), - } - } - - sender.join().expect("sender thread panicked"); - let elapsed = t0.elapsed().as_secs_f64(); - println!( - "\n--- FULL TRANSCRIPT ({:.1}s audio, {:.1}s wall, {:.1}x realtime) ---", - audio_duration, - elapsed, - elapsed / audio_duration - ); - println!("{full_transcript}"); - println!("--- END ({event_count} events) ---\n"); - assert!(!full_transcript.is_empty(), "expected non-empty transcript"); -} - -// cargo test -p transcribe-cactus e2e_streaming_dual_channel -- --ignored --nocapture -#[ignore = "requires local cactus model files"] -#[test] -fn e2e_streaming_dual_channel() { - let model_path = std::env::var("CACTUS_STT_MODEL") - .unwrap_or_else(|_| "/tmp/cactus-model/moonshine-base-cactus".to_string()); - let model_path = PathBuf::from(model_path); - assert!( - model_path.exists(), - "model not found: {}", - model_path.display() - ); - - let model = std::sync::Arc::new(hypr_cactus::Model::new(&model_path).unwrap()); - let options = hypr_cactus::TranscribeOptions::default(); - let chunk_size_ms = 300u32; - - let (audio_tx_0, mut 
event_stream_0, _cancel_0, _handle_0) = hypr_cactus::transcribe_stream( - model.clone(), - options.clone(), - hypr_cactus::CloudConfig::default(), - chunk_size_ms, - SAMPLE_RATE, - ); - let (audio_tx_1, mut event_stream_1, _cancel_1, _handle_1) = hypr_cactus::transcribe_stream( - model, - options, - hypr_cactus::CloudConfig::default(), - chunk_size_ms, - SAMPLE_RATE, - ); - - let samples_ch0 = bytes_to_f32_samples(hypr_data::english_1::AUDIO); - let samples_ch1 = bytes_to_f32_samples(hypr_data::english_2::AUDIO); - let chunk_size = 8_000; - - println!( - "\n--- dual channel: ch0={:.1}s ch1={:.1}s ---", - samples_ch0.len() as f64 / 16_000.0, - samples_ch1.len() as f64 / 16_000.0, - ); - - let t0 = std::time::Instant::now(); - - let sender = std::thread::spawn(move || { - let max_chunks = samples_ch0 - .chunks(chunk_size) - .count() - .max(samples_ch1.chunks(chunk_size).count()); - let mut ch0_iter = samples_ch0.chunks(chunk_size); - let mut ch1_iter = samples_ch1.chunks(chunk_size); - - for i in 0..max_chunks { - if let Some(chunk) = ch0_iter.next() { - audio_tx_0.blocking_send(chunk.to_vec()).expect("ch0 send"); - } - if let Some(chunk) = ch1_iter.next() { - audio_tx_1.blocking_send(chunk.to_vec()).expect("ch1 send"); - } - if i % 20 == 0 { - println!("[{:>6.1}s] sent chunk {}", t0.elapsed().as_secs_f64(), i,); - } - } - println!("[{:>6.1}s] all chunks sent", t0.elapsed().as_secs_f64(),); - }); - - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("failed to create tokio runtime"); - - let t0 = std::time::Instant::now(); - let mut transcript_0 = String::new(); - let mut transcript_1 = String::new(); - let mut events_0 = 0u32; - let mut events_1 = 0u32; - - loop { - let event = rt.block_on(async { - tokio::select! 
{ - e = event_stream_0.next() => e.map(|e| (0, e)), - e = event_stream_1.next() => e.map(|e| (1, e)), - } - }); - - let Some((ch, event)) = event else { break }; - match event { - Ok(hypr_cactus::TranscribeEvent { result: r, .. }) => { - let confirmed = r.confirmed.trim(); - let pending = r.pending.trim(); - if !confirmed.is_empty() || !pending.is_empty() { - println!( - "[{:>6.1}s] ch{ch} confirmed={:?} pending={:?}", - t0.elapsed().as_secs_f64(), - confirmed, - pending, - ); - } - if !confirmed.is_empty() { - let transcript = if ch == 0 { - &mut transcript_0 - } else { - &mut transcript_1 - }; - if !transcript.is_empty() { - transcript.push(' '); - } - transcript.push_str(confirmed); - } - if ch == 0 { - events_0 += 1; - } else { - events_1 += 1; - } - } - Err(e) => panic!("streaming error ch{ch}: {e}"), - } - } - - sender.join().expect("sender thread panicked"); - - println!("\n--- DUAL CHANNEL RESULTS ---"); - println!("ch0 ({events_0} events): {transcript_0}"); - println!("ch1 ({events_1} events): {transcript_1}"); - println!("---\n"); - - assert!(events_0 > 0, "expected events from channel 0"); - assert!(events_1 > 0, "expected events from channel 1"); -} - -// cargo test -p transcribe-cactus e2e_websocket_dual_channel -- --ignored --nocapture -#[ignore = "requires local cactus model files"] -#[test] -fn e2e_websocket_dual_channel() { - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("failed to create tokio runtime"); - rt.block_on(async { - let model_path = std::env::var("CACTUS_STT_MODEL") - .unwrap_or_else(|_| "/tmp/cactus-model/moonshine-base-cactus".to_string()); - let model_path = PathBuf::from(model_path); - assert!( - model_path.exists(), - "model not found: {}", - model_path.display() - ); - - let service = HandleError::new( - TranscribeService::builder().model_path(model_path).build(), - |err: String| async move { (StatusCode::INTERNAL_SERVER_ERROR, err) }, - ); - let app = 
Router::new().route_service("/v1/listen", service); - - let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - - let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); - tokio::spawn(async move { - let _ = axum::serve(listener, app) - .with_graceful_shutdown(async { - let _ = shutdown_rx.await; - }) - .await; - }); - - let ws_url = format!( - "ws://{}/v1/listen?channels=2&sample_rate=16000&chunk_size_ms=300", - addr - ); - let (ws_stream, _) = connect_async(&ws_url).await.expect("ws connect failed"); - let (mut ws_tx, mut ws_rx) = ws_stream.split(); - - // Build interleaved stereo PCM: ch0=english_1, ch1=english_2 - let audio_ch0 = hypr_data::english_1::AUDIO; - let audio_ch1 = hypr_data::english_2::AUDIO; - let frames_ch0 = audio_ch0.len() / 2; // i16 = 2 bytes - let frames_ch1 = audio_ch1.len() / 2; - let num_frames = frames_ch0.min(frames_ch1); - - let mut interleaved = Vec::with_capacity(num_frames * 4); - for i in 0..num_frames { - // ch0 sample (i16 LE) - interleaved.push(audio_ch0[i * 2]); - interleaved.push(audio_ch0[i * 2 + 1]); - // ch1 sample (i16 LE) - interleaved.push(audio_ch1[i * 2]); - interleaved.push(audio_ch1[i * 2 + 1]); - } - - let t0 = std::time::Instant::now(); - let chunk_bytes = 64_000; // 1s of stereo 16kHz i16 PCM (4 bytes per frame) - let num_chunks = 5; - println!( - "\n--- ws dual: sending {} chunks of {}B ({:.1}s each) ---", - num_chunks, - chunk_bytes, - chunk_bytes as f64 / 64_000.0, - ); - - let (close_tx, close_rx) = tokio::sync::oneshot::channel::<()>(); - let close_tx = std::cell::Cell::new(Some(close_tx)); - - let writer = tokio::spawn(async move { - for (i, chunk) in interleaved.chunks(chunk_bytes).take(num_chunks).enumerate() { - ws_tx - .send(WsMessage::Binary(chunk.to_vec().into())) - .await - .unwrap(); - println!("[{:>5.1}s] ws sent chunk {}", t0.elapsed().as_secs_f64(), i); - } - - let _ = close_rx.await; - println!( - "[{:>5.1}s] ws 
sending CloseStream", - t0.elapsed().as_secs_f64() - ); - let _ = ws_tx - .send(WsMessage::Text( - r#"{"type":"CloseStream"}"#.to_string().into(), - )) - .await; - }); - - let mut results_count = 0u32; - let mut saw_terminal = false; - let mut saw_error: Option = None; - let mut close_sent = false; - let mut channels_seen = std::collections::HashSet::new(); - - while let Ok(Some(Ok(msg))) = - tokio::time::timeout(Duration::from_secs(120), ws_rx.next()).await - { - match msg { - WsMessage::Text(text) => { - let Ok(v) = serde_json::from_str::(&text) else { - continue; - }; - let msg_type = v.get("type").and_then(|t| t.as_str()).unwrap_or("?"); - match msg_type { - "Results" => { - let transcript = v - .pointer("/channel/alternatives/0/transcript") - .and_then(|t| t.as_str()) - .unwrap_or(""); - let ch_idx = v - .pointer("/channel_index/0") - .and_then(|c| c.as_i64()) - .unwrap_or(-1); - channels_seen.insert(ch_idx); - println!( - "[{:>5.1}s] ws recv Results ch={} {:?}", - t0.elapsed().as_secs_f64(), - ch_idx, - transcript, - ); - results_count += 1; - - if results_count >= 6 && !close_sent { - close_sent = true; - if let Some(tx) = close_tx.take() { - let _ = tx.send(()); - } - } - } - "Metadata" => { - println!( - "[{:>5.1}s] ws recv Metadata (terminal)", - t0.elapsed().as_secs_f64() - ); - saw_terminal = true; - break; - } - "SpeechStarted" => { - println!( - "[{:>5.1}s] ws recv SpeechStarted", - t0.elapsed().as_secs_f64() - ); - } - "UtteranceEnd" => { - println!( - "[{:>5.1}s] ws recv UtteranceEnd", - t0.elapsed().as_secs_f64() - ); - } - "Error" => { - saw_error = v - .get("error_message") - .and_then(|m| m.as_str()) - .map(str::to_owned); - println!( - "[{:>5.1}s] ws recv Error: {:?}", - t0.elapsed().as_secs_f64(), - saw_error - ); - break; - } - other => { - println!("[{:>5.1}s] ws recv {other}", t0.elapsed().as_secs_f64()); - } - } - } - WsMessage::Close(_) => break, - _ => {} - } - } - - let _ = writer.await; - let _ = shutdown_tx.send(()); - println!( - 
"[{:>5.1}s] done ({} Results, terminal={}, channels_seen={:?})\n", - t0.elapsed().as_secs_f64(), - results_count, - saw_terminal, - channels_seen, - ); - - assert!(saw_error.is_none(), "ws error: {:?}", saw_error); - assert!(results_count > 0, "expected at least one Results message"); - assert!(saw_terminal, "expected terminal Metadata message"); - assert!( - channels_seen.contains(&0) && channels_seen.contains(&1), - "expected results from both channels, got {:?}", - channels_seen, - ); - }); -} - -// cargo test -p transcribe-cactus e2e_websocket -- --ignored --nocapture -#[ignore = "requires local cactus model files"] -#[test] -fn e2e_websocket_listen_with_real_model_inference() { - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("failed to create tokio runtime"); - rt.block_on(async { - let model_path = std::env::var("CACTUS_STT_MODEL") - .unwrap_or_else(|_| "/tmp/cactus-model/moonshine-base-cactus".to_string()); - let model_path = PathBuf::from(model_path); - assert!( - model_path.exists(), - "model not found: {}", - model_path.display() - ); - - let service = HandleError::new( - TranscribeService::builder().model_path(model_path).build(), - |err: String| async move { (StatusCode::INTERNAL_SERVER_ERROR, err) }, - ); - let app = Router::new().route_service("/v1/listen", service); - - let listener = tokio::net::TcpListener::bind("127.0.0.1:0") - .await - .unwrap(); - let addr = listener.local_addr().unwrap(); - - let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); - tokio::spawn(async move { - let _ = axum::serve(listener, app) - .with_graceful_shutdown(async { - let _ = shutdown_rx.await; - }) - .await; - }); - - let ws_url = format!( - "ws://{}/v1/listen?channels=1&sample_rate=16000&chunk_size_ms=300", - addr - ); - let (ws_stream, _) = connect_async(&ws_url).await.expect("ws connect failed"); - let (mut ws_tx, mut ws_rx) = ws_stream.split(); - - let t0 = std::time::Instant::now(); - let 
chunk_bytes = 32_000; // 1s of 16kHz i16 PCM per chunk - let num_chunks = 5; - println!( - "\n--- ws: sending {} chunks of {}B ({:.1}s each) ---", - num_chunks, - chunk_bytes, - chunk_bytes as f64 / 32_000.0, - ); - - let (close_tx, close_rx) = tokio::sync::oneshot::channel::<()>(); - let close_tx = std::cell::Cell::new(Some(close_tx)); - - let writer = tokio::spawn(async move { - for (i, chunk) in hypr_data::english_1::AUDIO - .chunks(chunk_bytes) - .take(num_chunks) - .enumerate() - { - ws_tx - .send(WsMessage::Binary(chunk.to_vec().into())) - .await - .unwrap(); - println!("[{:>5.1}s] ws sent chunk {}", t0.elapsed().as_secs_f64(), i); - } - - let _ = close_rx.await; - println!( - "[{:>5.1}s] ws sending CloseStream", - t0.elapsed().as_secs_f64() - ); - let _ = ws_tx - .send(WsMessage::Text( - r#"{"type":"CloseStream"}"#.to_string().into(), - )) - .await; - }); - - let mut results_count = 0u32; - let mut saw_terminal = false; - let mut saw_error: Option = None; - let mut close_sent = false; - - while let Ok(Some(Ok(msg))) = - tokio::time::timeout(Duration::from_secs(60), ws_rx.next()).await - { - match msg { - WsMessage::Text(text) => { - let Ok(v) = serde_json::from_str::(&text) else { - continue; - }; - let msg_type = v.get("type").and_then(|t| t.as_str()).unwrap_or("?"); - match msg_type { - "Results" => { - let transcript = v - .pointer("/channel/alternatives/0/transcript") - .and_then(|t| t.as_str()) - .unwrap_or(""); - let is_final = - v.get("is_final").and_then(|f| f.as_bool()).unwrap_or(false); - let speech_final = v - .get("speech_final") - .and_then(|f| f.as_bool()) - .unwrap_or(false); - println!( - "[{:>5.1}s] ws recv Results is_final={:<5} speech_final={:<5} {:?}", - t0.elapsed().as_secs_f64(), - is_final, - speech_final, - transcript, - ); - results_count += 1; - - if results_count >= 3 && !close_sent { - close_sent = true; - if let Some(tx) = close_tx.take() { - let _ = tx.send(()); - } - } - } - "Metadata" => { - println!( - "[{:>5.1}s] ws recv 
Metadata (terminal)", - t0.elapsed().as_secs_f64() - ); - saw_terminal = true; - break; - } - "SpeechStarted" => { - println!( - "[{:>5.1}s] ws recv SpeechStarted", - t0.elapsed().as_secs_f64() - ); - } - "UtteranceEnd" => { - println!( - "[{:>5.1}s] ws recv UtteranceEnd", - t0.elapsed().as_secs_f64() - ); - } - "Error" => { - saw_error = v - .get("error_message") - .and_then(|m| m.as_str()) - .map(str::to_owned); - println!( - "[{:>5.1}s] ws recv Error: {:?}", - t0.elapsed().as_secs_f64(), - saw_error - ); - break; - } - other => { - println!( - "[{:>5.1}s] ws recv {other}", - t0.elapsed().as_secs_f64() - ); - } - } - } - WsMessage::Close(_) => break, - _ => {} - } - } - - let _ = writer.await; - let _ = shutdown_tx.send(()); - println!( - "[{:>5.1}s] done ({} Results, terminal={})\n", - t0.elapsed().as_secs_f64(), - results_count, - saw_terminal, - ); - - assert!(saw_error.is_none(), "ws error: {:?}", saw_error); - assert!(results_count > 0, "expected at least one Results message"); - assert!(saw_terminal, "expected terminal Metadata message"); - }); -} diff --git a/crates/transcribe-cactus/tests/batch.rs b/crates/transcribe-cactus/tests/batch.rs new file mode 100644 index 0000000000..63bd3d5571 --- /dev/null +++ b/crates/transcribe-cactus/tests/batch.rs @@ -0,0 +1,74 @@ +mod common; + +use axum::error_handling::HandleError; +use axum::{Router, http::StatusCode}; + +use transcribe_cactus::TranscribeService; + +use common::model_path; + +#[ignore = "requires local cactus model files"] +#[test] +fn e2e_batch() { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async { + let app = Router::new().route_service( + "/v1/listen", + HandleError::new( + TranscribeService::builder() + .model_path(model_path()) + .build(), + |err: String| async move { (StatusCode::INTERNAL_SERVER_ERROR, err) }, + ), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = 
listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + tokio::spawn(async move { + axum::serve(listener, app) + .with_graceful_shutdown(async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + + let wav_bytes = std::fs::read(hypr_data::english_1::AUDIO_PATH) + .unwrap_or_else(|e| panic!("failed to read fixture wav: {e}")); + + let url = format!("http://{}/v1/listen?channels=1&sample_rate=16000", addr); + let client = reqwest::Client::new(); + let response = client + .post(&url) + .header("content-type", "audio/wav") + .body(wav_bytes) + .send() + .await + .expect("request failed"); + + assert_eq!(response.status(), 200); + let v: serde_json::Value = response.json().await.expect("response is not JSON"); + + let transcript = v + .pointer("/results/channels/0/alternatives/0/transcript") + .and_then(|t| t.as_str()) + .unwrap_or(""); + + assert!( + !transcript.trim().is_empty(), + "expected non-empty transcript" + ); + assert!( + v["metadata"]["duration"].as_f64().unwrap_or_default() > 0.0, + "expected positive duration in metadata" + ); + assert_eq!(v["metadata"]["channels"], 1); + + let _ = shutdown_tx.send(()); + }); +} diff --git a/crates/transcribe-cactus/tests/common/mod.rs b/crates/transcribe-cactus/tests/common/mod.rs new file mode 100644 index 0000000000..608b2594f3 --- /dev/null +++ b/crates/transcribe-cactus/tests/common/mod.rs @@ -0,0 +1,9 @@ +use std::path::PathBuf; + +pub fn model_path() -> PathBuf { + let path = std::env::var("CACTUS_STT_MODEL") + .unwrap_or_else(|_| "/tmp/cactus-model/moonshine-base-cactus".to_string()); + let path = PathBuf::from(path); + assert!(path.exists(), "model not found: {}", path.display()); + path +} diff --git a/crates/transcribe-cactus/tests/live.rs b/crates/transcribe-cactus/tests/live.rs new file mode 100644 index 0000000000..4d0c9fc0e5 --- /dev/null +++ b/crates/transcribe-cactus/tests/live.rs @@ -0,0 +1,320 @@ +mod common; + +use std::time::Duration; + 
+use axum::error_handling::HandleError; +use axum::{Router, http::StatusCode}; +use futures_util::{SinkExt, StreamExt}; +use sequential_test::sequential; +use tokio_tungstenite::{connect_async, tungstenite::Message as WsMessage}; + +use hypr_cactus::CloudConfig; +use transcribe_cactus::{CactusConfig, TranscribeService}; + +use common::model_path; + +async fn run_single_channel_opts( + cactus_config: CactusConfig, + audio_secs: usize, + close_after_results: u32, + timeout_secs: u64, +) { + let app = Router::new().route_service( + "/v1/listen", + HandleError::new( + TranscribeService::builder() + .model_path(model_path()) + .cactus_config(cactus_config) + .build(), + |err: String| async move { (StatusCode::INTERNAL_SERVER_ERROR, err) }, + ), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + tokio::spawn(async move { + axum::serve(listener, app) + .with_graceful_shutdown(async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + + let ws_url = format!( + "ws://{}/v1/listen?channels=1&sample_rate=16000&chunk_size_ms=300", + addr + ); + let (ws, _) = connect_async(&ws_url).await.expect("ws connect failed"); + let (mut tx, mut rx) = ws.split(); + + let (close_tx, close_rx) = tokio::sync::oneshot::channel::<()>(); + let close_tx = std::cell::Cell::new(Some(close_tx)); + + let t0 = std::time::Instant::now(); + + let writer = tokio::spawn(async move { + let audio = hypr_data::english_1::AUDIO; + for chunk in audio.chunks(32_000).cycle().take(audio_secs) { + tx.send(WsMessage::Binary(chunk.to_vec().into())) + .await + .unwrap(); + } + let _ = close_rx.await; + let _ = tx + .send(WsMessage::Text( + r#"{"type":"CloseStream"}"#.to_string().into(), + )) + .await; + }); + + let mut results = 0u32; + let mut saw_terminal = false; + let mut close_sent = false; + + while let Ok(Some(Ok(msg))) = + 
tokio::time::timeout(Duration::from_secs(timeout_secs), rx.next()).await + { + match msg { + WsMessage::Text(text) => { + let Ok(v) = serde_json::from_str::(&text) else { + continue; + }; + match v.get("type").and_then(|t| t.as_str()).unwrap_or("") { + "Results" => { + let transcript = v + .pointer("/channel/alternatives/0/transcript") + .and_then(|t| t.as_str()) + .unwrap_or(""); + let is_final = v.get("is_final").and_then(|f| f.as_bool()).unwrap_or(false); + let cloud_corrected = v + .pointer("/metadata/extra/cloud_corrected") + .and_then(|b| b.as_bool()) + .unwrap_or(false); + println!( + "[{:>5.1}s] is_final={} cloud={} {:?}", + t0.elapsed().as_secs_f64(), + is_final, + cloud_corrected, + transcript, + ); + if is_final { + results += 1; + } + if results >= close_after_results && !close_sent { + close_sent = true; + if let Some(tx) = close_tx.take() { + let _ = tx.send(()); + } + } + } + "Metadata" => { + println!("[{:>5.1}s] terminal", t0.elapsed().as_secs_f64()); + saw_terminal = true; + break; + } + "Error" => panic!("ws error: {:?}", v.get("error_message")), + _ => {} + } + } + WsMessage::Close(_) => break, + _ => {} + } + } + + let _ = writer.await; + let _ = shutdown_tx.send(()); + + assert!(results > 0, "expected Results messages"); + assert!(saw_terminal, "expected terminal Metadata message"); +} + +async fn run_single_channel(cactus_config: CactusConfig) { + run_single_channel_opts(cactus_config, 100, 3, 120).await; +} + +#[ignore = "requires local cactus model files"] +#[sequential] +#[test] +fn e2e_websocket_no_handoff() { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(run_single_channel(CactusConfig { + cloud: CloudConfig { + threshold: Some(0.0), + ..Default::default() + }, + ..Default::default() + })); +} + +#[ignore = "requires local cactus model files"] +#[sequential] +#[test] +fn e2e_websocket_with_handoff() { + let api_key = std::env::var("CACTUS_CLOUD_API_KEY") + 
.expect("CACTUS_CLOUD_API_KEY must be set for this test"); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(run_single_channel_opts( + CactusConfig { + // Well below model defaults (Whisper=0.4, Moonshine=0.35) to trigger aggressively + cloud: CloudConfig { + api_key: Some(api_key), + threshold: Some(0.05), + }, + ..Default::default() + }, + 120, + 30, + 180, + )); +} + +#[ignore = "requires local cactus model files"] +#[sequential] +#[test] +fn e2e_websocket_dual_channel_no_handoff() { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async { + let app = Router::new().route_service( + "/v1/listen", + HandleError::new( + TranscribeService::builder() + .model_path(model_path()) + .cactus_config(CactusConfig { + cloud: CloudConfig { + threshold: Some(0.0), + ..Default::default() + }, + ..Default::default() + }) + .build(), + |err: String| async move { (StatusCode::INTERNAL_SERVER_ERROR, err) }, + ), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + tokio::spawn(async move { + axum::serve(listener, app) + .with_graceful_shutdown(async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + + let ws_url = format!( + "ws://{}/v1/listen?channels=2&sample_rate=16000&chunk_size_ms=300", + addr + ); + let (ws, _) = connect_async(&ws_url).await.expect("ws connect failed"); + let (mut tx, mut rx) = ws.split(); + + let audio = hypr_data::english_1::AUDIO; + let num_frames = audio.len() / 2; + let mut interleaved = Vec::with_capacity(num_frames * 4); + for i in 0..num_frames { + interleaved.extend_from_slice(&audio[i * 2..i * 2 + 2]); + interleaved.extend_from_slice(&audio[i * 2..i * 2 + 2]); + } + + let (close_tx, close_rx) = tokio::sync::oneshot::channel::<()>(); + let close_tx = 
std::cell::Cell::new(Some(close_tx)); + + let t0 = std::time::Instant::now(); + + let writer = tokio::spawn(async move { + for chunk in interleaved.chunks(64_000).cycle().take(100) { + tx.send(WsMessage::Binary(chunk.to_vec().into())) + .await + .unwrap(); + } + let _ = close_rx.await; + let _ = tx + .send(WsMessage::Text( + r#"{"type":"CloseStream"}"#.to_string().into(), + )) + .await; + }); + + let mut results = 0u32; + let mut saw_terminal = false; + let mut close_sent = false; + let mut channels_seen = std::collections::HashSet::new(); + + while let Ok(Some(Ok(msg))) = + tokio::time::timeout(Duration::from_secs(120), rx.next()).await + { + match msg { + WsMessage::Text(text) => { + let Ok(v) = serde_json::from_str::(&text) else { + continue; + }; + match v.get("type").and_then(|t| t.as_str()).unwrap_or("") { + "Results" => { + let ch = v + .pointer("/channel_index/0") + .and_then(|c| c.as_i64()) + .unwrap_or(-1); + let transcript = v + .pointer("/channel/alternatives/0/transcript") + .and_then(|t| t.as_str()) + .unwrap_or(""); + let is_final = + v.get("is_final").and_then(|f| f.as_bool()).unwrap_or(false); + println!( + "[{:>5.1}s] ch={} is_final={} {:?}", + t0.elapsed().as_secs_f64(), + ch, + is_final, + transcript, + ); + channels_seen.insert(ch); + results += 1; + if results >= 6 && !close_sent { + close_sent = true; + if let Some(tx) = close_tx.take() { + let _ = tx.send(()); + } + } + } + "Metadata" => { + println!("[{:>5.1}s] terminal", t0.elapsed().as_secs_f64()); + saw_terminal = true; + break; + } + "Error" => panic!("ws error: {:?}", v.get("error_message")), + _ => {} + } + } + WsMessage::Close(_) => break, + _ => {} + } + } + + let _ = writer.await; + let _ = shutdown_tx.send(()); + + assert!(results > 0, "expected Results messages"); + assert!(saw_terminal, "expected terminal Metadata message"); + assert!( + channels_seen.contains(&0) && channels_seen.contains(&1), + "expected results from both channels, got {channels_seen:?}", + ); + }); +} From 
683c5cda9e708951d5b0376bde6c833428b55646 Mon Sep 17 00:00:00 2001 From: Yujong Lee Date: Sun, 22 Feb 2026 12:34:26 +0900 Subject: [PATCH 5/5] timeout --- .github/workflows/local_stt_e2e.yaml | 1 + crates/cactus/src/model.rs | 12 ++++++ crates/cactus/src/stt/batch.rs | 7 +++- crates/transcribe-cactus/src/service/batch.rs | 9 +++- crates/transcribe-cactus/tests/batch.rs | 41 +++++++++++++++++-- crates/transcribe-cactus/tests/live.rs | 28 ++++++++++--- 6 files changed, 85 insertions(+), 13 deletions(-) diff --git a/.github/workflows/local_stt_e2e.yaml b/.github/workflows/local_stt_e2e.yaml index ef8bde29a4..3ffc55eba7 100644 --- a/.github/workflows/local_stt_e2e.yaml +++ b/.github/workflows/local_stt_e2e.yaml @@ -45,3 +45,4 @@ jobs: env: CACTUS_STT_MODEL: ${{ github.workspace }}/vendor/cactus/weights/${{ matrix.model.name }} CACTUS_CLOUD_API_KEY: ${{ secrets.CACTUS_CLOUD_API_KEY }} + E2E_AUDIO_SECS: 20 diff --git a/crates/cactus/src/model.rs b/crates/cactus/src/model.rs index ca8c04c7ad..4b523caf29 100644 --- a/crates/cactus/src/model.rs +++ b/crates/cactus/src/model.rs @@ -8,6 +8,7 @@ use crate::error::{Error, Result}; pub struct Model { handle: NonNull, inference_lock: Mutex<()>, + is_moonshine: bool, } unsafe impl Send for Model {} @@ -39,9 +40,16 @@ impl ModelBuilder { let handle = NonNull::new(raw).ok_or_else(|| Error::Init("cactus_init returned null".into()))?; + let is_moonshine = self + .model_path + .to_string_lossy() + .to_lowercase() + .contains("moonshine"); + Ok(Model { handle, inference_lock: Mutex::new(()), + is_moonshine, }) } } @@ -57,6 +65,10 @@ impl Model { Self::builder(model_path).build() } + pub fn is_moonshine(&self) -> bool { + self.is_moonshine + } + /// Cancel an in-progress inference. Safe to call concurrently — only sets an /// atomic flag on the C++ side. 
pub fn stop(&self) { diff --git a/crates/cactus/src/stt/batch.rs b/crates/cactus/src/stt/batch.rs index 282d1d0613..19b9de0038 100644 --- a/crates/cactus/src/stt/batch.rs +++ b/crates/cactus/src/stt/batch.rs @@ -20,7 +20,12 @@ impl Model { options: &TranscribeOptions, ) -> Result { let guard = self.lock_inference(); - let prompt_c = CString::new(build_whisper_prompt(options))?; + let prompt = if self.is_moonshine() { + String::new() + } else { + build_whisper_prompt(options) + }; + let prompt_c = CString::new(prompt)?; let options_c = CString::new(serde_json::to_string(options)?)?; let mut buf = vec![0u8; RESPONSE_BUF_SIZE]; diff --git a/crates/transcribe-cactus/src/service/batch.rs b/crates/transcribe-cactus/src/service/batch.rs index 6e3b31bc52..c38ce514ef 100644 --- a/crates/transcribe-cactus/src/service/batch.rs +++ b/crates/transcribe-cactus/src/service/batch.rs @@ -320,9 +320,14 @@ mod tests { println!("{}", alternative.transcript.trim()); println!("--- END (confidence={:.2}) ---\n", alternative.confidence); + let transcript = alternative.transcript.trim().to_lowercase(); + assert!(!transcript.is_empty(), "expected non-empty transcript"); assert!( - !alternative.transcript.trim().is_empty(), - "expected non-empty transcript" + transcript.contains("maybe") + || transcript.contains("this") + || transcript.contains("talking"), + "transcript looks like a hallucination (got: {:?})", + transcript ); assert!( alternative.confidence.is_finite(), diff --git a/crates/transcribe-cactus/tests/batch.rs b/crates/transcribe-cactus/tests/batch.rs index 63bd3d5571..e0f5cd93a2 100644 --- a/crates/transcribe-cactus/tests/batch.rs +++ b/crates/transcribe-cactus/tests/batch.rs @@ -3,6 +3,32 @@ mod common; use axum::error_handling::HandleError; use axum::{Router, http::StatusCode}; +fn audio_wav_bytes() -> Vec<u8> { + let max_secs = std::env::var("E2E_AUDIO_SECS") + .ok() + .and_then(|s| s.parse::<usize>().ok()) + .unwrap_or(usize::MAX); + let full = hypr_data::english_1::AUDIO; + let pcm =
&full[..full.len().min(max_secs.saturating_mul(16000 * 2))]; + let data_len = pcm.len() as u32; + let mut w = Vec::with_capacity(44 + pcm.len()); + w.extend_from_slice(b"RIFF"); + w.extend_from_slice(&(36 + data_len).to_le_bytes()); + w.extend_from_slice(b"WAVE"); + w.extend_from_slice(b"fmt "); + w.extend_from_slice(&16u32.to_le_bytes()); + w.extend_from_slice(&1u16.to_le_bytes()); + w.extend_from_slice(&1u16.to_le_bytes()); + w.extend_from_slice(&16000u32.to_le_bytes()); + w.extend_from_slice(&32000u32.to_le_bytes()); + w.extend_from_slice(&2u16.to_le_bytes()); + w.extend_from_slice(&16u16.to_le_bytes()); + w.extend_from_slice(b"data"); + w.extend_from_slice(&data_len.to_le_bytes()); + w.extend_from_slice(pcm); + w +} + use transcribe_cactus::TranscribeService; use common::model_path; @@ -38,10 +64,9 @@ fn e2e_batch() { .unwrap(); }); - let wav_bytes = std::fs::read(hypr_data::english_1::AUDIO_PATH) - .unwrap_or_else(|e| panic!("failed to read fixture wav: {e}")); + let wav_bytes = audio_wav_bytes(); - let url = format!("http://{}/v1/listen?channels=1&sample_rate=16000", addr); + let url = format!("http://{}/v1/listen?channels=1&sample_rate=16000&language=en", addr); let client = reqwest::Client::new(); let response = client .post(&url) @@ -59,10 +84,18 @@ fn e2e_batch() { .and_then(|t| t.as_str()) .unwrap_or(""); + let transcript_lower = transcript.trim().to_lowercase(); assert!( - !transcript.trim().is_empty(), + !transcript_lower.is_empty(), "expected non-empty transcript" ); + assert!( + transcript_lower.contains("maybe") + || transcript_lower.contains("this") + || transcript_lower.contains("talking"), + "transcript looks like a hallucination (got: {:?})", + transcript_lower + ); assert!( v["metadata"]["duration"].as_f64().unwrap_or_default() > 0.0, "expected positive duration in metadata" diff --git a/crates/transcribe-cactus/tests/live.rs b/crates/transcribe-cactus/tests/live.rs index 4d0c9fc0e5..bbe0da2117 100644 --- 
a/crates/transcribe-cactus/tests/live.rs +++ b/crates/transcribe-cactus/tests/live.rs @@ -2,6 +2,18 @@ mod common; use std::time::Duration; +fn e2e_audio_secs(default: usize) -> usize { + std::env::var("E2E_AUDIO_SECS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(default) +} + +fn scale_close_after(audio_secs: usize, default_audio: usize, default_close: u32) -> u32 { + let ratio = audio_secs as f64 / default_audio as f64; + ((default_close as f64 * ratio).ceil() as u32).max(1) +} + use axum::error_handling::HandleError; use axum::{Router, http::StatusCode}; use futures_util::{SinkExt, StreamExt}; @@ -131,7 +143,8 @@ async fn run_single_channel_opts( } async fn run_single_channel(cactus_config: CactusConfig) { - run_single_channel_opts(cactus_config, 100, 3, 120).await; + let secs = e2e_audio_secs(100); + run_single_channel_opts(cactus_config, secs, scale_close_after(secs, 100, 3), 120).await; } #[ignore = "requires local cactus model files"] @@ -164,6 +177,7 @@ fn e2e_websocket_with_handoff() { .build() .unwrap(); + let secs = e2e_audio_secs(120); rt.block_on(run_single_channel_opts( CactusConfig { // Well below model defaults (Whisper=0.4, Moonshine=0.35) to trigger aggressively @@ -173,8 +187,8 @@ fn e2e_websocket_with_handoff() { }, ..Default::default() }, - 120, - 30, + secs, + scale_close_after(secs, 120, 30), 180, )); } @@ -188,7 +202,9 @@ fn e2e_websocket_dual_channel_no_handoff() { .build() .unwrap(); - rt.block_on(async { + let dual_secs = e2e_audio_secs(100); + let dual_close = scale_close_after(dual_secs, 100, 6); + rt.block_on(async move { let app = Router::new().route_service( "/v1/listen", HandleError::new( @@ -239,7 +255,7 @@ fn e2e_websocket_dual_channel_no_handoff() { let t0 = std::time::Instant::now(); let writer = tokio::spawn(async move { - for chunk in interleaved.chunks(64_000).cycle().take(100) { + for chunk in interleaved.chunks(64_000).cycle().take(dual_secs) { tx.send(WsMessage::Binary(chunk.to_vec().into())) .await .unwrap(); @@ 
-286,7 +302,7 @@ fn e2e_websocket_dual_channel_no_handoff() { ); channels_seen.insert(ch); results += 1; - if results >= 6 && !close_sent { + if results >= dual_close && !close_sent { close_sent = true; if let Some(tx) = close_tx.take() { let _ = tx.send(());