Skip to content

Commit

Permalink
add document check node, update variable parsing in llm completion node
Browse files Browse the repository at this point in the history
  • Loading branch information
au-re committed Oct 31, 2023
1 parent ea8cd88 commit e7b1d03
Show file tree
Hide file tree
Showing 11 changed files with 297 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ export const llmCompletionNodeType = "adapter/llm_completion" as const;

export const llmCompletionConfig: NodeConfig = {
name: "Instruction",
description: "Generate a completion using an LLM.",
description: "Generate a completion using an LLM",
status: "experimental",
tags: ["adapter", "text"],
globals: [],
Expand All @@ -26,19 +26,12 @@ export const llmCompletionConfig: NodeConfig = {
{
id: "completion",
name: "Completion",
description: "Text generated by the LLM.",
description: "Text generated by the LLM",
type: "text",
defaultValue: "",
},
],
inputs: [
{
id: "prompt",
name: "Prompt",
description: "The prompt to send to the LLM.",
type: "text",
defaultValue: "",
},
{
id: "model",
name: "Model",
Expand All @@ -50,5 +43,12 @@ export const llmCompletionConfig: NodeConfig = {
parameters: {},
},
},
{
id: "prompt",
name: "Prompt",
description: "The prompt to send to the LLM",
type: "text",
defaultValue: "",
},
],
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import { chat_models, completion_models, default_completion_model } from "@pufflig/ps-models";
import { NodeConfig } from "@pufflig/ps-types";

export const documentCheckNodeType = "modifier/document_check" as const;

export const documentCheck: NodeConfig = {
name: "Document Check",
description: "Run a checklist or extract information from a document.",
tags: ["modifier", "document", "text"],
status: "stable",
execution: {
inputs: [
{
id: "exec:input",
},
],
outputs: [
{
id: "exec:output",
name: "Completed",
},
],
},
outputs: [
{
id: "list",
name: "List",
description: "A list, checklist or other information about the document",
type: "text",
defaultValue: "",
},
],
inputs: [
{
id: "model",
name: "Model",
description: "The model to use",
type: "model",
definition: { ...completion_models, ...chat_models },
defaultValue: {
modelId: default_completion_model,
parameters: {},
},
},
{
id: "prompt",
name: "Prompt",
description: "Prompt to check the document with",
type: "text",
defaultValue: `Extract information in the document below and insert them in the csv table, don't overwrite existing values and keep things empty if you cannot find information in the document:\n\nTABLE EXAMPLE:\ncharacters, age\nmickey mouse, 10\ndonald duck, -\n\nTABLE:\n[[table]]\n\nDOCUMENT:\n[[document]]\n\nTABLE:\n`,
},
{
id: "table",
name: "Table",
description: "The list, table or checklist to parse the document with.",
type: "text",
defaultValue: "",
},
{
id: "document",
name: "Document",
description: "Document to be processed",
type: "text",
defaultValue: "",
},
],
};
3 changes: 3 additions & 0 deletions packages/@pufflig/ps-nodes-config/src/modifiers/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { addMessage, addMessageNodeType } from "./add_message/add_message";
import { addText, addTextNodeType } from "./add_text/add_text";
import { documentCheck, documentCheckNodeType } from "./document_check/document_check";
import { parseDocument, parseDocumentNodeType } from "./parse_document/parse_document";
import { splitText, splitTextNodeType } from "./split_text/split_text";
import { templateChat, templateChatNodeType } from "./template/template_chat";
Expand All @@ -12,6 +13,7 @@ export const modifierNodes = {
[splitTextNodeType]: splitText,
[templateChatNodeType]: templateChat,
[templateTextNodeType]: templateText,
[documentCheckNodeType]: documentCheck,
};

export const modifierNodeTypes = {
Expand All @@ -21,4 +23,5 @@ export const modifierNodeTypes = {
splitTextNodeType,
templateChatNodeType,
templateTextNodeType,
documentCheckNodeType,
};
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ export const parseDocument: NodeConfig = {
name: "Parse Document",
description: "Run a prompt over a document",
tags: ["modifier", "document", "text"],
status: "experimental",
status: "deprecated",
execution: {
inputs: [
{
Expand Down
42 changes: 21 additions & 21 deletions packages/@pufflig/ps-nodes/src/adapters/llm/llm_completion.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,6 @@ test("getInputDefinition - no variables", () => {
});
expect(variables).toMatchInlineSnapshot(`
[
{
"defaultValue": "summarize {{longText}}",
"description": "The prompt to send to the LLM.",
"id": "prompt",
"name": "Prompt",
"type": "text",
},
{
"defaultValue": {
"modelId": "gpt-3.5-turbo-instruct",
Expand Down Expand Up @@ -403,6 +396,13 @@ test("getInputDefinition - no variables", () => {
"name": "Model",
"type": "model",
},
{
"defaultValue": "summarize {{longText}}",
"description": "The prompt to send to the LLM",
"id": "prompt",
"name": "Prompt",
"type": "text",
},
{
"defaultValue": "",
"description": "",
Expand All @@ -425,13 +425,6 @@ test("getInputDefinition - if you pass a template and a variable, take value of
});
expect(variables).toMatchInlineSnapshot(`
[
{
"defaultValue": "summarize {{longText}}",
"description": "The prompt to send to the LLM.",
"id": "prompt",
"name": "Prompt",
"type": "text",
},
{
"defaultValue": {
"modelId": "gpt-3.5-turbo-instruct",
Expand Down Expand Up @@ -766,6 +759,13 @@ test("getInputDefinition - if you pass a template and a variable, take value of
"name": "Model",
"type": "model",
},
{
"defaultValue": "summarize {{longText}}",
"description": "The prompt to send to the LLM",
"id": "prompt",
"name": "Prompt",
"type": "text",
},
{
"defaultValue": "some long text",
"description": "",
Expand All @@ -788,13 +788,6 @@ test("getInputDefinition - ignores non existing variables", () => {
});
expect(variables).toMatchInlineSnapshot(`
[
{
"defaultValue": "summarize {{longText}}",
"description": "The prompt to send to the LLM.",
"id": "prompt",
"name": "Prompt",
"type": "text",
},
{
"defaultValue": {
"modelId": "gpt-3.5-turbo-instruct",
Expand Down Expand Up @@ -1129,6 +1122,13 @@ test("getInputDefinition - ignores non existing variables", () => {
"name": "Model",
"type": "model",
},
{
"defaultValue": "summarize {{longText}}",
"description": "The prompt to send to the LLM",
"id": "prompt",
"name": "Prompt",
"type": "text",
},
{
"defaultValue": "",
"description": "",
Expand Down
7 changes: 5 additions & 2 deletions packages/@pufflig/ps-nodes/src/adapters/llm/llm_completion.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { createCompletion } from "@pufflig/ps-sdk";
import { Execute, GetInputDefinition, ModelValue, Node, Param } from "@pufflig/ps-types";
import { getPromptStudioKey } from "../../utils/getPromptStudioKey";
import { extractVariables } from "../../utils/extractVariables";
import Mustache from "mustache";

export interface LLMCompletionInput {
prompt: string;
Expand All @@ -15,14 +16,16 @@ export interface LLMCompletionOutput {
}

export const execute: Execute<LLMCompletionInput, LLMCompletionOutput> = async (input, options = {}) => {
const { prompt, model } = input;
const { prompt, model, ...variables } = input;
const { modelId, parameters } = model;
const { globals } = options;

const renderedTemplate = Mustache.render(prompt, variables);

const result = await createCompletion({
apiKey: getPromptStudioKey(globals || {}),
modelId,
prompt,
prompt: renderedTemplate,
parameters,
config: globals,
options: {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import { execute, LLMCompletionInput } from "./document_check";
import axios from "axios";

jest.mock("axios");

describe("documentCheck", () => {
beforeEach(() => {
jest.resetAllMocks();
});

it("should return the completion string", async () => {
const input: LLMCompletionInput = {
prompt: "Hello, world!",
model: {
modelId: "test_model",
parameters: {},
},
document: "This is a test document.",
table: "test_table",
};

const expectedOutput = { result: "This is a test completion." };
const mockedAxiosResponse = { data: expectedOutput };
(axios.post as jest.MockedFunction<typeof axios.post>).mockResolvedValueOnce(mockedAxiosResponse);

const output = await execute(input);

expect(output).toEqual({ completion: "This is a test completion." });
expect(axios.post).toHaveBeenCalledTimes(1);
});

it("should parse input variables", async () => {
const input: LLMCompletionInput = {
prompt: "Hello, {{myVariable}}!",
model: {
modelId: "test_model",
parameters: {},
},
document: "This is a test document.",
table: "test_table",
myVariable: "myValue",
};

const expectedOutput = { result: "This is a test completion." };
const mockedAxiosResponse = { data: expectedOutput };
(axios.post as jest.MockedFunction<typeof axios.post>).mockResolvedValueOnce(mockedAxiosResponse);

const output = await execute(input);

expect(output).toEqual({ completion: "This is a test completion." });
expect(axios.post).toHaveBeenCalledTimes(1);
expect(axios.post).toHaveBeenCalledWith(
expect.any(String),
{
document: "This is a test document.",
format: "test_table",
modelId: "test_model",
options: {
cache: true,
track: true,
},
parameters: {},
prompt: "Hello, myValue!",
},
{
headers: {
Authorization: "Bearer undefined",
"Content-Type": "application/json",
},
}
);
});

it("should throw an error if the API call fails", async () => {
const input: LLMCompletionInput = {
prompt: "Hello, world!",
model: {
modelId: "test_model",
parameters: {},
},
document: "This is a test document.",
table: "test_table",
};

const expectedError = new Error("API call failed.");
(axios.post as jest.MockedFunction<typeof axios.post>).mockRejectedValueOnce(expectedError);

await expect(execute(input)).rejects.toThrow(expectedError);
expect(axios.post).toHaveBeenCalledTimes(1);
});
});
Loading

0 comments on commit e7b1d03

Please sign in to comment.