diff --git a/api_chat_completions_test.go b/api_chat_completions_test.go index 11ef6b5..67566e9 100644 --- a/api_chat_completions_test.go +++ b/api_chat_completions_test.go @@ -9,10 +9,11 @@ import ( "strings" "testing" + "github.com/stretchr/testify/require" + "github.com/northes/go-moonshot" "github.com/northes/go-moonshot/internal/httpx" "github.com/northes/go-moonshot/test" - "github.com/stretchr/testify/require" ) func TestChat(t *testing.T) { @@ -179,24 +180,20 @@ func TestUseTools(t *testing.T) { // check tool calls if len(resp.Choices) != 0 { if resp.Choices[0].FinishReason == moonshot.FinishReasonToolCalls { - for _, toolCall := range resp.Choices[0].Message.ToolCalls { - t.Logf("should tool calls: %v", test.MarshalJsonToStringX(toolCall)) - if strings.HasPrefix(toolCall.ID, functionName) { + for _, tool := range resp.Choices[0].Message.ToolCalls { + t.Logf("should tool calls: %v", test.MarshalJsonToStringX(tool)) + if strings.HasPrefix(tool.ID, functionName) { // tool calls ipInfo, err := IPLocate(ip) if err != nil { t.Fatal(err) } - b, err := json.Marshal(ipInfo) - if err != nil { - t.Fatal(err) - } builder.AddMessageFromChoices(resp.Choices) - t.Logf("tool calls result: %s", test.MarshalJsonToStringX(ipInfo)) + t.Logf("tool calls result: %s", ipInfo) - builder.AddToolContent(string(b), functionName, resp.Choices[0].Message.ToolCalls[0].ID) + builder.AddToolContent(ipInfo, functionName, tool.ID) } } } @@ -231,17 +228,77 @@ type IPLocateInfoResponse struct { Data *IPLocateInfo `json:"data"` } -func IPLocate(ip string) (*IPLocateInfo, error) { +func IPLocate(ip string) (string, error) { response, err := httpx.NewClient(fmt.Sprintf("https://apihut.co/ip/%s", ip)).Get(context.Background()) if err != nil { - return nil, err + return "", err + } + defer func() { + _ = response.Raw().Body.Close() + }() + + body, err := io.ReadAll(response.Raw().Body) + if err != nil { + return "", err + } + + return string(body), nil +} + +func TestBuiltinFunctionWebSearch(t *testing.T) { + if test.IsGithubActions() { + return + } + + cli, err := NewTestClient() + if err != nil { + t.Fatal(err) + } + ctx := context.Background() + + builder := moonshot.NewChatCompletionsBuilder() + builder.SetModel(moonshot.ModelMoonshotV1128K) + builder.AddUserContent("请搜索 Moonshot AI Context Caching 技术,并告诉我它是什么。") + builder.SetTool(&moonshot.ChatCompletionsTool{ + Type: moonshot.ChatCompletionsToolTypeBuiltinFunction, + Function: &moonshot.ChatCompletionsToolFunction{ + Name: moonshot.BuiltinFunctionWebSearch, + }, + }) + + resp, err := cli.Chat().Completions(ctx, builder.ToRequest()) + if err != nil { + t.Fatal(err) + } + + if len(resp.Choices) != 0 { + choice := resp.Choices[0] + if choice.FinishReason == moonshot.FinishReasonToolCalls { + for _, tool := range choice.Message.ToolCalls { + t.Logf("tool calls: %v", test.MarshalJsonToStringX(tool)) + if tool.Function.Name == moonshot.BuiltinFunctionWebSearch { + // web search + arguments := new(moonshot.ChatCompletionsToolBuiltinFunctionWebSearchArguments) + if err = json.Unmarshal([]byte(tool.Function.Arguments), arguments); err != nil { + t.Errorf("unmarshal tool arguments error: %v", err) + continue + } + + t.Logf("tool calls result: search_id: %s, total_tokens: %d", arguments.SearchResult.SearchId, arguments.Usage.TotalTokens) + + builder.AddMessageFromChoices(resp.Choices) + builder.AddToolContent(tool.Function.Arguments, tool.Function.Name, tool.ID) + } + } + } } - respData := new(IPLocateInfoResponse) - err = response.Unmarshal(respData) + t.Logf("builder: %v", test.MarshalJsonToStringX(builder.ToRequest())) + + resp, err = cli.Chat().Completions(ctx, builder.ToRequest()) if err != nil { - return nil, err + t.Fatal(err) } - return respData.Data, nil + t.Log(test.MarshalJsonToStringX(resp)) } diff --git a/api_chat_completions_tool.go b/api_chat_completions_tool.go index baf64cf..ebb806f 100644 --- a/api_chat_completions_tool.go +++ b/api_chat_completions_tool.go @@ -7,8 +7,8 @@ type ChatCompletionsTool struct { type ChatCompletionsToolFunction struct { Name string `json:"name"` - Description string `json:"description"` - Parameters *ChatCompletionsToolFunctionParameters `json:"parameters"` + Description string `json:"description,omitempty"` + Parameters *ChatCompletionsToolFunctionParameters `json:"parameters,omitempty"` } type ChatCompletionsToolFunctionParameters struct { @@ -34,3 +34,12 @@ type ChatCompletionsResponseToolCallsFunction struct { Name string `json:"name"` Arguments string `json:"arguments"` } + +type ChatCompletionsToolBuiltinFunctionWebSearchArguments struct { + SearchResult struct { + SearchId string `json:"search_id"` + } `json:"search_result"` + Usage struct { + TotalTokens int `json:"total_tokens"` + } `json:"usage"` +} diff --git a/api_context_cache_test.go b/api_context_cache_test.go index a5c2955..b7702eb 100644 --- a/api_context_cache_test.go +++ b/api_context_cache_test.go @@ -2,7 +2,6 @@ package moonshot_test import ( "context" - "os" "testing" "time" @@ -17,7 +16,7 @@ import ( // https://github.com/MoonshotAI/moonpalace func TestContextCache(t *testing.T) { - if isGithubActions() { + if test.IsGithubActions() { return } cli, err := NewTestClient() @@ -93,7 +92,7 @@ func TestContextCache(t *testing.T) { } func TestContextCache_Create(t *testing.T) { - if isGithubActions() { + if test.IsGithubActions() { return } cli, err := NewTestClient() @@ -124,7 +123,7 @@ func TestContextCache_Create(t *testing.T) { } func TestContextCache_Delete(t *testing.T) { - if isGithubActions() { + if test.IsGithubActions() { return } cli, err := NewTestClient() @@ -144,7 +143,7 @@ func TestContextCache_Delete(t *testing.T) { } func TestContextCache_List(t *testing.T) { - if isGithubActions() { + if test.IsGithubActions() { return } cli, err := NewTestClient() @@ -162,7 +161,7 @@ func TestContextCache_List(t *testing.T) { } func TestContextCache_CreateTag(t *testing.T) { - if isGithubActions() { + if test.IsGithubActions() { return } cli, err := NewTestClient() @@ -182,10 +181,3 @@ func TestContextCache_CreateTag(t *testing.T) { } assert.Equal(t, "MyCacheTag", createResponse.Tag) } - -func isGithubActions() bool { - if val, ok := os.LookupEnv("GITHUB_ACTIONS"); !ok || val != "true" { - return false - } - return true -} diff --git a/enum_builtin_function.go b/enum_builtin_function.go new file mode 100644 index 0000000..373b1d7 --- /dev/null +++ b/enum_builtin_function.go @@ -0,0 +1,5 @@ +package moonshot + +const ( + BuiltinFunctionWebSearch string = "$web_search" +) diff --git a/enum_chat_completions.go b/enum_chat_completions.go index f9731e9..021d12b 100644 --- a/enum_chat_completions.go +++ b/enum_chat_completions.go @@ -52,7 +52,8 @@ func (c ChatCompletionsFinishReason) String() string { type ChatCompletionsToolType string const ( - ChatCompletionsToolTypeFunction ChatCompletionsToolType = "function" + ChatCompletionsToolTypeFunction ChatCompletionsToolType = "function" + ChatCompletionsToolTypeBuiltinFunction ChatCompletionsToolType = "builtin_function" ) func (c ChatCompletionsToolType) String() string { diff --git a/test/env.go b/test/env.go new file mode 100644 index 0000000..8cf160c --- /dev/null +++ b/test/env.go @@ -0,0 +1,12 @@ +package test + +import ( + "os" +) + +func IsGithubActions() bool { + if val, ok := os.LookupEnv("GITHUB_ACTIONS"); !ok || val != "true" { + return false + } + return true +} diff --git a/test/files.go b/test/files.go index 75ecbed..f39f476 100644 --- a/test/files.go +++ b/test/files.go @@ -6,7 +6,7 @@ import ( ) func GenerateTestContent() []byte { - return []byte("夕阳无限好") + return []byte("夕阳无限好,麦当劳汉堡") } func GenerateTestFile(content []byte) (string, error) {