Skip to content

Commit 6ade9eb

Browse files
authored
feat: Add jina v2 code support (#140)
1 parent cfec7d7 commit 6ade9eb

File tree

6 files changed

+57
-7
lines changed

6 files changed

+57
-7
lines changed

src/models/image_embedding.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,34 +24,39 @@ pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
2424
description: String::from("CLIP vision encoder based on ViT-B/32"),
2525
model_code: String::from("Qdrant/clip-ViT-B-32-vision"),
2626
model_file: String::from("model.onnx"),
27+
additional_files: Vec::new(),
2728
},
2829
ModelInfo {
2930
model: ImageEmbeddingModel::Resnet50,
3031
dim: 2048,
3132
description: String::from("ResNet-50 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__."),
3233
model_code: String::from("Qdrant/resnet50-onnx"),
3334
model_file: String::from("model.onnx"),
35+
additional_files: Vec::new(),
3436
},
3537
ModelInfo {
3638
model: ImageEmbeddingModel::UnicomVitB16,
3739
dim: 768,
3840
description: String::from("Unicom Unicom-ViT-B-16 from open-metric-learning"),
3941
model_code: String::from("Qdrant/Unicom-ViT-B-16"),
4042
model_file: String::from("model.onnx"),
43+
additional_files: Vec::new(),
4144
},
4245
ModelInfo {
4346
model: ImageEmbeddingModel::UnicomVitB32,
4447
dim: 512,
4548
description: String::from("Unicom Unicom-ViT-B-32 from open-metric-learning"),
4649
model_code: String::from("Qdrant/Unicom-ViT-B-32"),
4750
model_file: String::from("model.onnx"),
51+
additional_files: Vec::new(),
4852
},
4953
ModelInfo {
5054
model: ImageEmbeddingModel::NomicEmbedVisionV15,
5155
dim: 768,
5256
description: String::from("Nomic NomicEmbedVisionV15"),
5357
model_code: String::from("nomic-ai/nomic-embed-vision-v1.5"),
5458
model_file: String::from("onnx/model.onnx"),
59+
additional_files: Vec::new(),
5560
},
5661
];
5762

src/models/model_info.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ pub struct ModelInfo<T> {
66
pub description: String,
77
pub model_code: String,
88
pub model_file: String,
9+
pub additional_files: Vec<String>,
910
}

src/models/sparse.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ pub fn models_list() -> Vec<ModelInfo<SparseModel>> {
1616
description: String::from("Splade sparse vector model for commercial use, v1"),
1717
model_code: String::from("Qdrant/Splade_PP_en_v1"),
1818
model_file: String::from("model.onnx"),
19+
additional_files: Vec::new(),
1920
}]
2021
}
2122

src/models/text_embedding.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ pub enum EmbeddingModel {
6565
GTELargeENV15Q,
6666
/// Qdrant/clip-ViT-B-32-text
6767
ClipVitB32,
68+
/// jinaai/jina-embeddings-v2-base-code
69+
JinaEmbeddingsV2BaseCode,
6870
}
6971

