From 7e11cf3f3a0cb1e295003e180ae15dee2329a5eb Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Wed, 12 Jun 2024 19:28:29 -0400 Subject: [PATCH] [WorkerHandler][Breaking] Create MLCEngine in worker handler internally (#472) This PR applies to all `WebWorkerMLCEngine`, `ServiceWorkerMLCEngine`, and `ExtensionServiceWorkerMLCEngine`. Prior to this PR, the worker thread script looked like the following: ```typescript import { MLCEngineServiceWorkerHandler, MLCEngine, } from "@mlc-ai/web-llm"; const engine = new MLCEngine(); let handler: MLCEngineServiceWorkerHandler; self.addEventListener("activate", function (event) { handler = new MLCEngineServiceWorkerHandler(engine); console.log("Service Worker is ready"); }); ``` This may confuse users because they need to instantiate an `MLCEngine` in the backend, and a `ServiceWorkerMLCEngine` in the frontend. After this PR, the script looks like the following: ```typescript import { ServiceWorkerMLCEngineHandler } from "@mlc-ai/web-llm"; let handler: ServiceWorkerMLCEngineHandler; self.addEventListener("activate", function (event) { handler = new ServiceWorkerMLCEngineHandler(); console.log("Service Worker is ready"); }); ``` That is, `WorkerHandler` does not take any constructor arguments (except `port` for ExtensionServiceWorker), and we will instantiate `MLCEngine` internally, making the flow more intuitive. For logit processor usage, we add `setLogitProcessorRegistry()` to the handler (see examples for the change). 
Besides, we rename: - `MLCEngineWorkerHandler` --> `WebWorkerMLCEngineHandler` - `MLCEngineServiceWorkerHandler` --> `ServiceWorkerMLCEngineHandler` - `MLCEngineExtensionServiceWorkerHandler` --> `ExtensionServiceWorkerMLCEngineHandler` --- README.md | 29 +++++++------------ .../src/background.ts | 8 ++--- examples/get-started-web-worker/src/worker.ts | 5 ++-- examples/logit-processor/src/worker.ts | 5 ++-- examples/service-worker/src/sw.ts | 11 ++----- examples/simple-chat-ts/src/worker.ts | 5 ++-- examples/simple-chat-upload/src/worker.ts | 5 ++-- src/extension_service_worker.ts | 20 ++++++------- src/index.ts | 6 ++-- src/service_worker.ts | 28 +++++++++--------- src/web_worker.ts | 21 +++++++++----- 11 files changed, 62 insertions(+), 81 deletions(-) diff --git a/README.md b/README.md index 1ca9c0f2..fa73db13 100644 --- a/README.md +++ b/README.md @@ -201,10 +201,8 @@ console.log(fullReply); You can put the heavy computation in a worker script to optimizing your application performance. To do so, you need to: -1. Create an MLCEngine in the worker thread for the actual inference. -2. Wrap the MLCEngine in the worker thread with a worker message handler to handle thread communications via messages under the hood. -3. Create a Worker Engine in your main application as a proxy to sending operations to the MLCEngine in the worker thread via sending messages. - +1. Create a handler in the worker thread that communicates with the frontend while handling the requests. +2. Create a Worker Engine in your main application, which under the hood sends message to the handler in worker thread. For detailed implementation for different kinds of Workers, check the following sections. #### Dedicated Web Worker @@ -213,16 +211,14 @@ WebLLM comes with API support for WebWorker so you can hook the generation process into a separate worker thread so that the computing in the worker thread won't disrupt the UI. 
-We will first create a worker script with a MLCEngine and -hook it up to a worker message handler. +We create a handler in the worker thread that communicates with the frontend while handling the requests. ```typescript // worker.ts -import { MLCEngineWorkerHandler, MLCEngine } from "@mlc-ai/web-llm"; +import { WebWorkerMLCEngineHandler } from "@mlc-ai/web-llm"; -// Hookup an MLCEngine to a worker handler -const engine = new MLCEngine(); -const handler = new MLCEngineWorkerHandler(engine); +// A handler that resides in the worker thread +const handler = new WebWorkerMLCEngineHandler(); self.onmessage = (msg: MessageEvent) => { handler.onmessage(msg); }; @@ -260,22 +256,17 @@ your application's offline experience. (Note, Service Worker's life cycle is managed by the browser and can be killed any time without notifying the webapp. `ServiceWorkerMLCEngine` will try to keep the service worker thread alive by periodically sending heartbeat events, but your application should also include proper error handling. Check `keepAliveMs` and `missedHeatbeat` in [`ServiceWorkerMLCEngine`](https://github.com/mlc-ai/web-llm/blob/main/src/service_worker.ts#L234) for more details.) -We first create a service worker script with a MLCEngine and hook it up to a worker message handler -that handles requests when the service worker is ready. +We create a handler in the worker thread that communicates with the frontend while handling the requests. 
```typescript // sw.ts -import { - MLCEngineServiceWorkerHandler, - MLCEngine, -} from "@mlc-ai/web-llm"; +import { ServiceWorkerMLCEngineHandler } from "@mlc-ai/web-llm"; -const engine = new MLCEngine(); -let handler: MLCEngineServiceWorkerHandler; +let handler: ServiceWorkerMLCEngineHandler; self.addEventListener("activate", function (event) { - handler = new MLCEngineServiceWorkerHandler(engine); + handler = new ServiceWorkerMLCEngineHandler(); console.log("Service Worker is ready"); }); ``` diff --git a/examples/chrome-extension-webgpu-service-worker/src/background.ts b/examples/chrome-extension-webgpu-service-worker/src/background.ts index 610537d9..3e3dc4d9 100644 --- a/examples/chrome-extension-webgpu-service-worker/src/background.ts +++ b/examples/chrome-extension-webgpu-service-worker/src/background.ts @@ -1,16 +1,12 @@ -import { - MLCEngineExtensionServiceWorkerHandler, - MLCEngine, -} from "@mlc-ai/web-llm"; +import { ExtensionServiceWorkerMLCEngineHandler } from "@mlc-ai/web-llm"; // Hookup an engine to a service worker handler -const engine = new MLCEngine(); let handler; chrome.runtime.onConnect.addListener(function (port) { console.assert(port.name === "web_llm_service_worker"); if (handler === undefined) { - handler = new MLCEngineExtensionServiceWorkerHandler(engine, port); + handler = new ExtensionServiceWorkerMLCEngineHandler(port); } else { handler.setPort(port); } diff --git a/examples/get-started-web-worker/src/worker.ts b/examples/get-started-web-worker/src/worker.ts index 06d9709a..6c62240b 100644 --- a/examples/get-started-web-worker/src/worker.ts +++ b/examples/get-started-web-worker/src/worker.ts @@ -1,8 +1,7 @@ -import { MLCEngineWorkerHandler, MLCEngine } from "@mlc-ai/web-llm"; +import { WebWorkerMLCEngineHandler } from "@mlc-ai/web-llm"; // Hookup an engine to a worker handler -const engine = new MLCEngine(); -const handler = new MLCEngineWorkerHandler(engine); +const handler = new WebWorkerMLCEngineHandler(); self.onmessage = (msg: 
MessageEvent) => { handler.onmessage(msg); }; diff --git a/examples/logit-processor/src/worker.ts b/examples/logit-processor/src/worker.ts index ac0f9c05..aa023704 100644 --- a/examples/logit-processor/src/worker.ts +++ b/examples/logit-processor/src/worker.ts @@ -8,9 +8,8 @@ const myLogitProcessor = new MyLogitProcessor(); const logitProcessorRegistry = new Map(); logitProcessorRegistry.set("phi-2-q4f32_1-MLC", myLogitProcessor); -const engine = new webllm.MLCEngine(); -engine.setLogitProcessorRegistry(logitProcessorRegistry); -const handler = new webllm.MLCEngineWorkerHandler(engine); +const handler = new webllm.WebWorkerMLCEngineHandler(); +handler.setLogitProcessorRegistry(logitProcessorRegistry); self.onmessage = (msg: MessageEvent) => { handler.onmessage(msg); }; diff --git a/examples/service-worker/src/sw.ts b/examples/service-worker/src/sw.ts index 1dd4f3ac..ffd5ae3f 100644 --- a/examples/service-worker/src/sw.ts +++ b/examples/service-worker/src/sw.ts @@ -1,13 +1,8 @@ -import { - MLCEngineServiceWorkerHandler, - MLCEngineInterface, - MLCEngine, -} from "@mlc-ai/web-llm"; +import { ServiceWorkerMLCEngineHandler } from "@mlc-ai/web-llm"; -const engine: MLCEngineInterface = new MLCEngine(); -let handler: MLCEngineServiceWorkerHandler; +let handler: ServiceWorkerMLCEngineHandler; self.addEventListener("activate", function (event) { - handler = new MLCEngineServiceWorkerHandler(engine); + handler = new ServiceWorkerMLCEngineHandler(); console.log("Web-LLM Service Worker Activated"); }); diff --git a/examples/simple-chat-ts/src/worker.ts b/examples/simple-chat-ts/src/worker.ts index a83ac98f..acca7a42 100644 --- a/examples/simple-chat-ts/src/worker.ts +++ b/examples/simple-chat-ts/src/worker.ts @@ -1,8 +1,7 @@ // Serve the engine workload through web worker -import { MLCEngineWorkerHandler, MLCEngine } from "@mlc-ai/web-llm"; +import { WebWorkerMLCEngineHandler } from "@mlc-ai/web-llm"; -const engine = new MLCEngine(); -const handler = new 
MLCEngineWorkerHandler(engine); +const handler = new WebWorkerMLCEngineHandler(); self.onmessage = (msg: MessageEvent) => { handler.onmessage(msg); }; diff --git a/examples/simple-chat-upload/src/worker.ts b/examples/simple-chat-upload/src/worker.ts index a83ac98f..acca7a42 100644 --- a/examples/simple-chat-upload/src/worker.ts +++ b/examples/simple-chat-upload/src/worker.ts @@ -1,8 +1,7 @@ // Serve the engine workload through web worker -import { MLCEngineWorkerHandler, MLCEngine } from "@mlc-ai/web-llm"; +import { WebWorkerMLCEngineHandler } from "@mlc-ai/web-llm"; -const engine = new MLCEngine(); -const handler = new MLCEngineWorkerHandler(engine); +const handler = new WebWorkerMLCEngineHandler(); self.onmessage = (msg: MessageEvent) => { handler.onmessage(msg); }; diff --git a/src/extension_service_worker.ts b/src/extension_service_worker.ts index b46c167d..7053acdb 100644 --- a/src/extension_service_worker.ts +++ b/src/extension_service_worker.ts @@ -10,7 +10,7 @@ import { import { MLCEngineInterface } from "./types"; import { ChatWorker, - MLCEngineWorkerHandler, + WebWorkerMLCEngineHandler, WebWorkerMLCEngine, } from "./web_worker"; import { areChatOptionsEqual } from "./utils"; @@ -25,14 +25,14 @@ import { ChatCompletionChunk } from "./openai_api_protocols/index"; * let handler; * chrome.runtime.onConnect.addListener(function (port) { * if (handler === undefined) { - * handler = new MLCEngineServiceWorkerHandler(engine, port); + * handler = new ServiceWorkerMLCEngineHandler(engine, port); * } else { * handler.setPort(port); * } * port.onMessage.addListener(handler.onmessage.bind(handler)); * }); */ -export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler { +export class ServiceWorkerMLCEngineHandler extends WebWorkerMLCEngineHandler { /** * The modelId and chatOpts that the underlying engine (backend) is currently loaded with. 
* @@ -46,8 +46,8 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler { chatOpts?: ChatOptions; port: chrome.runtime.Port | null; - constructor(engine: MLCEngineInterface, port: chrome.runtime.Port) { - super(engine); + constructor(port: chrome.runtime.Port) { + super(); this.port = port; port.onDisconnect.addListener(() => this.onPortDisconnect(port)); } @@ -127,10 +127,10 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler { // If not (due to possibly killed service worker), we reload here. if (this.modelId !== params.modelId) { log.warn( - "ServiceWorkerMLCEngine expects model is loaded in MLCEngineServiceWorkerHandler, " + + "ServiceWorkerMLCEngine expects model is loaded in ServiceWorkerMLCEngineHandler, " + "but it is not. This may due to service worker is unexpectedly killed. ", ); - log.info("Reloading engine in MLCEngineServiceWorkerHandler."); + log.info("Reloading engine in ServiceWorkerMLCEngineHandler."); await this.engine.reload(params.modelId, params.chatOpts); } const res = await this.engine.chatCompletion(params.request); @@ -147,10 +147,10 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler { // If not (due to possibly killed service worker), we reload here. if (this.modelId !== params.modelId) { log.warn( - "ServiceWorkerMLCEngine expects model is loaded in MLCEngineServiceWorkerHandler, " + + "ServiceWorkerMLCEngine expects model is loaded in ServiceWorkerMLCEngineHandler, " + "but it is not. This may due to service worker is unexpectedly killed. 
", ); - log.info("Reloading engine in MLCEngineServiceWorkerHandler."); + log.info("Reloading engine in ServiceWorkerMLCEngineHandler."); await this.engine.reload(params.modelId, params.chatOpts); } this.chatCompletionAsyncChunkGenerator = @@ -164,7 +164,7 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler { return; } - // All rest of message handling are the same as MLCEngineWorkerHandler + // All rest of message handling are the same as WebWorkerMLCEngineHandler super.onmessage(event); } } diff --git a/src/index.ts b/src/index.ts index 99348b68..c4060183 100644 --- a/src/index.ts +++ b/src/index.ts @@ -28,7 +28,7 @@ export { } from "./cache_util"; export { - MLCEngineWorkerHandler, + WebWorkerMLCEngineHandler, WebWorkerMLCEngine, CreateWebWorkerMLCEngine, } from "./web_worker"; @@ -36,13 +36,13 @@ export { export { WorkerRequest, WorkerResponse, CustomRequestParams } from "./message"; export { - MLCEngineServiceWorkerHandler, + ServiceWorkerMLCEngineHandler, ServiceWorkerMLCEngine, CreateServiceWorkerMLCEngine, } from "./service_worker"; export { - MLCEngineServiceWorkerHandler as MLCEngineExtensionServiceWorkerHandler, + ServiceWorkerMLCEngineHandler as ExtensionServiceWorkerMLCEngineHandler, ServiceWorkerMLCEngine as ExtensionServiceWorkerMLCEngine, CreateServiceWorkerMLCEngine as CreateExtensionServiceWorkerMLCEngine, } from "./extension_service_worker"; diff --git a/src/service_worker.ts b/src/service_worker.ts index 9232ee6c..7ab980be 100644 --- a/src/service_worker.ts +++ b/src/service_worker.ts @@ -8,9 +8,9 @@ import { ChatCompletionNonStreamingParams, ChatCompletionStreamInitParams, } from "./message"; -import { MLCEngineInterface, InitProgressReport } from "./types"; +import { InitProgressReport } from "./types"; import { - MLCEngineWorkerHandler, + WebWorkerMLCEngineHandler, WebWorkerMLCEngine, ChatWorker, } from "./web_worker"; @@ -30,14 +30,14 @@ type IServiceWorker = globalThis.ServiceWorker; * let handler; * 
chrome.runtime.onConnect.addListener(function (port) { * if (handler === undefined) { - * handler = new MLCEngineServiceWorkerHandler(engine, port); + * handler = new ServiceWorkerMLCEngineHandler(engine, port); * } else { * handler.setPort(port); * } * port.onMessage.addListener(handler.onmessage.bind(handler)); * }); */ -export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler { +export class ServiceWorkerMLCEngineHandler extends WebWorkerMLCEngineHandler { /** * The modelId and chatOpts that the underlying engine (backend) is currently loaded with. * @@ -56,14 +56,13 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler { >(); private initRequestUuid?: string; - constructor(engine: MLCEngineInterface) { + constructor() { if (!self || !("addEventListener" in self)) { throw new Error( - "MLCEngineServiceWorkerHandler must be created in the service worker script.", + "ServiceWorkerMLCEngineHandler must be created in the service worker script.", ); } - const customInitProgressCallback = engine.getInitProgressCallback(); - super(engine); + super(); const onmessage = this.onmessage.bind(this); this.engine.setInitProgressCallback((report: InitProgressReport) => { @@ -73,7 +72,6 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler { content: report, }; this.postMessage(msg); - customInitProgressCallback?.(report); }); self.addEventListener("message", (event) => { @@ -112,7 +110,7 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler { `ServiceWorker message: [${msg.kind}] ${JSON.stringify(msg.content)}`, ); - // Special case message handling different from MLCEngineWorkerHandler + // Special case message handling different from WebWorkerMLCEngineHandler if (msg.kind === "keepAlive") { const reply: WorkerResponse = { kind: "heartbeat", @@ -180,10 +178,10 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler { // If not (due to possibly killed service worker), we 
reload here. if (this.modelId !== params.modelId) { log.warn( - "ServiceWorkerMLCEngine expects model is loaded in MLCEngineServiceWorkerHandler, " + + "ServiceWorkerMLCEngine expects model is loaded in ServiceWorkerMLCEngineHandler, " + "but it is not. This may due to service worker is unexpectedly killed. ", ); - log.info("Reloading engine in MLCEngineServiceWorkerHandler."); + log.info("Reloading engine in ServiceWorkerMLCEngineHandler."); this.initRequestUuid = msg.uuid; await this.engine.reload(params.modelId, params.chatOpts); } @@ -202,10 +200,10 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler { // If not (due to possibly killed service worker), we reload here. if (this.modelId !== params.modelId) { log.warn( - "ServiceWorkerMLCEngine expects model is loaded in MLCEngineServiceWorkerHandler, " + + "ServiceWorkerMLCEngine expects model is loaded in ServiceWorkerMLCEngineHandler, " + "but it is not. This may due to service worker is unexpectedly killed. 
", ); - log.info("Reloading engine in MLCEngineServiceWorkerHandler."); + log.info("Reloading engine in ServiceWorkerMLCEngineHandler."); this.initRequestUuid = msg.uuid; await this.engine.reload(params.modelId, params.chatOpts); } @@ -221,7 +219,7 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler { return; } - // All rest of message handling are the same as MLCEngineWorkerHandler + // All rest of message handling are the same as WebWorkerMLCEngineHandler super.onmessage(msg, onComplete, onError); } } diff --git a/src/web_worker.ts b/src/web_worker.ts index c6def6a3..de79ed19 100644 --- a/src/web_worker.ts +++ b/src/web_worker.ts @@ -10,6 +10,7 @@ import { InitProgressCallback, InitProgressReport, LogLevel, + LogitProcessor, } from "./types"; import { ChatCompletionRequest, @@ -33,6 +34,7 @@ import { WorkerRequest, } from "./message"; import log from "loglevel"; +import { MLCEngine } from "./engine"; /** * Worker handler that can be used in a WebWorker @@ -42,11 +44,11 @@ import log from "loglevel"; * // setup a chat worker handler that routes * // requests to the chat * const engine = new MLCEngine(); - * cont handler = new MLCEngineWorkerHandler(engine); + * cont handler = new WebWorkerMLCEngineHandler(engine); * onmessage = handler.onmessage; */ -export class MLCEngineWorkerHandler { - protected engine: MLCEngineInterface; +export class WebWorkerMLCEngineHandler { + protected engine: MLCEngine; protected chatCompletionAsyncChunkGenerator?: AsyncGenerator< ChatCompletionChunk, void, @@ -56,10 +58,8 @@ export class MLCEngineWorkerHandler { /** * @param engine A concrete implementation of MLCEngineInterface */ - constructor(engine: MLCEngineInterface) { - this.engine = engine; - - const customInitProgressCallback = engine.getInitProgressCallback(); + constructor() { + this.engine = new MLCEngine(); this.engine.setInitProgressCallback((report: InitProgressReport) => { const msg: WorkerResponse = { kind: "initProgressCallback", @@ -67,7 
+67,6 @@ export class MLCEngineWorkerHandler { content: report, }; this.postMessage(msg); - customInitProgressCallback?.(report); }); } @@ -76,6 +75,12 @@ export class MLCEngineWorkerHandler { postMessage(msg); } + setLogitProcessorRegistry( + logitProcessorRegistry?: Map, + ) { + this.engine.setLogitProcessorRegistry(logitProcessorRegistry); + } + async handleTask( uuid: string, task: () => Promise,