Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ type StreamUsage struct {

// StreamDelta represents a delta in the chat completion stream.
type StreamDelta struct {
Role string `json:"role,omitempty"` // Role of the message.
Content string `json:"content"` // Content of the message.
ReasoningContent string `json:"reasoning_content,omitempty"` // Reasoning content of the message.
Role string `json:"role,omitempty"` // Role of the message.
Content string `json:"content"` // Content of the message.
ReasoningContent string `json:"reasoning_content,omitempty"` // Reasoning content of the message.
ToolCalls []ToolCall `json:"tool_calls,omitempty"` // Optional tool calls related to the message.
}

// StreamChoices represents a choice in the chat completion stream.
Expand Down
100 changes: 100 additions & 0 deletions chat_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package deepseek_test
import (
"context"
"errors"
"fmt"
"io"
"testing"

Expand Down Expand Up @@ -155,3 +156,102 @@ func streamChatCompletion(
assert.NotEmpty(t, fullMessage, "should accumulate message content")
return fullMessage, nil
}

// TestStreamingWithToolCalls tests the streaming of a tool call.
func TestStreamingWithToolCalls(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
config := testutil.LoadTestConfig(t)
client := deepseek.NewClient(config.APIKey)
ctx, cancel := context.WithTimeout(context.Background(), config.TestTimeout)
defer cancel()

var message = []deepseek.ChatCompletionMessage{
{
Role: constants.ChatMessageRoleUser,
Content: "Find restaurants near me. I'm currently in France, Paris.",
},
}

req := &deepseek.StreamChatCompletionRequest{
Model: deepseek.DeepSeekChat,
Messages: message,
Stream: true,
Tools: []deepseek.Tool{
{
Type: "function",
Function: deepseek.Function{
Name: "get_user_location",
Description: "Get the user's exact location coordinates",
Parameters: &deepseek.FunctionParameters{
Type: "object",
Properties: map[string]interface{}{
"country": map[string]interface{}{
"type": "string",
"description": "Country name",
},
"city": map[string]interface{}{
"type": "string",
"description": "City name",
},
},
},
},
},
},
}

stream, err := client.CreateChatCompletionStream(ctx, req)
require.NoError(t, err)
defer stream.Close()

var fullMessage string
var toolCallCount int
var fullToolCall struct {
id string
name string
arguments string
}

for {
resp, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
}
require.NoError(t, err, "stream should not error")

if len(resp.Choices) > 0 {
chunk := resp.Choices[0]

// Track regular content
if chunk.Delta.Content != "" {
fullMessage += chunk.Delta.Content
}

if len(chunk.Delta.ToolCalls) > 0 {
toolCall := chunk.Delta.ToolCalls[0]

// If we have a new ID, start tracking a new tool call
if toolCall.ID != "" {
toolCallCount++
}

// Update function name if provided
if toolCall.Function.Name != "" {
fullToolCall.name = toolCall.Function.Name
}

// Append to arguments if provided
if toolCall.Function.Arguments != "" {
fullToolCall.arguments += toolCall.Function.Arguments
}
}
}
}

assert.Equal(t, 1, toolCallCount, "should make exactly one tool call")
assert.Contains(t, fullToolCall.name, "get_user_location", "should call get_user_location")

fmt.Printf("Arguments: %s", fullToolCall.arguments)
}