Skip to content

Commit

Permalink
[WorkerHandler][Breaking] Create MLCEngine in worker handler internal…
Browse files Browse the repository at this point in the history
…ly (#472)

This PR applies to all `WebWorkerMLCEngine`, `ServiceWorkerMLCEngine`,
and `ExtensionServiceWorkerMLCEngine`.

Prior to this PR, the worker thread script looked like the following:

```typescript
import {
  MLCEngineServiceWorkerHandler,
  MLCEngine,
} from "@mlc-ai/web-llm";

const engine = new MLCEngine();
let handler: MLCEngineServiceWorkerHandler;

self.addEventListener("activate", function (event) {
  handler = new MLCEngineServiceWorkerHandler(engine);
  console.log("Service Worker is ready");
});
```

This may confuse users because they need to instantiate an `MLCEngine`
in the backend, and an `ServiceWorkerMLCEngine` in the frontend. After
this PR, the script looks like the following:

```typescript
import { ServiceWorkerMLCEngineHandler } from "@mlc-ai/web-llm";

let handler: ServiceWorkerMLCEngineHandler;

self.addEventListener("activate", function (event) {
  handler = new ServiceWorkerMLCEngineHandler();
  console.log("Service Worker is ready");
});
```

That is, `WorkerHandler` does not take any constructor (except `port`
for ExtensionServiceWorker), and we will instantiate `MLCEngine`
internally, making the flow more intuitive.

For logit processor usage, we add `setLogitProcessor()` to the handler
(see examples for the change).

Besides, we rename:
- `MLCEngineWorkerHandler` --> `WebWorkerMLCEngineHandler`
- `MLCEngineServiceWorkerHandler` --> `ServiceWorkerMLCEngineHandler`
- `MLCEngineExtensionServiceWorkerHandler` -->
`ExtensionServiceWorkerMLCEngineHandler`
  • Loading branch information
CharlieFRuan authored Jun 12, 2024
1 parent e9b83b6 commit 7e11cf3
Show file tree
Hide file tree
Showing 11 changed files with 62 additions and 81 deletions.
29 changes: 10 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,8 @@ console.log(fullReply);

You can put the heavy computation in a worker script to optimizing your application performance. To do so, you need to:

1. Create an MLCEngine in the worker thread for the actual inference.
2. Wrap the MLCEngine in the worker thread with a worker message handler to handle thread communications via messages under the hood.
3. Create a Worker Engine in your main application as a proxy to sending operations to the MLCEngine in the worker thread via sending messages.

1. Create a handler in the worker thread that communicates with the frontend while handling the requests.
2. Create a Worker Engine in your main application, which under the hood sends message to the handler in worker thread.
For detailed implementation for different kinds of Workers, check the following sections.

#### Dedicated Web Worker
Expand All @@ -213,16 +211,14 @@ WebLLM comes with API support for WebWorker so you can hook
the generation process into a separate worker thread so that
the computing in the worker thread won't disrupt the UI.

We will first create a worker script with a MLCEngine and
hook it up to a worker message handler.
We create a handler in the worker thread that communicates with the frontend while handling the requests.

```typescript
// worker.ts
import { MLCEngineWorkerHandler, MLCEngine } from "@mlc-ai/web-llm";
import { WebWorkerMLCEngineHandler } from "@mlc-ai/web-llm";

// Hookup an MLCEngine to a worker handler
const engine = new MLCEngine();
const handler = new MLCEngineWorkerHandler(engine);
// A handler that resides in the worker thread
const handler = new WebWorkerMLCEngineHandler();
self.onmessage = (msg: MessageEvent) => {
handler.onmessage(msg);
};
Expand Down Expand Up @@ -260,22 +256,17 @@ your application's offline experience.

(Note, Service Worker's life cycle is managed by the browser and can be killed any time without notifying the webapp. `ServiceWorkerMLCEngine` will try to keep the service worker thread alive by periodically sending heartbeat events, but your application should also include proper error handling. Check `keepAliveMs` and `missedHeatbeat` in [`ServiceWorkerMLCEngine`](https://github.com/mlc-ai/web-llm/blob/main/src/service_worker.ts#L234) for more details.)

We first create a service worker script with a MLCEngine and hook it up to a worker message handler
that handles requests when the service worker is ready.
We create a handler in the worker thread that communicates with the frontend while handling the requests.


```typescript
// sw.ts
import {
MLCEngineServiceWorkerHandler,
MLCEngine,
} from "@mlc-ai/web-llm";
import { ServiceWorkerMLCEngineHandler } from "@mlc-ai/web-llm";

const engine = new MLCEngine();
let handler: MLCEngineServiceWorkerHandler;
let handler: ServiceWorkerMLCEngineHandler;

self.addEventListener("activate", function (event) {
handler = new MLCEngineServiceWorkerHandler(engine);
handler = new ServiceWorkerMLCEngineHandler();
console.log("Service Worker is ready");
});
```
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import {
MLCEngineExtensionServiceWorkerHandler,
MLCEngine,
} from "@mlc-ai/web-llm";
import { ExtensionServiceWorkerMLCEngineHandler } from "@mlc-ai/web-llm";

// Hookup an engine to a service worker handler
const engine = new MLCEngine();
let handler;

chrome.runtime.onConnect.addListener(function (port) {
console.assert(port.name === "web_llm_service_worker");
if (handler === undefined) {
handler = new MLCEngineExtensionServiceWorkerHandler(engine, port);
handler = new ExtensionServiceWorkerMLCEngineHandler(port);
} else {
handler.setPort(port);
}
Expand Down
5 changes: 2 additions & 3 deletions examples/get-started-web-worker/src/worker.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import { MLCEngineWorkerHandler, MLCEngine } from "@mlc-ai/web-llm";
import { WebWorkerMLCEngineHandler } from "@mlc-ai/web-llm";

// Hookup an engine to a worker handler
const engine = new MLCEngine();
const handler = new MLCEngineWorkerHandler(engine);
const handler = new WebWorkerMLCEngineHandler();
self.onmessage = (msg: MessageEvent) => {
handler.onmessage(msg);
};
5 changes: 2 additions & 3 deletions examples/logit-processor/src/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ const myLogitProcessor = new MyLogitProcessor();
const logitProcessorRegistry = new Map<string, webllm.LogitProcessor>();
logitProcessorRegistry.set("phi-2-q4f32_1-MLC", myLogitProcessor);

const engine = new webllm.MLCEngine();
engine.setLogitProcessorRegistry(logitProcessorRegistry);
const handler = new webllm.MLCEngineWorkerHandler(engine);
const handler = new webllm.WebWorkerMLCEngineHandler();
handler.setLogitProcessorRegistry(logitProcessorRegistry);
self.onmessage = (msg: MessageEvent) => {
handler.onmessage(msg);
};
11 changes: 3 additions & 8 deletions examples/service-worker/src/sw.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
import {
MLCEngineServiceWorkerHandler,
MLCEngineInterface,
MLCEngine,
} from "@mlc-ai/web-llm";
import { ServiceWorkerMLCEngineHandler } from "@mlc-ai/web-llm";

const engine: MLCEngineInterface = new MLCEngine();
let handler: MLCEngineServiceWorkerHandler;
let handler: ServiceWorkerMLCEngineHandler;

self.addEventListener("activate", function (event) {
handler = new MLCEngineServiceWorkerHandler(engine);
handler = new ServiceWorkerMLCEngineHandler();
console.log("Web-LLM Service Worker Activated");
});
5 changes: 2 additions & 3 deletions examples/simple-chat-ts/src/worker.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// Serve the engine workload through web worker
import { MLCEngineWorkerHandler, MLCEngine } from "@mlc-ai/web-llm";
import { WebWorkerMLCEngineHandler } from "@mlc-ai/web-llm";

const engine = new MLCEngine();
const handler = new MLCEngineWorkerHandler(engine);
const handler = new WebWorkerMLCEngineHandler();
self.onmessage = (msg: MessageEvent) => {
handler.onmessage(msg);
};
5 changes: 2 additions & 3 deletions examples/simple-chat-upload/src/worker.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// Serve the engine workload through web worker
import { MLCEngineWorkerHandler, MLCEngine } from "@mlc-ai/web-llm";
import { WebWorkerMLCEngineHandler } from "@mlc-ai/web-llm";

const engine = new MLCEngine();
const handler = new MLCEngineWorkerHandler(engine);
const handler = new WebWorkerMLCEngineHandler();
self.onmessage = (msg: MessageEvent) => {
handler.onmessage(msg);
};
20 changes: 10 additions & 10 deletions src/extension_service_worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {
import { MLCEngineInterface } from "./types";

Check warning on line 10 in src/extension_service_worker.ts

View workflow job for this annotation

GitHub Actions / lint

'MLCEngineInterface' is defined but never used
import {
ChatWorker,
MLCEngineWorkerHandler,
WebWorkerMLCEngineHandler,
WebWorkerMLCEngine,
} from "./web_worker";
import { areChatOptionsEqual } from "./utils";
Expand All @@ -25,14 +25,14 @@ import { ChatCompletionChunk } from "./openai_api_protocols/index";
* let handler;
* chrome.runtime.onConnect.addListener(function (port) {
* if (handler === undefined) {
* handler = new MLCEngineServiceWorkerHandler(engine, port);
* handler = new ServiceWorkerMLCEngineHandler(engine, port);
* } else {
* handler.setPort(port);
* }
* port.onMessage.addListener(handler.onmessage.bind(handler));
* });
*/
export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler {
export class ServiceWorkerMLCEngineHandler extends WebWorkerMLCEngineHandler {
/**
* The modelId and chatOpts that the underlying engine (backend) is currently loaded with.
*
Expand All @@ -46,8 +46,8 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler {
chatOpts?: ChatOptions;
port: chrome.runtime.Port | null;

constructor(engine: MLCEngineInterface, port: chrome.runtime.Port) {
super(engine);
constructor(port: chrome.runtime.Port) {
super();
this.port = port;
port.onDisconnect.addListener(() => this.onPortDisconnect(port));
}
Expand Down Expand Up @@ -127,10 +127,10 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler {
// If not (due to possibly killed service worker), we reload here.
if (this.modelId !== params.modelId) {
log.warn(
"ServiceWorkerMLCEngine expects model is loaded in MLCEngineServiceWorkerHandler, " +
"ServiceWorkerMLCEngine expects model is loaded in ServiceWorkerMLCEngineHandler, " +
"but it is not. This may due to service worker is unexpectedly killed. ",
);
log.info("Reloading engine in MLCEngineServiceWorkerHandler.");
log.info("Reloading engine in ServiceWorkerMLCEngineHandler.");
await this.engine.reload(params.modelId, params.chatOpts);
}
const res = await this.engine.chatCompletion(params.request);
Expand All @@ -147,10 +147,10 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler {
// If not (due to possibly killed service worker), we reload here.
if (this.modelId !== params.modelId) {
log.warn(
"ServiceWorkerMLCEngine expects model is loaded in MLCEngineServiceWorkerHandler, " +
"ServiceWorkerMLCEngine expects model is loaded in ServiceWorkerMLCEngineHandler, " +
"but it is not. This may due to service worker is unexpectedly killed. ",
);
log.info("Reloading engine in MLCEngineServiceWorkerHandler.");
log.info("Reloading engine in ServiceWorkerMLCEngineHandler.");
await this.engine.reload(params.modelId, params.chatOpts);
}
this.chatCompletionAsyncChunkGenerator =
Expand All @@ -164,7 +164,7 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler {
return;
}

// All rest of message handling are the same as MLCEngineWorkerHandler
// All rest of message handling are the same as WebWorkerMLCEngineHandler
super.onmessage(event);
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,21 @@ export {
} from "./cache_util";

export {
MLCEngineWorkerHandler,
WebWorkerMLCEngineHandler,
WebWorkerMLCEngine,
CreateWebWorkerMLCEngine,
} from "./web_worker";

export { WorkerRequest, WorkerResponse, CustomRequestParams } from "./message";

export {
MLCEngineServiceWorkerHandler,
ServiceWorkerMLCEngineHandler,
ServiceWorkerMLCEngine,
CreateServiceWorkerMLCEngine,
} from "./service_worker";

export {
MLCEngineServiceWorkerHandler as MLCEngineExtensionServiceWorkerHandler,
ServiceWorkerMLCEngineHandler as ExtensionServiceWorkerMLCEngineHandler,
ServiceWorkerMLCEngine as ExtensionServiceWorkerMLCEngine,
CreateServiceWorkerMLCEngine as CreateExtensionServiceWorkerMLCEngine,
} from "./extension_service_worker";
Expand Down
28 changes: 13 additions & 15 deletions src/service_worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ import {
ChatCompletionNonStreamingParams,
ChatCompletionStreamInitParams,
} from "./message";
import { MLCEngineInterface, InitProgressReport } from "./types";
import { InitProgressReport } from "./types";
import {
MLCEngineWorkerHandler,
WebWorkerMLCEngineHandler,
WebWorkerMLCEngine,
ChatWorker,
} from "./web_worker";
Expand All @@ -30,14 +30,14 @@ type IServiceWorker = globalThis.ServiceWorker;
* let handler;
* chrome.runtime.onConnect.addListener(function (port) {
* if (handler === undefined) {
* handler = new MLCEngineServiceWorkerHandler(engine, port);
* handler = new ServiceWorkerMLCEngineHandler(engine, port);
* } else {
* handler.setPort(port);
* }
* port.onMessage.addListener(handler.onmessage.bind(handler));
* });
*/
export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler {
export class ServiceWorkerMLCEngineHandler extends WebWorkerMLCEngineHandler {
/**
* The modelId and chatOpts that the underlying engine (backend) is currently loaded with.
*
Expand All @@ -56,14 +56,13 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler {
>();
private initRequestUuid?: string;

constructor(engine: MLCEngineInterface) {
constructor() {
if (!self || !("addEventListener" in self)) {
throw new Error(
"MLCEngineServiceWorkerHandler must be created in the service worker script.",
"ServiceWorkerMLCEngineHandler must be created in the service worker script.",
);
}
const customInitProgressCallback = engine.getInitProgressCallback();
super(engine);
super();
const onmessage = this.onmessage.bind(this);

this.engine.setInitProgressCallback((report: InitProgressReport) => {
Expand All @@ -73,7 +72,6 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler {
content: report,
};
this.postMessage(msg);
customInitProgressCallback?.(report);
});

self.addEventListener("message", (event) => {
Expand Down Expand Up @@ -112,7 +110,7 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler {
`ServiceWorker message: [${msg.kind}] ${JSON.stringify(msg.content)}`,
);

// Special case message handling different from MLCEngineWorkerHandler
// Special case message handling different from WebWorkerMLCEngineHandler
if (msg.kind === "keepAlive") {
const reply: WorkerResponse = {
kind: "heartbeat",
Expand Down Expand Up @@ -180,10 +178,10 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler {
// If not (due to possibly killed service worker), we reload here.
if (this.modelId !== params.modelId) {
log.warn(
"ServiceWorkerMLCEngine expects model is loaded in MLCEngineServiceWorkerHandler, " +
"ServiceWorkerMLCEngine expects model is loaded in ServiceWorkerMLCEngineHandler, " +
"but it is not. This may due to service worker is unexpectedly killed. ",
);
log.info("Reloading engine in MLCEngineServiceWorkerHandler.");
log.info("Reloading engine in ServiceWorkerMLCEngineHandler.");
this.initRequestUuid = msg.uuid;
await this.engine.reload(params.modelId, params.chatOpts);
}
Expand All @@ -202,10 +200,10 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler {
// If not (due to possibly killed service worker), we reload here.
if (this.modelId !== params.modelId) {
log.warn(
"ServiceWorkerMLCEngine expects model is loaded in MLCEngineServiceWorkerHandler, " +
"ServiceWorkerMLCEngine expects model is loaded in ServiceWorkerMLCEngineHandler, " +
"but it is not. This may due to service worker is unexpectedly killed. ",
);
log.info("Reloading engine in MLCEngineServiceWorkerHandler.");
log.info("Reloading engine in ServiceWorkerMLCEngineHandler.");
this.initRequestUuid = msg.uuid;
await this.engine.reload(params.modelId, params.chatOpts);
}
Expand All @@ -221,7 +219,7 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler {
return;
}

// All rest of message handling are the same as MLCEngineWorkerHandler
// All rest of message handling are the same as WebWorkerMLCEngineHandler
super.onmessage(msg, onComplete, onError);
}
}
Expand Down
Loading

0 comments on commit 7e11cf3

Please sign in to comment.