7072
/// Centralized function to initialize the models map.
@@ -76,62 +78,71 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
7678
description: String::from("Sentence Transformer model, MiniLM-L6-v2"),
7779
model_code: String::from("Qdrant/all-MiniLM-L6-v2-onnx"),
7880
model_file: String::from("model.onnx"),
81+
additional_files: Vec::new(),
7982
},
8083
ModelInfo {
8184
model: EmbeddingModel::AllMiniLML6V2Q,
8285
dim: 384,
8386
description: String::from("Quantized Sentence Transformer model, MiniLM-L6-v2"),
8487
model_code: String::from("Xenova/all-MiniLM-L6-v2"),
8588
model_file: String::from("onnx/model_quantized.onnx"),
89+
additional_files: Vec::new(),
8690
},
8791
ModelInfo {
8892
model: EmbeddingModel::AllMiniLML12V2,
8993
dim: 384,
9094
description: String::from("Sentence Transformer model, MiniLM-L12-v2"),
9195
model_code: String::from("Xenova/all-MiniLM-L12-v2"),
9296
model_file: String::from("onnx/model.onnx"),
97+
additional_files: Vec::new(),
9398
},
9499
ModelInfo {
95100
model: EmbeddingModel::AllMiniLML12V2Q,
96101
dim: 384,
97102
description: String::from("Quantized Sentence Transformer model, MiniLM-L12-v2"),
98103
model_code: String::from("Xenova/all-MiniLM-L12-v2"),
99104
model_file: String::from("onnx/model_quantized.onnx"),
105+
additional_files: Vec::new(),
100106
},
101107
ModelInfo {
102108
model: EmbeddingModel::BGEBaseENV15,
103109
dim: 768,
104110
description: String::from("v1.5 release of the base English model"),
105111
model_code: String::from("Xenova/bge-base-en-v1.5"),
106112
model_file: String::from("onnx/model.onnx"),
113+
additional_files: Vec::new(),
107114
},
108115
ModelInfo {
109116
model: EmbeddingModel::BGEBaseENV15Q,
110117
dim: 768,
111118
description: String::from("Quantized v1.5 release of the large English model"),
112119
model_code: String::from("Qdrant/bge-base-en-v1.5-onnx-Q"),
113120
model_file: String::from("model_optimized.onnx"),
121+
additional_files: Vec::new(),
114122
},
115123
ModelInfo {
116124
model: EmbeddingModel::BGELargeENV15,
117125
dim: 1024,
118126
description: String::from("v1.5 release of the large English model"),
119127
model_code: String::from("Xenova/bge-large-en-v1.5"),
120128
model_file: String::from("onnx/model.onnx"),
129+
additional_files: Vec::new(),
121130
},
122131
ModelInfo {
123132
model: EmbeddingModel::BGELargeENV15Q,
124133
dim: 1024,
125134
description: String::from("Quantized v1.5 release of the large English model"),
126135
model_code: String::from("Qdrant/bge-large-en-v1.5-onnx-Q"),
127136
model_file: String::from("model_optimized.onnx"),
137+
additional_files: Vec::new(),
128138
},
129139
ModelInfo {
130140
model: EmbeddingModel::BGESmallENV15,
131141
dim: 384,
132142
description: String::from("v1.5 release of the fast and default English model"),
133143
model_code: String::from("Xenova/bge-small-en-v1.5"),
134144
model_file: String::from("onnx/model.onnx"),
145+
additional_files: Vec::new(),
135146
},
136147
ModelInfo {
137148
model: EmbeddingModel::BGESmallENV15Q,
@@ -141,20 +152,23 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
141152
),
142153
model_code: String::from("Qdrant/bge-small-en-v1.5-onnx-Q"),
143154
model_file: String::from("model_optimized.onnx"),
155+
additional_files: Vec::new(),
144156
},
145157
ModelInfo {
146158
model: EmbeddingModel::NomicEmbedTextV1,
147159
dim: 768,
148160
description: String::from("8192 context length english model"),
149161
model_code: String::from("nomic-ai/nomic-embed-text-v1"),
150162
model_file: String::from("onnx/model.onnx"),
163+
additional_files: Vec::new(),
151164
},
152165
ModelInfo {
153166
model: EmbeddingModel::NomicEmbedTextV15,
154167
dim: 768,
155168
description: String::from("v1.5 release of the 8192 context length english model"),
156169
model_code: String::from("nomic-ai/nomic-embed-text-v1.5"),
157170
model_file: String::from("onnx/model.onnx"),
171+
additional_files: Vec::new(),
158172
},
159173
ModelInfo {
160174
model: EmbeddingModel::NomicEmbedTextV15Q,
@@ -164,20 +178,23 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
164178
),
165179
model_code: String::from("nomic-ai/nomic-embed-text-v1.5"),
166180
model_file: String::from("onnx/model_quantized.onnx"),
181+
additional_files: Vec::new(),
167182
},
168183
ModelInfo {
169184
model: EmbeddingModel::ParaphraseMLMiniLML12V2Q,
170185
dim: 384,
171186
description: String::from("Quantized Multi-lingual model"),
172187
model_code: String::from("Qdrant/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q"),
173188
model_file: String::from("model_optimized.onnx"),
189+
additional_files: Vec::new(),
174190
},
175191
ModelInfo {
176192
model: EmbeddingModel::ParaphraseMLMiniLML12V2,
177193
dim: 384,
178194
description: String::from("Multi-lingual model"),
179195
model_code: String::from("Xenova/paraphrase-multilingual-MiniLM-L12-v2"),
180196
model_file: String::from("onnx/model.onnx"),
197+
additional_files: Vec::new(),
181198
},
182199
ModelInfo {
183200
model: EmbeddingModel::ParaphraseMLMpnetBaseV2,
@@ -187,83 +204,103 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
187204
),
188205
model_code: String::from("Xenova/paraphrase-multilingual-mpnet-base-v2"),
189206
model_file: String::from("onnx/model.onnx"),
207+
additional_files: Vec::new(),
190208
},
191209
ModelInfo {
192210
model: EmbeddingModel::BGESmallZHV15,
193211
dim: 512,
194212
description: String::from("v1.5 release of the small Chinese model"),
195213
model_code: String::from("Xenova/bge-small-zh-v1.5"),
196214
model_file: String::from("onnx/model.onnx"),
215+
additional_files: Vec::new(),
197216
},
198217
ModelInfo {
199218
model: EmbeddingModel::MultilingualE5Small,
200219
dim: 384,
201220
description: String::from("Small model of multilingual E5 Text Embeddings"),
202221
model_code: String::from("intfloat/multilingual-e5-small"),
203222
model_file: String::from("onnx/model.onnx"),
223+
additional_files: Vec::new(),
204224
},
205225
ModelInfo {
206226
model: EmbeddingModel::MultilingualE5Base,
207227
dim: 768,
208228
description: String::from("Base model of multilingual E5 Text Embeddings"),
209229
model_code: String::from("intfloat/multilingual-e5-base"),
210230
model_file: String::from("onnx/model.onnx"),
231+
additional_files: Vec::new(),
211232
},
212233
ModelInfo {
213234
model: EmbeddingModel::MultilingualE5Large,
214235
dim: 1024,
215236
description: String::from("Large model of multilingual E5 Text Embeddings"),
216237
model_code: String::from("Qdrant/multilingual-e5-large-onnx"),
217238
model_file: String::from("model.onnx"),
239+
additional_files: vec!["model.onnx_data".to_string()],
218240
},
219241
ModelInfo {
220242
model: EmbeddingModel::MxbaiEmbedLargeV1,
221243
dim: 1024,
222244
description: String::from("Large English embedding model from MixedBreed.ai"),
223245
model_code: String::from("mixedbread-ai/mxbai-embed-large-v1"),
224246
model_file: String::from("onnx/model.onnx"),
247+
additional_files: Vec::new(),
225248
},
226249
ModelInfo {
227250
model: EmbeddingModel::MxbaiEmbedLargeV1Q,
228251
dim: 1024,
229252
description: String::from("Quantized Large English embedding model from MixedBreed.ai"),
230253
model_code: String::from("mixedbread-ai/mxbai-embed-large-v1"),
231254
model_file: String::from("onnx/model_quantized.onnx"),
255+
additional_files: Vec::new(),
232256
},
233257
ModelInfo {
234258
model: EmbeddingModel::GTEBaseENV15,
235259
dim: 768,
236260
description: String::from("Large multilingual embedding model from Alibaba"),
237261
model_code: String::from("Alibaba-NLP/gte-base-en-v1.5"),
238262
model_file: String::from("onnx/model.onnx"),
263+
additional_files: Vec::new(),
239264
},
240265
ModelInfo {
241266
model: EmbeddingModel::GTEBaseENV15Q,
242267
dim: 768,
243268
description: String::from("Quantized Large multilingual embedding model from Alibaba"),
244269
model_code: String::from("Alibaba-NLP/gte-base-en-v1.5"),
245270
model_file: String::from("onnx/model_quantized.onnx"),
271+
additional_files: Vec::new(),
246272
},
247273
ModelInfo {
248274
model: EmbeddingModel::GTELargeENV15,
249275
dim: 1024,
250276
description: String::from("Large multilingual embedding model from Alibaba"),
251277
model_code: String::from("Alibaba-NLP/gte-large-en-v1.5"),
252278
model_file: String::from("onnx/model.onnx"),
279+
additional_files: Vec::new(),
253280
},
254281
ModelInfo {
255282
model: EmbeddingModel::GTELargeENV15Q,
256283
dim: 1024,
257284
description: String::from("Quantized Large multilingual embedding model from Alibaba"),
258285
model_code: String::from("Alibaba-NLP/gte-large-en-v1.5"),
259286
model_file: String::from("onnx/model_quantized.onnx"),
287+
additional_files: Vec::new(),
260288
},
261289
ModelInfo {
262290
model: EmbeddingModel::ClipVitB32,
263291
dim: 512,
264292
description: String::from("CLIP text encoder based on ViT-B/32"),
265293
model_code: String::from("Qdrant/clip-ViT-B-32-text"),
266294
model_file: String::from("model.onnx"),
295+
additional_files: Vec::new(),
296+
},
297+
ModelInfo {
298+
model: EmbeddingModel::JinaEmbeddingsV2BaseCode,
299+
dim: 768,
300+
description: String::from("Jina embeddings v2 base code"),
301+
model_code: String::from("jinaai/jina-embeddings-v2-base-code"),
302+
model_file: String::from("onnx/model.onnx"),
303+
additional_files: Vec::new(),
267304
},
268305
];
269306

