diff --git a/chat_stream.go b/chat_stream.go index 5154371..9950471 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -44,9 +44,10 @@ type StreamUsage struct { // StreamDelta represents a delta in the chat completion stream. type StreamDelta struct { - Role string `json:"role,omitempty"` // Role of the message. - Content string `json:"content"` // Content of the message. - ReasoningContent string `json:"reasoning_content,omitempty"` // Reasoning content of the message. + Role string `json:"role,omitempty"` // Role of the message. + Content string `json:"content"` // Content of the message. + ReasoningContent string `json:"reasoning_content,omitempty"` // Reasoning content of the message. + ToolCalls []ToolCall `json:"tool_calls,omitempty"` // Optional tool calls related to the message. } // StreamChoices represents a choice in the chat completion stream. diff --git a/chat_stream_test.go b/chat_stream_test.go index bc65a7d..eb27e13 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -3,6 +3,7 @@ package deepseek_test import ( "context" "errors" + "fmt" "io" "testing" @@ -155,3 +156,102 @@ func streamChatCompletion( assert.NotEmpty(t, fullMessage, "should accumulate message content") return fullMessage, nil } + +// TestStreamingWithToolCalls tests the streaming of a tool call. +func TestStreamingWithToolCalls(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } + config := testutil.LoadTestConfig(t) + client := deepseek.NewClient(config.APIKey) + ctx, cancel := context.WithTimeout(context.Background(), config.TestTimeout) + defer cancel() + + var message = []deepseek.ChatCompletionMessage{ + { + Role: constants.ChatMessageRoleUser, + Content: "Find restaurants near me. I'm currently in France, Paris.", + }, + } + + req := &deepseek.StreamChatCompletionRequest{ + Model: deepseek.DeepSeekChat, + Messages: message, + Stream: true, + Tools: []deepseek.Tool{ + { + Type: "function", + Function: deepseek.Function{ + Name: "get_user_location", + Description: "Get the user's exact location coordinates", + Parameters: &deepseek.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "country": map[string]interface{}{ + "type": "string", + "description": "Country name", + }, + "city": map[string]interface{}{ + "type": "string", + "description": "City name", + }, + }, + }, + }, + }, + }, + } + + stream, err := client.CreateChatCompletionStream(ctx, req) + require.NoError(t, err) + defer stream.Close() + + var fullMessage string + var toolCallCount int + var fullToolCall struct { + id string + name string + arguments string + } + + for { + resp, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + require.NoError(t, err, "stream should not error") + + if len(resp.Choices) > 0 { + chunk := resp.Choices[0] + + // Track regular content + if chunk.Delta.Content != "" { + fullMessage += chunk.Delta.Content + } + + if len(chunk.Delta.ToolCalls) > 0 { + toolCall := chunk.Delta.ToolCalls[0] + + // If we have a new ID, start tracking a new tool call + if toolCall.ID != "" { + toolCallCount++ + } + + // Update function name if provided + if toolCall.Function.Name != "" { + fullToolCall.name = toolCall.Function.Name + } + + // Append to arguments if provided + if toolCall.Function.Arguments != "" { + fullToolCall.arguments += toolCall.Function.Arguments + } + } + } + } + + assert.Equal(t, 1, toolCallCount, "should make exactly one tool call") + assert.Contains(t, fullToolCall.name, "get_user_location", "should call get_user_location") + + fmt.Printf("Arguments: %s", fullToolCall.arguments) +}