From afebd92cbc7b26a57f2f4f25e15525c7994dcd5a Mon Sep 17 00:00:00 2001 From: Aurelien Franky Date: Fri, 1 Mar 2024 17:20:03 +0100 Subject: [PATCH] add tests --- .../completionAPI/streamCompletion.test.ts | 94 +++++++++++++++++++ .../src/completionAPI/streamCompletion.ts | 13 ++- 2 files changed, 104 insertions(+), 3 deletions(-) create mode 100644 packages/@pufflig/ps-sdk/src/completionAPI/streamCompletion.test.ts diff --git a/packages/@pufflig/ps-sdk/src/completionAPI/streamCompletion.test.ts b/packages/@pufflig/ps-sdk/src/completionAPI/streamCompletion.test.ts new file mode 100644 index 0000000..28c4068 --- /dev/null +++ b/packages/@pufflig/ps-sdk/src/completionAPI/streamCompletion.test.ts @@ -0,0 +1,94 @@ +import { WebSocket } from "unws"; +import { getApiServiceWebSocketUrl } from "../constants"; +import { streamCompletion } from "./streamCompletion"; + +const mockWebSocketInstance = { + onopen: jest.fn(), + onerror: jest.fn(), + onmessage: jest.fn(), + send: jest.fn(), + close: jest.fn(), +}; + +jest.mock("unws", () => { + return { + WebSocket: jest.fn().mockImplementation(() => mockWebSocketInstance), + }; +}); + +jest.mock("../constants"); + +describe("streamCompletion", () => { + const fakeUrl = "ws://fake-url"; + + beforeEach(() => { + (getApiServiceWebSocketUrl as jest.Mock).mockReturnValue(fakeUrl); + }); + + it("should connect to WebSocket and send start message", async () => { + const input = { + apiKey: "fake-api-key", + modelId: "fake-model-id", + prompt: "fake-prompt", + }; + + const generator = streamCompletion(input); + + generator.next(); + + expect(WebSocket).toHaveBeenCalledWith(`${fakeUrl}/api/v1/completion/streamed`); + + // Simulate successful WebSocket connection + mockWebSocketInstance.onopen(); + + // Assert that a start message is sent + expect(mockWebSocketInstance.send).toHaveBeenCalledWith(expect.any(String)); + }); + + it.skip("should handle WebSocket error", async () => { + const input = { + apiKey: "fake-api-key", + modelId: "fake-model-id", + prompt: "fake-prompt", + }; + + const errorMessage = "WebSocket error"; + mockWebSocketInstance.onerror(new Error(errorMessage)); + + const generator = streamCompletion(input); + + await generator.next(); + + await expect(generator.next()).resolves.toEqual({ + done: true, + value: expect.objectContaining({ + is_error: true, + error: errorMessage, + }), + }); + }); + + it.skip("should handle received messages and close on end message", async () => { + const input = { + apiKey: "fake-api-key", + modelId: "fake-model-id", + prompt: "fake-prompt", + }; + + const generator = streamCompletion(input); + generator.next(); + + // Simulate receiving messages + const messageData1 = { datapoint: { model_output: "output1" } }; + const messageData2 = { datapoint: { model_output: "output2" } }; + mockWebSocketInstance.onmessage({ data: JSON.stringify({ type: "chunk", data: messageData1 }) }); + mockWebSocketInstance.onmessage({ data: JSON.stringify({ type: "chunk", data: messageData2 }) }); + mockWebSocketInstance.onmessage({ data: JSON.stringify({ type: "end" }) }); + + await expect(generator.next()).resolves.toEqual({ value: messageData1, done: false }); + await expect(generator.next()).resolves.toEqual({ value: messageData2, done: false }); + await expect(generator.next()).resolves.toEqual({ value: undefined, done: true }); + + expect(mockWebSocketInstance.close).toHaveBeenCalled(); + }); +}); diff --git a/packages/@pufflig/ps-sdk/src/completionAPI/streamCompletion.ts b/packages/@pufflig/ps-sdk/src/completionAPI/streamCompletion.ts index f2e221e..0ac93c7 100644 --- a/packages/@pufflig/ps-sdk/src/completionAPI/streamCompletion.ts +++ b/packages/@pufflig/ps-sdk/src/completionAPI/streamCompletion.ts @@ -4,6 +4,7 @@ import { getApiServiceWebSocketUrl } from "../constants"; import { CreateCompletionInput, CreateCompletionPayload } from "../types"; const logger = pino(); +const WAIT_FOR_CHUNKS_MS = 50; interface MessageData { datapoint: { @@ -32,6 +33,7 @@ class StreamedCompletionResponseMessage { this.data = data; } } + enum StreamedCompletionRequestType { CompletionRequestStart = "CompletionRequestStart", } @@ -89,12 +91,13 @@ export async function* streamCompletion(input: CreateCompletionInput) { websocket.onerror = (err: Error) => { logger.error("SDK:CompletionStreamError", err.message); - completionChunks.push({ + const errorMessage = { datapoint: { is_error: true, error: err.message, }, - }); + }; + completionChunks.push(errorMessage); streamEnded = true; websocket.close(); }; @@ -109,11 +112,15 @@ export async function* streamCompletion(input: CreateCompletionInput) { } }; + websocket.onclose = () => { + streamEnded = true; + }; + while (completionChunks.length > 0 || !streamEnded) { if (completionChunks.length > 0) { yield completionChunks.shift(); } else { - await new Promise((resolve) => setTimeout(resolve, 50)); // Wait for chunks to arrive + await new Promise((resolve) => setTimeout(resolve, WAIT_FOR_CHUNKS_MS)); // Wait for chunks to arrive } } }