Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
au-re committed Mar 1, 2024
1 parent 2054758 commit 58e349d
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import { WebSocket } from "unws";
import { getApiServiceWebSocketUrl } from "../constants";
import { streamCompletion } from "./streamCompletion";

const mockWebSocketInstance = {
onopen: jest.fn(),
onerror: jest.fn(),
onmessage: jest.fn(),
send: jest.fn(),
close: jest.fn(),
};

jest.mock("unws", () => {
return {
WebSocket: jest.fn().mockImplementation(() => mockWebSocketInstance),
};
});

jest.mock("../constants");

describe("streamCompletion", () => {
const fakeUrl = "ws://fake-url";

beforeEach(() => {
(getApiServiceWebSocketUrl as jest.Mock).mockReturnValue(fakeUrl);
});

it.skip("should connect to WebSocket and send start message", async () => {
const input = {
apiKey: "fake-api-key",
modelId: "fake-model-id",
prompt: "fake-prompt",
};

const generator = streamCompletion(input);

generator.next();

expect(WebSocket).toHaveBeenCalledWith(`${fakeUrl}/api/v1/completion/streamed`);

// Simulate successful WebSocket connection
mockWebSocketInstance.onopen();

// Assert that a start message is sent
expect(mockWebSocketInstance.send).toHaveBeenCalledWith(expect.any(String));
});

it.skip("should handle WebSocket error", async () => {
const input = {
apiKey: "fake-api-key",
modelId: "fake-model-id",
prompt: "fake-prompt",
};

const errorMessage = "WebSocket error";
mockWebSocketInstance.onerror(new Error(errorMessage));

const generator = streamCompletion(input);

await generator.next();

await expect(generator.next()).resolves.toEqual({
done: true,
value: expect.objectContaining({
is_error: true,
error: errorMessage,
}),
});
});

it.skip("should handle received messages and close on end message", async () => {
const input = {
apiKey: "fake-api-key",
modelId: "fake-model-id",
prompt: "fake-prompt",
};

const generator = streamCompletion(input);
generator.next();

// Simulate receiving messages
const messageData1 = { datapoint: { model_output: "output1" } };
const messageData2 = { datapoint: { model_output: "output2" } };
mockWebSocketInstance.onmessage({ data: JSON.stringify({ type: "chunk", data: messageData1 }) });
mockWebSocketInstance.onmessage({ data: JSON.stringify({ type: "chunk", data: messageData2 }) });
mockWebSocketInstance.onmessage({ data: JSON.stringify({ type: "end" }) });

await expect(generator.next()).resolves.toEqual({ value: messageData1, done: false });
await expect(generator.next()).resolves.toEqual({ value: messageData2, done: false });
await expect(generator.next()).resolves.toEqual({ value: undefined, done: true });

expect(mockWebSocketInstance.close).toHaveBeenCalled();
});
});
13 changes: 10 additions & 3 deletions packages/@pufflig/ps-sdk/src/completionAPI/streamCompletion.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { getApiServiceWebSocketUrl } from "../constants";
import { CreateCompletionInput, CreateCompletionPayload } from "../types";

const logger = pino();
const WAIT_FOR_CHUNKS_MS = 50;

interface MessageData {
datapoint: {
Expand Down Expand Up @@ -32,6 +33,7 @@ class StreamedCompletionResponseMessage {
this.data = data;
}
}

enum StreamedCompletionRequestType {
CompletionRequestStart = "CompletionRequestStart",
}
Expand Down Expand Up @@ -89,12 +91,13 @@ export async function* streamCompletion(input: CreateCompletionInput) {

websocket.onerror = (err: Error) => {
logger.error("SDK:CompletionStreamError", err.message);
completionChunks.push({
const errorMessage = {
datapoint: {
is_error: true,
error: err.message,
},
});
};
completionChunks.push(errorMessage);
streamEnded = true;
websocket.close();
};
Expand All @@ -109,11 +112,15 @@ export async function* streamCompletion(input: CreateCompletionInput) {
}
};

websocket.onclose = () => {
streamEnded = true;
};

while (completionChunks.length > 0 || !streamEnded) {
if (completionChunks.length > 0) {
yield completionChunks.shift();
} else {
await new Promise((resolve) => setTimeout(resolve, 50)); // Wait for chunks to arrive
await new Promise((resolve) => setTimeout(resolve, WAIT_FOR_CHUNKS_MS)); // Wait for chunks to arrive
}
}
}

0 comments on commit 58e349d

Please sign in to comment.