@@ -338,6 +375,8 @@ impl EmbeddingModel {
338375
EmbeddingModel::GTELargeENV15Q => Some(Pooling::Cls),
339376

340377
EmbeddingModel::ClipVitB32 => Some(Pooling::Mean),
378+
379+
EmbeddingModel::JinaEmbeddingsV2BaseCode => Some(Pooling::Mean),
341380
}
342381
}
343382

src/text_embedding/impl.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,12 @@ impl TextEmbedding {
6666
.get(model_file_name)
6767
.context(format!("Failed to retrieve {}", model_file_name))?;
6868

69-
// TODO: If more models need .onnx_data, implement a better way to handle this
70-
// Probably by adding `additional_files` field in the `ModelInfo` struct
71-
if model_name == EmbeddingModel::MultilingualE5Large {
72-
model_repo
73-
.get("model.onnx_data")
74-
.expect("Failed to retrieve model.onnx_data.");
69+
if !model_info.additional_files.is_empty() {
70+
for file in &model_info.additional_files {
71+
model_repo
72+
.get(file)
73+
.context(format!("Failed to retrieve {}", file))?;
74+
}
7575
}
7676

7777
// prioritise loading pooling config if available, if not (thanks qdrant!), look for it in hardcoded
@@ -132,6 +132,7 @@ impl TextEmbedding {
132132
.inputs
133133
.iter()
134134
.any(|input| input.name == "token_type_ids");
135+
135136
Self {
136137
tokenizer,
137138
session,

tests/embeddings.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ fn verify_embeddings(model: &EmbeddingModel, embeddings: &[Embedding]) -> Result
6262
EmbeddingModel::ParaphraseMLMiniLML12V2Q => [-0.07749095, -0.058981877, -0.043487836, -0.18775631],
6363
EmbeddingModel::ParaphraseMLMpnetBaseV2 => [0.39132136, 0.49490625, 0.65497226, 0.34237382],
6464
EmbeddingModel::ClipVitB32 => [0.7057363, 1.3549932, 0.46823958, 0.52351093],
65+
EmbeddingModel::JinaEmbeddingsV2BaseCode => [-0.31383067, -0.3758629, -0.24878195, -0.35373706],
6566
_ => panic!("Model {model} not found. If you have just inserted this `EmbeddingModel` variant, please update the expected embeddings."),
6667
};
6768

@@ -189,7 +190,9 @@ fn test_sparse_embeddings() {
189190
});
190191

191192
// Clear the model cache to avoid running out of space on GitHub Actions.
192-
clean_cache(supported_model.model_code.clone())
193+
if std::env::var("CI").is_ok() {
194+
clean_cache(supported_model.model_code.clone())
195+
}
193196
});
194197
}
195198

0 commit comments

Comments (0)