Skip to content

Commit ac944b0

Browse files
committed
Add embedding worker
Signed-off-by: Jay Wang <jay@zijie.wang>
1 parent 9054850 commit ac944b0

File tree

2 files changed

+87
-11
lines changed

2 files changed

+87
-11
lines changed

examples/rag-playground/src/components/prompt-panel/prompt-panel.ts

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import { LitElement, css, unsafeCSS, html, PropertyValues } from 'lit';
22
import { customElement, property, state, query } from 'lit/decorators.js';
33
import { unsafeHTML } from 'lit/directives/unsafe-html.js';
4-
import { pipeline } from '@xenova/transformers';
4+
import { EmbeddingModel } from '../../workers/embedding';
5+
6+
import type { EmbeddingWorkerMessage } from '../../workers/embedding';
57

68
import componentCSS from './prompt-panel.css?inline';
79
import EmbeddingWorkerInline from '../../workers/embedding?worker&inline';
@@ -17,12 +19,25 @@ export class MememoPromptPanel extends LitElement {
1719
//==========================================================================||
1820
embeddingWorker: Worker;
1921

22+
embeddingWorkerRequestCount = 0;
23+
24+
get embeddingWorkerRequestID() {
25+
this.embeddingWorkerRequestCount++;
26+
return `prompt-panel-${this.embeddingWorkerRequestCount}`;
27+
}
28+
2029
//==========================================================================||
2130
// Lifecycle Methods ||
2231
//==========================================================================||
2332
constructor() {
2433
super();
2534
this.embeddingWorker = new EmbeddingWorkerInline();
35+
this.embeddingWorker.addEventListener(
36+
'message',
37+
(e: MessageEvent<EmbeddingWorkerMessage>) => {
38+
this.embeddingWorkerMessageHandler(e);
39+
}
40+
);
2641
}
2742

2843
firstUpdated() {
@@ -40,12 +55,43 @@ export class MememoPromptPanel extends LitElement {
4055
//==========================================================================||
4156
async initData() {}
4257

43-
async getEmbedding() {}
58+
getEmbedding() {
59+
const message: EmbeddingWorkerMessage = {
60+
command: 'startExtractEmbedding',
61+
payload: {
62+
detail: '',
63+
requestID: this.embeddingWorkerRequestID,
64+
model: EmbeddingModel.gteSmall,
65+
sentences: ['Hello, how are you', 'yo']
66+
}
67+
};
68+
this.embeddingWorker.postMessage(message);
69+
}
4470

4571
//==========================================================================||
4672
// Event Handlers ||
4773
//==========================================================================||
4874

75+
embeddingWorkerMessageHandler(e: MessageEvent<EmbeddingWorkerMessage>) {
76+
switch (e.data.command) {
77+
case 'finishExtractEmbedding': {
78+
const embeddings = e.data.payload.embeddings;
79+
console.log(embeddings);
80+
break;
81+
}
82+
83+
case 'error': {
84+
console.error('Worker error: ', e.data.payload.message);
85+
break;
86+
}
87+
88+
default: {
89+
console.error('Worker: unknown message', e.data.command);
90+
break;
91+
}
92+
}
93+
}
94+
4995
//==========================================================================||
5096
// Private Helpers ||
5197
//==========================================================================||

examples/rag-playground/src/workers/embedding.ts

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ export type EmbeddingWorkerMessage =
1010
command: 'startExtractEmbedding';
1111
payload: {
1212
requestID: string;
13-
text: string;
13+
sentences: string[];
1414
model: EmbeddingModel;
1515
detail: string;
1616
};
@@ -19,10 +19,10 @@ export type EmbeddingWorkerMessage =
1919
command: 'finishExtractEmbedding';
2020
payload: {
2121
requestID: string;
22-
text: string;
22+
sentences: string[];
2323
model: EmbeddingModel;
2424
detail: string;
25-
embedding: number[];
25+
embeddings: number[][];
2626
};
2727
}
2828
| {
@@ -39,26 +39,56 @@ const extractors: Record<EmbeddingModel, Promise<FeatureExtractionPipeline>> = {
3939
'gte-small': pipeline('feature-extraction', 'gte-small')
4040
};
4141

42+
/**
43+
* Helper function to handle calls from the main thread
44+
* @param e Message event
45+
*/
46+
self.onmessage = (e: MessageEvent<EmbeddingWorkerMessage>) => {
47+
switch (e.data.command) {
48+
case 'startExtractEmbedding': {
49+
const { model, sentences, requestID, detail } = e.data.payload;
50+
startExtractEmbedding(model, sentences, requestID, detail);
51+
break;
52+
}
53+
54+
default: {
55+
console.error('Worker: unknown message', e.data.command);
56+
break;
57+
}
58+
}
59+
};
60+
4261
/**
4362
* Extract embeddings (one vector per sentence) from the input sentences
4463
* @param model Embedding model
4564
* @param sentences Input sentences to embed
4665
*/
47-
export const getEmbedding = async (
66+
export const startExtractEmbedding = async (
4867
model: EmbeddingModel,
49-
text: string,
68+
sentences: string[],
5069
requestID: string,
5170
detail: string
5271
) => {
5372
try {
5473
const extractor = await extractors[model];
55-
const sentences = [text];
5674
const output = await extractor(sentences, {
5775
pooling: 'mean',
5876
normalize: true
5977
});
6078

61-
const embedding = Array.from<number>(output.data as Float32Array);
79+
const embeddings: number[][] = [];
80+
const flattenEmbedding: number[] = Array.from<number>(
81+
output.data as Float32Array
82+
);
83+
84+
// Un-flatten the embedding output
85+
for (let i = 0; i < output.dims[0]; i++) {
86+
const curRow = flattenEmbedding.slice(
87+
i * output.dims[1],
88+
(i + 1) * output.dims[1]
89+
);
90+
embeddings.push(curRow);
91+
}
6292

6393
// Send result to the main thread
6494
const message: EmbeddingWorkerMessage = {
@@ -67,8 +97,8 @@ export const getEmbedding = async (
6797
model,
6898
requestID,
6999
detail,
70-
text,
71-
embedding
100+
sentences,
101+
embeddings
72102
}
73103
};
74104
postMessage(message);

0 commit comments

Comments
 (0)