-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
288 additions
and
185 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,202 +1,174 @@ | ||
import assert from "assert" | ||
import OpenAI from "openai" | ||
import { | ||
ChatCompletion, | ||
ChatCompletionCreateParams, | ||
ChatCompletionMessage | ||
} from "openai/resources/index.mjs" | ||
import { ZodSchema } from "zod" | ||
import { JsonSchema7Type, zodToJsonSchema } from "zod-to-json-schema" | ||
OAIBuildFunctionParams, | ||
OAIBuildMessageBasedParams, | ||
OAIBuildToolFunctionParams | ||
} from "@/oai/params" | ||
import { | ||
OAIResponseFnArgsParser, | ||
OAIResponseJSONStringParser, | ||
OAIResponseToolArgsParser | ||
} from "@/oai/parser" | ||
import OpenAI from "openai" | ||
import { ChatCompletion, ChatCompletionCreateParamsNonStreaming } from "openai/resources/index.mjs" | ||
import { ZodObject } from "zod" | ||
import zodToJsonSchema from "zod-to-json-schema" | ||
|
||
import { MODE } from "@/constants/modes" | ||
|
||
export class OpenAISchema { | ||
private response_model: ReturnType<typeof zodToJsonSchema> | ||
constructor(public zod_schema: ZodSchema) { | ||
this.response_model = zodToJsonSchema(zod_schema) | ||
} | ||
|
||
get definitions() { | ||
return this.response_model["definitions"] | ||
} | ||
|
||
get properties() { | ||
return this.response_model["properties"] | ||
} | ||
const MODE_TO_PARSER = { | ||
[MODE.FUNCTIONS]: OAIResponseFnArgsParser, | ||
[MODE.TOOLS]: OAIResponseToolArgsParser, | ||
[MODE.JSON]: OAIResponseJSONStringParser, | ||
[MODE.MD_JSON]: OAIResponseJSONStringParser, | ||
[MODE.JSON_SCHEMA]: OAIResponseJSONStringParser | ||
} | ||
|
||
get openai_schema() { | ||
return { | ||
name: this.response_model["title"] || "schema", | ||
description: | ||
this.response_model["description"] || | ||
`Correctly extracted \`${ | ||
this.response_model["title"] || "schema" | ||
}\` with all the required parameters with correct types`, | ||
parameters: Object.keys(this.response_model).reduce( | ||
(acc, curr) => { | ||
if ( | ||
curr.startsWith("$") || | ||
["title", "description", "additionalProperties"].includes(curr) | ||
) | ||
return acc | ||
acc[curr] = this.response_model[curr] | ||
return acc | ||
}, | ||
{} as { | ||
[key: string]: object | JsonSchema7Type | ||
} | ||
) | ||
} | ||
} | ||
const MODE_TO_PARAMS = { | ||
[MODE.FUNCTIONS]: OAIBuildFunctionParams, | ||
[MODE.TOOLS]: OAIBuildToolFunctionParams, | ||
[MODE.JSON]: OAIBuildMessageBasedParams, | ||
[MODE.MD_JSON]: OAIBuildMessageBasedParams, | ||
[MODE.JSON_SCHEMA]: OAIBuildMessageBasedParams | ||
} | ||
|
||
type PatchedChatCompletionCreateParams = ChatCompletionCreateParams & { | ||
response_model?: ZodSchema | OpenAISchema | ||
type PatchedChatCompletionCreateParams = ChatCompletionCreateParamsNonStreaming & { | ||
//eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
response_model?: ZodObject<any> | ||
max_retries?: number | ||
} | ||
|
||
function handleResponseModel( | ||
response_model: ZodSchema | OpenAISchema, | ||
args: PatchedChatCompletionCreateParams[], | ||
mode: MODE = "FUNCTIONS" | ||
): [OpenAISchema, PatchedChatCompletionCreateParams[], MODE] { | ||
if (!(response_model instanceof OpenAISchema)) { | ||
response_model = new OpenAISchema(response_model) | ||
export default class Instructor { | ||
private client: OpenAI | ||
private mode: MODE | ||
|
||
/** | ||
* Creates an instance of the `Instructor` class. | ||
* @param {OpenAI} client - The OpenAI client. | ||
* @param {string} mode - The mode of operation. | ||
*/ | ||
constructor({ client, mode }: { client: OpenAI; mode: MODE }) { | ||
this.client = client | ||
this.mode = mode | ||
} | ||
|
||
if (mode === MODE.FUNCTIONS) { | ||
args[0].functions = [response_model.openai_schema] | ||
args[0].function_call = { name: response_model.openai_schema.name } | ||
} else if (mode === MODE.TOOLS) { | ||
args[0].tools = [{ type: "function", function: response_model.openai_schema }] | ||
args[0].tool_choice = { | ||
type: "function", | ||
function: { name: response_model.openai_schema.name } | ||
/** | ||
* Handles chat completion with retries. | ||
* @param {PatchedChatCompletionCreateParams} params - The parameters for chat completion. | ||
* @returns {Promise<any>} The response from the chat completion. | ||
*/ | ||
private chatCompletion = async ({ | ||
max_retries = 3, | ||
...params | ||
}: PatchedChatCompletionCreateParams) => { | ||
let attempts = 0 | ||
let validationIssues = [] | ||
let lastMessage = null | ||
|
||
const completionParams = this.buildChatCompletionParams(params) | ||
|
||
const makeCompletionCall = async () => { | ||
let resolvedParams = completionParams | ||
|
||
try { | ||
if (validationIssues.length > 0) { | ||
resolvedParams = { | ||
...completionParams, | ||
messages: [ | ||
...completionParams.messages, | ||
...(lastMessage ? [lastMessage] : []), | ||
{ | ||
role: "system", | ||
content: `Your last response had the following validation issues, please try again: ${validationIssues.join( | ||
", " | ||
)}` | ||
} | ||
] | ||
} | ||
} | ||
|
||
const completion = await this.client.chat.completions.create(resolvedParams) | ||
const response = this.parseOAIResponse(completion) | ||
|
||
return response | ||
} catch (error) { | ||
throw error | ||
} | ||
} | ||
} else if ([MODE.JSON, MODE.MD_JSON, MODE.JSON_SCHEMA].includes(mode)) { | ||
let message: string = `As a genius expert, your task is to understand the content and provide the parsed objects in json that match the following json_schema: \n${JSON.stringify( | ||
response_model.properties | ||
)}` | ||
if (response_model["definitions"]) { | ||
message += `Here are some more definitions to adhere to: \n${JSON.stringify( | ||
response_model.definitions | ||
)}` | ||
|
||
const makeCompletionCallWithRetries = async () => { | ||
try { | ||
const data = await makeCompletionCall() | ||
const validation = params.response_model.safeParse(data) | ||
|
||
if (!validation.success) { | ||
if ("error" in validation) { | ||
lastMessage = { | ||
role: "assistant", | ||
content: JSON.stringify(data) | ||
} | ||
|
||
validationIssues = validation.error.issues.map(issue => issue.message) | ||
throw validation.error | ||
} else { | ||
throw new Error("Validation failed.") | ||
} | ||
} | ||
|
||
return data | ||
} catch (error) { | ||
if (attempts < max_retries) { | ||
attempts++ | ||
return await makeCompletionCallWithRetries() | ||
} else { | ||
throw error | ||
} | ||
} | ||
} | ||
if (mode === MODE.JSON) { | ||
args[0].response_format = { type: "json_object" } | ||
} else if (mode == MODE.JSON_SCHEMA) { | ||
args[0].response_format = { type: "json_object" } | ||
} else if (mode === MODE.MD_JSON) { | ||
args[0].messages.push({ | ||
role: "assistant", | ||
content: "```json" | ||
}) | ||
args[0].stop = "```" | ||
|
||
return await makeCompletionCallWithRetries() | ||
} | ||
|
||
/** | ||
* Builds the chat completion parameters. | ||
* @param {PatchedChatCompletionCreateParams} params - The parameters for chat completion. | ||
* @returns {ChatCompletionCreateParamsNonStreaming} The chat completion parameters. | ||
*/ | ||
private buildChatCompletionParams = ({ | ||
response_model, | ||
...params | ||
}: PatchedChatCompletionCreateParams): ChatCompletionCreateParamsNonStreaming => { | ||
const jsonSchema = zodToJsonSchema(response_model, "response_model") | ||
|
||
const definition = { | ||
name: "response_model", | ||
...jsonSchema.definitions.response_model | ||
} | ||
if (args[0].messages[0].role != "system") { | ||
args[0].messages.unshift({ role: "system", content: message }) | ||
} else { | ||
args[0].messages[0].content += `\n${message}` | ||
|
||
const paramsForMode = MODE_TO_PARAMS[this.mode](definition, params, this.mode) | ||
|
||
return { | ||
stream: false, | ||
...paramsForMode | ||
} | ||
} else { | ||
console.error("unknown mode", mode) | ||
} | ||
return [response_model, args, mode] | ||
} | ||
|
||
function processResponse( | ||
response: OpenAI.Chat.Completions.ChatCompletion, | ||
response_model: OpenAISchema, | ||
mode: MODE = "FUNCTIONS" | ||
) { | ||
const message = response.choices[0].message | ||
if (mode === MODE.FUNCTIONS) { | ||
assert.equal( | ||
message.function_call!.name, | ||
response_model.openai_schema.name, | ||
"Function name does not match" | ||
) | ||
return response_model.zod_schema.parse(JSON.parse(message.function_call!.arguments)) | ||
} else if (mode === MODE.TOOLS) { | ||
const tool_call = message.tool_calls![0] | ||
assert.equal( | ||
tool_call.function.name, | ||
response_model.openai_schema.name, | ||
"Tool name does not match" | ||
) | ||
return response_model.zod_schema.parse(JSON.parse(tool_call.function.arguments)) | ||
} else if ([MODE.JSON, MODE.MD_JSON, MODE.JSON_SCHEMA].includes(mode)) { | ||
return response_model.zod_schema.parse(JSON.parse(message.content!)) | ||
} else { | ||
console.error("unknown mode", mode) | ||
} | ||
} | ||
/** | ||
* Parses the OAI response. | ||
* @param {ChatCompletion} response - The response from the chat completion. | ||
* @returns {any} The parsed response. | ||
*/ | ||
private parseOAIResponse = (response: ChatCompletion) => { | ||
const parser = MODE_TO_PARSER[this.mode] | ||
|
||
function dumpMessage(message: ChatCompletionMessage) { | ||
const ret: ChatCompletionMessage = { | ||
role: message.role, | ||
content: message.content || "" | ||
} | ||
if (message.tool_calls) { | ||
ret["content"] += JSON.stringify(message.tool_calls) | ||
} | ||
if (message.function_call) { | ||
ret["content"] += JSON.stringify(message.function_call) | ||
return parser(response) | ||
} | ||
return ret | ||
} | ||
|
||
const patch = ({ | ||
client, | ||
mode | ||
}: { | ||
client: OpenAI | ||
response_model?: ZodSchema | OpenAISchema | ||
max_retries?: number | ||
mode?: MODE | ||
}): OpenAI => { | ||
client.chat.completions.create = new Proxy(client.chat.completions.create, { | ||
async apply(target, ctx, args: PatchedChatCompletionCreateParams[]) { | ||
const max_retries = args[0].max_retries || 1 | ||
let retries = 0, | ||
response: ChatCompletion | undefined = undefined, | ||
response_model = args[0].response_model | ||
;[response_model, args, mode] = handleResponseModel(response_model!, args, mode) | ||
|
||
delete args[0].response_model | ||
delete args[0].max_retries | ||
|
||
while (retries < max_retries) { | ||
try { | ||
response = (await target.apply( | ||
ctx, | ||
args as [PatchedChatCompletionCreateParams] | ||
)) as ChatCompletion | ||
return processResponse(response, response_model as OpenAISchema, mode) | ||
} catch (error) { | ||
console.error(error.errors || error) | ||
if (!response) { | ||
break | ||
} | ||
args[0].messages.push(dumpMessage(response.choices[0].message)) | ||
args[0].messages.push({ | ||
role: "user", | ||
content: `Recall the function correctly, fix the errors, exceptions found\n${error}` | ||
}) | ||
if (mode == MODE.MD_JSON) { | ||
args[0].messages.push({ role: "assistant", content: "```json" }) | ||
} | ||
retries++ | ||
if (retries > max_retries) { | ||
throw error | ||
} | ||
} finally { | ||
response = undefined | ||
} | ||
} | ||
/** | ||
* Public chat interface. | ||
*/ | ||
public chat = { | ||
completions: { | ||
create: this.chatCompletion | ||
} | ||
}) | ||
return client | ||
} | ||
} | ||
|
||
export default patch |
Oops, something went wrong.