From e2a44057477a8371cdfa89e392a586952d59808b Mon Sep 17 00:00:00 2001 From: Don Hardman Date: Fri, 6 Feb 2026 13:28:41 +0700 Subject: [PATCH 1/2] feat(local): support sharded safetensors --- embeddings/src/model/local.rs | 54 +++++++++++++++++++++++++++-------- 1 file changed, 42 insertions(+), 12 deletions(-) diff --git a/embeddings/src/model/local.rs b/embeddings/src/model/local.rs index ce04c02c..e37440ab 100644 --- a/embeddings/src/model/local.rs +++ b/embeddings/src/model/local.rs @@ -59,7 +59,7 @@ fn model_type_from_config(config: &str) -> Option { pub struct LocalModelInfo { pub config_path: PathBuf, pub tokenizer_path: PathBuf, - pub weights_path: PathBuf, + pub weights_paths: Vec<PathBuf>, } /// Download and cache model files from HuggingFace @@ -81,14 +81,48 @@ pub fn build_model_info( let tokenizer_path = api .get("tokenizer.json") .map_err(|_| LibError::ModelTokenizerFetchFailed)?; - let weights_path = api - .get("model.safetensors") - .map_err(|_| LibError::ModelWeightsFetchFailed)?; + let weights_paths = match api.get("model.safetensors") { + Ok(path) => vec![path], + Err(_) => { + // Support sharded safetensors via model.safetensors.index.json + let index_path = api + .get("model.safetensors.index.json") + .map_err(|_| LibError::ModelWeightsFetchFailed)?; + let index_contents = std::fs::read_to_string(&index_path) + .map_err(|_| LibError::ModelWeightsFetchFailed)?; + let index_json: Value = serde_json::from_str(&index_contents) + .map_err(|_| LibError::ModelWeightsFetchFailed)?; + let weight_map = index_json + .get("weight_map") + .and_then(Value::as_object) + .ok_or_else(|| LibError::ModelWeightsFetchFailed)?; + + let mut shards: Vec<String> = weight_map + .values() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect(); + shards.sort(); + shards.dedup(); + + if shards.is_empty() { + return Err(Box::new(LibError::ModelWeightsFetchFailed)); + } + + let mut paths = Vec::with_capacity(shards.len()); + for shard in shards { + let p = api + .get(&shard) 
.map_err(|_| LibError::ModelWeightsFetchFailed)?; + paths.push(p); + } + paths + } + }; Ok(LocalModelInfo { config_path, tokenizer_path, - weights_path, + weights_paths, }) } @@ -153,7 +187,7 @@ impl BertEmbeddingModel { let _ = tokenizer.with_truncation(None); let vb = unsafe { - VarBuilder::from_mmaped_safetensors(&[model_info.weights_path], BERT_DTYPE, &device) + VarBuilder::from_mmaped_safetensors(&model_info.weights_paths, BERT_DTYPE, &device) .map_err(|_| LibError::ModelWeightsLoadFailed)? }; @@ -216,12 +250,8 @@ impl CausalEmbeddingModel { let dtype = dtype_from_config(&config, &device); let vb = unsafe { - VarBuilder::from_mmaped_safetensors( - std::slice::from_ref(&model_info.weights_path), - dtype, - &device, - ) - .map_err(|_| LibError::ModelWeightsLoadFailed)? + VarBuilder::from_mmaped_safetensors(&model_info.weights_paths, dtype, &device) + .map_err(|_| LibError::ModelWeightsLoadFailed)? }; let vb = if vb.contains_tensor("model.embed_tokens.weight") { From 58cc94b0793b0c6e9e63092e647b1edb0cf50b1b Mon Sep 17 00:00:00 2001 From: Don Hardman Date: Fri, 6 Feb 2026 14:45:41 +0700 Subject: [PATCH 2/2] fix(model): remove unnecessary closure in error handling --- embeddings/src/model/local.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/embeddings/src/model/local.rs b/embeddings/src/model/local.rs index e37440ab..7f4125f7 100644 --- a/embeddings/src/model/local.rs +++ b/embeddings/src/model/local.rs @@ -95,7 +95,7 @@ pub fn build_model_info( let weight_map = index_json .get("weight_map") .and_then(Value::as_object) - .ok_or_else(|| LibError::ModelWeightsFetchFailed)?; + .ok_or(LibError::ModelWeightsFetchFailed)?; let mut shards: Vec<String> = weight_map .values()