diff --git a/embeddings/Cargo.lock b/embeddings/Cargo.lock
index d08e2944..15a230b6 100644
--- a/embeddings/Cargo.lock
+++ b/embeddings/Cargo.lock
@@ -191,7 +191,7 @@ checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3"
 [[package]]
 name = "candle-core"
 version = "0.9.2"
-source = "git+https://github.com/manticoresoftware/candle.git?rev=fe707e10fb22599f8124de632d605a68266ab8e8#fe707e10fb22599f8124de632d605a68266ab8e8"
+source = "git+https://github.com/manticoresoftware/candle.git?rev=196118024c1377f5fd132ca0db038d16901cd71f#196118024c1377f5fd132ca0db038d16901cd71f"
 dependencies = [
  "byteorder",
  "float8",
@@ -213,7 +213,7 @@ dependencies = [
 [[package]]
 name = "candle-nn"
 version = "0.9.2"
-source = "git+https://github.com/manticoresoftware/candle.git?rev=fe707e10fb22599f8124de632d605a68266ab8e8#fe707e10fb22599f8124de632d605a68266ab8e8"
+source = "git+https://github.com/manticoresoftware/candle.git?rev=196118024c1377f5fd132ca0db038d16901cd71f#196118024c1377f5fd132ca0db038d16901cd71f"
 dependencies = [
  "candle-core",
  "half",
@@ -228,7 +228,7 @@ dependencies = [
 [[package]]
 name = "candle-transformers"
 version = "0.9.2"
-source = "git+https://github.com/manticoresoftware/candle.git?rev=fe707e10fb22599f8124de632d605a68266ab8e8#fe707e10fb22599f8124de632d605a68266ab8e8"
+source = "git+https://github.com/manticoresoftware/candle.git?rev=196118024c1377f5fd132ca0db038d16901cd71f#196118024c1377f5fd132ca0db038d16901cd71f"
 dependencies = [
  "byteorder",
  "candle-core",
diff --git a/embeddings/Cargo.toml b/embeddings/Cargo.toml
index 26364a8e..89b6b94d 100644
--- a/embeddings/Cargo.toml
+++ b/embeddings/Cargo.toml
@@ -40,6 +40,6 @@ approx = "0.5.1"
 # candle-nn = { path = "../../candle/candle-nn" }
 # candle-transformers = { path = "../../candle/candle-transformers" }
 
-candle-core = { git = "https://github.com/manticoresoftware/candle.git", rev = "fe707e10fb22599f8124de632d605a68266ab8e8" }
-candle-nn = { git = "https://github.com/manticoresoftware/candle.git", rev = "fe707e10fb22599f8124de632d605a68266ab8e8" }
-candle-transformers = { git = "https://github.com/manticoresoftware/candle.git", rev = "fe707e10fb22599f8124de632d605a68266ab8e8" }
+candle-core = { git = "https://github.com/manticoresoftware/candle.git", rev = "196118024c1377f5fd132ca0db038d16901cd71f" }
+candle-nn = { git = "https://github.com/manticoresoftware/candle.git", rev = "196118024c1377f5fd132ca0db038d16901cd71f" }
+candle-transformers = { git = "https://github.com/manticoresoftware/candle.git", rev = "196118024c1377f5fd132ca0db038d16901cd71f" }
diff --git a/embeddings/README.md b/embeddings/README.md
index b280fac1..5a0a2c3f 100644
--- a/embeddings/README.md
+++ b/embeddings/README.md
@@ -13,4 +13,3 @@ cargo build --lib --release
 ```bash
 g++ -o test examples/test.cpp -Ltarget/release -lmanticore_knn_embeddings -I. -lpthread -ldl -std=c++17
 ```
-
diff --git a/embeddings/src/model/ffi_test.rs b/embeddings/src/model/ffi_test.rs
index bd145cae..5641a687 100644
--- a/embeddings/src/model/ffi_test.rs
+++ b/embeddings/src/model/ffi_test.rs
@@ -20,6 +20,74 @@ mod tests {
         }
     }
 
+    fn run_concurrent_ffi_embeddings(model_id: &str) {
+        use std::sync::{Arc, Barrier};
+        use std::thread;
+
+        let model_name = to_c_string(model_id);
+        let cache_path = to_c_string("");
+        let api_key = to_c_string("");
+
+        let result = TextModelWrapper::load_model(
+            model_name.as_ptr(),
+            model_name.as_bytes().len(),
+            cache_path.as_ptr(),
+            cache_path.as_bytes().len(),
+            api_key.as_ptr(),
+            api_key.as_bytes().len(),
+            false,
+        );
+
+        if result.model.is_null() {
+            let error_message = if result.error.is_null() {
+                "unknown error".to_string()
+            } else {
+                unsafe {
+                    CStr::from_ptr(result.error)
+                        .to_str()
+                        .unwrap_or("unknown error")
+                        .to_string()
+                }
+            };
+            TextModelWrapper::free_model_result(result);
+            panic!("failed to load model {}: {}", model_id, error_message);
+        }
+
+        let model_ptr = result.model as usize;
+        let start = Arc::new(Barrier::new(4));
+        let handles: Vec<_> = (0..3)
+            .map(|i| {
+                let start = Arc::clone(&start);
+                let model_id = model_id.to_string();
+                thread::spawn(move || {
+                    start.wait();
+                    let text = format!("Concurrent embedding test {} - {}", model_id, i);
+                    let item = create_string_item(&text);
+                    let items = [item];
+
+                    // Safety: emulate FFI callers that share a model pointer across threads.
+                    let wrapper = unsafe {
+                        std::mem::transmute::<*mut std::ffi::c_void, TextModelWrapper>(
+                            model_ptr as *mut std::ffi::c_void,
+                        )
+                    };
+                    let vec_result =
+                        TextModelWrapper::make_vect_embeddings(&wrapper, items.as_ptr(), 1);
+                    assert!(vec_result.error.is_null());
+                    assert_eq!(vec_result.len, 1);
+                    TextModelWrapper::free_vec_result(vec_result);
+                })
+            })
+            .collect();
+
+        start.wait();
+        for handle in handles {
+            handle.join().unwrap();
+        }
+
+        TextModelWrapper::free_model_result(result);
+    }
+
     #[test]
     fn test_text_model_result_structure() {
         // Test that TextModelResult has the expected structure
@@ -367,4 +435,9 @@ mod tests {
         assert_eq!(options2.api_key, Some("sk-test456".to_string()));
         assert_eq!(options2.use_gpu, None);
     }
+
+    #[test]
+    fn test_concurrent_qwen_embeddings_via_ffi() {
+        run_concurrent_ffi_embeddings("Qwen/Qwen3-Embedding-0.6B");
+    }
 }
diff --git a/embeddings/src/model/local.rs b/embeddings/src/model/local.rs
index 7f4125f7..2b823178 100644
--- a/embeddings/src/model/local.rs
+++ b/embeddings/src/model/local.rs
@@ -8,18 +8,21 @@ use candle_nn::VarBuilder;
 use candle_transformers::models::bert::{
     BertModel, Config as BertConfig, HiddenAct, DTYPE as BERT_DTYPE,
 };
-use candle_transformers::models::gemma::{Config as GemmaConfig, Model as GemmaModel};
+use candle_transformers::models::gemma::{Config as GemmaConfig, GemmaCache, Model as GemmaModel};
 use candle_transformers::models::llama::{
     Cache as LlamaCache, Config as LlamaConfig, Llama as LlamaModel,
     LlamaConfig as LlamaConfigSerde,
 };
-use candle_transformers::models::mistral::{Config as MistralConfig, Model as MistralModel};
-use candle_transformers::models::qwen3::{Config as QwenConfig, Model as QwenModel};
+use candle_transformers::models::mistral::{
+    Config as MistralConfig, MistralCache, Model as MistralModel,
+};
+use candle_transformers::models::qwen3::{Config as QwenConfig, Model as QwenModel, Qwen3Cache};
 use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
 use serde_json::Value;
-use std::cell::RefCell;
+use std::collections::HashMap;
 use std::error::Error;
 use std::path::PathBuf;
+use std::sync::{Arc, Mutex, OnceLock};
 use tokenizers::Tokenizer;
 
 /// Model architecture type - determines pooling strategy
@@ -62,12 +65,84 @@ pub struct LocalModelInfo {
     pub weights_paths: Vec<PathBuf>,
 }
 
+fn model_download_lock(model_id: &str) -> Arc<Mutex<()>> {
+    static LOCKS: OnceLock<Mutex<HashMap<String, Arc<Mutex<()>>>>> = OnceLock::new();
+    let locks = LOCKS.get_or_init(|| Mutex::new(HashMap::new()));
+    let mut map = locks.lock().unwrap_or_else(|e| e.into_inner());
+    map.entry(model_id.to_string())
+        .or_insert_with(|| Arc::new(Mutex::new(())))
+        .clone()
+}
+
+#[cfg(test)]
+struct DownloadTrackerGuard {
+    model_id: String,
+}
+
+#[cfg(test)]
+fn download_tracker_state() -> &'static Mutex<HashMap<String, (usize, usize)>> {
+    static TRACKER: OnceLock<Mutex<HashMap<String, (usize, usize)>>> = OnceLock::new();
+    TRACKER.get_or_init(|| Mutex::new(HashMap::new()))
+}
+
+#[cfg(test)]
+fn download_tracker_enter(model_id: &str) -> DownloadTrackerGuard {
+    let mut map = download_tracker_state()
+        .lock()
+        .unwrap_or_else(|e| e.into_inner());
+    let entry = map.entry(model_id.to_string()).or_insert((0, 0));
+    entry.0 += 1;
+    if entry.0 > entry.1 {
+        entry.1 = entry.0;
+    }
+    DownloadTrackerGuard {
+        model_id: model_id.to_string(),
+    }
+}
+
+#[cfg(test)]
+impl Drop for DownloadTrackerGuard {
+    fn drop(&mut self) {
+        let mut map = download_tracker_state()
+            .lock()
+            .unwrap_or_else(|e| e.into_inner());
+        if let Some(entry) = map.get_mut(&self.model_id) {
+            if entry.0 > 0 {
+                entry.0 -= 1;
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+pub(crate) fn reset_download_tracker(model_id: &str) {
+    let mut map = download_tracker_state()
+        .lock()
+        .unwrap_or_else(|e| e.into_inner());
+    map.insert(model_id.to_string(), (0, 0));
+}
+
+#[cfg(test)]
+pub(crate) fn download_max_for(model_id: &str) -> usize {
+    let map = download_tracker_state()
+        .lock()
+        .unwrap_or_else(|e| e.into_inner());
+    map.get(model_id).map(|entry| entry.1).unwrap_or(0)
+}
+
 /// Download and cache model files from HuggingFace
 pub fn build_model_info(
     cache_path: PathBuf,
     model_id: &str,
     revision: &str,
 ) -> Result<LocalModelInfo, Box<dyn Error>> {
+    let download_lock = model_download_lock(model_id);
+    let _download_guard = download_lock
+        .lock()
+        .map_err(|_| LibError::ModelLoadFailed)?;
+    #[cfg(test)]
+    let _download_tracker = download_tracker_enter(model_id);
+
     let repo = Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string());
     let api = ApiBuilder::new()
         .with_cache_dir(cache_path)
@@ -205,18 +280,21 @@ impl BertEmbeddingModel {
 
 pub enum CausalEmbeddingKind {
     Qwen {
-        model: RefCell<QwenModel>,
+        model: Arc<QwenModel>,
+        config: QwenConfig,
     },
     Llama {
-        model: RefCell<LlamaModel>,
+        model: Arc<LlamaModel>,
         config: LlamaConfig,
         dtype: DType,
     },
     Mistral {
-        model: RefCell<MistralModel>,
+        model: Arc<MistralModel>,
+        config: MistralConfig,
     },
     Gemma {
-        model: RefCell<GemmaModel>,
+        model: Arc<GemmaModel>,
+        config: GemmaConfig,
     },
 }
 
@@ -268,7 +346,8 @@
                     serde_json::from_str(&config).map_err(|_| LibError::ModelConfigParseFailed)?;
                 let model = QwenModel::new(&cfg, vb).map_err(|_| LibError::ModelLoadFailed)?;
                 CausalEmbeddingKind::Qwen {
-                    model: RefCell::new(model),
+                    model: Arc::new(model),
+                    config: cfg,
                 }
             }
             "llama" => {
@@ -277,7 +356,7 @@
                 let cfg = cfg_serde.into_config(false);
                 let model = LlamaModel::load(vb, &cfg).map_err(|_| LibError::ModelLoadFailed)?;
                 CausalEmbeddingKind::Llama {
-                    model: RefCell::new(model),
+                    model: Arc::new(model),
                     config: cfg,
                     dtype,
                 }
             }
             "mistral" => {
@@ -287,7 +366,8 @@
                     serde_json::from_str(&config).map_err(|_| LibError::ModelConfigParseFailed)?;
                 let model = MistralModel::new(&cfg, vb).map_err(|_| LibError::ModelLoadFailed)?;
                 CausalEmbeddingKind::Mistral {
-                    model: RefCell::new(model),
+                    model: Arc::new(model),
+                    config: cfg,
                 }
             }
             "gemma" | "gemma2" | "gemma3" | "gemma3_text" => {
@@ -296,7 +376,8 @@
                 let model =
                     GemmaModel::new(false, &cfg, vb).map_err(|_| LibError::ModelLoadFailed)?;
                 CausalEmbeddingKind::Gemma {
-                    model: RefCell::new(model),
+                    model: Arc::new(model),
+                    config: cfg,
                 }
             }
             _ => return Err(Box::new(LibError::ModelLoadFailed)),
@@ -414,10 +495,9 @@
                     divided.to_dtype(DType::F32)?
                 }
                 LocalModel::Causal(m) => match &m.kind {
-                    CausalEmbeddingKind::Qwen { model } => {
-                        let mut model = model.borrow_mut();
-                        model.clear_kv_cache();
-                        let emb = model.forward(&token_ids, 0)?;
+                    CausalEmbeddingKind::Qwen { model, config } => {
+                        let mut cache = Qwen3Cache::new(config.num_hidden_layers);
+                        let emb = model.forward_with_cache(&token_ids, 0, &mut cache)?;
                         let (_, n_tokens, _) = emb.dims3()?;
                         let summed = emb.sum(1)?.to_dtype(DType::F32)?;
                         let divisor = Tensor::new(n_tokens as f32, &device)?;
@@ -430,29 +510,27 @@
                         dtype,
                     } => {
                         let mut cache = LlamaCache::new(false, *dtype, config, &device)?;
-                        let emb = model
-                            .borrow()
-                            .forward_hidden_states(&token_ids, 0, &mut cache)?;
+                        let emb = model.forward_hidden_states(&token_ids, 0, &mut cache)?;
                         let (_, n_tokens, _) = emb.dims3()?;
                         let summed = emb.sum(1)?.to_dtype(DType::F32)?;
                         let divisor = Tensor::new(n_tokens as f32, &device)?;
                         let divided = summed.broadcast_div(&divisor)?;
                         divided.to_dtype(DType::F32)?
                     }
-                    CausalEmbeddingKind::Mistral { model } => {
-                        let mut model = model.borrow_mut();
-                        model.clear_kv_cache();
-                        let emb = model.forward_hidden_states(&token_ids, 0)?;
+                    CausalEmbeddingKind::Mistral { model, config } => {
+                        let mut cache = MistralCache::new(config.num_hidden_layers);
+                        let emb = model
+                            .forward_hidden_states_with_cache(&token_ids, 0, &mut cache)?;
                         let (_, n_tokens, _) = emb.dims3()?;
                         let summed = emb.sum(1)?.to_dtype(DType::F32)?;
                         let divisor = Tensor::new(n_tokens as f32, &device)?;
                         let divided = summed.broadcast_div(&divisor)?;
                         divided.to_dtype(DType::F32)?
                     }
-                    CausalEmbeddingKind::Gemma { model } => {
-                        let mut model = model.borrow_mut();
-                        model.clear_kv_cache();
-                        let emb = model.forward_hidden_states(&token_ids, 0)?;
+                    CausalEmbeddingKind::Gemma { model, config } => {
+                        let mut cache = GemmaCache::new(config.num_hidden_layers);
+                        let emb = model
+                            .forward_hidden_states_with_cache(&token_ids, 0, &mut cache)?;
                         let (_, n_tokens, _) = emb.dims3()?;
                         let summed = emb.sum(1)?.to_dtype(DType::F32)?;
                         let divisor = Tensor::new(n_tokens as f32, &device)?;
diff --git a/embeddings/src/model/local_test.rs b/embeddings/src/model/local_test.rs
index 558743c8..f27fa948 100644
--- a/embeddings/src/model/local_test.rs
+++ b/embeddings/src/model/local_test.rs
@@ -1,4 +1,4 @@
-use super::local::{build_model_info, LocalModel};
+use super::local::{build_model_info, download_max_for, reset_download_tracker, LocalModel};
 
 #[cfg(test)]
 mod tests {
@@ -7,6 +7,8 @@ mod tests {
     use crate::utils::{get_hidden_size, get_max_input_length};
     use approx::assert_abs_diff_eq;
     use std::path::PathBuf;
+    use std::sync::{Arc, Barrier};
+    use std::thread;
 
     fn check_embedding_properties(embedding: &[f32], expected_len: usize) {
         assert_eq!(embedding.len(), expected_len);
@@ -105,6 +107,33 @@ mod tests {
         }
     }
 
+    #[test]
+    fn test_concurrent_model_init_serialized() {
+        let cache_path = test_cache_path();
+        let model_id = "Qwen/Qwen3-Embedding-0.6B";
+        reset_download_tracker(model_id);
+
+        let start = Arc::new(Barrier::new(3));
+        let mut handles = Vec::new();
+
+        for _ in 0..2 {
+            let start = Arc::clone(&start);
+            let cache_path = cache_path.clone();
+            let model_id = model_id.to_string();
+            handles.push(thread::spawn(move || {
+                start.wait();
+                build_model_info(cache_path, &model_id, "main").expect("model init failed");
+            }));
+        }
+
+        start.wait();
+        for handle in handles {
+            handle.join().expect("model init thread panicked");
+        }
+
+        assert!(download_max_for(model_id) <= 1);
+    }
+
     #[test]
     fn test_model_id_variations() {
         let cache_path = PathBuf::from("/tmp/test_cache");
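
For reviewers, the per-model download lock added to `build_model_info` is easier to see as a condensed, standalone sketch. The `download_lock` name and the `main` driver below are illustrative only (not part of the patch): concurrent initializations of the same model id serialize on one mutex, while different model ids can still proceed in parallel.

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
use std::thread;

// Per-model-id mutex registry, mirroring model_download_lock in local.rs: the
// global map is created once, and each model id lazily gets its own Arc<Mutex<()>>.
fn download_lock(model_id: &str) -> Arc<Mutex<()>> {
    static LOCKS: OnceLock<Mutex<HashMap<String, Arc<Mutex<()>>>>> = OnceLock::new();
    let locks = LOCKS.get_or_init(|| Mutex::new(HashMap::new()));
    // into_inner() recovers the guard from a poisoned mutex instead of panicking.
    let mut map = locks.lock().unwrap_or_else(|e| e.into_inner());
    map.entry(model_id.to_string())
        .or_insert_with(|| Arc::new(Mutex::new(())))
        .clone()
}

fn main() {
    let handles: Vec<_> = (0..4)
        .map(|i| {
            thread::spawn(move || {
                let lock = download_lock("Qwen/Qwen3-Embedding-0.6B");
                let _guard = lock.lock().unwrap();
                // Only one thread per model id is inside this section at a time.
                println!("thread {i} would download or check the cache here");
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
}
```

Note that `download_lock` returns a cloned `Arc` and drops the registry guard before the caller takes the per-model lock, so the short-lived map lock never blocks for the duration of a download.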
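The `RefCell`-to-`Arc` change follows the same shape for all four architectures: the weights become shared immutable state, and the KV cache becomes a value each forward pass constructs for itself. Here is a toy sketch of that ownership split under stated assumptions; `Model` and `KvCache` below are hypothetical stand-ins, not the candle API.

```rust
use std::sync::Arc;
use std::thread;

// Toy stand-ins for a transformer and its KV cache: weights are immutable and
// shared via Arc; all per-request mutable state lives in the caller's cache.
struct Model {
    scale: f32,
}

struct KvCache {
    entries: Vec<f32>,
}

impl KvCache {
    fn new(num_layers: usize) -> Self {
        KvCache {
            entries: Vec::with_capacity(num_layers),
        }
    }
}

impl Model {
    // Takes &self, not &mut self: the forward pass only writes into the cache,
    // which is why the model can be shared across threads without a RefCell.
    fn forward(&self, token_ids: &[u32], cache: &mut KvCache) -> f32 {
        let sum: f32 = token_ids.iter().map(|&t| t as f32 * self.scale).sum();
        cache.entries.push(sum);
        sum / token_ids.len() as f32
    }
}

fn main() {
    let model = Arc::new(Model { scale: 0.5 });
    let handles: Vec<_> = (0..3u32)
        .map(|i| {
            let model = Arc::clone(&model);
            thread::spawn(move || {
                // Fresh cache per request: no shared &mut state, no clear_kv_cache().
                let mut cache = KvCache::new(4);
                model.forward(&[i, i + 1], &mut cache)
            })
        })
        .collect();
    for handle in handles {
        println!("embedding: {}", handle.join().unwrap());
    }
}
```

This is the property the new FFI test exercises when it shares one model pointer across threads: with the interior `RefCell` gone, a `&Model` whose fields are `Sync` can be used concurrently, and each request's mutable state is confined to its own cache.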