Skip to content

Commit 86594fc

Browse files
llenck and Anush008 authored
refactor: Turn most panics into recoverable errors (#128)
* for image embeddings, turn some panics into returned errors * turn more panics into recoverable errors * Simplify collection of image embeddings Co-authored-by: Anush <anushshetty90@gmail.com> * Pass errors in sparse text embedding and reranking batches onto the user instead of panicking * fix warnings: only import anyhow::Context if online feature enabled * cargo fmt --------- Co-authored-by: Anush <anushshetty90@gmail.com>
1 parent 49a2185 commit 86594fc

File tree

4 files changed

+101
-90
lines changed

4 files changed

+101
-90
lines changed

src/image_embedding/impl.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ use crate::{
1717
ModelInfo,
1818
};
1919
use anyhow::anyhow;
20+
#[cfg(feature = "online")]
21+
use anyhow::Context;
2022

2123
#[cfg(feature = "online")]
2224
use super::ImageInitOptions;
@@ -52,13 +54,13 @@ impl ImageEmbedding {
5254

5355
let preprocessor_file = model_repo
5456
.get("preprocessor_config.json")
55-
.unwrap_or_else(|_| panic!("Failed to retrieve preprocessor_config.json"));
57+
.context("Failed to retrieve preprocessor_config.json")?;
5658
let preprocessor = Compose::from_file(preprocessor_file)?;
5759

5860
let model_file_name = ImageEmbedding::get_model_info(&model_name).model_file;
5961
let model_file_reference = model_repo
6062
.get(&model_file_name)
61-
.unwrap_or_else(|_| panic!("Failed to retrieve {} ", model_file_name));
63+
.context(format!("Failed to retrieve {}", model_file_name))?;
6264

6365
let session = Session::builder()?
6466
.with_execution_providers(execution_providers)?
@@ -111,8 +113,7 @@ impl ImageEmbedding {
111113
let cache = Cache::new(cache_dir);
112114
let api = ApiBuilder::from_cache(cache)
113115
.with_progress(show_download_progress)
114-
.build()
115-
.unwrap();
116+
.build()?;
116117

117118
let repo = api.model(model.to_string());
118119
Ok(repo)
@@ -189,7 +190,9 @@ impl ImageEmbedding {
189190

190191
Ok(embeddings)
191192
})
192-
.flat_map(|result: Result<Vec<Vec<f32>>, anyhow::Error>| result.unwrap())
193+
.collect::<anyhow::Result<Vec<_>>>()?
194+
.into_iter()
195+
.flatten()
193196
.collect();
194197

195198
Ok(output)

src/reranking/impl.rs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#[cfg(feature = "online")]
2+
use anyhow::Context;
13
use anyhow::Result;
24
use ort::{
35
session::{builder::GraphOptimizationLevel, Session},
@@ -70,15 +72,16 @@ impl TextRerank {
7072
let model_repo = api.model(model_name.to_string());
7173

7274
let model_file_name = TextRerank::get_model_info(&model_name).model_file;
73-
let model_file_reference = model_repo
74-
.get(&model_file_name)
75-
.unwrap_or_else(|_| panic!("Failed to retrieve model file: {}", model_file_name));
75+
let model_file_reference = model_repo.get(&model_file_name).context(format!(
76+
"Failed to retrieve model file: {}",
77+
model_file_name
78+
))?;
7679
let additional_files = TextRerank::get_model_info(&model_name).additional_files;
7780
for additional_file in additional_files {
78-
let _additional_file_reference =
79-
model_repo.get(&additional_file).unwrap_or_else(|_| {
80-
panic!("Failed to retrieve additional file: {}", additional_file)
81-
});
81+
let _additional_file_reference = model_repo.get(&additional_file).context(format!(
82+
"Failed to retrieve additional file: {}",
83+
additional_file
84+
))?;
8285
}
8386

8487
let session = Session::builder()?
@@ -196,7 +199,9 @@ impl TextRerank {
196199

197200
Ok(scores)
198201
})
199-
.flat_map(|result: Result<Vec<f32>, anyhow::Error>| result.unwrap())
202+
.collect::<Result<Vec<_>>>()?
203+
.into_iter()
204+
.flatten()
200205
.collect();
201206

202207
// Return top_n_result of type Vec<RerankResult> ordered by score in descending order, don't use binary heap

src/sparse_text_embedding/impl.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use crate::{
44
models::sparse::{models_list, SparseModel},
55
ModelInfo, SparseEmbedding,
66
};
7+
#[cfg(feature = "online")]
8+
use anyhow::Context;
79
use anyhow::Result;
810
#[cfg(feature = "online")]
911
use hf_hub::{
@@ -55,7 +57,7 @@ impl SparseTextEmbedding {
5557
let model_file_name = SparseTextEmbedding::get_model_info(&model_name).model_file;
5658
let model_file_reference = model_repo
5759
.get(&model_file_name)
58-
.unwrap_or_else(|_| panic!("Failed to retrieve {} ", model_file_name));
60+
.context(format!("Failed to retrieve {} ", model_file_name))?;
5961

6062
let session = Session::builder()?
6163
.with_execution_providers(execution_providers)?
@@ -91,8 +93,7 @@ impl SparseTextEmbedding {
9193
let cache = Cache::new(cache_dir);
9294
let api = ApiBuilder::from_cache(cache)
9395
.with_progress(show_download_progress)
94-
.build()
95-
.unwrap();
96+
.build()?;
9697

9798
let repo = api.model(model.to_string());
9899
Ok(repo)
@@ -189,7 +190,9 @@ impl SparseTextEmbedding {
189190

190191
Ok(embeddings)
191192
})
192-
.flat_map(|result: Result<Vec<SparseEmbedding>, anyhow::Error>| result.unwrap())
193+
.collect::<Result<Vec<_>>>()?
194+
.into_iter()
195+
.flatten()
193196
.collect();
194197

195198
Ok(output)

src/text_embedding/impl.rs

Lines changed: 73 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ use crate::{
99
Embedding, EmbeddingModel, EmbeddingOutput, ModelInfo, QuantizationMode, SingleBatchOutput,
1010
};
1111
#[cfg(feature = "online")]
12+
use anyhow::Context;
13+
use anyhow::Result;
14+
#[cfg(feature = "online")]
1215
use hf_hub::{
1316
api::sync::{ApiBuilder, ApiRepo},
1417
Cache,
@@ -40,7 +43,7 @@ impl TextEmbedding {
4043
///
4144
/// Uses the total number of CPUs available as the number of intra-threads
4245
#[cfg(feature = "online")]
43-
pub fn try_new(options: InitOptions) -> anyhow::Result<Self> {
46+
pub fn try_new(options: InitOptions) -> Result<Self> {
4447
let InitOptions {
4548
model_name,
4649
execution_providers,
@@ -61,7 +64,7 @@ impl TextEmbedding {
6164
let model_file_name = &model_info.model_file;
6265
let model_file_reference = model_repo
6366
.get(model_file_name)
64-
.unwrap_or_else(|_| panic!("Failed to retrieve {} ", model_file_name));
67+
.context(format!("Failed to retrieve {}", model_file_name))?;
6568

6669
// TODO: If more models need .onnx_data, implement a better way to handle this
6770
// Probably by adding `additional_files` field in the `ModelInfo` struct
@@ -95,7 +98,7 @@ impl TextEmbedding {
9598
pub fn try_new_from_user_defined(
9699
model: UserDefinedEmbeddingModel,
97100
options: InitOptionsUserDefined,
98-
) -> anyhow::Result<Self> {
101+
) -> Result<Self> {
99102
let InitOptionsUserDefined {
100103
execution_providers,
101104
max_length,
@@ -147,8 +150,7 @@ impl TextEmbedding {
147150
let cache = Cache::new(cache_dir);
148151
let api = ApiBuilder::from_cache(cache)
149152
.with_progress(show_download_progress)
150-
.build()
151-
.unwrap();
153+
.build()?;
152154

153155
let repo = api.model(model.to_string());
154156
Ok(repo)
@@ -160,7 +162,7 @@ impl TextEmbedding {
160162
}
161163

162164
/// Get ModelInfo from EmbeddingModel
163-
pub fn get_model_info(model: &EmbeddingModel) -> anyhow::Result<&ModelInfo<EmbeddingModel>> {
165+
pub fn get_model_info(model: &EmbeddingModel) -> Result<&ModelInfo<EmbeddingModel>> {
164166
get_model_info(model).ok_or_else(|| {
165167
anyhow::Error::msg(format!(
166168
"Model {model:?} not found. Please check if the model is supported \
@@ -195,7 +197,7 @@ impl TextEmbedding {
195197
&'e self,
196198
texts: Vec<S>,
197199
batch_size: Option<usize>,
198-
) -> anyhow::Result<EmbeddingOutput<'r, 's>>
200+
) -> Result<EmbeddingOutput<'r, 's>>
199201
where
200202
'e: 'r,
201203
'e: 's,
@@ -223,72 +225,70 @@ impl TextEmbedding {
223225
_ => Ok(batch_size.unwrap_or(DEFAULT_BATCH_SIZE)),
224226
}?;
225227

226-
let batches =
227-
anyhow::Result::<Vec<_>>::from_par_iter(texts.par_chunks(batch_size).map(|batch| {
228-
// Encode the texts in the batch
229-
let inputs = batch.iter().map(|text| text.as_ref()).collect();
230-
let encodings = self.tokenizer.encode_batch(inputs, true).map_err(|e| {
231-
anyhow::Error::msg(e.to_string()).context("Failed to encode the batch.")
232-
})?;
233-
234-
// Extract the encoding length and batch size
235-
let encoding_length = encodings[0].len();
236-
let batch_size = batch.len();
237-
238-
let max_size = encoding_length * batch_size;
239-
240-
// Preallocate arrays with the maximum size
241-
let mut ids_array = Vec::with_capacity(max_size);
242-
let mut mask_array = Vec::with_capacity(max_size);
243-
let mut typeids_array = Vec::with_capacity(max_size);
244-
245-
// Not using par_iter because the closure needs to be FnMut
246-
encodings.iter().for_each(|encoding| {
247-
let ids = encoding.get_ids();
248-
let mask = encoding.get_attention_mask();
249-
let typeids = encoding.get_type_ids();
250-
251-
// Extend the preallocated arrays with the current encoding
252-
// Requires the closure to be FnMut
253-
ids_array.extend(ids.iter().map(|x| *x as i64));
254-
mask_array.extend(mask.iter().map(|x| *x as i64));
255-
typeids_array.extend(typeids.iter().map(|x| *x as i64));
256-
});
257-
258-
// Create CowArrays from vectors
259-
let inputs_ids_array =
260-
Array::from_shape_vec((batch_size, encoding_length), ids_array)?;
261-
262-
let attention_mask_array =
263-
Array::from_shape_vec((batch_size, encoding_length), mask_array)?;
264-
265-
let token_type_ids_array =
266-
Array::from_shape_vec((batch_size, encoding_length), typeids_array)?;
267-
268-
let mut session_inputs = ort::inputs![
269-
"input_ids" => Value::from_array(inputs_ids_array)?,
270-
"attention_mask" => Value::from_array(attention_mask_array.view())?,
271-
]?;
272-
273-
if self.need_token_type_ids {
274-
session_inputs.push((
275-
"token_type_ids".into(),
276-
Value::from_array(token_type_ids_array)?.into(),
277-
));
278-
}
228+
let batches = Result::<Vec<_>>::from_par_iter(texts.par_chunks(batch_size).map(|batch| {
229+
// Encode the texts in the batch
230+
let inputs = batch.iter().map(|text| text.as_ref()).collect();
231+
let encodings = self.tokenizer.encode_batch(inputs, true).map_err(|e| {
232+
anyhow::Error::msg(e.to_string()).context("Failed to encode the batch.")
233+
})?;
234+
235+
// Extract the encoding length and batch size
236+
let encoding_length = encodings[0].len();
237+
let batch_size = batch.len();
238+
239+
let max_size = encoding_length * batch_size;
240+
241+
// Preallocate arrays with the maximum size
242+
let mut ids_array = Vec::with_capacity(max_size);
243+
let mut mask_array = Vec::with_capacity(max_size);
244+
let mut typeids_array = Vec::with_capacity(max_size);
245+
246+
// Not using par_iter because the closure needs to be FnMut
247+
encodings.iter().for_each(|encoding| {
248+
let ids = encoding.get_ids();
249+
let mask = encoding.get_attention_mask();
250+
let typeids = encoding.get_type_ids();
251+
252+
// Extend the preallocated arrays with the current encoding
253+
// Requires the closure to be FnMut
254+
ids_array.extend(ids.iter().map(|x| *x as i64));
255+
mask_array.extend(mask.iter().map(|x| *x as i64));
256+
typeids_array.extend(typeids.iter().map(|x| *x as i64));
257+
});
258+
259+
// Create CowArrays from vectors
260+
let inputs_ids_array = Array::from_shape_vec((batch_size, encoding_length), ids_array)?;
261+
262+
let attention_mask_array =
263+
Array::from_shape_vec((batch_size, encoding_length), mask_array)?;
264+
265+
let token_type_ids_array =
266+
Array::from_shape_vec((batch_size, encoding_length), typeids_array)?;
267+
268+
let mut session_inputs = ort::inputs![
269+
"input_ids" => Value::from_array(inputs_ids_array)?,
270+
"attention_mask" => Value::from_array(attention_mask_array.view())?,
271+
]?;
272+
273+
if self.need_token_type_ids {
274+
session_inputs.push((
275+
"token_type_ids".into(),
276+
Value::from_array(token_type_ids_array)?.into(),
277+
));
278+
}
279279

280-
Ok(
281-
// Package all the data required for post-processing (e.g. pooling)
282-
// into a SingleBatchOutput struct.
283-
SingleBatchOutput {
284-
session_outputs: self
285-
.session
286-
.run(session_inputs)
287-
.map_err(anyhow::Error::new)?,
288-
attention_mask_array,
289-
},
290-
)
291-
}))?;
280+
Ok(
281+
// Package all the data required for post-processing (e.g. pooling)
282+
// into a SingleBatchOutput struct.
283+
SingleBatchOutput {
284+
session_outputs: self
285+
.session
286+
.run(session_inputs)
287+
.map_err(anyhow::Error::new)?,
288+
attention_mask_array,
289+
},
290+
)
291+
}))?;
292292

293293
Ok(EmbeddingOutput::new(batches))
294294
}
@@ -308,7 +308,7 @@ impl TextEmbedding {
308308
&self,
309309
texts: Vec<S>,
310310
batch_size: Option<usize>,
311-
) -> anyhow::Result<Vec<Embedding>> {
311+
) -> Result<Vec<Embedding>> {
312312
let batches = self.transform(texts, batch_size)?;
313313

314314
batches.export_with_transformer(output::transformer_with_precedence(

0 commit comments

Comments (0)