[ServiceWorker] Reload model when service worker killed (#471)
The service worker may be killed by the browser despite us sending
`keepAlive` messages. Once killed, all state in the worker thread is
lost. When the frontend then sends another request, expecting the model
to be loaded, the backend MLCEngine throws an error, essentially due to
the mismatch between the frontend's expectation and the backend's reality.

This PR adds `modelId` and `chatOpts` to `WebWorkerMLCEngine` (and hence to
ServiceWorker and ExtensionServiceWorker), with the following semantics: the
model and config that the frontend expects the backend to currently be
loaded with. When sending a request, we send this expectation along with it.

On the backend, upon receiving a chat completion request, the handler compares
the expectation with its current state. If they do not match (because the
service worker was killed), it reloads the model before carrying out the request.
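
In essence, the backend-side check boils down to the pattern below. This is a
minimal, self-contained sketch only: `FakeEngine`, `BackendHandler`, and their
members are hypothetical stand-ins for illustration, not web-llm classes; the
actual logic lives in `MLCEngineServiceWorkerHandler` in the diffs further down.

```ts
// Hypothetical, simplified sketch of the expectation check described above.
// FakeEngine and BackendHandler are illustrative stand-ins, not web-llm classes.

interface ChatOptions {
  temperature?: number;
}

interface ChatCompletionParams {
  request: { messages: { role: string; content: string }[] };
  // The model and config the frontend expects to be loaded on the backend.
  modelId: string;
  chatOpts?: ChatOptions;
}

class FakeEngine {
  loadedModelId?: string;

  async reload(modelId: string, _chatOpts?: ChatOptions): Promise<void> {
    // Pretend to fetch and initialize the model weights.
    this.loadedModelId = modelId;
  }

  async chatCompletion(request: ChatCompletionParams["request"]) {
    if (this.loadedModelId === undefined) {
      throw new Error("No model loaded");
    }
    return { model: this.loadedModelId, numMessages: request.messages.length };
  }
}

class BackendHandler {
  // What the handler believes is loaded; lost whenever the worker is killed.
  modelId?: string;

  constructor(private engine: FakeEngine = new FakeEngine()) {}

  async onChatCompletion(params: ChatCompletionParams) {
    // Compare the frontend's expectation with the backend's reality; reload on mismatch.
    if (this.modelId !== params.modelId) {
      await this.engine.reload(params.modelId, params.chatOpts);
      this.modelId = params.modelId;
    }
    return this.engine.chatCompletion(params.request);
  }
}
```

In the real handler the same check guards both `chatCompletionNonStreaming` and
`chatCompletionStreamInit`, and `unload` clears the stored `modelId`/`chatOpts`.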

---------

Co-authored-by: Nestor Qin <imba.qxy@gmail.com>
CharlieFRuan and Neet-Nestor authored Jun 12, 2024
1 parent 74ed3be commit e9b83b6
Showing 4 changed files with 188 additions and 6 deletions.
78 changes: 75 additions & 3 deletions src/extension_service_worker.ts
@@ -1,14 +1,20 @@
import * as tvmjs from "tvmjs";
import log from "loglevel";
import { AppConfig, ChatOptions, MLCEngineConfig } from "./config";
import { ReloadParams, WorkerRequest } from "./message";
import { ChatOptions, MLCEngineConfig } from "./config";
import {
ReloadParams,
WorkerRequest,
ChatCompletionNonStreamingParams,
ChatCompletionStreamInitParams,
} from "./message";
import { MLCEngineInterface } from "./types";
import {
ChatWorker,
MLCEngineWorkerHandler,
WebWorkerMLCEngine,
} from "./web_worker";
import { areChatOptionsEqual } from "./utils";
import { ChatCompletionChunk } from "./openai_api_protocols/index";

/**
* Worker handler that can be used in a ServiceWorker.
@@ -27,9 +33,17 @@ import { areChatOptionsEqual } from "./utils";
* });
*/
export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler {
/**
* The modelId and chatOpts that the underlying engine (backend) is currently loaded with.
*
* TODO(webllm-team): This is always in sync with `this.engine` unless the device is lost for an
* unexpected reason. Therefore, we should get these from `this.engine` directly and make the
* handler stateless. We should also perhaps make `engine` of type `MLCEngine` instead. In
* addition, consider whether we should add appConfig, or use the engine's API to find the
* corresponding model record rather than relying on just the modelId.
*/
modelId?: string;
chatOpts?: ChatOptions;
appConfig?: AppConfig;
port: chrome.runtime.Port | null;

constructor(engine: MLCEngineInterface, port: chrome.runtime.Port) {
@@ -93,6 +107,64 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler {
});
return;
}

// Unset modelId and chatOpts since backend unloads the model
if (msg.kind === "unload") {
this.handleTask(msg.uuid, async () => {
await this.engine.unload();
this.modelId = undefined;
this.chatOpts = undefined;
return null;
});
return;
}

if (msg.kind === "chatCompletionNonStreaming") {
// Directly return the ChatCompletion response
this.handleTask(msg.uuid, async () => {
const params = msg.content as ChatCompletionNonStreamingParams;
// Check whether the frontend's expectation matches the backend's state (modelId and chatOpts).
// If not (possibly because the service worker was killed), reload here.
if (this.modelId !== params.modelId) {
log.warn(
"ServiceWorkerMLCEngine expects a model to be loaded in MLCEngineServiceWorkerHandler, " +
"but it is not. This may be due to the service worker being unexpectedly killed. ",
);
log.info("Reloading engine in MLCEngineServiceWorkerHandler.");
await this.engine.reload(params.modelId, params.chatOpts);
}
const res = await this.engine.chatCompletion(params.request);
return res;
});
return;
}

if (msg.kind === "chatCompletionStreamInit") {
// One-time set up that instantiates the chunk generator in worker
this.handleTask(msg.uuid, async () => {
const params = msg.content as ChatCompletionStreamInitParams;
// Check whether the frontend's expectation matches the backend's state (modelId and chatOpts).
// If not (possibly because the service worker was killed), reload here.
if (this.modelId !== params.modelId) {
log.warn(
"ServiceWorkerMLCEngine expects a model to be loaded in MLCEngineServiceWorkerHandler, " +
"but it is not. This may be due to the service worker being unexpectedly killed. ",
);
log.info("Reloading engine in MLCEngineServiceWorkerHandler.");
await this.engine.reload(params.modelId, params.chatOpts);
}
this.chatCompletionAsyncChunkGenerator =
(await this.engine.chatCompletion(params.request)) as AsyncGenerator<
ChatCompletionChunk,
void,
void
>;
return null;
});
return;
}

// All other message handling is the same as in MLCEngineWorkerHandler
super.onmessage(event);
}
}
10 changes: 10 additions & 0 deletions src/message.ts
@@ -58,9 +58,19 @@ export interface ForwardTokensAndSampleParams {
}
export interface ChatCompletionNonStreamingParams {
request: ChatCompletionRequestNonStreaming;
// The model and chatOpts that the frontend engine expects the backend to be loaded with.
// If not loaded (e.g., because the service worker was unexpectedly killed), the handler will call reload().
// TODO(webllm-team): should add appConfig here as well.
modelId: string;
chatOpts: ChatOptions;
}
export interface ChatCompletionStreamInitParams {
request: ChatCompletionRequestStreaming;
// The model and chatOpts that the frontend engine expects the backend to be loaded with.
// If not loaded (e.g., because the service worker was unexpectedly killed), the handler will call reload().
// TODO(webllm-team): should add appConfig here as well.
modelId: string;
chatOpts: ChatOptions;
}

export interface CustomRequestParams {
84 changes: 81 additions & 3 deletions src/service_worker.ts
@@ -1,14 +1,21 @@
import * as tvmjs from "tvmjs";
import log from "loglevel";
import { AppConfig, ChatOptions, MLCEngineConfig } from "./config";
import { ReloadParams, WorkerRequest, WorkerResponse } from "./message";
import { ChatOptions, MLCEngineConfig } from "./config";
import {
ReloadParams,
WorkerRequest,
WorkerResponse,
ChatCompletionNonStreamingParams,
ChatCompletionStreamInitParams,
} from "./message";
import { MLCEngineInterface, InitProgressReport } from "./types";
import {
MLCEngineWorkerHandler,
WebWorkerMLCEngine,
ChatWorker,
} from "./web_worker";
import { areChatOptionsEqual } from "./utils";
import { ChatCompletionChunk } from "./openai_api_protocols/index";

/* Service Worker Script */

@@ -31,9 +38,17 @@ type IServiceWorker = globalThis.ServiceWorker;
* });
*/
export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler {
/**
* The modelId and chatOpts that the underlying engine (backend) is currently loaded with.
*
* TODO(webllm-team): This is always in sync with `this.engine` unless the device is lost for an
* unexpected reason. Therefore, we should get these from `this.engine` directly and make the
* handler stateless. We should also perhaps make `engine` of type `MLCEngine` instead. In
* addition, consider whether we should add appConfig, or use the engine's API to find the
* corresponding model record rather than relying on just the modelId.
*/
modelId?: string;
chatOpts?: ChatOptions;
appConfig?: AppConfig;

private clientRegistry = new Map<
string,
@@ -97,6 +112,7 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler {
`ServiceWorker message: [${msg.kind}] ${JSON.stringify(msg.content)}`,
);

// Special-case message handling that differs from MLCEngineWorkerHandler
if (msg.kind === "keepAlive") {
const reply: WorkerResponse = {
kind: "heartbeat",
@@ -144,6 +160,68 @@ export class MLCEngineServiceWorkerHandler extends MLCEngineWorkerHandler {
});
return;
}

if (msg.kind === "unload") {
this.handleTask(msg.uuid, async () => {
await this.engine.unload();
onComplete?.(null);
this.modelId = undefined;
this.chatOpts = undefined;
return null;
});
return;
}

if (msg.kind === "chatCompletionNonStreaming") {
// Directly return the ChatCompletion response
this.handleTask(msg.uuid, async () => {
const params = msg.content as ChatCompletionNonStreamingParams;
// Check whether the frontend's expectation matches the backend's state (modelId and chatOpts).
// If not (possibly because the service worker was killed), reload here.
if (this.modelId !== params.modelId) {
log.warn(
"ServiceWorkerMLCEngine expects a model to be loaded in MLCEngineServiceWorkerHandler, " +
"but it is not. This may be due to the service worker being unexpectedly killed. ",
);
log.info("Reloading engine in MLCEngineServiceWorkerHandler.");
this.initRequestUuid = msg.uuid;
await this.engine.reload(params.modelId, params.chatOpts);
}
const res = await this.engine.chatCompletion(params.request);
onComplete?.(res);
return res;
});
return;
}

if (msg.kind === "chatCompletionStreamInit") {
// One-time set up that instantiates the chunk generator in worker
this.handleTask(msg.uuid, async () => {
const params = msg.content as ChatCompletionStreamInitParams;
// Check whether the frontend's expectation matches the backend's state (modelId and chatOpts).
// If not (possibly because the service worker was killed), reload here.
if (this.modelId !== params.modelId) {
log.warn(
"ServiceWorkerMLCEngine expects a model to be loaded in MLCEngineServiceWorkerHandler, " +
"but it is not. This may be due to the service worker being unexpectedly killed. ",
);
log.info("Reloading engine in MLCEngineServiceWorkerHandler.");
this.initRequestUuid = msg.uuid;
await this.engine.reload(params.modelId, params.chatOpts);
}
this.chatCompletionAsyncChunkGenerator =
(await this.engine.chatCompletion(params.request)) as AsyncGenerator<
ChatCompletionChunk,
void,
void
>;
onComplete?.(null);
return null;
});
return;
}

// All other message handling is the same as in MLCEngineWorkerHandler
super.onmessage(msg, onComplete, onError);
}
}
22 changes: 22 additions & 0 deletions src/web_worker.ts
@@ -327,6 +327,14 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
public worker: ChatWorker;
public chat: API.Chat;

/**
* The modelId and chatOpts that the frontend expects the backend engine to currently be loaded
* with. Needed for the service worker. It is the backend's and the handler's job to live up to
* this expectation despite the service worker possibly being killed.
*/
modelId?: string;
chatOpts?: ChatOptions;

private initProgressCallback?: InitProgressCallback;
private generateCallbackRegistry = new Map<
string,
@@ -421,6 +429,8 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
},
};
await this.getPromise<null>(msg);
this.modelId = modelId;
this.chatOpts = chatOpts;
}

async getMaxStorageBufferBindingSize(): Promise<number> {
@@ -496,6 +506,8 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
content: null,
};
await this.getPromise<null>(msg);
this.modelId = undefined;
this.chatOpts = undefined;
}

async resetChat(keepStats = false): Promise<void> {
@@ -563,13 +575,21 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
async chatCompletion(
request: ChatCompletionRequest,
): Promise<AsyncIterable<ChatCompletionChunk> | ChatCompletion> {
if (this.modelId === undefined) {
throw new Error(
`${this.constructor.name} is not loaded with a model. Did you call \`engine.reload()\`?`,
);
}

if (request.stream) {
// First let worker instantiate a generator
const msg: WorkerRequest = {
kind: "chatCompletionStreamInit",
uuid: crypto.randomUUID(),
content: {
request: request,
modelId: this.modelId,
chatOpts: this.chatOpts,
},
};
await this.getPromise<null>(msg);
@@ -584,6 +604,8 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
uuid: crypto.randomUUID(),
content: {
request: request,
modelId: this.modelId,
chatOpts: this.chatOpts,
},
};
return await this.getPromise<ChatCompletion>(msg);
