From 7e725df1566ce3843860475edb4d62d4662ed5af Mon Sep 17 00:00:00 2001 From: Sugam Panthi Date: Tue, 25 Mar 2025 15:07:24 -0500 Subject: [PATCH 1/3] Added the ability to stream using tool calls --- chat_stream.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/chat_stream.go b/chat_stream.go index 5154371..9950471 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -44,9 +44,10 @@ type StreamUsage struct { // StreamDelta represents a delta in the chat completion stream. type StreamDelta struct { - Role string `json:"role,omitempty"` // Role of the message. - Content string `json:"content"` // Content of the message. - ReasoningContent string `json:"reasoning_content,omitempty"` // Reasoning content of the message. + Role string `json:"role,omitempty"` // Role of the message. + Content string `json:"content"` // Content of the message. + ReasoningContent string `json:"reasoning_content,omitempty"` // Reasoning content of the message. + ToolCalls []ToolCall `json:"tool_calls,omitempty"` // Optional tool calls related to the message. } // StreamChoices represents a choice in the chat completion stream. From 95813662d81e988eb9485184d2b72680420e3461 Mon Sep 17 00:00:00 2001 From: Sugam Panthi Date: Tue, 25 Mar 2025 15:07:56 -0500 Subject: [PATCH 2/3] Added a new test case for checking stream using tool calls --- chat_stream_test.go | 214 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 214 insertions(+) diff --git a/chat_stream_test.go b/chat_stream_test.go index bc65a7d..2e25d4d 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -3,6 +3,7 @@ package deepseek_test import ( "context" "errors" + "fmt" "io" "testing" @@ -155,3 +156,216 @@ func streamChatCompletion( assert.NotEmpty(t, fullMessage, "should accumulate message content") return fullMessage, nil } + +// TestStreamingWithToolCall tests the streaming of a tool call. +func TestStreamingWithToolCall(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } + config := testutil.LoadTestConfig(t) + client := deepseek.NewClient(config.APIKey) + ctx, cancel := context.WithTimeout(context.Background(), config.TestTimeout) + defer cancel() + + var message = []deepseek.ChatCompletionMessage{ + { + Role: constants.ChatMessageRoleUser, + Content: "Find restaurants near me. I'm currently in Italy", + }, + } + + req := &deepseek.StreamChatCompletionRequest{ + Model: deepseek.DeepSeekChat, + Messages: message, + Stream: true, + Tools: []deepseek.Tool{ + { + Type: "function", + Function: deepseek.Function{ + Name: "get_user_location", + Description: "Get the user's exact location coordinates", + Parameters: &deepseek.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "country": map[string]interface{}{ + "type": "string", + "description": "Country name", + }, + "city": map[string]interface{}{ + "type": "string", + "description": "City name", + }, + }, + }, + }, + }, + }, + } + + stream, err := client.CreateChatCompletionStream(ctx, req) + require.NoError(t, err) + defer stream.Close() + + var fullMessage string + var toolCallCount int + + // Initialize the slice with enough capacity + fullToolCall := make([]struct { + id string + name string + arguments string + }, 2) // Expecting 2 tool calls + + // Track the current tool call index + currentToolCallIndex := -1 + + for { + resp, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + require.NoError(t, err, "stream should not error") + + if len(resp.Choices) > 0 { + chunk := resp.Choices[0] + + // Track regular content + if chunk.Delta.Content != "" { + fullMessage += chunk.Delta.Content + } + + if len(chunk.Delta.ToolCalls) > 0 { + toolCall := chunk.Delta.ToolCalls[0] + + // If we have a new ID, start tracking a new tool call + if toolCall.ID != "" && (currentToolCallIndex == -1 || + fullToolCall[currentToolCallIndex].id != toolCall.ID) { + currentToolCallIndex++ + fullToolCall[currentToolCallIndex].id = toolCall.ID + toolCallCount++ + } + + // Update function name if provided + if toolCall.Function.Name != "" { + fullToolCall[currentToolCallIndex].name = toolCall.Function.Name + } + + // Append to arguments if provided + if toolCall.Function.Arguments != "" { + fullToolCall[currentToolCallIndex].arguments += toolCall.Function.Arguments + } + } + } + } + + t.Logf("Assistant's response: %s", fullMessage) + + // Generate toolCalls for assertions + toolCalls := make([]string, len(fullToolCall)) + for i, tc := range fullToolCall { + toolCalls[i] = tc.name + ": " + tc.arguments + } + + assert.NotEmpty(t, fullMessage, "should accumulate message content") + assert.Equal(t, 2, toolCallCount, "should make exactly two tool calls") + assert.Contains(t, toolCalls[0], "get_user_location", "first tool call should get location") + assert.Contains(t, toolCalls[1], "search_restaurants", "second tool call should search restaurants") +} + +func TestStreamingWithToolCalls(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } + config := testutil.LoadTestConfig(t) + client := deepseek.NewClient(config.APIKey) + ctx, cancel := context.WithTimeout(context.Background(), config.TestTimeout) + defer cancel() + + var message = []deepseek.ChatCompletionMessage{ + { + Role: constants.ChatMessageRoleUser, + Content: "Find restaurants near me. I'm currently in France, Paris.", + }, + } + + req := &deepseek.StreamChatCompletionRequest{ + Model: deepseek.DeepSeekChat, + Messages: message, + Stream: true, + Tools: []deepseek.Tool{ + { + Type: "function", + Function: deepseek.Function{ + Name: "get_user_location", + Description: "Get the user's exact location coordinates", + Parameters: &deepseek.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "country": map[string]interface{}{ + "type": "string", + "description": "Country name", + }, + "city": map[string]interface{}{ + "type": "string", + "description": "City name", + }, + }, + }, + }, + }, + }, + } + + stream, err := client.CreateChatCompletionStream(ctx, req) + require.NoError(t, err) + defer stream.Close() + + var fullMessage string + var toolCallCount int + var fullToolCall struct { + id string + name string + arguments string + } + + for { + resp, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + require.NoError(t, err, "stream should not error") + + if len(resp.Choices) > 0 { + chunk := resp.Choices[0] + + // Track regular content + if chunk.Delta.Content != "" { + fullMessage += chunk.Delta.Content + } + + if len(chunk.Delta.ToolCalls) > 0 { + toolCall := chunk.Delta.ToolCalls[0] + + // If we have a new ID, start tracking a new tool call + if toolCall.ID != "" { + toolCallCount++ + } + + // Update function name if provided + if toolCall.Function.Name != "" { + fullToolCall.name = toolCall.Function.Name + } + + // Append to arguments if provided + if toolCall.Function.Arguments != "" { + fullToolCall.arguments += toolCall.Function.Arguments + } + } + } + } + + assert.Equal(t, 1, toolCallCount, "should make exactly one tool call") + assert.Contains(t, fullToolCall.name, "get_user_location", "should call get_user_location") + + fmt.Printf("Arguments: %s", fullToolCall.arguments) +} From 4113cf15419f3a197c1e366fbaeccb300b4a42ac Mon Sep 17 00:00:00 2001 From: Sugam Panthi Date: Tue, 25 Mar 2025 15:18:13 -0500 Subject: [PATCH 3/3] Fixed the issue with 2 tests --- chat_stream_test.go | 116 +------------------------------------------- 1 file changed, 1 insertion(+), 115 deletions(-) diff --git a/chat_stream_test.go b/chat_stream_test.go index 2e25d4d..eb27e13 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -157,121 +157,7 @@ func streamChatCompletion( return fullMessage, nil } -// TestStreamingWithToolCall tests the streaming of a tool call. -func TestStreamingWithToolCall(t *testing.T) { - if testing.Short() { - t.Skip("skipping test in short mode.") - } - config := testutil.LoadTestConfig(t) - client := deepseek.NewClient(config.APIKey) - ctx, cancel := context.WithTimeout(context.Background(), config.TestTimeout) - defer cancel() - - var message = []deepseek.ChatCompletionMessage{ - { - Role: constants.ChatMessageRoleUser, - Content: "Find restaurants near me. I'm currently in Italy", - }, - } - - req := &deepseek.StreamChatCompletionRequest{ - Model: deepseek.DeepSeekChat, - Messages: message, - Stream: true, - Tools: []deepseek.Tool{ - { - Type: "function", - Function: deepseek.Function{ - Name: "get_user_location", - Description: "Get the user's exact location coordinates", - Parameters: &deepseek.FunctionParameters{ - Type: "object", - Properties: map[string]interface{}{ - "country": map[string]interface{}{ - "type": "string", - "description": "Country name", - }, - "city": map[string]interface{}{ - "type": "string", - "description": "City name", - }, - }, - }, - }, - }, - }, - } - - stream, err := client.CreateChatCompletionStream(ctx, req) - require.NoError(t, err) - defer stream.Close() - - var fullMessage string - var toolCallCount int - - // Initialize the slice with enough capacity - fullToolCall := make([]struct { - id string - name string - arguments string - }, 2) // Expecting 2 tool calls - - // Track the current tool call index - currentToolCallIndex := -1 - - for { - resp, err := stream.Recv() - if errors.Is(err, io.EOF) { - break - } - require.NoError(t, err, "stream should not error") - - if len(resp.Choices) > 0 { - chunk := resp.Choices[0] - - // Track regular content - if chunk.Delta.Content != "" { - fullMessage += chunk.Delta.Content - } - - if len(chunk.Delta.ToolCalls) > 0 { - toolCall := chunk.Delta.ToolCalls[0] - - // If we have a new ID, start tracking a new tool call - if toolCall.ID != "" && (currentToolCallIndex == -1 || - fullToolCall[currentToolCallIndex].id != toolCall.ID) { - currentToolCallIndex++ - fullToolCall[currentToolCallIndex].id = toolCall.ID - toolCallCount++ - } - - // Update function name if provided - if toolCall.Function.Name != "" { - fullToolCall[currentToolCallIndex].name = toolCall.Function.Name - } - - // Append to arguments if provided - if toolCall.Function.Arguments != "" { - fullToolCall[currentToolCallIndex].arguments += toolCall.Function.Arguments - } - } - } - } - - t.Logf("Assistant's response: %s", fullMessage) - - // Generate toolCalls for assertions - toolCalls := make([]string, len(fullToolCall)) - for i, tc := range fullToolCall { - toolCalls[i] = tc.name + ": " + tc.arguments - } - - assert.NotEmpty(t, fullMessage, "should accumulate message content") - assert.Equal(t, 2, toolCallCount, "should make exactly two tool calls") - assert.Contains(t, toolCalls[0], "get_user_location", "first tool call should get location") - assert.Contains(t, toolCalls[1], "search_restaurants", "second tool call should search restaurants") -} - +// TestStreamingWithToolCalls tests the streaming of a tool call. func TestStreamingWithToolCalls(t *testing.T) { if testing.Short() { t.Skip("skipping test in short mode.")