From 6dde1f6cc0d8fcbdfe510d9f22a4e4e3a365c2b1 Mon Sep 17 00:00:00 2001 From: sam Date: Wed, 3 Jan 2024 15:20:42 -0800 Subject: [PATCH] feat: add type inference using zod model (#29) Co-authored-by: Jason Liu --- .../classification/multi_prediction/index.ts | 2 +- .../classification/simple_prediction/index.ts | 2 +- examples/extract_user/index.ts | 15 ++++- examples/passthrough/index.ts | 20 ++++++ .../exampleGraphMaker.ts | 2 +- src/instructor.ts | 63 ++++++++++--------- tsconfig.json | 1 + 7 files changed, 72 insertions(+), 33 deletions(-) create mode 100644 examples/passthrough/index.ts diff --git a/examples/classification/multi_prediction/index.ts b/examples/classification/multi_prediction/index.ts index 45bf367f..cb8d62fc 100644 --- a/examples/classification/multi_prediction/index.ts +++ b/examples/classification/multi_prediction/index.ts @@ -26,7 +26,7 @@ const client = Instructor({ }) const createClassification = async (data: string): Promise => { - const classification: MultiClassification = await client.chat.completions.create({ + const classification = await client.chat.completions.create({ messages: [{ role: "user", content: `"Classify the following support ticket: ${data}` }], model: "gpt-3.5-turbo", response_model: MultiClassificationSchema, diff --git a/examples/classification/simple_prediction/index.ts b/examples/classification/simple_prediction/index.ts index c8235b7e..92f1e8f4 100644 --- a/examples/classification/simple_prediction/index.ts +++ b/examples/classification/simple_prediction/index.ts @@ -25,7 +25,7 @@ const client = Instructor({ }) const createClassification = async (data: string): Promise => { - const classification: SimpleClassification = await client.chat.completions.create({ + const classification = await client.chat.completions.create({ messages: [{ role: "user", content: `"Classify the following text: ${data}` }], model: "gpt-3.5-turbo", response_model: SimpleClassificationSchema, diff --git a/examples/extract_user/index.ts b/examples/extract_user/index.ts index 864d2e42..9d50bd43 100644 --- a/examples/extract_user/index.ts +++ b/examples/extract_user/index.ts @@ -21,11 +21,22 @@ const client = Instructor({ mode: "FUNCTIONS" }) -const user: User = await client.chat.completions.create({ +const user = await client.chat.completions.create({ messages: [{ role: "user", content: "Jason Liu is 30 years old" }], model: "gpt-3.5-turbo", response_model: UserSchema, max_retries: 3 -}) +}); + + +// let's now verify that the response type is inferred correctly + +const age: number = user.age; +// @ts-expect-error - age is a number, not a string +const _age: string = user.age; +const name: string = user.name; + +// @ts-expect-error - this property does not exist +user.missing; console.log(user) diff --git a/examples/passthrough/index.ts b/examples/passthrough/index.ts new file mode 100644 index 00000000..a364572c --- /dev/null +++ b/examples/passthrough/index.ts @@ -0,0 +1,20 @@ +import Instructor from "@/instructor" +import OpenAI from "openai" + +const oai = new OpenAI({ + apiKey: process.env.OPENAI_API_KEY ?? undefined, + organization: process.env.OPENAI_ORG_ID ?? undefined +}) + +const client = Instructor({ + client: oai, + mode: "FUNCTIONS" +}) + +// ensures that when no `response_model` is provided, the response type is `ChatCompletion` +const completion = await client.chat.completions.create({ + messages: [{ role: "user", content: "Jason Liu is 30 years old" }], + model: "gpt-3.5-turbo", + max_retries: 3 +}) satisfies OpenAI.Chat.ChatCompletion; + diff --git a/examples/resolving-complex-entitities/exampleGraphMaker.ts b/examples/resolving-complex-entitities/exampleGraphMaker.ts index 5b28b7cb..b991085e 100644 --- a/examples/resolving-complex-entitities/exampleGraphMaker.ts +++ b/examples/resolving-complex-entitities/exampleGraphMaker.ts @@ -39,7 +39,7 @@ function createHtmlDocument(data) { }) .join(",\n") - const edgeDefs = [] + const edgeDefs: string[] = [] data.entities.forEach(entity => { entity.dependencies.forEach(depId => { // @ts-ignore diff --git a/src/instructor.ts b/src/instructor.ts index 65773b13..f2bd6456 100644 --- a/src/instructor.ts +++ b/src/instructor.ts @@ -14,7 +14,7 @@ import type { ChatCompletionCreateParamsNonStreaming, ChatCompletionMessageParam } from "openai/resources/index.mjs" -import { ZodObject } from "zod" +import type { ZodObject, z } from "zod" import zodToJsonSchema from "zod-to-json-schema" import { fromZodError } from "zod-validation-error" @@ -36,9 +36,9 @@ const MODE_TO_PARAMS = { [MODE.JSON_SCHEMA]: OAIBuildMessageBasedParams } -type PatchedChatCompletionCreateParams = ChatCompletionCreateParamsNonStreaming & { +interface PatchedChatCompletionCreateParams | undefined> extends ChatCompletionCreateParamsNonStreaming { //eslint-disable-next-line @typescript-eslint/no-explicit-any - response_model?: ZodObject + response_model?: Model max_retries?: number } @@ -57,11 +57,19 @@ class Instructor { } /** - * Handles chat completion with retries. - * @param {PatchedChatCompletionCreateParams} params - The parameters for chat completion. - * @returns {Promise} The response from the chat completion. + * Handles chat completion with retries and parses the response if a response model is provided. + * + * @param params - The parameters for chat completion. + * @returns The parsed response model if {@link PatchedChatCompletionCreateParams.response_model} is provided, otherwise the original chat completion. */ - chatCompletion = async ({ max_retries = 3, ...params }: PatchedChatCompletionCreateParams) => { + async chatCompletion | undefined = undefined>({ + max_retries = 3, + ...params + }: PatchedChatCompletionCreateParams): + Promise + ? z.infer + : OpenAI.Chat.Completions.ChatCompletion > { + let attempts = 0 let validationIssues = "" let lastMessage: ChatCompletionMessageParam | null = null @@ -87,8 +95,10 @@ class Instructor { } const completion = await this.client.chat.completions.create(resolvedParams) + if (params.response_model === undefined) { + return completion; + } const response = this.parseOAIResponse(completion) - return response } catch (error) { throw error @@ -98,22 +108,25 @@ class Instructor { const makeCompletionCallWithRetries = async () => { try { const data = await makeCompletionCall() - if (params.response_model === undefined) return data - const validation = params.response_model.safeParse(data) - if (!validation.success) { - if ("error" in validation) { - lastMessage = { - role: "assistant", - content: JSON.stringify(data) + if (params.response_model === undefined) { + return data; + } else { + const validation = params.response_model.safeParse(data) + if (!validation.success) { + if ("error" in validation) { + lastMessage = { + role: "assistant", + content: JSON.stringify(data) + } + + validationIssues = fromZodError(validation.error).message + throw validation.error + } else { + throw new Error("Validation failed.") } - - validationIssues = fromZodError(validation.error).message - throw validation.error - } else { - throw new Error("Validation failed.") } + return validation.data } - return validation.data } catch (error) { if (attempts < max_retries) { attempts++ @@ -135,13 +148,7 @@ class Instructor { private buildChatCompletionParams = ({ response_model, ...params - }: PatchedChatCompletionCreateParams): ChatCompletionCreateParamsNonStreaming => { - if (response_model === undefined) { - return { - stream: false, - ...params - } - } + }: PatchedChatCompletionCreateParams): ChatCompletionCreateParamsNonStreaming => { const jsonSchema = zodToJsonSchema(response_model, "response_model") const definition = { diff --git a/tsconfig.json b/tsconfig.json index a3dff1b2..f21ed520 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -1,6 +1,7 @@ { "compilerOptions": { "strict": false, + "strictNullChecks": true, "noEmit": true, "allowJs": true, "jsx": "preserve",