From ff6e901a4e447a69051dac7f07c8522a74a99397 Mon Sep 17 00:00:00 2001
From: Sergey Nikolaev
Date: Sat, 7 Feb 2026 17:28:31 +0000
Subject: [PATCH 1/5] fix: add local-model concurrency guard and expand model
 tests

- Serialize causal model predictions to avoid concurrent Qwen crashes
- Add FFI concurrency tests for Qwen and other local models
- Auto-download required models for integration tests
- Document test cache env vars in README
---
 embeddings/README.md               |   9 +++
 embeddings/src/model/ffi_test.rs   | 114 +++++++++++++++++++++++++++++
 embeddings/src/model/local.rs      |  12 +++
 embeddings/src/model/local_test.rs |  30 ++++++++
 4 files changed, 165 insertions(+)

diff --git a/embeddings/README.md b/embeddings/README.md
index b280fac1..46bdb7fa 100644
--- a/embeddings/README.md
+++ b/embeddings/README.md
@@ -14,3 +14,12 @@ cargo build --lib --release
 ```bash
 g++ -o test examples/test.cpp -Ltarget/release -lmanticore_knn_embeddings -I. -lpthread -ldl -std=c++17
 ```
+## Testing
+
+Some integration tests download model files into a cache directory if they are missing. You can
+override the cache location with environment variables:
+
+- `MANTICORE_TEST_CACHE`: preferred cache path for tests
+- `MANTICORE_CACHE_PATH`: fallback cache path for tests
+
+If neither is set, tests default to `./.cache/manticore` under the `embeddings` crate directory.
diff --git a/embeddings/src/model/ffi_test.rs b/embeddings/src/model/ffi_test.rs
index bd145cae..7113cec4 100644
--- a/embeddings/src/model/ffi_test.rs
+++ b/embeddings/src/model/ffi_test.rs
@@ -6,6 +6,10 @@ use std::ptr;
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::model::local::build_model_info;
+    use std::collections::HashSet;
+    use std::path::PathBuf;
+    use std::sync::{Mutex, OnceLock};
 
     // Helper function to create a C string from Rust string
     fn to_c_string(s: &str) -> CString {
@@ -20,6 +24,97 @@ mod tests {
         }
     }
 
+    fn test_cache_root() -> String {
+        std::env::var("MANTICORE_TEST_CACHE")
+            .or_else(|_| std::env::var("MANTICORE_CACHE_PATH"))
+            .unwrap_or_else(|_| format!("{}/.cache/manticore", env!("CARGO_MANIFEST_DIR")))
+    }
+
+    fn ensure_model_cached(model_id: &str, cache_path: &PathBuf) {
+        static DOWNLOADED: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();
+        let downloaded = DOWNLOADED.get_or_init(|| Mutex::new(HashSet::new()));
+        let mut set = downloaded.lock().expect("model cache lock poisoned");
+        if set.contains(model_id) {
+            return;
+        }
+        std::fs::create_dir_all(cache_path).expect("failed to create model cache directory");
+        build_model_info(cache_path.clone(), model_id, "main")
+            .expect("failed to download model into cache");
+        set.insert(model_id.to_string());
+    }
+
+    fn run_concurrent_ffi_embeddings(model_id: &str) {
+        use std::sync::Arc;
+        use std::thread;
+
+        let model_id = model_id.to_string();
+        let cache_root = test_cache_root();
+        let cache_path_buf = PathBuf::from(&cache_root);
+        ensure_model_cached(&model_id, &cache_path_buf);
+
+        let model_name = to_c_string(&model_id);
+        let cache_path = to_c_string(&cache_root);
+        let api_key = to_c_string("");
+
+        let result = TextModelWrapper::load_model(
+            model_name.as_ptr(),
+            model_name.as_bytes().len(),
+            cache_path.as_ptr(),
+            cache_path.as_bytes().len(),
+            api_key.as_ptr(),
+            api_key.as_bytes().len(),
+            false,
+        );
+
+        if result.model.is_null() {
+            let error_message = if result.error.is_null() {
+                "unknown error".to_string()
+            } else {
+                unsafe {
+                    CStr::from_ptr(result.error)
+                        .to_str()
+                        .unwrap_or("unknown error")
+                        .to_string()
+                }
+            };
+            TextModelWrapper::free_model_result(result);
+            panic!("failed to load model {}: {}", model_id, error_message);
+        }
+
+        let model_ptr = result.model as usize;
+        let start = Arc::new(std::sync::Barrier::new(4));
+        let handles: Vec<_> = (0..3)
+            .map(|i| {
+                let start = Arc::clone(&start);
+                let model_ptr = model_ptr;
+                let model_id = model_id.clone();
+                thread::spawn(move || {
+                    start.wait();
+                    let text = format!("Concurrent embedding test {} - {}", model_id, i);
+                    let item = create_string_item(&text);
+                    let items = [item];
+
+                    // Safety: emulate FFI callers that share a model pointer across threads.
+                    let wrapper = unsafe {
+                        std::mem::transmute::<*mut std::ffi::c_void, TextModelWrapper>(
+                            model_ptr as *mut std::ffi::c_void,
+                        )
+                    };
+                    let vec_result =
+                        TextModelWrapper::make_vect_embeddings(&wrapper, items.as_ptr(), 1);
+                    TextModelWrapper::free_vec_result(vec_result);
+                })
+            })
+            .collect();
+
+        start.wait();
+        for handle in handles {
+            handle.join().unwrap();
+        }
+
+        TextModelWrapper::free_model_result(result);
+    }
+
     #[test]
     fn test_text_model_result_structure() {
         // Test that TextModelResult has the expected structure
@@ -367,4 +462,23 @@ mod tests {
         assert_eq!(options2.api_key, Some("sk-test456".to_string()));
         assert_eq!(options2.use_gpu, None);
     }
+
+    #[test]
+    fn test_concurrent_qwen_embeddings_via_ffi() {
+        run_concurrent_ffi_embeddings("Qwen/Qwen3-Embedding-0.6B");
+    }
+
+    #[test]
+    fn test_concurrent_other_models_via_ffi() {
+        let model_ids = [
+            "sentence-transformers/all-MiniLM-L6-v2",
+            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+            "Locutusque/TinyMistral-248M-v2",
+            "h2oai/embeddinggemma-300m",
+        ];
+
+        for model_id in model_ids {
+            run_concurrent_ffi_embeddings(model_id);
+        }
+    }
 }
diff --git a/embeddings/src/model/local.rs b/embeddings/src/model/local.rs
index 7f4125f7..14123105 100644
--- a/embeddings/src/model/local.rs
+++ b/embeddings/src/model/local.rs
@@ -20,6 +20,7 @@ use serde_json::Value;
 use std::cell::RefCell;
 use std::error::Error;
 use std::path::PathBuf;
+use std::sync::Mutex;
 use tokenizers::Tokenizer;
 
 /// Model architecture type - determines pooling strategy
@@ -227,6 +228,7 @@ pub struct CausalEmbeddingModel {
     max_input_len: usize,
     hidden_size: usize,
     device: Device,
+    predict_lock: Mutex<()>,
 }
 
 impl CausalEmbeddingModel {
@@ -312,6 +314,7 @@ impl CausalEmbeddingModel {
             max_input_len,
             hidden_size,
             device,
+            predict_lock: Mutex::new(()),
         })
     }
 }
@@ -375,6 +378,15 @@ impl LocalModel {
 
 impl TextModel for LocalModel {
     fn predict(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, Box<dyn Error>> {
+        let _predict_guard = match self {
+            LocalModel::Causal(m) => Some(
+                m.predict_lock
+                    .lock()
+                    .map_err(|_| LibError::ModelLoadFailed)?,
+            ),
+            LocalModel::Bert(_) => None,
+        };
+
         let (device, max_input_len) = match self {
             LocalModel::Bert(m) => (m.device.clone(), m.max_input_len),
             LocalModel::Causal(m) => (m.device.clone(), m.max_input_len),
diff --git a/embeddings/src/model/local_test.rs b/embeddings/src/model/local_test.rs
index 558743c8..5d54c2e5 100644
--- a/embeddings/src/model/local_test.rs
+++ b/embeddings/src/model/local_test.rs
@@ -6,7 +6,9 @@ mod tests {
     use crate::model::TextModel;
     use crate::utils::{get_hidden_size, get_max_input_length};
     use approx::assert_abs_diff_eq;
+    use std::collections::HashSet;
     use std::path::PathBuf;
+    use std::sync::{Mutex, OnceLock};
 
     fn check_embedding_properties(embedding: &[f32], expected_len: usize) {
         assert_eq!(embedding.len(), expected_len);
         let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
         assert_abs_diff_eq!(norm, 1.0, epsilon = 1e-6);
     }
 
+    fn ensure_model_cached(model_id: &str, cache_path: &PathBuf) {
+        static DOWNLOADED: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();
+        let
downloaded = DOWNLOADED.get_or_init(|| Mutex::new(HashSet::new())); + let mut set = downloaded.lock().expect("model cache lock poisoned"); + if set.contains(model_id) { + return; + } + std::fs::create_dir_all(cache_path).expect("failed to create model cache directory"); + build_model_info(cache_path.clone(), model_id, "main") + .expect("failed to download model into cache"); + set.insert(model_id.to_string()); + } + fn test_cache_path() -> PathBuf { + if let Ok(path) = std::env::var("MANTICORE_TEST_CACHE") { + return PathBuf::from(path); + } + if let Ok(path) = std::env::var("MANTICORE_CACHE_PATH") { + return PathBuf::from(path); + } PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(".cache/manticore") } @@ -325,6 +346,7 @@ mod tests { fn test_all_minilm_l6_v2() { let model_id = "sentence-transformers/all-MiniLM-L6-v2"; let cache_path = test_cache_path(); + ensure_model_cached(model_id, &cache_path); let test_sentences = [ "This is a test sentence.", @@ -343,6 +365,7 @@ mod tests { fn test_embedding_consistency() { let model_id = "sentence-transformers/all-MiniLM-L6-v2"; let cache_path = test_cache_path(); + ensure_model_cached(model_id, &cache_path); let local_model = LocalModel::new(model_id, cache_path, false).unwrap(); let sentence = &["This is a test sentence."]; @@ -358,6 +381,7 @@ mod tests { fn test_hidden_size() { let model_id = "sentence-transformers/all-MiniLM-L6-v2"; let cache_path = test_cache_path(); + ensure_model_cached(model_id, &cache_path); let local_model = LocalModel::new(model_id, cache_path, false).unwrap(); assert_eq!(local_model.get_hidden_size(), 384); } @@ -366,6 +390,7 @@ mod tests { fn test_max_input_len() { let model_id = "sentence-transformers/all-MiniLM-L6-v2"; let cache_path = test_cache_path(); + ensure_model_cached(model_id, &cache_path); let local_model = LocalModel::new(model_id, cache_path, false).unwrap(); assert_eq!(local_model.get_max_input_len(), 512); } @@ -375,6 +400,7 @@ mod tests { // Integration test for Qwen embedding models let model_id = "Qwen/Qwen3-Embedding-0.6B"; let cache_path = test_cache_path(); + ensure_model_cached(model_id, &cache_path); let local_model = LocalModel::new(model_id, cache_path.clone(), false) .expect("Qwen model should load successfully"); @@ -395,6 +421,7 @@ mod tests { // Integration test for Llama-based embedding models. let model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"; let cache_path = test_cache_path(); + ensure_model_cached(model_id, &cache_path); let local_model = LocalModel::new(model_id, cache_path.clone(), false).expect("Llama model should load"); @@ -410,6 +437,7 @@ mod tests { // Integration test for Mistral-based embedding models. let model_id = "Locutusque/TinyMistral-248M-v2"; let cache_path = test_cache_path(); + ensure_model_cached(model_id, &cache_path); let local_model = LocalModel::new(model_id, cache_path.clone(), false) .expect("Mistral model should load"); @@ -424,6 +452,7 @@ mod tests { // Integration test for Gemma-based embedding models. 
let model_id = "h2oai/embeddinggemma-300m"; let cache_path = test_cache_path(); + ensure_model_cached(model_id, &cache_path); let local_model = LocalModel::new(model_id, cache_path.clone(), false).expect("Gemma model should load"); @@ -438,6 +467,7 @@ mod tests { // Test batch processing with Qwen model let model_id = "Qwen/Qwen3-Embedding-0.6B"; let cache_path = test_cache_path(); + ensure_model_cached(model_id, &cache_path); let result = LocalModel::new(model_id, cache_path.clone(), false); From 92e38da7b591ea33af871326741dd84bead0e72f Mon Sep 17 00:00:00 2001 From: Don Hardman Date: Mon, 9 Feb 2026 17:14:56 +0700 Subject: [PATCH 2/5] fix(donwload): remove useless download fix for embedding model --- embeddings/README.md | 10 --- embeddings/src/model/ffi_test.rs | 114 ----------------------------- embeddings/src/model/local_test.rs | 30 -------- 3 files changed, 154 deletions(-) diff --git a/embeddings/README.md b/embeddings/README.md index 46bdb7fa..5a0a2c3f 100644 --- a/embeddings/README.md +++ b/embeddings/README.md @@ -13,13 +13,3 @@ cargo build --lib --release ```bash g++ -o test examples/test.cpp -Ltarget/release -lmanticore_knn_embeddings -I. -lpthread -ldl -std=c++17 ``` - -## Testing - -Some integration tests download model files into a cache directory if they are missing. You can -override the cache location with environment variables: - -- `MANTICORE_TEST_CACHE`: preferred cache path for tests -- `MANTICORE_CACHE_PATH`: fallback cache path for tests - -If neither is set, tests use `./.cache/manticore` under the repo. diff --git a/embeddings/src/model/ffi_test.rs b/embeddings/src/model/ffi_test.rs index 7113cec4..bd145cae 100644 --- a/embeddings/src/model/ffi_test.rs +++ b/embeddings/src/model/ffi_test.rs @@ -6,10 +6,6 @@ use std::ptr; #[cfg(test)] mod tests { use super::*; - use crate::model::local::build_model_info; - use std::collections::HashSet; - use std::path::PathBuf; - use std::sync::{Mutex, OnceLock}; // Helper function to create a C string from Rust string fn to_c_string(s: &str) -> CString { @@ -24,97 +20,6 @@ mod tests { } } - fn test_cache_root() -> String { - std::env::var("MANTICORE_TEST_CACHE") - .or_else(|_| std::env::var("MANTICORE_CACHE_PATH")) - .unwrap_or_else(|_| format!("{}/.cache/manticore", env!("CARGO_MANIFEST_DIR"))) - } - - fn ensure_model_cached(model_id: &str, cache_path: &PathBuf) { - static DOWNLOADED: OnceLock>> = OnceLock::new(); - let downloaded = DOWNLOADED.get_or_init(|| Mutex::new(HashSet::new())); - let mut set = downloaded.lock().expect("model cache lock poisoned"); - if set.contains(model_id) { - return; - } - std::fs::create_dir_all(cache_path).expect("failed to create model cache directory"); - build_model_info(cache_path.clone(), model_id, "main") - .expect("failed to download model into cache"); - set.insert(model_id.to_string()); - } - - fn run_concurrent_ffi_embeddings(model_id: &str) { - use std::sync::Arc; - use std::thread; - - let model_id = model_id.to_string(); - let cache_root = test_cache_root(); - let cache_path_buf = PathBuf::from(&cache_root); - ensure_model_cached(&model_id, &cache_path_buf); - - let model_name = to_c_string(&model_id); - let cache_path = to_c_string(&cache_root); - let api_key = to_c_string(""); - - let result = TextModelWrapper::load_model( - model_name.as_ptr(), - model_name.as_bytes().len(), - cache_path.as_ptr(), - cache_path.as_bytes().len(), - api_key.as_ptr(), - api_key.as_bytes().len(), - false, - ); - - if result.model.is_null() { - let error_message = if result.error.is_null() { - "unknown 
error".to_string() - } else { - unsafe { - CStr::from_ptr(result.error) - .to_str() - .unwrap_or("unknown error") - .to_string() - } - }; - TextModelWrapper::free_model_result(result); - panic!("failed to load model {}: {}", model_id, error_message); - } - - let model_ptr = result.model as usize; - let start = Arc::new(std::sync::Barrier::new(4)); - let handles: Vec<_> = (0..3) - .map(|i| { - let start = Arc::clone(&start); - let model_ptr = model_ptr; - let model_id = model_id.clone(); - thread::spawn(move || { - start.wait(); - let text = format!("Concurrent embedding test {} - {}", model_id, i); - let item = create_string_item(&text); - let items = [item]; - - // Safety: emulate FFI callers that share a model pointer across threads. - let wrapper = unsafe { - std::mem::transmute::<*mut std::ffi::c_void, TextModelWrapper>( - model_ptr as *mut std::ffi::c_void, - ) - }; - let vec_result = - TextModelWrapper::make_vect_embeddings(&wrapper, items.as_ptr(), 1); - TextModelWrapper::free_vec_result(vec_result); - }) - }) - .collect(); - - start.wait(); - for handle in handles { - handle.join().unwrap(); - } - - TextModelWrapper::free_model_result(result); - } - #[test] fn test_text_model_result_structure() { // Test that TextModelResult has the expected structure @@ -462,23 +367,4 @@ mod tests { assert_eq!(options2.api_key, Some("sk-test456".to_string())); assert_eq!(options2.use_gpu, None); } - - #[test] - fn test_concurrent_qwen_embeddings_via_ffi() { - run_concurrent_ffi_embeddings("Qwen/Qwen3-Embedding-0.6B"); - } - - #[test] - fn test_concurrent_other_models_via_ffi() { - let model_ids = [ - "sentence-transformers/all-MiniLM-L6-v2", - "TinyLlama/TinyLlama-1.1B-Chat-v1.0", - "Locutusque/TinyMistral-248M-v2", - "h2oai/embeddinggemma-300m", - ]; - - for model_id in model_ids { - run_concurrent_ffi_embeddings(model_id); - } - } } diff --git a/embeddings/src/model/local_test.rs b/embeddings/src/model/local_test.rs index 5d54c2e5..558743c8 100644 --- a/embeddings/src/model/local_test.rs +++ b/embeddings/src/model/local_test.rs @@ -6,9 +6,7 @@ mod tests { use crate::model::TextModel; use crate::utils::{get_hidden_size, get_max_input_length}; use approx::assert_abs_diff_eq; - use std::collections::HashSet; use std::path::PathBuf; - use std::sync::{Mutex, OnceLock}; fn check_embedding_properties(embedding: &[f32], expected_len: usize) { assert_eq!(embedding.len(), expected_len); @@ -16,26 +14,7 @@ mod tests { assert_abs_diff_eq!(norm, 1.0, epsilon = 1e-6); } - fn ensure_model_cached(model_id: &str, cache_path: &PathBuf) { - static DOWNLOADED: OnceLock>> = OnceLock::new(); - let downloaded = DOWNLOADED.get_or_init(|| Mutex::new(HashSet::new())); - let mut set = downloaded.lock().expect("model cache lock poisoned"); - if set.contains(model_id) { - return; - } - std::fs::create_dir_all(cache_path).expect("failed to create model cache directory"); - build_model_info(cache_path.clone(), model_id, "main") - .expect("failed to download model into cache"); - set.insert(model_id.to_string()); - } - fn test_cache_path() -> PathBuf { - if let Ok(path) = std::env::var("MANTICORE_TEST_CACHE") { - return PathBuf::from(path); - } - if let Ok(path) = std::env::var("MANTICORE_CACHE_PATH") { - return PathBuf::from(path); - } PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(".cache/manticore") } @@ -346,7 +325,6 @@ mod tests { fn test_all_minilm_l6_v2() { let model_id = "sentence-transformers/all-MiniLM-L6-v2"; let cache_path = test_cache_path(); - ensure_model_cached(model_id, &cache_path); let test_sentences = [ 
"This is a test sentence.", @@ -365,7 +343,6 @@ mod tests { fn test_embedding_consistency() { let model_id = "sentence-transformers/all-MiniLM-L6-v2"; let cache_path = test_cache_path(); - ensure_model_cached(model_id, &cache_path); let local_model = LocalModel::new(model_id, cache_path, false).unwrap(); let sentence = &["This is a test sentence."]; @@ -381,7 +358,6 @@ mod tests { fn test_hidden_size() { let model_id = "sentence-transformers/all-MiniLM-L6-v2"; let cache_path = test_cache_path(); - ensure_model_cached(model_id, &cache_path); let local_model = LocalModel::new(model_id, cache_path, false).unwrap(); assert_eq!(local_model.get_hidden_size(), 384); } @@ -390,7 +366,6 @@ mod tests { fn test_max_input_len() { let model_id = "sentence-transformers/all-MiniLM-L6-v2"; let cache_path = test_cache_path(); - ensure_model_cached(model_id, &cache_path); let local_model = LocalModel::new(model_id, cache_path, false).unwrap(); assert_eq!(local_model.get_max_input_len(), 512); } @@ -400,7 +375,6 @@ mod tests { // Integration test for Qwen embedding models let model_id = "Qwen/Qwen3-Embedding-0.6B"; let cache_path = test_cache_path(); - ensure_model_cached(model_id, &cache_path); let local_model = LocalModel::new(model_id, cache_path.clone(), false) .expect("Qwen model should load successfully"); @@ -421,7 +395,6 @@ mod tests { // Integration test for Llama-based embedding models. let model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"; let cache_path = test_cache_path(); - ensure_model_cached(model_id, &cache_path); let local_model = LocalModel::new(model_id, cache_path.clone(), false).expect("Llama model should load"); @@ -437,7 +410,6 @@ mod tests { // Integration test for Mistral-based embedding models. let model_id = "Locutusque/TinyMistral-248M-v2"; let cache_path = test_cache_path(); - ensure_model_cached(model_id, &cache_path); let local_model = LocalModel::new(model_id, cache_path.clone(), false) .expect("Mistral model should load"); @@ -452,7 +424,6 @@ mod tests { // Integration test for Gemma-based embedding models. 
let model_id = "h2oai/embeddinggemma-300m"; let cache_path = test_cache_path(); - ensure_model_cached(model_id, &cache_path); let local_model = LocalModel::new(model_id, cache_path.clone(), false).expect("Gemma model should load"); @@ -467,7 +438,6 @@ mod tests { // Test batch processing with Qwen model let model_id = "Qwen/Qwen3-Embedding-0.6B"; let cache_path = test_cache_path(); - ensure_model_cached(model_id, &cache_path); let result = LocalModel::new(model_id, cache_path.clone(), false); From 4886b5a2260e16756f8d9cf80947dac48b4a960d Mon Sep 17 00:00:00 2001 From: Don Hardman Date: Mon, 9 Feb 2026 17:45:41 +0700 Subject: [PATCH 3/5] fix(tests): add concurrency tests --- embeddings/src/model/ffi_test.rs | 74 +++++++++++++++++++++++++++++ embeddings/src/model/local.rs | 75 +++++++++++++++++++++++++++++- embeddings/src/model/local_test.rs | 31 +++++++++++- 3 files changed, 178 insertions(+), 2 deletions(-) diff --git a/embeddings/src/model/ffi_test.rs b/embeddings/src/model/ffi_test.rs index bd145cae..4df125d7 100644 --- a/embeddings/src/model/ffi_test.rs +++ b/embeddings/src/model/ffi_test.rs @@ -20,6 +20,75 @@ mod tests { } } + fn run_concurrent_ffi_embeddings(model_id: &str) { + use std::sync::{Arc, Barrier}; + use std::thread; + + let model_name = to_c_string(model_id); + let cache_path = to_c_string(""); + let api_key = to_c_string(""); + + let result = TextModelWrapper::load_model( + model_name.as_ptr(), + model_name.as_bytes().len(), + cache_path.as_ptr(), + cache_path.as_bytes().len(), + api_key.as_ptr(), + api_key.as_bytes().len(), + false, + ); + + if result.model.is_null() { + let error_message = if result.error.is_null() { + "unknown error".to_string() + } else { + unsafe { + CStr::from_ptr(result.error) + .to_str() + .unwrap_or("unknown error") + .to_string() + } + }; + TextModelWrapper::free_model_result(result); + panic!("failed to load model {}: {}", model_id, error_message); + } + + let model_ptr = result.model as usize; + let start = Arc::new(Barrier::new(4)); + let handles: Vec<_> = (0..3) + .map(|i| { + let start = Arc::clone(&start); + let model_ptr = model_ptr; + let model_id = model_id.to_string(); + thread::spawn(move || { + start.wait(); + let text = format!("Concurrent embedding test {} - {}", model_id, i); + let item = create_string_item(&text); + let items = [item]; + + // Safety: emulate FFI callers that share a model pointer across threads. 
+                    let wrapper = unsafe {
+                        std::mem::transmute::<*mut std::ffi::c_void, TextModelWrapper>(
+                            model_ptr as *mut std::ffi::c_void,
+                        )
+                    };
+                    let vec_result =
+                        TextModelWrapper::make_vect_embeddings(&wrapper, items.as_ptr(), 1);
+                    assert!(vec_result.error.is_null());
+                    assert_eq!(vec_result.len, 1);
+                    TextModelWrapper::free_vec_result(vec_result);
+                })
+            })
+            .collect();
+
+        start.wait();
+        for handle in handles {
+            handle.join().unwrap();
+        }
+
+        TextModelWrapper::free_model_result(result);
+    }
+
     #[test]
     fn test_text_model_result_structure() {
         // Test that TextModelResult has the expected structure
@@ -367,4 +436,9 @@ mod tests {
         assert_eq!(options2.api_key, Some("sk-test456".to_string()));
         assert_eq!(options2.use_gpu, None);
     }
+
+    #[test]
+    fn test_concurrent_qwen_embeddings_via_ffi() {
+        run_concurrent_ffi_embeddings("Qwen/Qwen3-Embedding-0.6B");
+    }
 }
diff --git a/embeddings/src/model/local.rs b/embeddings/src/model/local.rs
index 14123105..b61f75a5 100644
--- a/embeddings/src/model/local.rs
+++ b/embeddings/src/model/local.rs
@@ -18,9 +18,10 @@ use candle_transformers::models::qwen3::{Config as QwenConfig, Model as QwenMode
 use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
 use serde_json::Value;
 use std::cell::RefCell;
+use std::collections::HashMap;
 use std::error::Error;
 use std::path::PathBuf;
-use std::sync::Mutex;
+use std::sync::{Arc, Mutex, OnceLock};
 use tokenizers::Tokenizer;
 
 /// Model architecture type - determines pooling strategy
@@ -63,12 +64,84 @@ pub struct LocalModelInfo {
     pub weights_paths: Vec<PathBuf>,
 }
 
+fn model_download_lock(model_id: &str) -> Arc<Mutex<()>> {
+    static LOCKS: OnceLock<Mutex<HashMap<String, Arc<Mutex<()>>>>> = OnceLock::new();
+    let locks = LOCKS.get_or_init(|| Mutex::new(HashMap::new()));
+    let mut map = locks.lock().unwrap_or_else(|e| e.into_inner());
+    map.entry(model_id.to_string())
+        .or_insert_with(|| Arc::new(Mutex::new(())))
+        .clone()
+}
+
+#[cfg(test)]
+struct DownloadTrackerGuard {
+    model_id: String,
+}
+
+#[cfg(test)]
+fn download_tracker_state() -> &'static Mutex<HashMap<String, (usize, usize)>> {
+    static TRACKER: OnceLock<Mutex<HashMap<String, (usize, usize)>>> = OnceLock::new();
+    TRACKER.get_or_init(|| Mutex::new(HashMap::new()))
+}
+
+#[cfg(test)]
+fn download_tracker_enter(model_id: &str) -> DownloadTrackerGuard {
+    let mut map = download_tracker_state()
+        .lock()
+        .unwrap_or_else(|e| e.into_inner());
+    let entry = map.entry(model_id.to_string()).or_insert((0, 0));
+    entry.0 += 1;
+    if entry.0 > entry.1 {
+        entry.1 = entry.0;
+    }
+    DownloadTrackerGuard {
+        model_id: model_id.to_string(),
+    }
+}
+
+#[cfg(test)]
+impl Drop for DownloadTrackerGuard {
+    fn drop(&mut self) {
+        let mut map = download_tracker_state()
+            .lock()
+            .unwrap_or_else(|e| e.into_inner());
+        if let Some(entry) = map.get_mut(&self.model_id) {
+            if entry.0 > 0 {
+                entry.0 -= 1;
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+pub(crate) fn reset_download_tracker(model_id: &str) {
+    let mut map = download_tracker_state()
+        .lock()
+        .unwrap_or_else(|e| e.into_inner());
+    map.insert(model_id.to_string(), (0, 0));
+}
+
+#[cfg(test)]
+pub(crate) fn download_max_for(model_id: &str) -> usize {
+    let map = download_tracker_state()
+        .lock()
+        .unwrap_or_else(|e| e.into_inner());
+    map.get(model_id).map(|entry| entry.1).unwrap_or(0)
+}
+
 /// Download and cache model files from HuggingFace
 pub fn build_model_info(
     cache_path: PathBuf,
     model_id: &str,
     revision: &str,
 ) -> Result<LocalModelInfo, Box<dyn Error>> {
+    let download_lock = model_download_lock(model_id);
+    let _download_guard = download_lock
+        .lock()
+        .map_err(|_| LibError::ModelLoadFailed)?;
+    #[cfg(test)]
+    let _download_tracker = download_tracker_enter(model_id);
+
     let
repo = Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string()); let api = ApiBuilder::new() .with_cache_dir(cache_path) diff --git a/embeddings/src/model/local_test.rs b/embeddings/src/model/local_test.rs index 558743c8..f27fa948 100644 --- a/embeddings/src/model/local_test.rs +++ b/embeddings/src/model/local_test.rs @@ -1,4 +1,4 @@ -use super::local::{build_model_info, LocalModel}; +use super::local::{build_model_info, download_max_for, reset_download_tracker, LocalModel}; #[cfg(test)] mod tests { @@ -7,6 +7,8 @@ mod tests { use crate::utils::{get_hidden_size, get_max_input_length}; use approx::assert_abs_diff_eq; use std::path::PathBuf; + use std::sync::{Arc, Barrier}; + use std::thread; fn check_embedding_properties(embedding: &[f32], expected_len: usize) { assert_eq!(embedding.len(), expected_len); @@ -105,6 +107,33 @@ mod tests { } } + #[test] + fn test_concurrent_model_init_serialized() { + let cache_path = test_cache_path(); + let model_id = "Qwen/Qwen3-Embedding-0.6B"; + reset_download_tracker(model_id); + + let start = Arc::new(Barrier::new(3)); + let mut handles = Vec::new(); + + for _ in 0..2 { + let start = Arc::clone(&start); + let cache_path = cache_path.clone(); + let model_id = model_id.to_string(); + handles.push(thread::spawn(move || { + start.wait(); + build_model_info(cache_path, &model_id, "main").expect("model init failed"); + })); + } + + start.wait(); + for handle in handles { + handle.join().expect("model init thread panicked"); + } + + assert!(download_max_for(model_id) <= 1); + } + #[test] fn test_model_id_variations() { let cache_path = PathBuf::from("/tmp/test_cache"); From c1df774558662035ae97e2145156c0c0c23d8c84 Mon Sep 17 00:00:00 2001 From: Don Hardman Date: Mon, 9 Feb 2026 18:43:29 +0700 Subject: [PATCH 4/5] fix(causal): get rid of mutex and use separate cache instead --- embeddings/Cargo.lock | 6 +-- embeddings/Cargo.toml | 6 +-- embeddings/src/model/local.rs | 69 ++++++++++++++++------------------- 3 files changed, 37 insertions(+), 44 deletions(-) diff --git a/embeddings/Cargo.lock b/embeddings/Cargo.lock index d08e2944..15a230b6 100644 --- a/embeddings/Cargo.lock +++ b/embeddings/Cargo.lock @@ -191,7 +191,7 @@ checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" [[package]] name = "candle-core" version = "0.9.2" -source = "git+https://github.com/manticoresoftware/candle.git?rev=fe707e10fb22599f8124de632d605a68266ab8e8#fe707e10fb22599f8124de632d605a68266ab8e8" +source = "git+https://github.com/manticoresoftware/candle.git?rev=196118024c1377f5fd132ca0db038d16901cd71f#196118024c1377f5fd132ca0db038d16901cd71f" dependencies = [ "byteorder", "float8", @@ -213,7 +213,7 @@ dependencies = [ [[package]] name = "candle-nn" version = "0.9.2" -source = "git+https://github.com/manticoresoftware/candle.git?rev=fe707e10fb22599f8124de632d605a68266ab8e8#fe707e10fb22599f8124de632d605a68266ab8e8" +source = "git+https://github.com/manticoresoftware/candle.git?rev=196118024c1377f5fd132ca0db038d16901cd71f#196118024c1377f5fd132ca0db038d16901cd71f" dependencies = [ "candle-core", "half", @@ -228,7 +228,7 @@ dependencies = [ [[package]] name = "candle-transformers" version = "0.9.2" -source = "git+https://github.com/manticoresoftware/candle.git?rev=fe707e10fb22599f8124de632d605a68266ab8e8#fe707e10fb22599f8124de632d605a68266ab8e8" +source = "git+https://github.com/manticoresoftware/candle.git?rev=196118024c1377f5fd132ca0db038d16901cd71f#196118024c1377f5fd132ca0db038d16901cd71f" dependencies = [ "byteorder", "candle-core", diff 
--git a/embeddings/Cargo.toml b/embeddings/Cargo.toml
index 26364a8e..89b6b94d 100644
--- a/embeddings/Cargo.toml
+++ b/embeddings/Cargo.toml
@@ -40,6 +40,6 @@ approx = "0.5.1"
 # candle-nn = { path = "../../candle/candle-nn" }
 # candle-transformers = { path = "../../candle/candle-transformers" }
 
-candle-core = { git = "https://github.com/manticoresoftware/candle.git", rev = "fe707e10fb22599f8124de632d605a68266ab8e8" }
-candle-nn = { git = "https://github.com/manticoresoftware/candle.git", rev = "fe707e10fb22599f8124de632d605a68266ab8e8" }
-candle-transformers = { git = "https://github.com/manticoresoftware/candle.git", rev = "fe707e10fb22599f8124de632d605a68266ab8e8" }
+candle-core = { git = "https://github.com/manticoresoftware/candle.git", rev = "196118024c1377f5fd132ca0db038d16901cd71f" }
+candle-nn = { git = "https://github.com/manticoresoftware/candle.git", rev = "196118024c1377f5fd132ca0db038d16901cd71f" }
+candle-transformers = { git = "https://github.com/manticoresoftware/candle.git", rev = "196118024c1377f5fd132ca0db038d16901cd71f" }
diff --git a/embeddings/src/model/local.rs b/embeddings/src/model/local.rs
index b61f75a5..2b823178 100644
--- a/embeddings/src/model/local.rs
+++ b/embeddings/src/model/local.rs
@@ -8,16 +8,17 @@ use candle_nn::VarBuilder;
 use candle_transformers::models::bert::{
     BertModel, Config as BertConfig, HiddenAct, DTYPE as BERT_DTYPE,
 };
-use candle_transformers::models::gemma::{Config as GemmaConfig, Model as GemmaModel};
+use candle_transformers::models::gemma::{Config as GemmaConfig, GemmaCache, Model as GemmaModel};
 use candle_transformers::models::llama::{
     Cache as LlamaCache, Config as LlamaConfig, Llama as LlamaModel,
     LlamaConfig as LlamaConfigSerde,
 };
-use candle_transformers::models::mistral::{Config as MistralConfig, Model as MistralModel};
-use candle_transformers::models::qwen3::{Config as QwenConfig, Model as QwenModel};
+use candle_transformers::models::mistral::{
+    Config as MistralConfig, MistralCache, Model as MistralModel,
+};
+use candle_transformers::models::qwen3::{Config as QwenConfig, Model as QwenModel, Qwen3Cache};
 use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
 use serde_json::Value;
-use std::cell::RefCell;
 use std::collections::HashMap;
 use std::error::Error;
 use std::path::PathBuf;
 use std::sync::{Arc, Mutex, OnceLock};
 use tokenizers::Tokenizer;
@@ -279,18 +280,21 @@ pub enum CausalEmbeddingKind {
     Qwen {
-        model: RefCell<QwenModel>,
+        model: Arc<QwenModel>,
+        config: QwenConfig,
     },
     Llama {
-        model: RefCell<LlamaModel>,
+        model: Arc<LlamaModel>,
         config: LlamaConfig,
         dtype: DType,
     },
     Mistral {
-        model: RefCell<MistralModel>,
+        model: Arc<MistralModel>,
+        config: MistralConfig,
     },
     Gemma {
-        model: RefCell<GemmaModel>,
+        model: Arc<GemmaModel>,
+        config: GemmaConfig,
     },
 }
@@ -301,7 +305,6 @@ pub struct CausalEmbeddingModel {
     max_input_len: usize,
     hidden_size: usize,
     device: Device,
-    predict_lock: Mutex<()>,
 }
 
 impl CausalEmbeddingModel {
@@ -343,7 +346,8 @@ impl CausalEmbeddingModel {
                     serde_json::from_str(&config).map_err(|_| LibError::ModelConfigParseFailed)?;
                 let model = QwenModel::new(&cfg, vb).map_err(|_| LibError::ModelLoadFailed)?;
                 CausalEmbeddingKind::Qwen {
-                    model: RefCell::new(model),
+                    model: Arc::new(model),
+                    config: cfg,
                 }
             }
             "llama" => {
                 let cfg_serde: LlamaConfigSerde =
                     serde_json::from_str(&config).map_err(|_| LibError::ModelConfigParseFailed)?;
                 let cfg = cfg_serde.into_config(false);
                 let model = LlamaModel::load(vb, &cfg).map_err(|_| LibError::ModelLoadFailed)?;
                 CausalEmbeddingKind::Llama {
-                    model: RefCell::new(model),
+                    model: Arc::new(model),
                     config: cfg,
                     dtype,
                 }
             }
             "mistral" => {
                 let cfg: MistralConfig =
                     serde_json::from_str(&config).map_err(|_| LibError::ModelConfigParseFailed)?;
                 let model =
                    MistralModel::new(&cfg, vb).map_err(|_| LibError::ModelLoadFailed)?;
                 CausalEmbeddingKind::Mistral {
-                    model: RefCell::new(model),
+                    model: Arc::new(model),
+                    config: cfg,
                 }
             }
             "gemma" | "gemma2" | "gemma3" | "gemma3_text" => {
                 let cfg: GemmaConfig =
                     serde_json::from_str(&config).map_err(|_| LibError::ModelConfigParseFailed)?;
                 let model =
                     GemmaModel::new(false, &cfg, vb).map_err(|_| LibError::ModelLoadFailed)?;
                 CausalEmbeddingKind::Gemma {
-                    model: RefCell::new(model),
+                    model: Arc::new(model),
+                    config: cfg,
                 }
             }
             _ => return Err(Box::new(LibError::ModelLoadFailed)),
@@ -387,7 +393,6 @@ impl CausalEmbeddingModel {
             max_input_len,
             hidden_size,
             device,
-            predict_lock: Mutex::new(()),
         })
     }
 }
@@ -451,15 +456,6 @@ impl LocalModel {
 
 impl TextModel for LocalModel {
     fn predict(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, Box<dyn Error>> {
-        let _predict_guard = match self {
-            LocalModel::Causal(m) => Some(
-                m.predict_lock
-                    .lock()
-                    .map_err(|_| LibError::ModelLoadFailed)?,
-            ),
-            LocalModel::Bert(_) => None,
-        };
-
         let (device, max_input_len) = match self {
             LocalModel::Bert(m) => (m.device.clone(), m.max_input_len),
             LocalModel::Causal(m) => (m.device.clone(), m.max_input_len),
@@ -499,10 +495,9 @@ impl TextModel for LocalModel {
                     divided.to_dtype(DType::F32)?
                 }
                 LocalModel::Causal(m) => match &m.kind {
-                    CausalEmbeddingKind::Qwen { model } => {
-                        let mut model = model.borrow_mut();
-                        model.clear_kv_cache();
-                        let emb = model.forward(&token_ids, 0)?;
+                    CausalEmbeddingKind::Qwen { model, config } => {
+                        let mut cache = Qwen3Cache::new(config.num_hidden_layers);
+                        let emb = model.forward_with_cache(&token_ids, 0, &mut cache)?;
                         let (_, n_tokens, _) = emb.dims3()?;
                         let summed = emb.sum(1)?.to_dtype(DType::F32)?;
                         let divisor = Tensor::new(n_tokens as f32, &device)?;
@@ -515,29 +510,27 @@ impl TextModel for LocalModel {
                         dtype,
                     } => {
                         let mut cache = LlamaCache::new(false, *dtype, config, &device)?;
-                        let emb = model
-                            .borrow()
-                            .forward_hidden_states(&token_ids, 0, &mut cache)?;
+                        let emb = model.forward_hidden_states(&token_ids, 0, &mut cache)?;
                         let (_, n_tokens, _) = emb.dims3()?;
                         let summed = emb.sum(1)?.to_dtype(DType::F32)?;
                         let divisor = Tensor::new(n_tokens as f32, &device)?;
                         let divided = summed.broadcast_div(&divisor)?;
                         divided.to_dtype(DType::F32)?
                     }
-                    CausalEmbeddingKind::Mistral { model } => {
-                        let mut model = model.borrow_mut();
-                        model.clear_kv_cache();
-                        let emb = model.forward_hidden_states(&token_ids, 0)?;
+                    CausalEmbeddingKind::Mistral { model, config } => {
+                        let mut cache = MistralCache::new(config.num_hidden_layers);
+                        let emb = model
+                            .forward_hidden_states_with_cache(&token_ids, 0, &mut cache)?;
                         let (_, n_tokens, _) = emb.dims3()?;
                         let summed = emb.sum(1)?.to_dtype(DType::F32)?;
                         let divisor = Tensor::new(n_tokens as f32, &device)?;
                         let divided = summed.broadcast_div(&divisor)?;
                         divided.to_dtype(DType::F32)?
                    }
-                    CausalEmbeddingKind::Gemma { model } => {
-                        let mut model = model.borrow_mut();
-                        model.clear_kv_cache();
-                        let emb = model.forward_hidden_states(&token_ids, 0)?;
+                    CausalEmbeddingKind::Gemma { model, config } => {
+                        let mut cache = GemmaCache::new(config.num_hidden_layers);
+                        let emb = model
+                            .forward_hidden_states_with_cache(&token_ids, 0, &mut cache)?;
                         let (_, n_tokens, _) = emb.dims3()?;
                         let summed = emb.sum(1)?.to_dtype(DType::F32)?;
                         let divisor = Tensor::new(n_tokens as f32, &device)?;

From b8efe7ed35becc1858fab41e4685b8613bbe2d92 Mon Sep 17 00:00:00 2001
From: Don Hardman
Date: Mon, 9 Feb 2026 18:45:11 +0700
Subject: [PATCH 5/5] fix(tests): get rid of useless line in test

---
 embeddings/src/model/ffi_test.rs | 1 -
 1 file changed, 1 deletion(-)

diff --git a/embeddings/src/model/ffi_test.rs b/embeddings/src/model/ffi_test.rs
index 4df125d7..5641a687 100644
--- a/embeddings/src/model/ffi_test.rs
+++ b/embeddings/src/model/ffi_test.rs
@@ -58,7 +58,6 @@ mod tests {
         let handles: Vec<_> = (0..3)
             .map(|i| {
                 let start = Arc::clone(&start);
-                let model_ptr = model_ptr;
                 let model_id = model_id.to_string();
                 thread::spawn(move || {
                     start.wait();
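
The thread-safety argument behind PATCH 4: model weights become shared immutable state (`Arc<QwenModel>` and friends), while the only mutable state in a forward pass, the KV cache, is created fresh inside each `predict` call (`Qwen3Cache::new(config.num_hidden_layers)`). With no shared mutable state, concurrent callers need no lock. Below is a minimal, self-contained sketch of that ownership pattern; `Model`, `KvCache`, and `forward_with_cache` here are illustrative stand-ins, not the crate's or candle's actual types.

```rust
use std::sync::Arc;
use std::thread;

// Stand-in for a loaded model: weights are immutable after construction,
// so sharing &self across threads is safe without interior mutability.
struct Model {
    weights: Vec<f32>,
}

// Stand-in for a per-call KV cache: the only mutable state in a forward pass.
struct KvCache {
    layers: Vec<Vec<f32>>,
}

impl KvCache {
    fn new(num_layers: usize) -> Self {
        Self {
            layers: vec![Vec::new(); num_layers],
        }
    }
}

impl Model {
    // Takes &self plus a caller-owned cache: no RefCell, no Mutex needed.
    fn forward_with_cache(&self, token_ids: &[u32], cache: &mut KvCache) -> f32 {
        for layer in &mut cache.layers {
            layer.extend(token_ids.iter().map(|&t| t as f32));
        }
        self.weights.iter().sum::<f32>() * token_ids.len() as f32
    }
}

fn main() {
    let model = Arc::new(Model {
        weights: vec![0.5; 4],
    });
    let handles: Vec<_> = (0..3u32)
        .map(|i| {
            let model = Arc::clone(&model);
            thread::spawn(move || {
                // Each thread builds its own cache, so concurrent calls never
                // touch shared mutable state, mirroring the per-call
                // Qwen3Cache/MistralCache/GemmaCache usage in PATCH 4.
                let mut cache = KvCache::new(2);
                model.forward_with_cache(&[i, i + 1], &mut cache)
            })
        })
        .collect();
    for handle in handles {
        println!("stub result: {}", handle.join().unwrap());
    }
}
```

The trade-off is a freshly allocated cache on every call instead of a reused one; for one-shot embedding passes (always position offset 0, no incremental decoding) that cost is small next to the forward pass itself, and it buys lock-free concurrency for FFI callers sharing one model pointer.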