Skip to content

Commit

Permalink
initial stream handling
Browse files Browse the repository at this point in the history
  • Loading branch information
roodboi committed Jan 3, 2024
1 parent 3893e1a commit 4ad3fd4
Show file tree
Hide file tree
Showing 9 changed files with 190 additions and 48 deletions.
22 changes: 0 additions & 22 deletions .github/workflows/npm-publish.yml

This file was deleted.

Binary file modified bun.lockb
Binary file not shown.
31 changes: 31 additions & 0 deletions examples/extract_user copy/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import Instructor from "@/instructor"
import OpenAI from "openai"
import { z } from "zod"

// Schema for the structured output: an age plus a full name, where the name
// must contain at least one space (i.e. first + last).
const UserSchema = z.object({
  age: z.number(),
  name: z.string().refine(value => value.includes(" "), {
    message: "Name must contain a space"
  })
})

type User = z.infer<typeof UserSchema>

// Plain OpenAI client, configured from the environment.
const openaiClient = new OpenAI({
  apiKey: process.env.OPENAI_API_KEY ?? undefined,
  organization: process.env.OPENAI_ORG_ID ?? undefined
})

// Instructor wrapper adding response_model / max_retries support on top of
// the stock chat-completions API.
const instructor = Instructor({
  client: openaiClient,
  mode: "FUNCTIONS"
})

// Extract a User from free-form text; validation failures are retried up to
// 3 times before surfacing an error.
const user: User = await instructor.chat.completions.create({
  messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
  model: "gpt-3.5-turbo",
  response_model: UserSchema,
  max_retries: 3
})

console.log(user)
53 changes: 53 additions & 0 deletions examples/extract_user_stream/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import Instructor from "@/instructor"
import OpenAI from "openai"
import { z } from "zod"

// Shape of the structured output we want the model to produce incrementally.
const UserSchema = z.object({
  age: z.number(),
  name: z.string()
})

type User = z.infer<typeof UserSchema>

const oai = new OpenAI({
  apiKey: process.env.OPENAI_API_KEY ?? undefined,
  organization: process.env.OPENAI_ORG_ID ?? undefined
})

const client = Instructor({
  client: oai,
  mode: "FUNCTIONS"
})

// With stream: true the call resolves to a stream of partial JSON snapshots
// of the response object rather than a finished value.
const userStream = await client.chat.completions.create({
  messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
  model: "gpt-3.5-turbo",
  response_model: UserSchema,
  max_retries: 3,
  stream: true
})

const reader = userStream.readable.getReader()
const decoder = new TextDecoder()

let result = {}

