From e7b1d0385a9e2a2c211e73a78d7130aba12824e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien?= Date: Tue, 31 Oct 2023 12:05:19 +0100 Subject: [PATCH] add document check node, update variable parsing in llm completion node --- .../src/adapters/llm/llm_completion.ts | 18 ++-- .../document_check/document_check.ts | 67 +++++++++++++ .../ps-nodes-config/src/modifiers/index.ts | 3 + .../parse_document/parse_document.ts | 2 +- .../src/adapters/llm/llm_completion.test.ts | 42 ++++----- .../src/adapters/llm/llm_completion.ts | 7 +- .../document_check/document_check.test.ts | 91 ++++++++++++++++++ .../document_check/document_check.ts | 94 +++++++++++++++++++ .../@pufflig/ps-nodes/src/modifiers/index.ts | 2 + packages/@pufflig/ps-sdk/src/index.ts | 1 + .../@pufflig/ps-sdk/src/refineCompletion.ts | 5 +- 11 files changed, 297 insertions(+), 35 deletions(-) create mode 100644 packages/@pufflig/ps-nodes-config/src/modifiers/document_check/document_check.ts create mode 100644 packages/@pufflig/ps-nodes/src/modifiers/document_check/document_check.test.ts create mode 100644 packages/@pufflig/ps-nodes/src/modifiers/document_check/document_check.ts diff --git a/packages/@pufflig/ps-nodes-config/src/adapters/llm/llm_completion.ts b/packages/@pufflig/ps-nodes-config/src/adapters/llm/llm_completion.ts index 99b073b..30ea12b 100644 --- a/packages/@pufflig/ps-nodes-config/src/adapters/llm/llm_completion.ts +++ b/packages/@pufflig/ps-nodes-config/src/adapters/llm/llm_completion.ts @@ -5,7 +5,7 @@ export const llmCompletionNodeType = "adapter/llm_completion" as const; export const llmCompletionConfig: NodeConfig = { name: "Instruction", - description: "Generate a completion using an LLM.", + description: "Generate a completion using an LLM", status: "experimental", tags: ["adapter", "text"], globals: [], @@ -26,19 +26,12 @@ export const llmCompletionConfig: NodeConfig = { { id: "completion", name: "Completion", - description: "Text generated by the LLM.", + description: "Text generated by the LLM", type: "text", defaultValue: "", }, ], inputs: [ - { - id: "prompt", - name: "Prompt", - description: "The prompt to send to the LLM.", - type: "text", - defaultValue: "", - }, { id: "model", name: "Model", @@ -50,5 +43,12 @@ export const llmCompletionConfig: NodeConfig = { parameters: {}, }, }, + { + id: "prompt", + name: "Prompt", + description: "The prompt to send to the LLM", + type: "text", + defaultValue: "", + }, ], }; diff --git a/packages/@pufflig/ps-nodes-config/src/modifiers/document_check/document_check.ts b/packages/@pufflig/ps-nodes-config/src/modifiers/document_check/document_check.ts new file mode 100644 index 0000000..a19a5bb --- /dev/null +++ b/packages/@pufflig/ps-nodes-config/src/modifiers/document_check/document_check.ts @@ -0,0 +1,67 @@ +import { chat_models, completion_models, default_completion_model } from "@pufflig/ps-models"; +import { NodeConfig } from "@pufflig/ps-types"; + +export const documentCheckNodeType = "modifier/document_check" as const; + +export const documentCheck: NodeConfig = { + name: "Document Check", + description: "Run a checklist or extract information from a document.", + tags: ["modifier", "document", "text"], + status: "stable", + execution: { + inputs: [ + { + id: "exec:input", + }, + ], + outputs: [ + { + id: "exec:output", + name: "Completed", + }, + ], + }, + outputs: [ + { + id: "list", + name: "List", + description: "A list, checklist or other information about the document", + type: "text", + defaultValue: "", + }, + ], + inputs: [ + { + id: "model", + name: "Model", + description: "The model to use", + type: "model", + definition: { ...completion_models, ...chat_models }, + defaultValue: { + modelId: default_completion_model, + parameters: {}, + }, + }, + { + id: "prompt", + name: "Prompt", + description: "Prompt to check the document with", + type: "text", + defaultValue: `Extract information in the document below and insert them in the csv table, don't overwrite existing values and keep things empty if you cannot find information in the document:\n\nTABLE EXAMPLE:\ncharacters, age\nmickey mouse, 10\ndonald duck, -\n\nTABLE:\n[[table]]\n\nDOCUMENT:\n[[document]]\n\nTABLE:\n`, + }, + { + id: "table", + name: "Table", + description: "The list, table or checklist to parse the document with.", + type: "text", + defaultValue: "", + }, + { + id: "document", + name: "Document", + description: "Document to be processed", + type: "text", + defaultValue: "", + }, + ], +}; diff --git a/packages/@pufflig/ps-nodes-config/src/modifiers/index.ts b/packages/@pufflig/ps-nodes-config/src/modifiers/index.ts index dffe66b..287e9dc 100644 --- a/packages/@pufflig/ps-nodes-config/src/modifiers/index.ts +++ b/packages/@pufflig/ps-nodes-config/src/modifiers/index.ts @@ -1,5 +1,6 @@ import { addMessage, addMessageNodeType } from "./add_message/add_message"; import { addText, addTextNodeType } from "./add_text/add_text"; +import { documentCheck, documentCheckNodeType } from "./document_check/document_check"; import { parseDocument, parseDocumentNodeType } from "./parse_document/parse_document"; import { splitText, splitTextNodeType } from "./split_text/split_text"; import { templateChat, templateChatNodeType } from "./template/template_chat"; @@ -12,6 +13,7 @@ export const modifierNodes = { [splitTextNodeType]: splitText, [templateChatNodeType]: templateChat, [templateTextNodeType]: templateText, + [documentCheckNodeType]: documentCheck, }; export const modifierNodeTypes = { @@ -21,4 +23,5 @@ export const modifierNodeTypes = { splitTextNodeType, templateChatNodeType, templateTextNodeType, + documentCheckNodeType, }; diff --git a/packages/@pufflig/ps-nodes-config/src/modifiers/parse_document/parse_document.ts b/packages/@pufflig/ps-nodes-config/src/modifiers/parse_document/parse_document.ts index 1257e65..b1e800a 100644 --- a/packages/@pufflig/ps-nodes-config/src/modifiers/parse_document/parse_document.ts +++ b/packages/@pufflig/ps-nodes-config/src/modifiers/parse_document/parse_document.ts @@ -6,7 +6,7 @@ export const parseDocument: NodeConfig = { name: "Parse Document", description: "Run a prompt over a document", tags: ["modifier", "document", "text"], - status: "experimental", + status: "deprecated", execution: { inputs: [ { diff --git a/packages/@pufflig/ps-nodes/src/adapters/llm/llm_completion.test.ts b/packages/@pufflig/ps-nodes/src/adapters/llm/llm_completion.test.ts index 9d472c5..3d49e93 100644 --- a/packages/@pufflig/ps-nodes/src/adapters/llm/llm_completion.test.ts +++ b/packages/@pufflig/ps-nodes/src/adapters/llm/llm_completion.test.ts @@ -62,13 +62,6 @@ test("getInputDefinition - no variables", () => { }); expect(variables).toMatchInlineSnapshot(` [ - { - "defaultValue": "summarize {{longText}}", - "description": "The prompt to send to the LLM.", - "id": "prompt", - "name": "Prompt", - "type": "text", - }, { "defaultValue": { "modelId": "gpt-3.5-turbo-instruct", @@ -403,6 +396,13 @@ test("getInputDefinition - no variables", () => { "name": "Model", "type": "model", }, + { + "defaultValue": "summarize {{longText}}", + "description": "The prompt to send to the LLM", + "id": "prompt", + "name": "Prompt", + "type": "text", + }, { "defaultValue": "", "description": "", @@ -425,13 +425,6 @@ test("getInputDefinition - if you pass a template and a variable, take value of }); expect(variables).toMatchInlineSnapshot(` [ - { - "defaultValue": "summarize {{longText}}", - "description": "The prompt to send to the LLM.", - "id": "prompt", - "name": "Prompt", - "type": "text", - }, { "defaultValue": { "modelId": "gpt-3.5-turbo-instruct", @@ -766,6 +759,13 @@ test("getInputDefinition - if you pass a template and a variable, take value of "name": "Model", "type": "model", }, + { + "defaultValue": "summarize {{longText}}", + "description": "The prompt to send to the LLM", + "id": "prompt", + "name": "Prompt", + "type": "text", + }, { "defaultValue": "some long text", "description": "", @@ -788,13 +788,6 @@ test("getInputDefinition - ignores non existing variables", () => { }); expect(variables).toMatchInlineSnapshot(` [ - { - "defaultValue": "summarize {{longText}}", - "description": "The prompt to send to the LLM.", - "id": "prompt", - "name": "Prompt", - "type": "text", - }, { "defaultValue": { "modelId": "gpt-3.5-turbo-instruct", @@ -1129,6 +1122,13 @@ test("getInputDefinition - ignores non existing variables", () => { "name": "Model", "type": "model", }, + { + "defaultValue": "summarize {{longText}}", + "description": "The prompt to send to the LLM", + "id": "prompt", + "name": "Prompt", + "type": "text", + }, { "defaultValue": "", "description": "", diff --git a/packages/@pufflig/ps-nodes/src/adapters/llm/llm_completion.ts b/packages/@pufflig/ps-nodes/src/adapters/llm/llm_completion.ts index f30d694..9652ff4 100644 --- a/packages/@pufflig/ps-nodes/src/adapters/llm/llm_completion.ts +++ b/packages/@pufflig/ps-nodes/src/adapters/llm/llm_completion.ts @@ -3,6 +3,7 @@ import { createCompletion } from "@pufflig/ps-sdk"; import { Execute, GetInputDefinition, ModelValue, Node, Param } from "@pufflig/ps-types"; import { getPromptStudioKey } from "../../utils/getPromptStudioKey"; import { extractVariables } from "../../utils/extractVariables"; +import Mustache from "mustache"; export interface LLMCompletionInput { prompt: string; @@ -15,14 +16,16 @@ export interface LLMCompletionOutput { } export const execute: Execute = async (input, options = {}) => { - const { prompt, model } = input; + const { prompt, model, ...variables } = input; const { modelId, parameters } = model; const { globals } = options; + const renderedTemplate = Mustache.render(prompt, variables); + const result = await createCompletion({ apiKey: getPromptStudioKey(globals || {}), modelId, - prompt, + prompt: renderedTemplate, parameters, config: globals, options: { diff --git a/packages/@pufflig/ps-nodes/src/modifiers/document_check/document_check.test.ts b/packages/@pufflig/ps-nodes/src/modifiers/document_check/document_check.test.ts new file mode 100644 index 0000000..add9310 --- /dev/null +++ b/packages/@pufflig/ps-nodes/src/modifiers/document_check/document_check.test.ts @@ -0,0 +1,91 @@ +import { execute, LLMCompletionInput } from "./document_check"; +import axios from "axios"; + +jest.mock("axios"); + +describe("documentCheck", () => { + beforeEach(() => { + jest.resetAllMocks(); + }); + + it("should return the completion string", async () => { + const input: LLMCompletionInput = { + prompt: "Hello, world!", + model: { + modelId: "test_model", + parameters: {}, + }, + document: "This is a test document.", + table: "test_table", + }; + + const expectedOutput = { result: "This is a test completion." }; + const mockedAxiosResponse = { data: expectedOutput }; + (axios.post as jest.MockedFunction).mockResolvedValueOnce(mockedAxiosResponse); + + const output = await execute(input); + + expect(output).toEqual({ completion: "This is a test completion." }); + expect(axios.post).toHaveBeenCalledTimes(1); + }); + + it("should parse input variables", async () => { + const input: LLMCompletionInput = { + prompt: "Hello, {{myVariable}}!", + model: { + modelId: "test_model", + parameters: {}, + }, + document: "This is a test document.", + table: "test_table", + myVariable: "myValue", + }; + + const expectedOutput = { result: "This is a test completion." }; + const mockedAxiosResponse = { data: expectedOutput }; + (axios.post as jest.MockedFunction).mockResolvedValueOnce(mockedAxiosResponse); + + const output = await execute(input); + + expect(output).toEqual({ completion: "This is a test completion." }); + expect(axios.post).toHaveBeenCalledTimes(1); + expect(axios.post).toHaveBeenCalledWith( + expect.any(String), + { + document: "This is a test document.", + format: "test_table", + modelId: "test_model", + options: { + cache: true, + track: true, + }, + parameters: {}, + prompt: "Hello, myValue!", + }, + { + headers: { + Authorization: "Bearer undefined", + "Content-Type": "application/json", + }, + } + ); + }); + + it("should throw an error if the API call fails", async () => { + const input: LLMCompletionInput = { + prompt: "Hello, world!", + model: { + modelId: "test_model", + parameters: {}, + }, + document: "This is a test document.", + table: "test_table", + }; + + const expectedError = new Error("API call failed."); + (axios.post as jest.MockedFunction).mockRejectedValueOnce(expectedError); + + await expect(execute(input)).rejects.toThrow(expectedError); + expect(axios.post).toHaveBeenCalledTimes(1); + }); +}); diff --git a/packages/@pufflig/ps-nodes/src/modifiers/document_check/document_check.ts b/packages/@pufflig/ps-nodes/src/modifiers/document_check/document_check.ts new file mode 100644 index 0000000..d303742 --- /dev/null +++ b/packages/@pufflig/ps-nodes/src/modifiers/document_check/document_check.ts @@ -0,0 +1,94 @@ +import { nodeTypes, nodes } from "@pufflig/ps-nodes-config"; +import { refineCompletion } from "@pufflig/ps-sdk"; +import { Execute, GetInputDefinition, ModelValue, Node, Param } from "@pufflig/ps-types"; +import { getPromptStudioKey } from "../../utils/getPromptStudioKey"; +import { extractVariables } from "../../utils/extractVariables"; +import Mustache from "mustache"; + +export interface LLMCompletionInput { + prompt: string; + model: ModelValue; + document: string; + table: string; + [key: string]: any; +} + +export interface LLMCompletionOutput { + completion: string; +} + +export const execute: Execute = async (input, options = {}) => { + const { prompt, model, document, table, ...variables } = input; + const { modelId, parameters } = model; + const { globals } = options; + + const renderedPrompt = Mustache.render(prompt, variables); + + const { result } = await refineCompletion({ + apiKey: getPromptStudioKey(globals || {}), + modelId, + prompt: renderedPrompt, + document: document, + format: table, + parameters, + config: globals, + options: { + cache: true, + track: true, + }, + }); + + return { + completion: result || "", + }; +}; + +/** + * Returns a new input definition given variables extracted from the template. + * + * @param input + * @param prev + * @returns + */ +export const getInputDefinition: GetInputDefinition = (input) => { + const { prompt, ...rest } = input; + + if (prompt === undefined) { + return nodes[nodeTypes.documentCheckNodeType].inputs; + } + + const definitionsWithDefaults = nodes[nodeTypes.documentCheckNodeType].inputs.map((input) => { + if (input.id === "prompt") { + return { + ...input, + defaultValue: prompt, + } as Param; + } + return input; + }); + + const extractedVariables = extractVariables(prompt); + + if (extractedVariables) { + const extractedVariablesWithDefaults = extractedVariables + .filter((param) => { + return ["document", "table"].includes(param.id); + }) + .map((variable) => { + return { + ...variable, + defaultValue: rest[variable.id] || "", + } as Param; + }); + + return [...definitionsWithDefaults, ...extractedVariablesWithDefaults]; + } + + return definitionsWithDefaults; +}; + +export const documentCheck: Node = { + ...nodes[nodeTypes.documentCheckNodeType], + execute, + getInputDefinition, +}; diff --git a/packages/@pufflig/ps-nodes/src/modifiers/index.ts b/packages/@pufflig/ps-nodes/src/modifiers/index.ts index f298cee..84866ee 100644 --- a/packages/@pufflig/ps-nodes/src/modifiers/index.ts +++ b/packages/@pufflig/ps-nodes/src/modifiers/index.ts @@ -1,6 +1,7 @@ import { nodeTypes } from "@pufflig/ps-nodes-config"; import { addMessage } from "./add_message/add_message"; import { addText } from "./add_text/add_text"; +import { documentCheck } from "./document_check/document_check"; import { parseDocument } from "./parse_document/parse_document"; import { splitText } from "./split_text/split_text"; import { templateChat } from "./template/template_chat"; @@ -13,4 +14,5 @@ export const modifierNodes = { [nodeTypes.templateChatNodeType]: templateChat, [nodeTypes.templateTextNodeType]: templateText, [nodeTypes.parseDocumentNodeType]: parseDocument, + [nodeTypes.documentCheckNodeType]: documentCheck, }; diff --git a/packages/@pufflig/ps-sdk/src/index.ts b/packages/@pufflig/ps-sdk/src/index.ts index e8ab367..ca57f67 100644 --- a/packages/@pufflig/ps-sdk/src/index.ts +++ b/packages/@pufflig/ps-sdk/src/index.ts @@ -1,2 +1,3 @@ export { createCompletion } from "./createCompletion"; export { refineCompletion } from "./refineCompletion"; +export { mapCompletion } from "./mapCompletion"; diff --git a/packages/@pufflig/ps-sdk/src/refineCompletion.ts b/packages/@pufflig/ps-sdk/src/refineCompletion.ts index 5e1bf22..b90670d 100644 --- a/packages/@pufflig/ps-sdk/src/refineCompletion.ts +++ b/packages/@pufflig/ps-sdk/src/refineCompletion.ts @@ -16,11 +16,12 @@ interface RefineCompletionInput { } interface Completion { - datapoint?: { + result: string; + datapoints?: { model_output: string; model_input: string; model_id: string; - }; + }[]; } interface RefineCompletionPayload {