Skip to content

Commit

Permalink
chore: refactor called function name (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
henomis authored Jul 10, 2023
1 parent 2e691a1 commit ee3df70
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 15 deletions.
23 changes: 12 additions & 11 deletions examples/chat/functions/main.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
package main

import (
"bufio"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"

"github.com/henomis/lingoose/chat"
"github.com/henomis/lingoose/llm/openai"
"github.com/henomis/lingoose/prompt"
)

func main() {
// fmt.Printf("What's your name?\n> ")
// reader := bufio.NewReader(os.Stdin)
// name, _ := reader.ReadString('\n')

name := "simone"
fmt.Printf("What's your name?\n> ")
reader := bufio.NewReader(os.Stdin)
name, _ := reader.ReadString('\n')

llmChat := chat.New(

Expand All @@ -41,25 +41,26 @@ func main() {
if err != nil {
panic(err)
}
// fmt.Printf("\n%s", response)
_ = response

if llmOpenAI.CalledFunctionName() == nil {
fmt.Printf("expected called function name to be set")
return
}

llmChat.AddPromptMessages(
[]chat.PromptMessage{
{
Type: chat.MessageTypeFunction,
Prompt: prompt.New(response),
Name: llmOpenAI.LastFunctionCallName(),
Name: llmOpenAI.CalledFunctionName(),
},
},
)

response, err = llmOpenAI.Chat(context.Background(), llmChat)
_, err = llmOpenAI.Chat(context.Background(), llmChat)
if err != nil {
panic(err)
}
_ = response
// fmt.Printf("\n%s", response)

}

Expand Down
2 changes: 1 addition & 1 deletion llm/openai/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (o *openAI) functionCall(response openai.ChatCompletionResponse) (string, e
return "", fmt.Errorf("%s: %w", ErrOpenAIChat, err)
}

o.lastFunctionCalledName = fn.Name
o.calledFunctionName = &fn.Name

return resultAsJSON, nil
}
7 changes: 4 additions & 3 deletions llm/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ type openAI struct {
usageCallback OpenAIUsageCallback
functions map[string]Function
functionsMaxIterations uint
lastFunctionCalledName string
calledFunctionName *string
}

func New(model Model, temperature float32, maxTokens int, verbose bool) *openAI {
Expand Down Expand Up @@ -127,8 +127,8 @@ func (o *openAI) WithFunctionCallMaxIterations(maxIterations uint) *openAI {
return o
}

func (o *openAI) LastFunctionCallName() *string {
return &o.lastFunctionCalledName
func (o *openAI) CalledFunctionName() *string {
return o.calledFunctionName
}

func NewCompletion() *openAI {
Expand Down Expand Up @@ -275,6 +275,7 @@ func (o *openAI) Chat(ctx context.Context, prompt *chat.Chat) (string, error) {

content := response.Choices[0].Message.Content

o.calledFunctionName = nil
if response.Choices[0].FinishReason == "function_call" && len(o.functions) > 0 {
content, err = o.functionCall(response)
if err != nil {
Expand Down

0 comments on commit ee3df70

Please sign in to comment.