try {
  while (true) {
    const { value, done } = await reader.read()
    if (done) break

    // { stream: true } lets TextDecoder buffer a multi-byte character that is
    // split across chunk boundaries instead of emitting replacement chars.
    const chunkValue = decoder.decode(value, { stream: true })
    result = JSON.parse(chunkValue)
    console.log(result)
  }
} catch (e) {
  // Surface parse/read failures on stderr rather than swallowing them.
  console.error(e)
} finally {
  reader.releaseLock()
}
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
},
"homepage": "https://github.com/jxnl/instructor-js#readme",
"dependencies": {
"schema-stream": "^1.1.0",
"zod-to-json-schema": "^3.22.3"
},
"peerDependencies": {
Expand Down
56 changes: 38 additions & 18 deletions src/instructor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ import {
OAIResponseJSONStringParser,
OAIResponseToolArgsParser
} from "@/oai/parser"
import { OAIStream } from "@/oai/stream"
import OpenAI from "openai"
import { ChatCompletion, ChatCompletionCreateParamsNonStreaming } from "openai/resources/index.mjs"
import { ZodObject } from "zod"
import { ChatCompletion, ChatCompletionCreateParams } from "openai/resources/index.mjs"

Check warning on line 13 in src/instructor.ts

View workflow job for this annotation

GitHub Actions / run-tests

'ChatCompletion' is defined but never used. Allowed unused vars must match /^_/u
import { SchemaStream } from "schema-stream"
import { z, ZodObject } from "zod"

Check warning on line 15 in src/instructor.ts

View workflow job for this annotation

GitHub Actions / run-tests

'z' is defined but never used. Allowed unused vars must match /^_/u
import zodToJsonSchema from "zod-to-json-schema"

import { MODE } from "@/constants/modes"
Expand All @@ -31,7 +33,7 @@ const MODE_TO_PARAMS = {
[MODE.JSON_SCHEMA]: OAIBuildMessageBasedParams
}

type PatchedChatCompletionCreateParams = ChatCompletionCreateParamsNonStreaming & {
type PatchedChatCompletionCreateParams = ChatCompletionCreateParams & {
//eslint-disable-next-line @typescript-eslint/no-explicit-any
response_model?: ZodObject<any>
max_retries?: number
Expand Down Expand Up @@ -84,9 +86,14 @@ class Instructor {
}

const completion = await this.client.chat.completions.create(resolvedParams)
const response = this.parseOAIResponse(completion)
const parser = MODE_TO_PARSER[this.mode]

return response
if ("choices" in completion) {
const parsedCompletion = parser(completion)
return JSON.parse(parsedCompletion)
} else {
return OAIStream({ res: completion, parser })
}
} catch (error) {
throw error
}
Expand All @@ -95,6 +102,15 @@ class Instructor {
const makeCompletionCallWithRetries = async () => {
try {
const data = await makeCompletionCall()

//short circuit if this is a stream for now
if (params.stream) {
return this.partialStreamResponse({
stream: data,
schema: params.response_model
})
}

const validation = params.response_model.safeParse(data)

if (!validation.success) {
Expand Down Expand Up @@ -125,15 +141,30 @@ class Instructor {
return await makeCompletionCallWithRetries()
}

/**
 * Pipes a raw completion byte stream through a SchemaStream parser so that
 * consumers can read progressively-completed partial objects matching `schema`.
 *
 * @param stream - byte stream of partial JSON emitted by the completion call.
 * @param schema - zod schema describing the final object shape.
 * @returns the transform produced by `SchemaStream.parse` — presumably a
 *          TransformStream whose `readable` side yields parsed output; callers
 *          read from `.readable` (see the stream example) — TODO confirm against
 *          the schema-stream API.
 */
private async partialStreamResponse({ stream, schema }) {
  // NOTE(review): dropped `onKeyComplete: console.log` — leftover debug hook
  // that logged every completed key to stdout.
  const streamParser = new SchemaStream(schema)

  const parser = streamParser.parse({
    stringStreaming: true,
    handleUnescapedNewLines: true
  })

  stream.pipeThrough(parser)

  return parser
}

/**
* Builds the chat completion parameters.
* @param {PatchedChatCompletionCreateParams} params - The parameters for chat completion.
* @returns {ChatCompletionCreateParamsNonStreaming} The chat completion parameters.
* @returns {ChatCompletionCreateParams} The chat completion parameters.
*/
private buildChatCompletionParams = ({
response_model,
...params
}: PatchedChatCompletionCreateParams): ChatCompletionCreateParamsNonStreaming => {
}: PatchedChatCompletionCreateParams): ChatCompletionCreateParams => {
const jsonSchema = zodToJsonSchema(response_model, "response_model")

const definition = {
Expand All @@ -149,17 +180,6 @@ class Instructor {
}
}

/**
* Parses the OAI response.
* @param {ChatCompletion} response - The response from the chat completion.
* @returns {any} The parsed response.
*/
private parseOAIResponse = (response: ChatCompletion) => {
const parser = MODE_TO_PARSER[this.mode]

return parser(response)
}

/**
* Public chat interface.
*/
Expand Down
23 changes: 15 additions & 8 deletions src/oai/parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ export function OAIResponseTextParser(
) {
const parsedData = typeof data === "string" ? JSON.parse(data) : data

const text = parsedData?.choices[0]?.message?.content ?? "{}"
const text = parsedData?.choices[0]?.message?.content ?? null

return JSON.parse(text)
return text
}

/**
Expand All @@ -36,9 +36,12 @@ export function OAIResponseFnArgsParser(
) {
const parsedData = typeof data === "string" ? JSON.parse(data) : data

const text = parsedData.choices?.[0]?.message?.function_call?.arguments ?? "{}"
const text =
parsedData.choices?.[0].delta?.function_call?.arguments ??
parsedData.choices?.[0]?.message?.function_call?.arguments ??
null

return JSON.parse(text)
return text
}

/**
Expand All @@ -56,9 +59,12 @@ export function OAIResponseToolArgsParser(
) {
const parsedData = typeof data === "string" ? JSON.parse(data) : data

const text = parsedData.choices?.[0]?.message?.tool_call?.function?.arguments ?? "{}"
const text =
parsedData.choices?.[0].delta?.tool_call?.function?.arguments ??
parsedData.choices?.[0]?.message?.tool_call?.function?.arguments ??
null

return JSON.parse(text)
return text
}

/**
Expand All @@ -75,7 +81,8 @@ export function OAIResponseJSONStringParser(
| OpenAI.Chat.Completions.ChatCompletion
) {
const parsedData = typeof data === "string" ? JSON.parse(data) : data
const text = parsedData?.choices[0]?.message?.content ?? "{}"
const text =
parsedData.choices?.[0].delta?.content ?? parsedData?.choices[0]?.message?.content ?? null

return JSON.parse(text)
return text
}
51 changes: 51 additions & 0 deletions src/oai/stream.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import OpenAI from "openai"
import { Stream } from "openai/streaming"

interface OaiStreamArgs {
  res: Stream<OpenAI.Chat.Completions.ChatCompletionChunk>
  // Extracts the partial-content string from a single chunk; may return null
  // for chunks that carry no textual payload (e.g. role-only deltas).
  parser: (chunk: OpenAI.Chat.Completions.ChatCompletionChunk) => string | null
}

/**
 * `OAIStream` creates a ReadableStream that parses the SSE response from OAI
 * and returns a parsed string from the response.
 *
 * Stops reading at the first chunk whose `finish_reason` is "stop", or when
 * the consumer cancels the returned stream.
 *
 * @param {OaiStreamArgs} args - The arguments for the function.
 * @returns {ReadableStream<Uint8Array>} - The created ReadableStream.
 */
export function OAIStream({ res, parser }: OaiStreamArgs): ReadableStream<Uint8Array> {
  const encoder = new TextEncoder()
  // Set when the consumer cancels the ReadableStream so we stop pulling
  // chunks from the underlying OAI stream. (The previous `cancelGenerator`
  // was a no-op: it was assigned a function that just returned and was only
  // invoked immediately before `break`.)
  let cancelled = false

  async function* generateStream(): AsyncGenerator<string> {
    for await (const part of res) {
      if (cancelled || part?.choices?.[0]?.finish_reason === "stop") {
        break
      }

      const parsed = parser(part)

      // Skip chunks with no content: TextEncoder coerces its argument to a
      // string, so encoding null would inject the literal bytes "null" into
      // the output and corrupt the streamed JSON.
      if (parsed != null && parsed !== "") {
        yield parsed
      }
    }
  }

  return new ReadableStream({
    async start(controller) {
      for await (const parsedData of generateStream()) {
        controller.enqueue(encoder.encode(parsedData))
      }

      controller.close()
    },
    cancel() {
      cancelled = true
    }
  })
}
1 change: 1 addition & 0 deletions tests/functions.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ async function extractUser() {
const user: User = await client.chat.completions.create({
messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
model: "gpt-3.5-turbo",

response_model: UserSchema,
max_retries: 3
})
Expand Down

0 comments on commit 4ad3fd4

Please sign in to comment.