Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 42 additions & 12 deletions embeddings/src/model/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ fn model_type_from_config(config: &str) -> Option<String> {
/// Local filesystem paths of the files needed to load a model.
///
/// In the visible code this is the value `build_model_info` returns, built from
/// the paths handed back by its `api.get(...)` calls.
pub struct LocalModelInfo {
/// Path to the model configuration file.
/// NOTE(review): how this path is obtained is not visible here — confirm against `build_model_info`.
pub config_path: PathBuf,
/// Path returned by fetching `tokenizer.json`.
pub tokenizer_path: PathBuf,
/// Path to a single weights file.
/// NOTE(review): this line and the next look like the before/after sides of one
/// change (single path replaced by a list) — confirm only one of them exists in the real file.
pub weights_path: PathBuf,
/// Paths of the safetensors weight files: either the single `model.safetensors`
/// file, or one entry per distinct shard named in the `weight_map` of
/// `model.safetensors.index.json`, in sorted shard-name order.
pub weights_paths: Vec<PathBuf>,
}

/// Download and cache the model's config, tokenizer, and weight files from Hugging Face
Expand All @@ -81,14 +81,48 @@ pub fn build_model_info(
let tokenizer_path = api
.get("tokenizer.json")
.map_err(|_| LibError::ModelTokenizerFetchFailed)?;
let weights_path = api
.get("model.safetensors")
.map_err(|_| LibError::ModelWeightsFetchFailed)?;
let weights_paths = match api.get("model.safetensors") {
Ok(path) => vec![path],
Err(_) => {
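// `model.safetensors` could not be fetched (for any reason): fall back to sharded
// safetensors, taking the shard file names from the `weight_map` in model.safetensors.index.json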
let index_path = api
.get("model.safetensors.index.json")
.map_err(|_| LibError::ModelWeightsFetchFailed)?;
let index_contents = std::fs::read_to_string(&index_path)
.map_err(|_| LibError::ModelWeightsFetchFailed)?;
let index_json: Value = serde_json::from_str(&index_contents)
.map_err(|_| LibError::ModelWeightsFetchFailed)?;
let weight_map = index_json
.get("weight_map")
.and_then(Value::as_object)
.ok_or(LibError::ModelWeightsFetchFailed)?;

let mut shards: Vec<String> = weight_map
.values()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect();
shards.sort();
shards.dedup();

if shards.is_empty() {
return Err(Box::new(LibError::ModelWeightsFetchFailed));
}

let mut paths = Vec::with_capacity(shards.len());
for shard in shards {
let p = api
.get(&shard)
.map_err(|_| LibError::ModelWeightsFetchFailed)?;
paths.push(p);
}
paths
}
};

Ok(LocalModelInfo {
config_path,
tokenizer_path,
weights_path,
weights_paths,
})
}

Expand Down Expand Up @@ -153,7 +187,7 @@ impl BertEmbeddingModel {
let _ = tokenizer.with_truncation(None);

let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[model_info.weights_path], BERT_DTYPE, &device)
VarBuilder::from_mmaped_safetensors(&model_info.weights_paths, BERT_DTYPE, &device)
.map_err(|_| LibError::ModelWeightsLoadFailed)?
};

Expand Down Expand Up @@ -216,12 +250,8 @@ impl CausalEmbeddingModel {

let dtype = dtype_from_config(&config, &device);
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(
std::slice::from_ref(&model_info.weights_path),
dtype,
&device,
)
.map_err(|_| LibError::ModelWeightsLoadFailed)?
VarBuilder::from_mmaped_safetensors(&model_info.weights_paths, dtype, &device)
.map_err(|_| LibError::ModelWeightsLoadFailed)?
};

let vb = if vb.contains_tensor("model.embed_tokens.weight") {
Expand Down
Loading