diff --git a/packages/@pufflig/ps-sdk/package.json b/packages/@pufflig/ps-sdk/package.json index a1bae02..6c58ee0 100644 --- a/packages/@pufflig/ps-sdk/package.json +++ b/packages/@pufflig/ps-sdk/package.json @@ -13,7 +13,7 @@ "prepublishOnly": "yarn build", "storybook": "storybook dev -p 6006", "build-storybook": "storybook build", - "test": "echo \"no test specified\" " + "test": "jest" }, "devDependencies": { "typescript": "^5.1.6", diff --git a/packages/@pufflig/ps-sdk/src/constants.ts b/packages/@pufflig/ps-sdk/src/constants.ts index fb4a43a..7c7a8c5 100644 --- a/packages/@pufflig/ps-sdk/src/constants.ts +++ b/packages/@pufflig/ps-sdk/src/constants.ts @@ -1,5 +1,5 @@ -export const SERVICE_URL = "https://api.prompt.studio/api/v1/completion"; +export const SERVICE_URL = "https://api.prompt.studio"; export const getServiceUrl = () => { - return process.env.PROMPT_STUDIO_SERVICE_URL || SERVICE_URL; + return `${process.env.PROMPT_STUDIO_SERVICE_BASE_URL || SERVICE_URL}`; }; diff --git a/packages/@pufflig/ps-sdk/src/createCompletion.test.ts b/packages/@pufflig/ps-sdk/src/createCompletion.test.ts new file mode 100644 index 0000000..2c8d46e --- /dev/null +++ b/packages/@pufflig/ps-sdk/src/createCompletion.test.ts @@ -0,0 +1,61 @@ +import axios from "axios"; +import { createCompletion } from "./createCompletion"; + +jest.mock("axios"); + +describe("createCompletion", () => { + const mockAxios = axios as jest.Mocked; + + const mockResponse = { + datapoint: { + model_output: "Hello, world!", + model_input: "Hello", + model_id: "model_123", + }, + }; + + const mockInput = { + apiKey: "my-api-key", + modelId: "model_123", + prompt: "Hello", + parameters: { name: "John" }, + config: { temperature: 0.5 }, + options: { track: true }, + }; + + beforeEach(() => { + jest.resetAllMocks(); + }); + + it("should make a POST request to the correct URL with the correct data", async () => { + mockAxios.post.mockResolvedValueOnce({ data: mockResponse }); + + const result = await createCompletion(mockInput); + + expect(mockAxios.post).toHaveBeenCalledTimes(1); + expect(mockAxios.post).toHaveBeenCalledWith( + "https://api.prompt.studio/api/v1/completion/buffered", + { + modelId: "model_123", + prompt: "Hello", + parameters: { name: "John" }, + config: { temperature: 0.5 }, + options: { track: true }, + }, + { + headers: { + Authorization: "Bearer my-api-key", + "Content-Type": "application/json", + }, + } + ); + expect(result).toEqual(mockResponse); + }); + + it("should throw an error if the request fails", async () => { + const mockError = new Error("Request failed"); + mockAxios.post.mockRejectedValueOnce(mockError); + + await expect(createCompletion(mockInput)).rejects.toThrow(mockError); + }); +}); diff --git a/packages/@pufflig/ps-sdk/src/createCompletion.ts b/packages/@pufflig/ps-sdk/src/createCompletion.ts index 8d5a165..605a9e5 100644 --- a/packages/@pufflig/ps-sdk/src/createCompletion.ts +++ b/packages/@pufflig/ps-sdk/src/createCompletion.ts @@ -1,6 +1,5 @@ -import axios, { AxiosRequestConfig } from "axios"; +import axios from "axios"; import { getServiceUrl } from "./constants"; -import { Datapoint } from "./types"; interface CreateCompletionInput { apiKey: string; @@ -14,66 +13,50 @@ interface CreateCompletionInput { }; } -interface Callbacks { - onNewToken?: (token: string) => void; +interface Completion { + datapoint?: { + model_output: string; + model_input: string; + model_id: string; + }; } -interface Completion { - datapoint?: Datapoint; +interface CreateCompletionPayload { + modelId: string; + prompt: string; + parameters?: Record; + config?: Record; + options?: { + track?: boolean; + cache?: boolean; + }; } -export async function createCompletion(input: CreateCompletionInput, callbacks?: Callbacks): Promise { +export async function createCompletion(input: CreateCompletionInput): Promise { const { modelId, prompt, apiKey, config, options, parameters = {} } = input; - const payload: AxiosRequestConfig = { - method: "post", - url: getServiceUrl(), - responseType: "stream", + const payload: CreateCompletionPayload = { + modelId, + prompt, + parameters, + }; + + const requestConfig = { headers: { "Content-Type": "application/json", Authorization: `Bearer ${apiKey}`, }, - data: { - modelId, - prompt: prompt, - parameters, - }, }; if (config) { - payload.data.config = config; + payload.config = config; } if (options) { - payload.data.options = options; + payload.options = options; } - const response = await axios(payload); - - const stream = response.data; - - let result = {}; - - stream.on("data", (buffer: Buffer) => { - const chunk = buffer.toString("utf-8"); - const rows = chunk.split("\n\n"); - rows.forEach((row) => { - const match = row.match(/^data: (.+)/); - const data = JSON.parse(match?.[1] || "{}"); - if (data.datapoint?.model_output) { - callbacks?.onNewToken?.(data.datapoint.model_output); - result = data; - } - }); - }); - - return new Promise((resolve, reject) => { - stream.on("end", () => { - resolve(result); - }); + const response = await axios.post(`${getServiceUrl()}/api/v1/completion/buffered`, payload, requestConfig); - stream.on("error", (error: Error) => { - reject(error); - }); - }); + return response.data; } diff --git a/packages/@pufflig/ps-sdk/src/index.ts b/packages/@pufflig/ps-sdk/src/index.ts index 92fce7d..e8ab367 100644 --- a/packages/@pufflig/ps-sdk/src/index.ts +++ b/packages/@pufflig/ps-sdk/src/index.ts @@ -1 +1,2 @@ export { createCompletion } from "./createCompletion"; +export { refineCompletion } from "./refineCompletion"; diff --git a/packages/@pufflig/ps-sdk/src/mapCompletion.test.ts b/packages/@pufflig/ps-sdk/src/mapCompletion.test.ts new file mode 100644 index 0000000..66c79bd --- /dev/null +++ b/packages/@pufflig/ps-sdk/src/mapCompletion.test.ts @@ -0,0 +1,87 @@ +import { mapCompletion } from "./mapCompletion"; +import axios from "axios"; + +jest.mock("axios"); + +describe("mapCompletion", () => { + const mockAxios = axios as jest.Mocked; + + beforeEach(() => { + jest.clearAllMocks(); + }); + + it("should return completions for each chunk of the document", async () => { + const input = { + apiKey: "myApiKey", + modelId: "myModelId", + prompt: "myPrompt", + document: "myDocument", + parameters: { myParam: "myValue" }, + config: { myConfig: "myValue" }, + options: { track: true, cache: false }, + }; + + const mockResponse = { + data: { + completions: [ + { + datapoints: { + model_output: "output1", + model_input: "input1", + model_id: "id1", + }, + }, + { + datapoints: { + model_output: "output2", + model_input: "input2", + model_id: "id2", + }, + }, + ], + }, + }; + + mockAxios.post.mockResolvedValueOnce(mockResponse); + + const result = await mapCompletion(input); + + expect(mockAxios.post).toHaveBeenCalledTimes(1); + expect(mockAxios.post).toHaveBeenCalledWith( + "https://api.prompt.studio/api/v1/completion/mapped", + { + prompt: "myPrompt", + document: "myDocument", + modelId: "myModelId", + parameters: { myParam: "myValue" }, + config: { myConfig: "myValue" }, + options: { track: true, cache: false }, + }, + { + headers: { + Authorization: "Bearer myApiKey", + "Content-Type": "application/json", + }, + } + ); + + expect(result).toEqual({ + completions: [ + { + datapoints: { + model_output: "output1", + model_input: "input1", + model_id: "id1", + }, + }, + { + datapoints: { + model_output: "output2", + model_input: "input2", + model_id: "id2", + }, + }, + ], + }); + }); +}); diff --git a/packages/@pufflig/ps-sdk/src/mapCompletion.ts b/packages/@pufflig/ps-sdk/src/mapCompletion.ts new file mode 100644 index 0000000..40e95d5 --- /dev/null +++ b/packages/@pufflig/ps-sdk/src/mapCompletion.ts @@ -0,0 +1,74 @@ +import axios from "axios"; +import { getServiceUrl } from "./constants"; + +interface MapCompletionInput { + apiKey: string; + modelId: string; + prompt: string; + document: string; + parameters?: Record; + config?: Record; + options?: { + track?: boolean; + cache?: boolean; + }; +} + +interface Completion { + datapoints?: { + model_output: string; + model_input: string; + model_id: string; + }; +} + +interface MapCompletionPayload { + modelId: string; + prompt: string; + document: string; + parameters?: Record; + config?: Record; + options?: { + track?: boolean; + cache?: boolean; + }; +} + +/** + * Map a prompt over a document of variable length. Return a completion for each chunk. + * + * @param input.document - The document to be processed + * @param input.parameters - Parameters to be passed to the model + * @param input.modelId - Name of the LLM to be used + * + * @returns The completion + */ +export async function mapCompletion(input: MapCompletionInput): Promise<{ completions: Completion[] }> { + const { modelId, prompt, document, apiKey, config, options, parameters = {} } = input; + + const payload: MapCompletionPayload = { + modelId, + prompt, + document, + parameters, + }; + + const requestConfig = { + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${apiKey}`, + }, + }; + + if (config) { + payload.config = config; + } + + if (options) { + payload.options = options; + } + + const response = await axios.post(`${getServiceUrl()}/api/v1/completion/mapped`, payload, requestConfig); + + return response.data; +} diff --git a/packages/@pufflig/ps-sdk/src/refineCompletion.test.ts b/packages/@pufflig/ps-sdk/src/refineCompletion.test.ts new file mode 100644 index 0000000..4c84f33 --- /dev/null +++ b/packages/@pufflig/ps-sdk/src/refineCompletion.test.ts @@ -0,0 +1,72 @@ +import { refineCompletion } from "./refineCompletion"; +import axios from "axios"; + +jest.mock("axios"); + +describe("refineCompletion", () => { + const mockAxios = axios as jest.Mocked; + + afterEach(() => { + jest.resetAllMocks(); + }); + + it("should return the completion object", async () => { + const mockCompletion = { + datapoint: { + model_output: "output", + model_input: "input", + model_id: "id", + }, + }; + mockAxios.post.mockResolvedValue({ data: mockCompletion }); + + const input = { + apiKey: "api-key", + modelId: "model-id", + prompt: "prompt", + format: "format", + document: "document", + parameters: { param1: "value1", param2: "value2" }, + config: { config1: "value1", config2: "value2" }, + options: { track: true, cache: false }, + }; + + const result = await refineCompletion(input); + + expect(mockAxios.post).toHaveBeenCalledTimes(1); + expect(mockAxios.post).toHaveBeenCalledWith( + "https://api.prompt.studio/api/v1/completion/looped", + { + prompt: "prompt", + format: "format", + document: "document", + modelId: "model-id", + parameters: { param1: "value1", param2: "value2" }, + config: { config1: "value1", config2: "value2" }, + options: { track: true, cache: false }, + }, + { + headers: { + Authorization: "Bearer api-key", + "Content-Type": "application/json", + }, + } + ); + expect(result).toEqual(mockCompletion); + }); + + it("should throw an error if the API call fails", async () => { + const mockError = new Error("API call failed"); + mockAxios.post.mockRejectedValueOnce(mockError); + + const input = { + apiKey: "api-key", + modelId: "model-id", + prompt: "prompt", + format: "format", + document: "document", + }; + + await expect(refineCompletion(input)).rejects.toThrow(mockError); + }); +}); diff --git a/packages/@pufflig/ps-sdk/src/refineCompletion.ts b/packages/@pufflig/ps-sdk/src/refineCompletion.ts new file mode 100644 index 0000000..5e1bf22 --- /dev/null +++ b/packages/@pufflig/ps-sdk/src/refineCompletion.ts @@ -0,0 +1,78 @@ +import axios from "axios"; +import { getServiceUrl } from "./constants"; + +interface RefineCompletionInput { + apiKey: string; + modelId: string; + prompt: string; + format: string; + document: string; + parameters?: Record; + config?: Record; + options?: { + track?: boolean; + cache?: boolean; + }; +} + +interface Completion { + datapoint?: { + model_output: string; + model_input: string; + model_id: string; + }; +} + +interface RefineCompletionPayload { + modelId: string; + prompt: string; + format: string; + document: string; + parameters?: Record; + config?: Record; + options?: { + track?: boolean; + cache?: boolean; + }; +} + +/** + * Loop a prompt over documents of variable size, refining the outcome for each chunk. + * + * @param input.format - The value to be refined, e.g. a json or csv table + * @param input.document - The document to be processed + * @param input.parameters - Parameters to be passed to the model + * @param input.modelId - Name of the LLM to be used + * + * @returns The completion + */ +export async function refineCompletion(input: RefineCompletionInput): Promise { + const { modelId, prompt, format, document, apiKey, config, options, parameters = {} } = input; + + const payload: RefineCompletionPayload = { + modelId, + prompt, + format, + document, + parameters, + }; + + const requestConfig = { + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${apiKey}`, + }, + }; + + if (config) { + payload.config = config; + } + + if (options) { + payload.options = options; + } + + const response = await axios.post(`${getServiceUrl()}/api/v1/completion/looped`, payload, requestConfig); + + return response.data; +}