@@ -9,6 +9,9 @@ use crate::{
9
9
Embedding , EmbeddingModel , EmbeddingOutput , ModelInfo , QuantizationMode , SingleBatchOutput ,
10
10
} ;
11
11
#[ cfg( feature = "online" ) ]
12
+ use anyhow:: Context ;
13
+ use anyhow:: Result ;
14
+ #[ cfg( feature = "online" ) ]
12
15
use hf_hub:: {
13
16
api:: sync:: { ApiBuilder , ApiRepo } ,
14
17
Cache ,
@@ -40,7 +43,7 @@ impl TextEmbedding {
40
43
///
41
44
/// Uses the total number of CPUs available as the number of intra-threads
42
45
#[ cfg( feature = "online" ) ]
43
- pub fn try_new ( options : InitOptions ) -> anyhow :: Result < Self > {
46
+ pub fn try_new ( options : InitOptions ) -> Result < Self > {
44
47
let InitOptions {
45
48
model_name,
46
49
execution_providers,
@@ -61,7 +64,7 @@ impl TextEmbedding {
61
64
let model_file_name = & model_info. model_file ;
62
65
let model_file_reference = model_repo
63
66
. get ( model_file_name)
64
- . unwrap_or_else ( |_| panic ! ( "Failed to retrieve {} " , model_file_name) ) ;
67
+ . context ( format ! ( "Failed to retrieve {}" , model_file_name) ) ? ;
65
68
66
69
// TODO: If more models need .onnx_data, implement a better way to handle this
67
70
// Probably by adding `additional_files` field in the `ModelInfo` struct
@@ -95,7 +98,7 @@ impl TextEmbedding {
95
98
pub fn try_new_from_user_defined (
96
99
model : UserDefinedEmbeddingModel ,
97
100
options : InitOptionsUserDefined ,
98
- ) -> anyhow :: Result < Self > {
101
+ ) -> Result < Self > {
99
102
let InitOptionsUserDefined {
100
103
execution_providers,
101
104
max_length,
@@ -147,8 +150,7 @@ impl TextEmbedding {
147
150
let cache = Cache :: new ( cache_dir) ;
148
151
let api = ApiBuilder :: from_cache ( cache)
149
152
. with_progress ( show_download_progress)
150
- . build ( )
151
- . unwrap ( ) ;
153
+ . build ( ) ?;
152
154
153
155
let repo = api. model ( model. to_string ( ) ) ;
154
156
Ok ( repo)
@@ -160,7 +162,7 @@ impl TextEmbedding {
160
162
}
161
163
162
164
/// Get ModelInfo from EmbeddingModel
163
- pub fn get_model_info ( model : & EmbeddingModel ) -> anyhow :: Result < & ModelInfo < EmbeddingModel > > {
165
+ pub fn get_model_info ( model : & EmbeddingModel ) -> Result < & ModelInfo < EmbeddingModel > > {
164
166
get_model_info ( model) . ok_or_else ( || {
165
167
anyhow:: Error :: msg ( format ! (
166
168
"Model {model:?} not found. Please check if the model is supported \
@@ -195,7 +197,7 @@ impl TextEmbedding {
195
197
& ' e self ,
196
198
texts : Vec < S > ,
197
199
batch_size : Option < usize > ,
198
- ) -> anyhow :: Result < EmbeddingOutput < ' r , ' s > >
200
+ ) -> Result < EmbeddingOutput < ' r , ' s > >
199
201
where
200
202
' e : ' r ,
201
203
' e : ' s ,
@@ -223,72 +225,70 @@ impl TextEmbedding {
223
225
_ => Ok ( batch_size. unwrap_or ( DEFAULT_BATCH_SIZE ) ) ,
224
226
} ?;
225
227
226
- let batches =
227
- anyhow:: Result :: < Vec < _ > > :: from_par_iter ( texts. par_chunks ( batch_size) . map ( |batch| {
228
- // Encode the texts in the batch
229
- let inputs = batch. iter ( ) . map ( |text| text. as_ref ( ) ) . collect ( ) ;
230
- let encodings = self . tokenizer . encode_batch ( inputs, true ) . map_err ( |e| {
231
- anyhow:: Error :: msg ( e. to_string ( ) ) . context ( "Failed to encode the batch." )
232
- } ) ?;
233
-
234
- // Extract the encoding length and batch size
235
- let encoding_length = encodings[ 0 ] . len ( ) ;
236
- let batch_size = batch. len ( ) ;
237
-
238
- let max_size = encoding_length * batch_size;
239
-
240
- // Preallocate arrays with the maximum size
241
- let mut ids_array = Vec :: with_capacity ( max_size) ;
242
- let mut mask_array = Vec :: with_capacity ( max_size) ;
243
- let mut typeids_array = Vec :: with_capacity ( max_size) ;
244
-
245
- // Not using par_iter because the closure needs to be FnMut
246
- encodings. iter ( ) . for_each ( |encoding| {
247
- let ids = encoding. get_ids ( ) ;
248
- let mask = encoding. get_attention_mask ( ) ;
249
- let typeids = encoding. get_type_ids ( ) ;
250
-
251
- // Extend the preallocated arrays with the current encoding
252
- // Requires the closure to be FnMut
253
- ids_array. extend ( ids. iter ( ) . map ( |x| * x as i64 ) ) ;
254
- mask_array. extend ( mask. iter ( ) . map ( |x| * x as i64 ) ) ;
255
- typeids_array. extend ( typeids. iter ( ) . map ( |x| * x as i64 ) ) ;
256
- } ) ;
257
-
258
- // Create CowArrays from vectors
259
- let inputs_ids_array =
260
- Array :: from_shape_vec ( ( batch_size, encoding_length) , ids_array) ?;
261
-
262
- let attention_mask_array =
263
- Array :: from_shape_vec ( ( batch_size, encoding_length) , mask_array) ?;
264
-
265
- let token_type_ids_array =
266
- Array :: from_shape_vec ( ( batch_size, encoding_length) , typeids_array) ?;
267
-
268
- let mut session_inputs = ort:: inputs![
269
- "input_ids" => Value :: from_array( inputs_ids_array) ?,
270
- "attention_mask" => Value :: from_array( attention_mask_array. view( ) ) ?,
271
- ] ?;
272
-
273
- if self . need_token_type_ids {
274
- session_inputs. push ( (
275
- "token_type_ids" . into ( ) ,
276
- Value :: from_array ( token_type_ids_array) ?. into ( ) ,
277
- ) ) ;
278
- }
228
+ let batches = Result :: < Vec < _ > > :: from_par_iter ( texts. par_chunks ( batch_size) . map ( |batch| {
229
+ // Encode the texts in the batch
230
+ let inputs = batch. iter ( ) . map ( |text| text. as_ref ( ) ) . collect ( ) ;
231
+ let encodings = self . tokenizer . encode_batch ( inputs, true ) . map_err ( |e| {
232
+ anyhow:: Error :: msg ( e. to_string ( ) ) . context ( "Failed to encode the batch." )
233
+ } ) ?;
234
+
235
+ // Extract the encoding length and batch size
236
+ let encoding_length = encodings[ 0 ] . len ( ) ;
237
+ let batch_size = batch. len ( ) ;
238
+
239
+ let max_size = encoding_length * batch_size;
240
+
241
+ // Preallocate arrays with the maximum size
242
+ let mut ids_array = Vec :: with_capacity ( max_size) ;
243
+ let mut mask_array = Vec :: with_capacity ( max_size) ;
244
+ let mut typeids_array = Vec :: with_capacity ( max_size) ;
245
+
246
+ // Not using par_iter because the closure needs to be FnMut
247
+ encodings. iter ( ) . for_each ( |encoding| {
248
+ let ids = encoding. get_ids ( ) ;
249
+ let mask = encoding. get_attention_mask ( ) ;
250
+ let typeids = encoding. get_type_ids ( ) ;
251
+
252
+ // Extend the preallocated arrays with the current encoding
253
+ // Requires the closure to be FnMut
254
+ ids_array. extend ( ids. iter ( ) . map ( |x| * x as i64 ) ) ;
255
+ mask_array. extend ( mask. iter ( ) . map ( |x| * x as i64 ) ) ;
256
+ typeids_array. extend ( typeids. iter ( ) . map ( |x| * x as i64 ) ) ;
257
+ } ) ;
258
+
259
+ // Create CowArrays from vectors
260
+ let inputs_ids_array = Array :: from_shape_vec ( ( batch_size, encoding_length) , ids_array) ?;
261
+
262
+ let attention_mask_array =
263
+ Array :: from_shape_vec ( ( batch_size, encoding_length) , mask_array) ?;
264
+
265
+ let token_type_ids_array =
266
+ Array :: from_shape_vec ( ( batch_size, encoding_length) , typeids_array) ?;
267
+
268
+ let mut session_inputs = ort:: inputs![
269
+ "input_ids" => Value :: from_array( inputs_ids_array) ?,
270
+ "attention_mask" => Value :: from_array( attention_mask_array. view( ) ) ?,
271
+ ] ?;
272
+
273
+ if self . need_token_type_ids {
274
+ session_inputs. push ( (
275
+ "token_type_ids" . into ( ) ,
276
+ Value :: from_array ( token_type_ids_array) ?. into ( ) ,
277
+ ) ) ;
278
+ }
279
279
280
- Ok (
281
- // Package all the data required for post-processing (e.g. pooling)
282
- // into a SingleBatchOutput struct.
283
- SingleBatchOutput {
284
- session_outputs : self
285
- . session
286
- . run ( session_inputs)
287
- . map_err ( anyhow:: Error :: new) ?,
288
- attention_mask_array,
289
- } ,
290
- )
291
- } ) ) ?;
280
+ Ok (
281
+ // Package all the data required for post-processing (e.g. pooling)
282
+ // into a SingleBatchOutput struct.
283
+ SingleBatchOutput {
284
+ session_outputs : self
285
+ . session
286
+ . run ( session_inputs)
287
+ . map_err ( anyhow:: Error :: new) ?,
288
+ attention_mask_array,
289
+ } ,
290
+ )
291
+ } ) ) ?;
292
292
293
293
Ok ( EmbeddingOutput :: new ( batches) )
294
294
}
@@ -308,7 +308,7 @@ impl TextEmbedding {
308
308
& self ,
309
309
texts : Vec < S > ,
310
310
batch_size : Option < usize > ,
311
- ) -> anyhow :: Result < Vec < Embedding > > {
311
+ ) -> Result < Vec < Embedding > > {
312
312
let batches = self . transform ( texts, batch_size) ?;
313
313
314
314
batches. export_with_transformer ( output:: transformer_with_precedence (
0 commit comments