From f31e73e36bbb510422fd2928d38cf18df97995d9 Mon Sep 17 00:00:00 2001 From: Dimitri Kennedy Date: Tue, 2 Jan 2024 16:30:01 -0500 Subject: [PATCH] add back --- examples/extract_user/index.ts | 4 +- src/instructor.ts | 330 +++++++++++++++------------------ src/oai/params.ts | 54 ++++++ src/oai/parser.ts | 81 ++++++++ tests/functions.test.ts | 4 +- 5 files changed, 288 insertions(+), 185 deletions(-) create mode 100644 src/oai/params.ts create mode 100644 src/oai/parser.ts diff --git a/examples/extract_user/index.ts b/examples/extract_user/index.ts index 710df601..685ecebe 100644 --- a/examples/extract_user/index.ts +++ b/examples/extract_user/index.ts @@ -16,16 +16,14 @@ const oai = new OpenAI({ organization: process.env.OPENAI_ORG_ID ?? undefined }) -const client = Instructor({ +const client = new Instructor({ client: oai, mode: "FUNCTIONS" }) -//@ts-expect-error these types wont work since were using a proxy and just returning the OAI instance type const user: User = await client.chat.completions.create({ messages: [{ role: "user", content: "Jason Liu is 30 years old" }], model: "gpt-3.5-turbo", - //@ts-expect-error same as above response_model: UserSchema, max_retries: 3 }) diff --git a/src/instructor.ts b/src/instructor.ts index 5fe3e214..418a721d 100644 --- a/src/instructor.ts +++ b/src/instructor.ts @@ -1,202 +1,174 @@ -import assert from "assert" -import OpenAI from "openai" import { - ChatCompletion, - ChatCompletionCreateParams, - ChatCompletionMessage -} from "openai/resources/index.mjs" -import { ZodSchema } from "zod" -import { JsonSchema7Type, zodToJsonSchema } from "zod-to-json-schema" + OAIBuildFunctionParams, + OAIBuildMessageBasedParams, + OAIBuildToolFunctionParams +} from "@/oai/params" +import { + OAIResponseFnArgsParser, + OAIResponseJSONStringParser, + OAIResponseToolArgsParser +} from "@/oai/parser" +import OpenAI from "openai" +import { ChatCompletion, ChatCompletionCreateParamsNonStreaming } from "openai/resources/index.mjs" +import { ZodObject } from "zod" +import zodToJsonSchema from "zod-to-json-schema" import { MODE } from "@/constants/modes" -export class OpenAISchema { - private response_model: ReturnType - constructor(public zod_schema: ZodSchema) { - this.response_model = zodToJsonSchema(zod_schema) - } - - get definitions() { - return this.response_model["definitions"] - } - - get properties() { - return this.response_model["properties"] - } +const MODE_TO_PARSER = { + [MODE.FUNCTIONS]: OAIResponseFnArgsParser, + [MODE.TOOLS]: OAIResponseToolArgsParser, + [MODE.JSON]: OAIResponseJSONStringParser, + [MODE.MD_JSON]: OAIResponseJSONStringParser, + [MODE.JSON_SCHEMA]: OAIResponseJSONStringParser +} - get openai_schema() { - return { - name: this.response_model["title"] || "schema", - description: - this.response_model["description"] || - `Correctly extracted \`${ - this.response_model["title"] || "schema" - }\` with all the required parameters with correct types`, - parameters: Object.keys(this.response_model).reduce( - (acc, curr) => { - if ( - curr.startsWith("$") || - ["title", "description", "additionalProperties"].includes(curr) - ) - return acc - acc[curr] = this.response_model[curr] - return acc - }, - {} as { - [key: string]: object | JsonSchema7Type - } - ) - } - } +const MODE_TO_PARAMS = { + [MODE.FUNCTIONS]: OAIBuildFunctionParams, + [MODE.TOOLS]: OAIBuildToolFunctionParams, + [MODE.JSON]: OAIBuildMessageBasedParams, + [MODE.MD_JSON]: OAIBuildMessageBasedParams, + [MODE.JSON_SCHEMA]: OAIBuildMessageBasedParams } -type PatchedChatCompletionCreateParams = ChatCompletionCreateParams & { - response_model?: ZodSchema | OpenAISchema +type PatchedChatCompletionCreateParams = ChatCompletionCreateParamsNonStreaming & { + //eslint-disable-next-line @typescript-eslint/no-explicit-any + response_model?: ZodObject max_retries?: number } -function handleResponseModel( - response_model: ZodSchema | OpenAISchema, - args: PatchedChatCompletionCreateParams[], - mode: MODE = "FUNCTIONS" -): [OpenAISchema, PatchedChatCompletionCreateParams[], MODE] { - if (!(response_model instanceof OpenAISchema)) { - response_model = new OpenAISchema(response_model) +export default class Instructor { + private client: OpenAI + private mode: MODE + + /** + * Creates an instance of the `Instructor` class. + * @param {OpenAI} client - The OpenAI client. + * @param {string} mode - The mode of operation. + */ + constructor({ client, mode }: { client: OpenAI; mode: MODE }) { + this.client = client + this.mode = mode } - if (mode === MODE.FUNCTIONS) { - args[0].functions = [response_model.openai_schema] - args[0].function_call = { name: response_model.openai_schema.name } - } else if (mode === MODE.TOOLS) { - args[0].tools = [{ type: "function", function: response_model.openai_schema }] - args[0].tool_choice = { - type: "function", - function: { name: response_model.openai_schema.name } + /** + * Handles chat completion with retries. + * @param {PatchedChatCompletionCreateParams} params - The parameters for chat completion. + * @returns {Promise} The response from the chat completion. + */ + private chatCompletion = async ({ + max_retries = 3, + ...params + }: PatchedChatCompletionCreateParams) => { + let attempts = 0 + let validationIssues = [] + let lastMessage = null + + const completionParams = this.buildChatCompletionParams(params) + + const makeCompletionCall = async () => { + let resolvedParams = completionParams + + try { + if (validationIssues.length > 0) { + resolvedParams = { + ...completionParams, + messages: [ + ...completionParams.messages, + ...(lastMessage ? [lastMessage] : []), + { + role: "system", + content: `Your last response had the following validation issues, please try again: ${validationIssues.join( + ", " + )}` + } + ] + } + } + + const completion = await this.client.chat.completions.create(resolvedParams) + const response = this.parseOAIResponse(completion) + + return response + } catch (error) { + throw error + } } - } else if ([MODE.JSON, MODE.MD_JSON, MODE.JSON_SCHEMA].includes(mode)) { - let message: string = `As a genius expert, your task is to understand the content and provide the parsed objects in json that match the following json_schema: \n${JSON.stringify( - response_model.properties - )}` - if (response_model["definitions"]) { - message += `Here are some more definitions to adhere to: \n${JSON.stringify( - response_model.definitions - )}` + + const makeCompletionCallWithRetries = async () => { + try { + const data = await makeCompletionCall() + const validation = params.response_model.safeParse(data) + + if (!validation.success) { + if ("error" in validation) { + lastMessage = { + role: "assistant", + content: JSON.stringify(data) + } + + validationIssues = validation.error.issues.map(issue => issue.message) + throw validation.error + } else { + throw new Error("Validation failed.") + } + } + + return data + } catch (error) { + if (attempts < max_retries) { + attempts++ + return await makeCompletionCallWithRetries() + } else { + throw error + } + } } - if (mode === MODE.JSON) { - args[0].response_format = { type: "json_object" } - } else if (mode == MODE.JSON_SCHEMA) { - args[0].response_format = { type: "json_object" } - } else if (mode === MODE.MD_JSON) { - args[0].messages.push({ - role: "assistant", - content: "```json" - }) - args[0].stop = "```" + + return await makeCompletionCallWithRetries() + } + + /** + * Builds the chat completion parameters. + * @param {PatchedChatCompletionCreateParams} params - The parameters for chat completion. + * @returns {ChatCompletionCreateParamsNonStreaming} The chat completion parameters. + */ + private buildChatCompletionParams = ({ + response_model, + ...params + }: PatchedChatCompletionCreateParams): ChatCompletionCreateParamsNonStreaming => { + const jsonSchema = zodToJsonSchema(response_model, "response_model") + + const definition = { + name: "response_model", + ...jsonSchema.definitions.response_model } - if (args[0].messages[0].role != "system") { - args[0].messages.unshift({ role: "system", content: message }) - } else { - args[0].messages[0].content += `\n${message}` + + const paramsForMode = MODE_TO_PARAMS[this.mode](definition, params, this.mode) + + return { + stream: false, + ...paramsForMode } - } else { - console.error("unknown mode", mode) } - return [response_model, args, mode] -} -function processResponse( - response: OpenAI.Chat.Completions.ChatCompletion, - response_model: OpenAISchema, - mode: MODE = "FUNCTIONS" -) { - const message = response.choices[0].message - if (mode === MODE.FUNCTIONS) { - assert.equal( - message.function_call!.name, - response_model.openai_schema.name, - "Function name does not match" - ) - return response_model.zod_schema.parse(JSON.parse(message.function_call!.arguments)) - } else if (mode === MODE.TOOLS) { - const tool_call = message.tool_calls![0] - assert.equal( - tool_call.function.name, - response_model.openai_schema.name, - "Tool name does not match" - ) - return response_model.zod_schema.parse(JSON.parse(tool_call.function.arguments)) - } else if ([MODE.JSON, MODE.MD_JSON, MODE.JSON_SCHEMA].includes(mode)) { - return response_model.zod_schema.parse(JSON.parse(message.content!)) - } else { - console.error("unknown mode", mode) - } -} + /** + * Parses the OAI response. + * @param {ChatCompletion} response - The response from the chat completion. + * @returns {any} The parsed response. + */ + private parseOAIResponse = (response: ChatCompletion) => { + const parser = MODE_TO_PARSER[this.mode] -function dumpMessage(message: ChatCompletionMessage) { - const ret: ChatCompletionMessage = { - role: message.role, - content: message.content || "" - } - if (message.tool_calls) { - ret["content"] += JSON.stringify(message.tool_calls) - } - if (message.function_call) { - ret["content"] += JSON.stringify(message.function_call) + return parser(response) } - return ret -} -const patch = ({ - client, - mode -}: { - client: OpenAI - response_model?: ZodSchema | OpenAISchema - max_retries?: number - mode?: MODE -}): OpenAI => { - client.chat.completions.create = new Proxy(client.chat.completions.create, { - async apply(target, ctx, args: PatchedChatCompletionCreateParams[]) { - const max_retries = args[0].max_retries || 1 - let retries = 0, - response: ChatCompletion | undefined = undefined, - response_model = args[0].response_model - ;[response_model, args, mode] = handleResponseModel(response_model!, args, mode) - - delete args[0].response_model - delete args[0].max_retries - - while (retries < max_retries) { - try { - response = (await target.apply( - ctx, - args as [PatchedChatCompletionCreateParams] - )) as ChatCompletion - return processResponse(response, response_model as OpenAISchema, mode) - } catch (error) { - console.error(error.errors || error) - if (!response) { - break - } - args[0].messages.push(dumpMessage(response.choices[0].message)) - args[0].messages.push({ - role: "user", - content: `Recall the function correctly, fix the errors, exceptions found\n${error}` - }) - if (mode == MODE.MD_JSON) { - args[0].messages.push({ role: "assistant", content: "```json" }) - } - retries++ - if (retries > max_retries) { - throw error - } - } finally { - response = undefined - } - } + /** + * Public chat interface. + */ + public chat = { + completions: { + create: this.chatCompletion } - }) - return client + } } - -export default patch diff --git a/src/oai/params.ts b/src/oai/params.ts new file mode 100644 index 00000000..ae1795ca --- /dev/null +++ b/src/oai/params.ts @@ -0,0 +1,54 @@ +import { MODE } from "@/constants/modes" + +export function OAIBuildFunctionParams(definition, params) { + return { + ...params, + function_call: { + name: definition.name + }, + functions: [...(params?.functions ?? []), definition] + } +} + +export function OAIBuildToolFunctionParams(definition, params) { + return { + ...params, + tool_choice: { + type: "function", + function: { name: definition.name } + }, + tools: [...(params?.tools ?? []), definition] + } +} + +export function OAIBuildMessageBasedParams(definition, params, mode) { + const MODE_SPECIFIC_CONFIGS = { + [MODE.JSON]: { + response_format: { type: "json_object" } + }, + [MODE.JSON_SCHEMA]: { + response_format: { type: "json_object", schema: definition } + } + } + + const modeConfig = MODE_SPECIFIC_CONFIGS[mode] ?? {} + + return { + ...params, + ...modeConfig, + messages: [ + ...(params?.messages ?? []), + { + role: "SYSTEM", + content: ` + Given a user prompt, you will return fully valid JSON based on the following description and schema. + You will return no other prose. You will take into account the descriptions for each paramater within the schema + and return a valid JSON object that matches the schema and those instructions. + + description: ${definition?.description} + json schema: ${JSON.stringify(definition)} + ` + } + ] + } +} diff --git a/src/oai/parser.ts b/src/oai/parser.ts new file mode 100644 index 00000000..e5a4276d --- /dev/null +++ b/src/oai/parser.ts @@ -0,0 +1,81 @@ +import OpenAI from "openai" +import { Stream } from "openai/streaming" + +/** + * `OAIResponseTextParser` parses a JSON string and extracts the text content. + * + * @param {string} data - The JSON string to parse. + * @returns {string} - The extracted text content. + * + */ +export function OAIResponseTextParser( + data: + | string + | Stream + | OpenAI.Chat.Completions.ChatCompletion +) { + const parsedData = typeof data === "string" ? JSON.parse(data) : data + + const text = parsedData?.choices[0]?.message?.content ?? "{}" + + return JSON.parse(text) +} + +/** + * `OAIResponseFnArgsParser` parses a JSON string and extracts the function call arguments. + * + * @param {string} data - The JSON string to parse. + * @returns {Object} - The extracted function call arguments. + * + */ +export function OAIResponseFnArgsParser( + data: + | string + | Stream + | OpenAI.Chat.Completions.ChatCompletion +) { + const parsedData = typeof data === "string" ? JSON.parse(data) : data + + const text = parsedData.choices?.[0]?.message?.function_call?.arguments ?? "{}" + + return JSON.parse(text) +} + +/** + * `OAIResponseToolArgsParser` parses a JSON string and extracts the tool call arguments. + * + * @param {string} data - The JSON string to parse. + * @returns {Object} - The extracted tool call arguments. + * + */ +export function OAIResponseToolArgsParser( + data: + | string + | Stream + | OpenAI.Chat.Completions.ChatCompletion +) { + const parsedData = typeof data === "string" ? JSON.parse(data) : data + + const text = parsedData.choices?.[0]?.message?.tool_call?.function?.arguments ?? "{}" + + return JSON.parse(text) +} + +/** + * `OAIResponseJSONParser` parses a JSON string and extracts the JSON content. + * + * @param {string} data - The JSON string to parse. + * @returns {Object} - The extracted JSON content. + * + */ +export function OAIResponseJSONStringParser( + data: + | string + | Stream + | OpenAI.Chat.Completions.ChatCompletion +) { + const parsedData = typeof data === "string" ? JSON.parse(data) : data + const text = parsedData?.choices[0]?.message?.content ?? "{}" + + return JSON.parse(text) +} diff --git a/tests/functions.test.ts b/tests/functions.test.ts index fa7044b9..1c003c55 100644 --- a/tests/functions.test.ts +++ b/tests/functions.test.ts @@ -18,16 +18,14 @@ async function extractUser() { organization: process.env.OPENAI_ORG_ID ?? undefined }) - const client = Instructor({ + const client = new Instructor({ client: oai, mode: "FUNCTIONS" }) - //@ts-expect-error these types wont work since were using a proxy and just returning the OAI instance type const user: User = await client.chat.completions.create({ messages: [{ role: "user", content: "Jason Liu is 30 years old" }], model: "gpt-3.5-turbo", - //@ts-expect-error same as above response_model: UserSchema, max_retries: 3 })