diff --git a/examples/pipeline/summarize/main.go b/examples/pipeline/summarize/main.go index 06783b48..ca979ca8 100644 --- a/examples/pipeline/summarize/main.go +++ b/examples/pipeline/summarize/main.go @@ -14,7 +14,7 @@ import ( func main() { summarize := summarizepipeline.New( - openai.NewCompletion().WithMaxTokens(1000).WithVerbose(true).WithModel(openai.GPT3TextDavinci002), + openai.NewCompletion().WithMaxTokens(1000).WithVerbose(true).WithModel(openai.GPT3Dot5TurboInstruct), loader.NewTextLoader("state_of_the_union.txt", nil). WithTextSplitter(textsplitter.NewRecursiveCharacterTextSplitter(2000, 0)), ) diff --git a/go.mod b/go.mod index 1b5bf4bd..62512320 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/henomis/qdrant-go v1.1.0 github.com/invopop/jsonschema v0.7.0 github.com/pkoukk/tiktoken-go v0.1.1 - github.com/sashabaranov/go-openai v1.12.0 + github.com/sashabaranov/go-openai v1.17.9 ) require ( diff --git a/go.sum b/go.sum index 379ee07b..d9524525 100644 --- a/go.sum +++ b/go.sum @@ -29,8 +29,8 @@ github.com/pkoukk/tiktoken-go v0.1.1 h1:jtkYlIECjyM9OW1w4rjPmTohK4arORP9V25y6TM6 github.com/pkoukk/tiktoken-go v0.1.1/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/sashabaranov/go-openai v1.12.0 h1:aRNHH0gtVfrpIaEolD0sWrLLRnYQNK4cH/bIAHwL8Rk= -github.com/sashabaranov/go-openai v1.12.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= +github.com/sashabaranov/go-openai v1.17.9 h1:QEoBiGKWW68W79YIfXWEFZ7l5cEgZBV4/Ow3uy+5hNY= +github.com/sashabaranov/go-openai v1.17.9/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.3.1-0.20190311161405-34c6fa2dc709/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= diff --git a/llm/openai/function.go b/llm/openai/function.go index c4115fc5..2f1d6c71 100644 --- a/llm/openai/function.go +++ b/llm/openai/function.go @@ -49,18 +49,21 @@ func (o *OpenAI) BindFunction( return nil } -func (o *OpenAI) getFunctions() []openai.FunctionDefinition { - functions := []openai.FunctionDefinition{} +func (o *OpenAI) getFunctions() []openai.Tool { + tools := []openai.Tool{} for _, function := range o.functions { - functions = append(functions, openai.FunctionDefinition{ - Name: function.Name, - Description: function.Description, - Parameters: function.Parameters, + tools = append(tools, openai.Tool{ + Type: "function", + Function: openai.FunctionDefinition{ + Name: function.Name, + Description: function.Description, + Parameters: function.Parameters, + }, }) } - return functions + return tools } func extractFunctionParameter(f interface{}) (map[string]interface{}, error) { @@ -170,12 +173,16 @@ func callFnWithArgumentAsJSON(fn interface{}, argumentAsJSON string) (string, er } func (o *OpenAI) functionCall(response openai.ChatCompletionResponse) (string, error) { - fn, ok := o.functions[response.Choices[0].Message.FunctionCall.Name] + fn, ok := o.functions[response.Choices[0].Message.ToolCalls[0].Function.Name] if !ok { - return "", fmt.Errorf("%w: unknown function %s", ErrOpenAIChat, response.Choices[0].Message.FunctionCall.Name) + return "", fmt.Errorf( + "%w: unknown function %s", + ErrOpenAIChat, + response.Choices[0].Message.ToolCalls[0].Function.Name, + ) } - resultAsJSON, err := callFnWithArgumentAsJSON(fn.Fn, response.Choices[0].Message.FunctionCall.Arguments) + resultAsJSON, err := callFnWithArgumentAsJSON(fn.Fn, response.Choices[0].Message.ToolCalls[0].Function.Arguments) if err != nil { return "", fmt.Errorf("%w: %w", ErrOpenAIChat, err) } diff --git a/llm/openai/openai.go b/llm/openai/openai.go index 9307976c..43e87e20 100644 --- a/llm/openai/openai.go +++ b/llm/openai/openai.go @@ -32,29 +32,29 @@ const ( type Model string const ( - GPT432K0613 Model = openai.GPT432K0613 - GPT432K0314 Model = openai.GPT432K0314 - GPT432K Model = openai.GPT432K - GPT40613 Model = openai.GPT40613 - GPT40314 Model = openai.GPT40314 - GPT4 Model = openai.GPT4 - GPT3Dot5Turbo0613 Model = openai.GPT3Dot5Turbo0613 - GPT3Dot5Turbo0301 Model = openai.GPT3Dot5Turbo0301 - GPT3Dot5Turbo16K Model = openai.GPT3Dot5Turbo16K - GPT3Dot5Turbo16K0613 Model = openai.GPT3Dot5Turbo16K0613 - GPT3Dot5Turbo Model = openai.GPT3Dot5Turbo - GPT3TextDavinci003 Model = openai.GPT3TextDavinci003 - GPT3TextDavinci002 Model = openai.GPT3TextDavinci002 - GPT3TextCurie001 Model = openai.GPT3TextCurie001 - GPT3TextBabbage001 Model = openai.GPT3TextBabbage001 - GPT3TextAda001 Model = openai.GPT3TextAda001 - GPT3TextDavinci001 Model = openai.GPT3TextDavinci001 - GPT3DavinciInstructBeta Model = openai.GPT3DavinciInstructBeta - GPT3Davinci Model = openai.GPT3Davinci - GPT3CurieInstructBeta Model = openai.GPT3CurieInstructBeta - GPT3Curie Model = openai.GPT3Curie - GPT3Ada Model = openai.GPT3Ada - GPT3Babbage Model = openai.GPT3Babbage + GPT432K0613 Model = openai.GPT432K0613 + GPT432K0314 Model = openai.GPT432K0314 + GPT432K Model = openai.GPT432K + GPT40613 Model = openai.GPT40613 + GPT40314 Model = openai.GPT40314 + GPT4TurboPreview Model = openai.GPT4TurboPreview + GPT4VisionPreview Model = openai.GPT4VisionPreview + GPT4 Model = openai.GPT4 + GPT3Dot5Turbo1106 Model = openai.GPT3Dot5Turbo1106 + GPT3Dot5Turbo0613 Model = openai.GPT3Dot5Turbo0613 + GPT3Dot5Turbo0301 Model = openai.GPT3Dot5Turbo0301 + GPT3Dot5Turbo16K Model = openai.GPT3Dot5Turbo16K + GPT3Dot5Turbo16K0613 Model = openai.GPT3Dot5Turbo16K0613 + GPT3Dot5Turbo Model = openai.GPT3Dot5Turbo + GPT3Dot5TurboInstruct Model = openai.GPT3Dot5TurboInstruct + GPT3Davinci Model = openai.GPT3Davinci + GPT3Davinci002 Model = openai.GPT3Davinci002 + GPT3Curie Model = openai.GPT3Curie + GPT3Curie002 Model = openai.GPT3Curie002 + GPT3Ada Model = openai.GPT3Ada + GPT3Ada002 Model = openai.GPT3Ada002 + GPT3Babbage Model = openai.GPT3Babbage + GPT3Babbage002 Model = openai.GPT3Babbage002 ) type UsageCallback func(types.Meta) @@ -70,6 +70,7 @@ type OpenAI struct { usageCallback UsageCallback functions map[string]Function functionsMaxIterations uint + toolChoice *string calledFunctionName *string finishReason string cache *cache.Cache @@ -137,6 +138,11 @@ func (o *OpenAI) WithCompletionCache(cache *cache.Cache) *OpenAI { return o } +func (o *OpenAI) WithToolChoice(toolChoice string) *OpenAI { + o.toolChoice = &toolChoice + return o +} + // CalledFunctionName returns the name of the function that was called. func (o *OpenAI) CalledFunctionName() *string { return o.calledFunctionName @@ -149,7 +155,7 @@ func (o *OpenAI) FinishReason() string { func NewCompletion() *OpenAI { return New( - GPT3TextDavinci003, + GPT3Dot5TurboInstruct, DefaultOpenAITemperature, DefaultOpenAIMaxTokens, false, @@ -308,7 +314,17 @@ func (o *OpenAI) Chat(ctx context.Context, prompt *chat.Chat) (string, error) { } if len(o.functions) > 0 { - chatCompletionRequest.Functions = o.getFunctions() + chatCompletionRequest.Tools = o.getFunctions() + if o.toolChoice != nil { + chatCompletionRequest.ToolChoice = openai.ToolChoice{ + Type: openai.ToolTypeFunction, + Function: openai.ToolFunction{ + Name: *o.toolChoice, + }, + } + } else { + chatCompletionRequest.ToolChoice = "auto" + } } response, err := o.openAIClient.CreateChatCompletion( @@ -332,10 +348,10 @@ func (o *OpenAI) Chat(ctx context.Context, prompt *chat.Chat) (string, error) { o.finishReason = string(response.Choices[0].FinishReason) o.calledFunctionName = nil - if response.Choices[0].FinishReason == "function_call" && len(o.functions) > 0 { + if len(response.Choices[0].Message.ToolCalls) > 0 && len(o.functions) > 0 { if o.verbose { - fmt.Printf("Calling function %s\n", response.Choices[0].Message.FunctionCall.Name) - fmt.Printf("Function call arguments: %s\n", response.Choices[0].Message.FunctionCall.Arguments) + fmt.Printf("Calling function %s\n", response.Choices[0].Message.ToolCalls[0].Function.Name) + fmt.Printf("Function call arguments: %s\n", response.Choices[0].Message.ToolCalls[0].Function.Arguments) } content, err = o.functionCall(response)