diff --git a/.github/workflows/local_stt_e2e.yaml b/.github/workflows/local_stt_e2e.yaml new file mode 100644 index 0000000000..3ffc55eba7 --- /dev/null +++ b/.github/workflows/local_stt_e2e.yaml @@ -0,0 +1,48 @@ +on: + workflow_dispatch: + push: + branches: [main] + paths: + - crates/transcribe-cactus/** + - crates/cactus/** + - crates/cactus-sys/** + pull_request: + paths: + - crates/transcribe-cactus/** + - crates/cactus/** + - crates/cactus-sys/** + +jobs: + local-stt-e2e: + runs-on: depot-ubuntu-24.04-arm-8 + strategy: + matrix: + model: + - name: whisper-small + repo: openai/whisper-small + defaults: + run: + shell: bash + steps: + - uses: actions/checkout@v4 + with: + submodules: recursive + - uses: ./.github/actions/rust_install + with: + platform: linux + - run: | + sudo apt-get update + sudo apt-get install -y cmake build-essential libcurl4-openssl-dev libclang-dev + - run: | + pip3 install --break-system-packages huggingface-hub + pip3 install --break-system-packages -e vendor/cactus/python/ --no-deps + - uses: actions/cache@v4 + with: + path: vendor/cactus/weights/ + key: cactus-models-${{ matrix.model.name }}-arm-v1 + - run: cactus download ${{ matrix.model.repo }} + - run: cargo test -p transcribe-cactus -- --ignored --nocapture + env: + CACTUS_STT_MODEL: ${{ github.workspace }}/vendor/cactus/weights/${{ matrix.model.name }} + CACTUS_CLOUD_API_KEY: ${{ secrets.CACTUS_CLOUD_API_KEY }} + E2E_AUDIO_SECS: 20 diff --git a/Cargo.lock b/Cargo.lock index 4d557a748c..f5a42f463e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -15600,6 +15600,21 @@ version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" +[[package]] +name = "sequential-macro" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb5facc5f409a55d25bf271c853402a00e1187097d326757043f5dd711944d07" + +[[package]] +name = "sequential-test" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0d9c0d773bc7e7733264f460e5dfa00b2510421ddd6284db0749eef8dfb79e9" +dependencies = [ + "sequential-macro", +] + [[package]] name = "serde" version = "1.0.228" @@ -20186,6 +20201,8 @@ dependencies = [ "futures-util", "language", "owhisper-interface", + "reqwest 0.13.2", + "sequential-test", "serde_json", "serde_qs", "tempfile", diff --git a/crates/cactus/src/model.rs b/crates/cactus/src/model.rs index ca8c04c7ad..4b523caf29 100644 --- a/crates/cactus/src/model.rs +++ b/crates/cactus/src/model.rs @@ -8,6 +8,7 @@ use crate::error::{Error, Result}; pub struct Model { handle: NonNull, inference_lock: Mutex<()>, + is_moonshine: bool, } unsafe impl Send for Model {} @@ -39,9 +40,16 @@ impl ModelBuilder { let handle = NonNull::new(raw).ok_or_else(|| Error::Init("cactus_init returned null".into()))?; + let is_moonshine = self + .model_path + .to_string_lossy() + .to_lowercase() + .contains("moonshine"); + Ok(Model { handle, inference_lock: Mutex::new(()), + is_moonshine, }) } } @@ -57,6 +65,10 @@ impl Model { Self::builder(model_path).build() } + pub fn is_moonshine(&self) -> bool { + self.is_moonshine + } + /// Cancel an in-progress inference. Safe to call concurrently — only sets an /// atomic flag on the C++ side. pub fn stop(&self) { diff --git a/crates/cactus/src/stt/batch.rs b/crates/cactus/src/stt/batch.rs index 282d1d0613..19b9de0038 100644 --- a/crates/cactus/src/stt/batch.rs +++ b/crates/cactus/src/stt/batch.rs @@ -20,7 +20,12 @@ impl Model { options: &TranscribeOptions, ) -> Result { let guard = self.lock_inference(); - let prompt_c = CString::new(build_whisper_prompt(options))?; + let prompt = if self.is_moonshine() { + String::new() + } else { + build_whisper_prompt(options) + }; + let prompt_c = CString::new(prompt)?; let options_c = CString::new(serde_json::to_string(options)?)?; let mut buf = vec![0u8; RESPONSE_BUF_SIZE]; diff --git a/crates/transcribe-cactus/Cargo.toml b/crates/transcribe-cactus/Cargo.toml index 8c670d2556..89d6b56e5b 100644 --- a/crates/transcribe-cactus/Cargo.toml +++ b/crates/transcribe-cactus/Cargo.toml @@ -24,5 +24,14 @@ tower = { workspace = true } tracing = { workspace = true } [dev-dependencies] +hypr-audio-utils = { workspace = true } +hypr-cactus = { workspace = true } hypr-data = { workspace = true } + +axum = { workspace = true, features = ["ws"] } +futures-util = { workspace = true } +reqwest = { workspace = true, features = ["json"] } +sequential-test = "0.2" +serde_json = { workspace = true } +tokio = { workspace = true } tokio-tungstenite = { workspace = true } diff --git a/crates/transcribe-cactus/src/config.rs b/crates/transcribe-cactus/src/config.rs index f1fe3e578f..9bab874233 100644 --- a/crates/transcribe-cactus/src/config.rs +++ b/crates/transcribe-cactus/src/config.rs @@ -1,13 +1,13 @@ #[derive(Clone, Debug)] pub struct CactusConfig { - pub cloud_handoff: bool, + pub cloud: hypr_cactus::CloudConfig, pub min_chunk_sec: f32, } impl Default for CactusConfig { fn default() -> Self { Self { - cloud_handoff: true, + cloud: hypr_cactus::CloudConfig::default(), min_chunk_sec: 2.5, } } diff --git a/crates/transcribe-cactus/src/service/batch.rs b/crates/transcribe-cactus/src/service/batch.rs index 6e3b31bc52..c38ce514ef 100644 --- a/crates/transcribe-cactus/src/service/batch.rs +++ b/crates/transcribe-cactus/src/service/batch.rs @@ -320,9 +320,14 @@ mod tests { println!("{}", alternative.transcript.trim()); println!("--- END (confidence={:.2}) ---\n", alternative.confidence); + let transcript = alternative.transcript.trim().to_lowercase(); + assert!(!transcript.is_empty(), "expected non-empty transcript"); assert!( - !alternative.transcript.trim().is_empty(), - "expected non-empty transcript" + transcript.contains("maybe") + || transcript.contains("this") + || transcript.contains("talking"), + "transcript looks like a hallucination (got: {:?})", + transcript ); assert!( alternative.confidence.is_finite(), diff --git a/crates/transcribe-cactus/src/service/streaming/message.rs b/crates/transcribe-cactus/src/service/streaming/message.rs index a545be110d..3e1f79a4b9 100644 --- a/crates/transcribe-cactus/src/service/streaming/message.rs +++ b/crates/transcribe-cactus/src/service/streaming/message.rs @@ -56,3 +56,124 @@ pub(super) fn process_incoming_message(msg: &Message, channels: u8) -> IncomingM _ => IncomingMessage::Audio(AudioExtract::Empty), } } + +#[cfg(test)] +mod tests { + use axum::extract::ws::Message; + use owhisper_interface::ControlMessage; + + use super::*; + + #[test] + fn control_message_finalize_parsed() { + let msg = Message::Text(r#"{"type":"Finalize"}"#.into()); + match process_incoming_message(&msg, 1) { + IncomingMessage::Control(ControlMessage::Finalize) => {} + other => panic!( + "expected Finalize, got {:?}", + std::mem::discriminant(&other) + ), + } + } + + #[test] + fn control_message_keep_alive_parsed() { + let msg = Message::Text(r#"{"type":"KeepAlive"}"#.into()); + match process_incoming_message(&msg, 1) { + IncomingMessage::Control(ControlMessage::KeepAlive) => {} + other => panic!( + "expected KeepAlive, got {:?}", + std::mem::discriminant(&other) + ), + } + } + + #[test] + fn control_message_close_stream_parsed() { + let msg = Message::Text(r#"{"type":"CloseStream"}"#.into()); + match process_incoming_message(&msg, 1) { + IncomingMessage::Control(ControlMessage::CloseStream) => {} + other => panic!( + "expected CloseStream, got {:?}", + std::mem::discriminant(&other) + ), + } + } + + #[test] + fn audio_chunk_parsed_over_control() { + let chunk = owhisper_interface::ListenInputChunk::End; + let json = serde_json::to_string(&chunk).unwrap(); + let msg = Message::Text(json.into()); + match process_incoming_message(&msg, 1) { + IncomingMessage::Audio(AudioExtract::End) => {} + other => panic!( + "expected Audio(End), got {:?}", + std::mem::discriminant(&other) + ), + } + } + + #[test] + fn close_frame_yields_end() { + let msg = Message::Close(None); + match process_incoming_message(&msg, 1) { + IncomingMessage::Audio(AudioExtract::End) => {} + other => panic!( + "expected Audio(End), got {:?}", + std::mem::discriminant(&other) + ), + } + } + + #[test] + fn binary_single_channel_yields_mono() { + let samples: Vec = vec![1000, 2000, 3000]; + let data: Vec = samples.iter().flat_map(|s| s.to_le_bytes()).collect(); + let msg = Message::Binary(data.into()); + match process_incoming_message(&msg, 1) { + IncomingMessage::Audio(AudioExtract::Mono(s)) => assert!(!s.is_empty()), + other => panic!( + "expected Audio(Mono), got {:?}", + std::mem::discriminant(&other) + ), + } + } + + #[test] + fn binary_dual_channel_yields_dual() { + // 2 interleaved i16 samples (4 bytes per frame: ch0, ch1) + let samples: Vec = vec![1000, -1000, 2000, -2000]; + let data: Vec = samples.iter().flat_map(|s| s.to_le_bytes()).collect(); + let msg = Message::Binary(data.into()); + match process_incoming_message(&msg, 2) { + IncomingMessage::Audio(AudioExtract::Dual { ch0, ch1 }) => { + assert_eq!(ch0.len(), 2); + assert_eq!(ch1.len(), 2); + assert!(ch0[0] > 0.0); + assert!(ch1[0] < 0.0); + } + other => panic!( + "expected Audio(Dual), got {:?}", + std::mem::discriminant(&other) + ), + } + } + + #[test] + fn dual_audio_json_yields_dual() { + let chunk = owhisper_interface::ListenInputChunk::DualAudio { + mic: vec![0x00, 0x10], + speaker: vec![0x00, 0x20], + }; + let json = serde_json::to_string(&chunk).unwrap(); + let msg = Message::Text(json.into()); + match process_incoming_message(&msg, 1) { + IncomingMessage::Audio(AudioExtract::Dual { .. }) => {} + other => panic!( + "expected Audio(Dual), got {:?}", + std::mem::discriminant(&other) + ), + } + } +} diff --git a/crates/transcribe-cactus/src/service/streaming/mod.rs b/crates/transcribe-cactus/src/service/streaming/mod.rs index b429f4cca7..e6131389de 100644 --- a/crates/transcribe-cactus/src/service/streaming/mod.rs +++ b/crates/transcribe-cactus/src/service/streaming/mod.rs @@ -2,9 +2,6 @@ mod message; mod response; mod session; -#[cfg(test)] -mod tests; - use std::{ future::Future, path::PathBuf, diff --git a/crates/transcribe-cactus/src/service/streaming/response.rs b/crates/transcribe-cactus/src/service/streaming/response.rs index 0e7b99aee1..852432454c 100644 --- a/crates/transcribe-cactus/src/service/streaming/response.rs +++ b/crates/transcribe-cactus/src/service/streaming/response.rs @@ -163,3 +163,172 @@ pub(super) fn format_timestamp_now() -> String { year, month, day, hours, minutes, seconds, millis ) } + +#[cfg(test)] +mod tests { + use std::path::Path; + + use owhisper_interface::stream::StreamResponse; + + use super::*; + + #[test] + fn session_metadata_has_required_fields() { + let meta = build_session_metadata(Path::new("/some/path/whisper-large-v3")); + assert!(!meta.request_id.is_empty()); + assert!(!meta.model_uuid.is_empty()); + assert_eq!(meta.model_info.name, "whisper-large-v3"); + assert_eq!(meta.model_info.arch, "cactus"); + assert!(meta.extra.is_some()); + } + + #[test] + fn format_timestamp_produces_iso8601() { + let ts = format_timestamp_now(); + assert!(ts.ends_with('Z')); + assert!(ts.contains('T')); + assert_eq!(ts.len(), 24); + } + + #[test] + fn transcript_response_serializes_as_results() { + let meta = build_session_metadata(Path::new("/models/whisper-small")); + let resp = build_transcript_response( + "hello world", + 0.0, + 1.5, + 0.95, + Some("en"), + true, + true, + false, + &meta, + &[0, 1], + None, + ); + + let v: serde_json::Value = + serde_json::from_str(&serde_json::to_string(&resp).unwrap()).unwrap(); + assert_eq!(v["type"], "Results"); + assert_eq!(v["is_final"], true); + assert_eq!(v["speech_final"], true); + assert_eq!(v["from_finalize"], false); + assert_eq!(v["start"], 0.0); + assert_eq!(v["duration"], 1.5); + assert_eq!(v["channel"]["alternatives"][0]["transcript"], "hello world"); + assert_eq!( + v["channel"]["alternatives"][0]["words"] + .as_array() + .unwrap() + .len(), + 2 + ); + assert_eq!(v["channel"]["alternatives"][0]["languages"][0], "en"); + assert!(!v["metadata"]["request_id"].as_str().unwrap().is_empty()); + assert_eq!(v["metadata"]["model_info"]["name"], "whisper-small"); + assert_eq!(v["metadata"]["model_info"]["arch"], "cactus"); + assert!( + v["metadata"]["extra"]["started_unix_millis"] + .as_u64() + .is_some() + ); + assert_eq!(v["channel_index"], serde_json::json!([0, 1])); + } + + #[test] + fn transcript_response_from_finalize_flag() { + let meta = build_session_metadata(Path::new("/models/test")); + let resp = build_transcript_response( + "test", + 1.0, + 0.5, + 0.9, + None, + true, + true, + true, + &meta, + &[0, 2], + None, + ); + let v: serde_json::Value = + serde_json::from_str(&serde_json::to_string(&resp).unwrap()).unwrap(); + assert_eq!(v["from_finalize"], true); + assert_eq!(v["channel_index"], serde_json::json!([0, 2])); + } + + #[test] + fn transcript_response_channel_index() { + let meta = build_session_metadata(Path::new("/models/test")); + let resp = build_transcript_response( + "speaker text", + 0.0, + 1.0, + 0.8, + None, + true, + true, + false, + &meta, + &[1, 2], + None, + ); + let v: serde_json::Value = + serde_json::from_str(&serde_json::to_string(&resp).unwrap()).unwrap(); + assert_eq!(v["channel_index"], serde_json::json!([1, 2])); + } + + #[test] + fn terminal_response_serializes_as_metadata() { + let resp = StreamResponse::TerminalResponse { + request_id: "test-id".to_string(), + created: "2026-01-01T00:00:00.000Z".to_string(), + duration: 10.5, + channels: 1, + }; + let v: serde_json::Value = + serde_json::from_str(&serde_json::to_string(&resp).unwrap()).unwrap(); + assert_eq!(v["type"], "Metadata"); + assert_eq!(v["request_id"], "test-id"); + assert_eq!(v["duration"], 10.5); + assert_eq!(v["channels"], 1); + } + + #[test] + fn error_response_serializes() { + let resp = StreamResponse::ErrorResponse { + error_code: None, + error_message: "model failed".to_string(), + provider: "cactus".to_string(), + }; + let v: serde_json::Value = + serde_json::from_str(&serde_json::to_string(&resp).unwrap()).unwrap(); + assert_eq!(v["type"], "Error"); + assert_eq!(v["error_message"], "model failed"); + assert_eq!(v["provider"], "cactus"); + } + + #[test] + fn speech_started_response_serializes() { + let resp = StreamResponse::SpeechStartedResponse { + channel: vec![0], + timestamp: 1.23, + }; + let v: serde_json::Value = + serde_json::from_str(&serde_json::to_string(&resp).unwrap()).unwrap(); + assert_eq!(v["type"], "SpeechStarted"); + assert_eq!(v["timestamp"], 1.23); + } + + #[test] + fn utterance_end_response_serializes() { + let resp = StreamResponse::UtteranceEndResponse { + channel: vec![0], + last_word_end: 5.67, + }; + let v: serde_json::Value = + serde_json::from_str(&serde_json::to_string(&resp).unwrap()).unwrap(); + assert_eq!(v["type"], "UtteranceEnd"); + assert_eq!(v["last_word_end"], 5.67); + } +} diff --git a/crates/transcribe-cactus/src/service/streaming/session.rs b/crates/transcribe-cactus/src/service/streaming/session.rs index c9aa0212f5..f6bdc39f05 100644 --- a/crates/transcribe-cactus/src/service/streaming/session.rs +++ b/crates/transcribe-cactus/src/service/streaming/session.rs @@ -69,24 +69,18 @@ pub(super) async fn handle_websocket( let mut worker_handles = Vec::with_capacity(total_channels); + let model = match hypr_cactus::Model::builder(&model_path).build() { + Ok(m) => std::sync::Arc::new(m), + Err(e) => { + tracing::error!(error = %e, "failed to load model"); + return; + } + }; + for ch_idx in 0..total_channels { - let model = match hypr_cactus::Model::builder(&model_path).build() { - Ok(m) => std::sync::Arc::new(m), - Err(e) => { - tracing::error!(error = %e, "failed to load model for channel {ch_idx}"); - return; - } - }; - let cloud_config = if cactus_config.cloud_handoff { - hypr_cactus::CloudConfig::default() - } else { - hypr_cactus::CloudConfig { - threshold: Some(0.0), - ..Default::default() - } - }; + let cloud_config = cactus_config.cloud.clone(); let (audio_tx, event_rx, cancel_token, handle) = hypr_cactus::transcribe_stream( - model, + model.clone(), options.clone(), cloud_config, chunk_size_ms, diff --git a/crates/transcribe-cactus/src/service/streaming/tests.rs b/crates/transcribe-cactus/src/service/streaming/tests.rs deleted file mode 100644 index 0bdb741534..0000000000 --- a/crates/transcribe-cactus/src/service/streaming/tests.rs +++ /dev/null @@ -1,588 +0,0 @@ -use std::path::{Path, PathBuf}; -use std::time::Duration; - -use axum::extract::ws::Message; -use axum::{Router, error_handling::HandleError, http::StatusCode}; -use futures_util::{SinkExt, StreamExt}; -use tokio_tungstenite::{connect_async, tungstenite::Message as WsMessage}; - -use hypr_audio_utils::bytes_to_f32_samples; -use owhisper_interface::ControlMessage; -use owhisper_interface::stream::StreamResponse; - -use super::TranscribeService; -use super::message::{AudioExtract, IncomingMessage, process_incoming_message}; -use super::response::{build_session_metadata, build_transcript_response, format_timestamp_now}; -use super::session::SAMPLE_RATE; - -#[test] -fn control_message_finalize_parsed() { - let msg = Message::Text(r#"{"type":"Finalize"}"#.into()); - match process_incoming_message(&msg, 1) { - IncomingMessage::Control(ControlMessage::Finalize) => {} - other => panic!( - "expected Finalize, got {:?}", - std::mem::discriminant(&other) - ), - } -} - -#[test] -fn control_message_keep_alive_parsed() { - let msg = Message::Text(r#"{"type":"KeepAlive"}"#.into()); - match process_incoming_message(&msg, 1) { - IncomingMessage::Control(ControlMessage::KeepAlive) => {} - other => panic!( - "expected KeepAlive, got {:?}", - std::mem::discriminant(&other) - ), - } -} - -#[test] -fn control_message_close_stream_parsed() { - let msg = Message::Text(r#"{"type":"CloseStream"}"#.into()); - match process_incoming_message(&msg, 1) { - IncomingMessage::Control(ControlMessage::CloseStream) => {} - other => panic!( - "expected CloseStream, got {:?}", - std::mem::discriminant(&other) - ), - } -} - -#[test] -fn audio_chunk_parsed_over_control() { - let chunk = owhisper_interface::ListenInputChunk::End; - let json = serde_json::to_string(&chunk).unwrap(); - let msg = Message::Text(json.into()); - match process_incoming_message(&msg, 1) { - IncomingMessage::Audio(AudioExtract::End) => {} - other => panic!( - "expected Audio(End), got {:?}", - std::mem::discriminant(&other) - ), - } -} - -#[test] -fn close_frame_yields_end() { - let msg = Message::Close(None); - match process_incoming_message(&msg, 1) { - IncomingMessage::Audio(AudioExtract::End) => {} - other => panic!( - "expected Audio(End), got {:?}", - std::mem::discriminant(&other) - ), - } -} - -#[test] -fn binary_single_channel_yields_mono() { - let samples: Vec = vec![1000, 2000, 3000]; - let data: Vec = samples.iter().flat_map(|s| s.to_le_bytes()).collect(); - let msg = Message::Binary(data.into()); - match process_incoming_message(&msg, 1) { - IncomingMessage::Audio(AudioExtract::Mono(s)) => assert!(!s.is_empty()), - other => panic!( - "expected Audio(Mono), got {:?}", - std::mem::discriminant(&other) - ), - } -} - -#[test] -fn binary_dual_channel_yields_dual() { - // 2 interleaved i16 samples (4 bytes per frame: ch0, ch1) - let samples: Vec = vec![1000, -1000, 2000, -2000]; - let data: Vec = samples.iter().flat_map(|s| s.to_le_bytes()).collect(); - let msg = Message::Binary(data.into()); - match process_incoming_message(&msg, 2) { - IncomingMessage::Audio(AudioExtract::Dual { ch0, ch1 }) => { - assert_eq!(ch0.len(), 2); - assert_eq!(ch1.len(), 2); - assert!(ch0[0] > 0.0); - assert!(ch1[0] < 0.0); - } - other => panic!( - "expected Audio(Dual), got {:?}", - std::mem::discriminant(&other) - ), - } -} - -#[test] -fn dual_audio_json_yields_dual() { - let chunk = owhisper_interface::ListenInputChunk::DualAudio { - mic: vec![0x00, 0x10], - speaker: vec![0x00, 0x20], - }; - let json = serde_json::to_string(&chunk).unwrap(); - let msg = Message::Text(json.into()); - match process_incoming_message(&msg, 1) { - IncomingMessage::Audio(AudioExtract::Dual { .. }) => {} - other => panic!( - "expected Audio(Dual), got {:?}", - std::mem::discriminant(&other) - ), - } -} - -#[test] -fn transcript_response_serializes_as_results() { - let meta = build_session_metadata(Path::new("/models/whisper-small")); - let resp = build_transcript_response( - "hello world", - 0.0, - 1.5, - 0.95, - Some("en"), - true, - true, - false, - &meta, - &[0, 1], - None, - ); - - let json = serde_json::to_string(&resp).unwrap(); - let v: serde_json::Value = serde_json::from_str(&json).unwrap(); - - assert_eq!(v["type"], "Results"); - assert_eq!(v["is_final"], true); - assert_eq!(v["speech_final"], true); - assert_eq!(v["from_finalize"], false); - assert_eq!(v["start"], 0.0); - assert_eq!(v["duration"], 1.5); - assert_eq!(v["channel"]["alternatives"][0]["transcript"], "hello world"); - assert_eq!( - v["channel"]["alternatives"][0]["words"] - .as_array() - .unwrap() - .len(), - 2 - ); - assert_eq!(v["channel"]["alternatives"][0]["languages"][0], "en"); - assert!(!v["metadata"]["request_id"].as_str().unwrap().is_empty()); - assert_eq!(v["metadata"]["model_info"]["name"], "whisper-small"); - assert_eq!(v["metadata"]["model_info"]["arch"], "cactus"); - assert!( - v["metadata"]["extra"]["started_unix_millis"] - .as_u64() - .is_some() - ); - assert_eq!(v["channel_index"], serde_json::json!([0, 1])); -} - -#[test] -fn transcript_response_from_finalize_flag() { - let meta = build_session_metadata(Path::new("/models/test")); - let resp = build_transcript_response( - "test", - 1.0, - 0.5, - 0.9, - None, - true, - true, - true, - &meta, - &[0, 2], - None, - ); - - let json = serde_json::to_string(&resp).unwrap(); - let v: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(v["from_finalize"], true); - assert_eq!(v["channel_index"], serde_json::json!([0, 2])); -} - -#[test] -fn transcript_response_channel_1_of_2() { - let meta = build_session_metadata(Path::new("/models/test")); - let resp = build_transcript_response( - "speaker text", - 0.0, - 1.0, - 0.8, - None, - true, - true, - false, - &meta, - &[1, 2], - None, - ); - - let json = serde_json::to_string(&resp).unwrap(); - let v: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(v["channel_index"], serde_json::json!([1, 2])); -} - -#[test] -fn terminal_response_serializes_as_metadata() { - let resp = StreamResponse::TerminalResponse { - request_id: "test-id".to_string(), - created: "2026-01-01T00:00:00.000Z".to_string(), - duration: 10.5, - channels: 1, - }; - - let json = serde_json::to_string(&resp).unwrap(); - let v: serde_json::Value = serde_json::from_str(&json).unwrap(); - - assert_eq!(v["type"], "Metadata"); - assert_eq!(v["request_id"], "test-id"); - assert_eq!(v["duration"], 10.5); - assert_eq!(v["channels"], 1); -} - -#[test] -fn error_response_serializes() { - let resp = StreamResponse::ErrorResponse { - error_code: None, - error_message: "model failed".to_string(), - provider: "cactus".to_string(), - }; - - let json = serde_json::to_string(&resp).unwrap(); - let v: serde_json::Value = serde_json::from_str(&json).unwrap(); - - assert_eq!(v["type"], "Error"); - assert_eq!(v["error_message"], "model failed"); - assert_eq!(v["provider"], "cactus"); -} - -#[test] -fn speech_started_response_serializes() { - let resp = StreamResponse::SpeechStartedResponse { - channel: vec![0], - timestamp: 1.23, - }; - - let json = serde_json::to_string(&resp).unwrap(); - let v: serde_json::Value = serde_json::from_str(&json).unwrap(); - - assert_eq!(v["type"], "SpeechStarted"); - assert_eq!(v["timestamp"], 1.23); -} - -#[test] -fn utterance_end_response_serializes() { - let resp = StreamResponse::UtteranceEndResponse { - channel: vec![0], - last_word_end: 5.67, - }; - - let json = serde_json::to_string(&resp).unwrap(); - let v: serde_json::Value = serde_json::from_str(&json).unwrap(); - - assert_eq!(v["type"], "UtteranceEnd"); - assert_eq!(v["last_word_end"], 5.67); -} - -#[test] -fn session_metadata_has_required_fields() { - let meta = build_session_metadata(Path::new("/some/path/whisper-large-v3")); - assert!(!meta.request_id.is_empty()); - assert!(!meta.model_uuid.is_empty()); - assert_eq!(meta.model_info.name, "whisper-large-v3"); - assert_eq!(meta.model_info.arch, "cactus"); - assert!(meta.extra.is_some()); -} - -#[test] -fn format_timestamp_produces_iso8601() { - let ts = format_timestamp_now(); - assert!(ts.ends_with('Z')); - assert!(ts.contains('T')); - assert_eq!(ts.len(), 24); -} - -// cargo test -p transcribe-cactus e2e_streaming -- --ignored --nocapture -#[ignore = "requires local cactus model files"] -#[test] -fn e2e_streaming() { - let model_path = std::env::var("CACTUS_STT_MODEL") - .unwrap_or_else(|_| "/tmp/cactus-model/moonshine-base-cactus".to_string()); - let model_path = PathBuf::from(model_path); - assert!( - model_path.exists(), - "model not found: {}", - model_path.display() - ); - - let model = std::sync::Arc::new(hypr_cactus::Model::new(&model_path).unwrap()); - let options = hypr_cactus::TranscribeOptions::default(); - let chunk_size_ms = 300u32; - - let (audio_tx, mut event_stream, _cancel, _worker_handle) = hypr_cactus::transcribe_stream( - model, - options, - hypr_cactus::CloudConfig::default(), - chunk_size_ms, - SAMPLE_RATE, - ); - - let samples = bytes_to_f32_samples(hypr_data::english_1::AUDIO); - let chunk_size = 8_000; - let total_chunks = (samples.len() + chunk_size - 1) / chunk_size; - let audio_duration = samples.len() as f64 / 16_000.0; - println!( - "\n--- feeding ALL audio: {:.1}s ({} chunks of {:.1}s) ---", - audio_duration, - total_chunks, - chunk_size as f64 / 16_000.0, - ); - - let t0 = std::time::Instant::now(); - - let sender = std::thread::spawn(move || { - for (i, chunk) in samples.chunks(chunk_size).enumerate() { - audio_tx.blocking_send(chunk.to_vec()).expect("send failed"); - if i % 20 == 0 { - println!( - "[{:>6.1}s] sent chunk {}/{}", - t0.elapsed().as_secs_f64(), - i, - total_chunks - ); - } - } - println!( - "[{:>6.1}s] all {} chunks sent", - t0.elapsed().as_secs_f64(), - total_chunks - ); - }); - - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("failed to create tokio runtime"); - - let t0 = std::time::Instant::now(); - let mut full_transcript = String::new(); - let mut event_count = 0u32; - while let Some(event) = rt.block_on(event_stream.next()) { - match event { - Ok(hypr_cactus::TranscribeEvent { result: r, .. }) => { - let confirmed = r.confirmed.trim(); - let pending = r.pending.trim(); - if !confirmed.is_empty() || !pending.is_empty() { - println!( - "[{:>6.1}s] confirmed={:?} pending={:?}", - t0.elapsed().as_secs_f64(), - confirmed, - pending, - ); - } - if !confirmed.is_empty() { - if !full_transcript.is_empty() { - full_transcript.push(' '); - } - full_transcript.push_str(confirmed); - } - event_count += 1; - } - Err(e) => panic!("streaming error: {e}"), - } - } - - sender.join().expect("sender thread panicked"); - let elapsed = t0.elapsed().as_secs_f64(); - println!( - "\n--- FULL TRANSCRIPT ({:.1}s audio, {:.1}s wall, {:.1}x realtime) ---", - audio_duration, - elapsed, - elapsed / audio_duration - ); - println!("{full_transcript}"); - println!("--- END ({event_count} events) ---\n"); - assert!(!full_transcript.is_empty(), "expected non-empty transcript"); -} - -// cargo test -p transcribe-cactus e2e_websocket -- --ignored --nocapture -#[ignore = "requires local cactus model files"] -#[test] -fn e2e_websocket_listen_with_real_model_inference() { - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("failed to create tokio runtime"); - rt.block_on(async { - let model_path = std::env::var("CACTUS_STT_MODEL") - .unwrap_or_else(|_| "/tmp/cactus-model/moonshine-base-cactus".to_string()); - let model_path = PathBuf::from(model_path); - assert!( - model_path.exists(), - "model not found: {}", - model_path.display() - ); - - let service = HandleError::new( - TranscribeService::builder().model_path(model_path).build(), - |err: String| async move { (StatusCode::INTERNAL_SERVER_ERROR, err) }, - ); - let app = Router::new().route_service("/v1/listen", service); - - let listener = tokio::net::TcpListener::bind("127.0.0.1:0") - .await - .unwrap(); - let addr = listener.local_addr().unwrap(); - - let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); - tokio::spawn(async move { - let _ = axum::serve(listener, app) - .with_graceful_shutdown(async { - let _ = shutdown_rx.await; - }) - .await; - }); - - let ws_url = format!( - "ws://{}/v1/listen?channels=1&sample_rate=16000&chunk_size_ms=300", - addr - ); - let (ws_stream, _) = connect_async(&ws_url).await.expect("ws connect failed"); - let (mut ws_tx, mut ws_rx) = ws_stream.split(); - - let t0 = std::time::Instant::now(); - let chunk_bytes = 32_000; // 1s of 16kHz i16 PCM per chunk - let num_chunks = 5; - println!( - "\n--- ws: sending {} chunks of {}B ({:.1}s each) ---", - num_chunks, - chunk_bytes, - chunk_bytes as f64 / 32_000.0, - ); - - let (close_tx, close_rx) = tokio::sync::oneshot::channel::<()>(); - let close_tx = std::cell::Cell::new(Some(close_tx)); - - let writer = tokio::spawn(async move { - for (i, chunk) in hypr_data::english_1::AUDIO - .chunks(chunk_bytes) - .take(num_chunks) - .enumerate() - { - ws_tx - .send(WsMessage::Binary(chunk.to_vec().into())) - .await - .unwrap(); - println!("[{:>5.1}s] ws sent chunk {}", t0.elapsed().as_secs_f64(), i); - } - - let _ = close_rx.await; - println!( - "[{:>5.1}s] ws sending CloseStream", - t0.elapsed().as_secs_f64() - ); - let _ = ws_tx - .send(WsMessage::Text( - r#"{"type":"CloseStream"}"#.to_string().into(), - )) - .await; - }); - - let mut results_count = 0u32; - let mut saw_terminal = false; - let mut saw_error: Option = None; - let mut close_sent = false; - - while let Ok(Some(Ok(msg))) = - tokio::time::timeout(Duration::from_secs(60), ws_rx.next()).await - { - match msg { - WsMessage::Text(text) => { - let Ok(v) = serde_json::from_str::(&text) else { - continue; - }; - let msg_type = v.get("type").and_then(|t| t.as_str()).unwrap_or("?"); - match msg_type { - "Results" => { - let transcript = v - .pointer("/channel/alternatives/0/transcript") - .and_then(|t| t.as_str()) - .unwrap_or(""); - let is_final = - v.get("is_final").and_then(|f| f.as_bool()).unwrap_or(false); - let speech_final = v - .get("speech_final") - .and_then(|f| f.as_bool()) - .unwrap_or(false); - println!( - "[{:>5.1}s] ws recv Results is_final={:<5} speech_final={:<5} {:?}", - t0.elapsed().as_secs_f64(), - is_final, - speech_final, - transcript, - ); - results_count += 1; - - if results_count >= 3 && !close_sent { - close_sent = true; - if let Some(tx) = close_tx.take() { - let _ = tx.send(()); - } - } - } - "Metadata" => { - println!( - "[{:>5.1}s] ws recv Metadata (terminal)", - t0.elapsed().as_secs_f64() - ); - saw_terminal = true; - break; - } - "SpeechStarted" => { - println!( - "[{:>5.1}s] ws recv SpeechStarted", - t0.elapsed().as_secs_f64() - ); - } - "UtteranceEnd" => { - println!( - "[{:>5.1}s] ws recv UtteranceEnd", - t0.elapsed().as_secs_f64() - ); - } - "Error" => { - saw_error = v - .get("error_message") - .and_then(|m| m.as_str()) - .map(str::to_owned); - println!( - "[{:>5.1}s] ws recv Error: {:?}", - t0.elapsed().as_secs_f64(), - saw_error - ); - break; - } - other => { - println!( - "[{:>5.1}s] ws recv {other}", - t0.elapsed().as_secs_f64() - ); - } - } - } - WsMessage::Close(_) => break, - _ => {} - } - } - - let _ = writer.await; - let _ = shutdown_tx.send(()); - println!( - "[{:>5.1}s] done ({} Results, terminal={})\n", - t0.elapsed().as_secs_f64(), - results_count, - saw_terminal, - ); - - assert!(saw_error.is_none(), "ws error: {:?}", saw_error); - assert!(results_count > 0, "expected at least one Results message"); - assert!(saw_terminal, "expected terminal Metadata message"); - }); -} diff --git a/crates/transcribe-cactus/tests/batch.rs b/crates/transcribe-cactus/tests/batch.rs new file mode 100644 index 0000000000..e0f5cd93a2 --- /dev/null +++ b/crates/transcribe-cactus/tests/batch.rs @@ -0,0 +1,107 @@ +mod common; + +use axum::error_handling::HandleError; +use axum::{Router, http::StatusCode}; + +fn audio_wav_bytes() -> Vec { + let max_secs = std::env::var("E2E_AUDIO_SECS") + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or(usize::MAX); + let full = hypr_data::english_1::AUDIO; + let pcm = &full[..full.len().min(max_secs.saturating_mul(16000 * 2))]; + let data_len = pcm.len() as u32; + let mut w = Vec::with_capacity(44 + pcm.len()); + w.extend_from_slice(b"RIFF"); + w.extend_from_slice(&(36 + data_len).to_le_bytes()); + w.extend_from_slice(b"WAVE"); + w.extend_from_slice(b"fmt "); + w.extend_from_slice(&16u32.to_le_bytes()); + w.extend_from_slice(&1u16.to_le_bytes()); + w.extend_from_slice(&1u16.to_le_bytes()); + w.extend_from_slice(&16000u32.to_le_bytes()); + w.extend_from_slice(&32000u32.to_le_bytes()); + w.extend_from_slice(&2u16.to_le_bytes()); + w.extend_from_slice(&16u16.to_le_bytes()); + w.extend_from_slice(b"data"); + w.extend_from_slice(&data_len.to_le_bytes()); + w.extend_from_slice(pcm); + w +} + +use transcribe_cactus::TranscribeService; + +use common::model_path; + +#[ignore = "requires local cactus model files"] +#[test] +fn e2e_batch() { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async { + let app = Router::new().route_service( + "/v1/listen", + HandleError::new( + TranscribeService::builder() + .model_path(model_path()) + .build(), + |err: String| async move { (StatusCode::INTERNAL_SERVER_ERROR, err) }, + ), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + tokio::spawn(async move { + axum::serve(listener, app) + .with_graceful_shutdown(async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + + let wav_bytes = audio_wav_bytes(); + + let url = format!("http://{}/v1/listen?channels=1&sample_rate=16000&language=en", addr); + let client = reqwest::Client::new(); + let response = client + .post(&url) + .header("content-type", "audio/wav") + .body(wav_bytes) + .send() + .await + .expect("request failed"); + + assert_eq!(response.status(), 200); + let v: serde_json::Value = response.json().await.expect("response is not JSON"); + + let transcript = v + .pointer("/results/channels/0/alternatives/0/transcript") + .and_then(|t| t.as_str()) + .unwrap_or(""); + + let transcript_lower = transcript.trim().to_lowercase(); + assert!( + !transcript_lower.is_empty(), + "expected non-empty transcript" + ); + assert!( + transcript_lower.contains("maybe") + || transcript_lower.contains("this") + || transcript_lower.contains("talking"), + "transcript looks like a hallucination (got: {:?})", + transcript_lower + ); + assert!( + v["metadata"]["duration"].as_f64().unwrap_or_default() > 0.0, + "expected positive duration in metadata" + ); + assert_eq!(v["metadata"]["channels"], 1); + + let _ = shutdown_tx.send(()); + }); +} diff --git a/crates/transcribe-cactus/tests/common/mod.rs b/crates/transcribe-cactus/tests/common/mod.rs new file mode 100644 index 0000000000..608b2594f3 --- /dev/null +++ b/crates/transcribe-cactus/tests/common/mod.rs @@ -0,0 +1,9 @@ +use std::path::PathBuf; + +pub fn model_path() -> PathBuf { + let path = std::env::var("CACTUS_STT_MODEL") + .unwrap_or_else(|_| "/tmp/cactus-model/moonshine-base-cactus".to_string()); + let path = PathBuf::from(path); + assert!(path.exists(), "model not found: {}", path.display()); + path +} diff --git a/crates/transcribe-cactus/tests/live.rs b/crates/transcribe-cactus/tests/live.rs new file mode 100644 index 0000000000..bbe0da2117 --- /dev/null +++ b/crates/transcribe-cactus/tests/live.rs @@ -0,0 +1,336 @@ +mod common; + +use std::time::Duration; + +fn e2e_audio_secs(default: usize) -> usize { + std::env::var("E2E_AUDIO_SECS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(default) +} + +fn scale_close_after(audio_secs: usize, default_audio: usize, default_close: u32) -> u32 { + let ratio = audio_secs as f64 / default_audio as f64; + ((default_close as f64 * ratio).ceil() as u32).max(1) +} + +use axum::error_handling::HandleError; +use axum::{Router, http::StatusCode}; +use futures_util::{SinkExt, StreamExt}; +use sequential_test::sequential; +use tokio_tungstenite::{connect_async, tungstenite::Message as WsMessage}; + +use hypr_cactus::CloudConfig; +use transcribe_cactus::{CactusConfig, TranscribeService}; + +use common::model_path; + +async fn run_single_channel_opts( + cactus_config: CactusConfig, + audio_secs: usize, + close_after_results: u32, + timeout_secs: u64, +) { + let app = Router::new().route_service( + "/v1/listen", + HandleError::new( + TranscribeService::builder() + .model_path(model_path()) + .cactus_config(cactus_config) + .build(), + |err: String| async move { (StatusCode::INTERNAL_SERVER_ERROR, err) }, + ), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + tokio::spawn(async move { + axum::serve(listener, app) + .with_graceful_shutdown(async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + + let ws_url = format!( + "ws://{}/v1/listen?channels=1&sample_rate=16000&chunk_size_ms=300", + addr + ); + let (ws, _) = connect_async(&ws_url).await.expect("ws connect failed"); + let (mut tx, mut rx) = ws.split(); + + let (close_tx, close_rx) = tokio::sync::oneshot::channel::<()>(); + let close_tx = std::cell::Cell::new(Some(close_tx)); + + let t0 = std::time::Instant::now(); + + let writer = tokio::spawn(async move { + let audio = hypr_data::english_1::AUDIO; + for chunk in audio.chunks(32_000).cycle().take(audio_secs) { + tx.send(WsMessage::Binary(chunk.to_vec().into())) + .await + .unwrap(); + } + let _ = close_rx.await; + let _ = tx + .send(WsMessage::Text( + r#"{"type":"CloseStream"}"#.to_string().into(), + )) + .await; + }); + + let mut results = 0u32; + let mut saw_terminal = false; + let mut close_sent = false; + + while let Ok(Some(Ok(msg))) = + tokio::time::timeout(Duration::from_secs(timeout_secs), rx.next()).await + { + match msg { + WsMessage::Text(text) => { + let Ok(v) = serde_json::from_str::(&text) else { + continue; + }; + match v.get("type").and_then(|t| t.as_str()).unwrap_or("") { + "Results" => { + let transcript = v + .pointer("/channel/alternatives/0/transcript") + .and_then(|t| t.as_str()) + .unwrap_or(""); + let is_final = v.get("is_final").and_then(|f| f.as_bool()).unwrap_or(false); + let cloud_corrected = v + .pointer("/metadata/extra/cloud_corrected") + .and_then(|b| b.as_bool()) + .unwrap_or(false); + println!( + "[{:>5.1}s] is_final={} cloud={} {:?}", + t0.elapsed().as_secs_f64(), + is_final, + cloud_corrected, + transcript, + ); + if is_final { + results += 1; + } + if results >= close_after_results && !close_sent { + close_sent = true; + if let Some(tx) = close_tx.take() { + let _ = tx.send(()); + } + } + } + "Metadata" => { + println!("[{:>5.1}s] terminal", t0.elapsed().as_secs_f64()); + saw_terminal = true; + break; + } + "Error" => panic!("ws error: {:?}", v.get("error_message")), + _ => {} + } + } + WsMessage::Close(_) => break, + _ => {} + } + } + + let _ = writer.await; + let _ = shutdown_tx.send(()); + + assert!(results > 0, "expected Results messages"); + assert!(saw_terminal, "expected terminal Metadata message"); +} + +async fn run_single_channel(cactus_config: CactusConfig) { + let secs = e2e_audio_secs(100); + run_single_channel_opts(cactus_config, secs, scale_close_after(secs, 100, 3), 120).await; +} + +#[ignore = "requires local cactus model files"] +#[sequential] +#[test] +fn e2e_websocket_no_handoff() { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(run_single_channel(CactusConfig { + cloud: CloudConfig { + threshold: Some(0.0), + ..Default::default() + }, + ..Default::default() + })); +} + +#[ignore = "requires local cactus model files"] +#[sequential] +#[test] +fn e2e_websocket_with_handoff() { + let api_key = std::env::var("CACTUS_CLOUD_API_KEY") + .expect("CACTUS_CLOUD_API_KEY must be set for this test"); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + let secs = e2e_audio_secs(120); + rt.block_on(run_single_channel_opts( + CactusConfig { + // Well below model defaults (Whisper=0.4, Moonshine=0.35) to trigger aggressively + cloud: CloudConfig { + api_key: Some(api_key), + threshold: Some(0.05), + }, + ..Default::default() + }, + secs, + scale_close_after(secs, 120, 30), + 180, + )); +} + +#[ignore = "requires local cactus model files"] +#[sequential] +#[test] +fn e2e_websocket_dual_channel_no_handoff() { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + let dual_secs = e2e_audio_secs(100); + let dual_close = scale_close_after(dual_secs, 100, 6); + rt.block_on(async move { + let app = Router::new().route_service( + "/v1/listen", + HandleError::new( + TranscribeService::builder() + .model_path(model_path()) + .cactus_config(CactusConfig { + cloud: CloudConfig { + threshold: Some(0.0), + ..Default::default() + }, + ..Default::default() + }) + .build(), + |err: String| async move { (StatusCode::INTERNAL_SERVER_ERROR, err) }, + ), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + tokio::spawn(async move { + axum::serve(listener, app) + .with_graceful_shutdown(async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + + let ws_url = format!( + "ws://{}/v1/listen?channels=2&sample_rate=16000&chunk_size_ms=300", + addr + ); + let (ws, _) = connect_async(&ws_url).await.expect("ws connect failed"); + let (mut tx, mut rx) = ws.split(); + + let audio = hypr_data::english_1::AUDIO; + let num_frames = audio.len() / 2; + let mut interleaved = Vec::with_capacity(num_frames * 4); + for i in 0..num_frames { + interleaved.extend_from_slice(&audio[i * 2..i * 2 + 2]); + interleaved.extend_from_slice(&audio[i * 2..i * 2 + 2]); + } + + let (close_tx, close_rx) = tokio::sync::oneshot::channel::<()>(); + let close_tx = std::cell::Cell::new(Some(close_tx)); + + let t0 = std::time::Instant::now(); + + let writer = tokio::spawn(async move { + for chunk in interleaved.chunks(64_000).cycle().take(dual_secs) { + tx.send(WsMessage::Binary(chunk.to_vec().into())) + .await + .unwrap(); + } + let _ = close_rx.await; + let _ = tx + .send(WsMessage::Text( + r#"{"type":"CloseStream"}"#.to_string().into(), + )) + .await; + }); + + let mut results = 0u32; + let mut saw_terminal = false; + let mut close_sent = false; + let mut channels_seen = std::collections::HashSet::new(); + + while let Ok(Some(Ok(msg))) = + tokio::time::timeout(Duration::from_secs(120), rx.next()).await + { + match msg { + WsMessage::Text(text) => { + let Ok(v) = serde_json::from_str::(&text) else { + continue; + }; + match v.get("type").and_then(|t| t.as_str()).unwrap_or("") { + "Results" => { + let ch = v + .pointer("/channel_index/0") + .and_then(|c| c.as_i64()) + .unwrap_or(-1); + let transcript = v + .pointer("/channel/alternatives/0/transcript") + .and_then(|t| t.as_str()) + .unwrap_or(""); + let is_final = + v.get("is_final").and_then(|f| f.as_bool()).unwrap_or(false); + println!( + "[{:>5.1}s] ch={} is_final={} {:?}", + t0.elapsed().as_secs_f64(), + ch, + is_final, + transcript, + ); + channels_seen.insert(ch); + results += 1; + if results >= dual_close && !close_sent { + close_sent = true; + if let Some(tx) = close_tx.take() { + let _ = tx.send(()); + } + } + } + "Metadata" => { + println!("[{:>5.1}s] terminal", t0.elapsed().as_secs_f64()); + saw_terminal = true; + break; + } + "Error" => panic!("ws error: {:?}", v.get("error_message")), + _ => {} + } + } + WsMessage::Close(_) => break, + _ => {} + } + } + + let _ = writer.await; + let _ = shutdown_tx.send(()); + + assert!(results > 0, "expected Results messages"); + assert!(saw_terminal, "expected terminal Metadata message"); + assert!( + channels_seen.contains(&0) && channels_seen.contains(&1), + "expected results from both channels, got {channels_seen:?}", + ); + }); +}