From cee36515522910e03c8c6a0b9d1759f99844ed45 Mon Sep 17 00:00:00 2001
From: SachaMorard <2254275+SachaMorard@users.noreply.github.com>
Date: Wed, 14 Jan 2026 07:26:15 +0100
Subject: [PATCH] feat: sdk alignment

---
 README.md           |  226 ++-------
 edgee/edgee.go      |  143 +++---
 edgee/edgee_test.go | 1172 +++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 1302 insertions(+), 239 deletions(-)
 create mode 100644 edgee/edgee_test.go

diff --git a/README.md b/README.md
index ff747ba..3321548 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,9 @@
-# Edgee Gateway SDK for Go
+# Edgee Go SDK
 
-Lightweight Go SDK for Edgee AI Gateway.
+Lightweight, type-safe Go SDK for the [Edgee AI Gateway](https://www.edgee.cloud).
+
+[![Go Reference](https://pkg.go.dev/badge/github.com/edgee-cloud/go-sdk.svg)](https://pkg.go.dev/github.com/edgee-cloud/go-sdk)
+[![License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
 
 ## Installation
 
@@ -8,87 +11,52 @@ Lightweight Go SDK for Edgee AI Gateway.
 go get github.com/edgee-cloud/go-sdk/edgee
 ```
 
-## Usage
+## Quick Start
 
 ```go
-import "github.com/edgee-cloud/go-sdk/edgee"
-
-// Create client (uses EDGEE_API_KEY environment variable)
-client, err := edgee.NewClient(nil)
-if err != nil {
-    log.Fatal(err)
-}
-
-// Or create with explicit API key
-client, err := edgee.NewClient("your-api-key")
-
-// Or create with full config
-client, err := edgee.NewClient(&edgee.Config{
-    APIKey:  "your-api-key",
-    BaseURL: "https://api.edgee.ai",
-})
-```
+package main
+
+import (
+	"fmt"
+	"log"
+	"github.com/edgee-cloud/go-sdk/edgee"
+)
+
+func main() {
+	client, err := edgee.NewClient("your-api-key")
+	if err != nil {
+		log.Fatal(err)
+	}
 
-### Simple Input
+	response, err := client.Send("gpt-4o", "What is the capital of France?")
+	if err != nil {
+		log.Fatal(err)
+	}
 
-```go
-response, err := client.ChatCompletion("gpt-4o", "What is the capital of France?")
-if err != nil {
-    log.Fatal(err)
+	fmt.Println(response.Text())
+	// "The capital of France is Paris."
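+
+	// Token usage is also reported on the response (illustrative addition:
+	// Usage may be nil when the gateway omits it, so guard the access).
+	if response.Usage != nil {
+		fmt.Println(response.Usage.TotalTokens)
+	}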
} - -fmt.Println(response.Text()) ``` -### Full Input with Messages +## Send Method + +The `Send()` method makes non-streaming chat completion requests: ```go -response, err := client.ChatCompletion("gpt-4o", map[string]interface{}{ - "messages": []map[string]string{ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello!"}, - }, -}) +response, err := client.Send("gpt-4o", "Hello, world!") if err != nil { log.Fatal(err) } -fmt.Println(response.Text()) +// Access response +fmt.Println(response.Text()) // Text content +fmt.Println(response.FinishReason()) // Finish reason +fmt.Println(response.ToolCalls()) // Tool calls (if any) ``` -### With Tools +## Stream Method -```go -response, err := client.ChatCompletion("gpt-4o", map[string]interface{}{ - "messages": []map[string]string{ - {"role": "user", "content": "What's the weather in Paris?"}, - }, - "tools": []map[string]interface{}{ - { - "type": "function", - "function": map[string]interface{}{ - "name": "get_weather", - "description": "Get weather for a location", - "parameters": map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "location": map[string]string{"type": "string"}, - }, - }, - }, - }, - }, - "tool_choice": "auto", -}) - -if toolCalls := response.ToolCalls(); len(toolCalls) > 0 { - fmt.Printf("Tool calls: %+v\n", toolCalls) -} -``` - -### Streaming - -Access chunk properties for streaming: +The `Stream()` method enables real-time streaming responses: ```go chunkChan, errChan := client.Stream("gpt-4o", "Tell me a story") @@ -102,35 +70,9 @@ for { if text := chunk.Text(); text != "" { fmt.Print(text) } - case err := <-errChan: - if err != nil { - log.Fatal(err) - } - } -} -``` - -#### Accessing Full Chunk Data - -When you need complete access to the streaming response: - -```go -chunkChan, errChan := client.Stream("gpt-4o", "Hello") - -for { - select { - case chunk, ok := <-chunkChan: - if !ok { - return - } - if role := chunk.Role(); role != "" { - fmt.Printf("Role: %s\n", role) - } - if text := chunk.Text(); text != "" { - fmt.Print(text) - } + if reason := chunk.FinishReason(); reason != "" { - fmt.Printf("\nFinish: %s\n", reason) + fmt.Printf("\nFinished: %s\n", reason) } case err := <-errChan: if err != nil { @@ -140,89 +82,27 @@ for { } ``` -## Response Types - -### SendResponse - -```go -type SendResponse struct { - ID string - Object string - Created int64 - Model string - Choices []ChatCompletionChoice - Usage *Usage -} +## Features -// Convenience methods for easy access -func (r *SendResponse) Text() string // Shortcut for Choices[0].Message.Content -func (r *SendResponse) MessageContent() *Message // Shortcut for Choices[0].Message -func (r *SendResponse) FinishReason() string // Shortcut for Choices[0].FinishReason -func (r *SendResponse) ToolCalls() []ToolCall // Shortcut for Choices[0].Message.ToolCalls -``` +- ✅ **Type-safe** - Strong typing with Go structs and interfaces +- ✅ **OpenAI-compatible** - Works with any model supported by Edgee +- ✅ **Streaming** - Real-time response streaming with channels +- ✅ **Tool calling** - Full support for function calling +- ✅ **Flexible input** - Accept strings, InputObject, or maps +- ✅ **Minimal dependencies** - Uses only standard library and essential packages -### ChatCompletionChoice +## Documentation -```go -type ChatCompletionChoice struct { - Index int - Message *Message // For non-streaming responses - Delta *ChatCompletionDelta // For streaming responses - FinishReason *string -} -``` 
- -### Message +For complete documentation, examples, and API reference, visit: -```go -type Message struct { - Role string - Content string - Name *string - ToolCalls []ToolCall - ToolCallID *string -} -``` - -### Usage - -```go -type Usage struct { - PromptTokens int - CompletionTokens int - TotalTokens int -} -``` - -### Streaming Response - -```go -type StreamChunk struct { - ID string - Object string - Created int64 - Model string - Choices []ChatCompletionChoice -} - -// Convenience methods for easy access -func (c *StreamChunk) Text() string // Shortcut for Choices[0].Delta.Content -func (c *StreamChunk) Role() string // Shortcut for Choices[0].Delta.Role -func (c *StreamChunk) FinishReason() string // Shortcut for Choices[0].FinishReason -``` - -### ChatCompletionDelta - -```go -type ChatCompletionDelta struct { - Role *string // Only present in first chunk - Content *string - ToolCalls []ToolCall -} -``` +**👉 [Official Go SDK Documentation](https://www.edgee.cloud/docs/sdk/go)** -To learn more about this SDK, please refer to the [dedicated documentation](https://www.edgee.cloud/docs/sdk/go). +The documentation includes: +- [Configuration guide](https://www.edgee.cloud/docs/sdk/go/configuration) - Multiple ways to configure the SDK +- [Send method](https://www.edgee.cloud/docs/sdk/go/send) - Complete guide to non-streaming requests +- [Stream method](https://www.edgee.cloud/docs/sdk/go/stream) - Streaming responses guide +- [Tools](https://www.edgee.cloud/docs/sdk/go/tools) - Function calling guide ## License -Apache-2.0 +Licensed under the Apache License, Version 2.0. See [LICENSE](LICENSE) for details. diff --git a/edgee/edgee.go b/edgee/edgee.go index 7c5c6f8..93a6585 100644 --- a/edgee/edgee.go +++ b/edgee/edgee.go @@ -49,40 +49,46 @@ type Tool struct { // FunctionDefinition defines a function tool type FunctionDefinition struct { - Name string `json:"name"` - Description *string `json:"description,omitempty"` - Parameters map[string]interface{} `json:"parameters,omitempty"` + Name string `json:"name"` + Description *string `json:"description,omitempty"` + Parameters map[string]any `json:"parameters,omitempty"` } // InputObject represents structured input for chat completion type InputObject struct { - Messages []Message `json:"messages"` - Tools []Tool `json:"tools,omitempty"` - ToolChoice interface{} `json:"tool_choice,omitempty"` // string or object + Messages []Message `json:"messages"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` // string or object } -// ChatCompletionRequest represents the request body for chat completions -type ChatCompletionRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Stream bool `json:"stream,omitempty"` - Tools []Tool `json:"tools,omitempty"` - ToolChoice interface{} `json:"tool_choice,omitempty"` +// Request represents the request body for chat completions +type Request struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + Stream bool `json:"stream,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` } -// ChatCompletionDelta represents a streaming chunk delta -type ChatCompletionDelta struct { +// StreamDelta represents a streaming chunk delta +type StreamDelta struct { Role *string `json:"role,omitempty"` Content *string `json:"content,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` } -// ChatCompletionChoice represents a choice in the response -type ChatCompletionChoice 
struct { - Index int `json:"index"` - Message *Message `json:"message,omitempty"` - Delta *ChatCompletionDelta `json:"delta,omitempty"` - FinishReason *string `json:"finish_reason,omitempty"` +// Choice represents a choice in the response +type Choice struct { + Index int `json:"index"` + Message *Message `json:"message,omitempty"` + FinishReason *string `json:"finish_reason,omitempty"` +} + +// StreamChoice represents a choice in the streaming response +type StreamChoice struct { + Index int `json:"index"` + Delta *StreamDelta `json:"delta,omitempty"` + FinishReason *string `json:"finish_reason,omitempty"` } // Usage represents token usage information @@ -94,12 +100,12 @@ type Usage struct { // SendResponse represents the response from a non-streaming request type SendResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionChoice `json:"choices"` - Usage *Usage `json:"usage,omitempty"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []Choice `json:"choices"` + Usage *Usage `json:"usage,omitempty"` } // Text returns the text content from the first choice (convenience method) @@ -136,11 +142,11 @@ func (r *SendResponse) ToolCalls() []ToolCall { // StreamChunk represents a streaming response chunk type StreamChunk struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionChoice `json:"choices"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []StreamChoice `json:"choices"` } // Text returns the text content from the first choice (convenience method) @@ -183,7 +189,7 @@ type Client struct { // - Pass a string to set the API key directly // - Pass a *Config to set both API key and base URL // - Pass nil to use environment variables (EDGEE_API_KEY, EDGEE_BASE_URL) -func NewClient(config interface{}) (*Client, error) { +func NewClient(config any) (*Client, error) { var apiKey, baseURL string switch v := config.(type) { @@ -229,31 +235,28 @@ func NewClient(config interface{}) (*Client, error) { // Send sends a chat completion request with flexible input: // - Pass a string for simple user input // - Pass an InputObject for full control -// - Pass a map[string]interface{} with "messages", "tools", "tool_choice" keys -func (c *Client) Send(model string, input interface{}, stream bool) (interface{}, error) { - req, err := c.buildRequest(model, input, stream) +// - Pass a map[string]any with "messages", "tools", "tool_choice" keys +func (c *Client) Send(model string, input any) (response SendResponse, err error) { + req, err := c.buildRequest(model, input, false) if err != nil { - return nil, err - } - - if stream { - return c.handleStreamingResponse(req) + return } - return c.handleNonStreamingResponse(req) + response, err = c.handleNonStreamingResponse(req) + return } // ChatCompletion sends a non-streaming chat completion request (convenience method) -func (c *Client) ChatCompletion(model string, input interface{}) (*SendResponse, error) { - result, err := c.Send(model, input, false) +func (c *Client) ChatCompletion(model string, input any) (response SendResponse, err error) { + response, err = c.Send(model, input) if err != nil { - return nil, err + return } - return result.(*SendResponse), nil + return } // Stream sends a streaming chat 
completion request (convenience method) -func (c *Client) Stream(model string, input interface{}) (<-chan *StreamChunk, <-chan error) { - result, err := c.Send(model, input, true) +func (c *Client) Stream(model string, input any) (<-chan *StreamChunk, <-chan error) { + req, err := c.buildRequest(model, input, true) if err != nil { errChan := make(chan error, 1) errChan <- err @@ -263,15 +266,21 @@ func (c *Client) Stream(model string, input interface{}) (<-chan *StreamChunk, < return chunkChan, errChan } - streamResult := result.(struct { - ChunkChan <-chan *StreamChunk - ErrChan <-chan error - }) - return streamResult.ChunkChan, streamResult.ErrChan + result, err := c.handleStreamingResponse(req) + if err != nil { + errChan := make(chan error, 1) + errChan <- err + close(errChan) + chunkChan := make(chan *StreamChunk) + close(chunkChan) + return chunkChan, errChan + } + + return result.ChunkChan, result.ErrChan } -func (c *Client) buildRequest(model string, input interface{}, stream bool) (*ChatCompletionRequest, error) { - req := &ChatCompletionRequest{ +func (c *Client) buildRequest(model string, input any, stream bool) (*Request, error) { + req := &Request{ Model: model, Stream: stream, } @@ -288,7 +297,7 @@ func (c *Client) buildRequest(model string, input interface{}, stream bool) (*Ch req.Messages = v.Messages req.Tools = v.Tools req.ToolChoice = v.ToolChoice - case map[string]interface{}: + case map[string]any: // Map input if messages, ok := v["messages"]; ok { msgBytes, err := json.Marshal(messages) @@ -318,15 +327,15 @@ func (c *Client) buildRequest(model string, input interface{}, stream bool) (*Ch return req, nil } -func (c *Client) handleNonStreamingResponse(req *ChatCompletionRequest) (*SendResponse, error) { +func (c *Client) handleNonStreamingResponse(req *Request) (response SendResponse, err error) { body, err := json.Marshal(req) if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) + return response, fmt.Errorf("failed to marshal request: %w", err) } httpReq, err := http.NewRequest("POST", c.baseURL+APIEndpoint, bytes.NewReader(body)) if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + return response, fmt.Errorf("failed to create request: %w", err) } httpReq.Header.Set("Content-Type", "application/json") @@ -335,24 +344,26 @@ func (c *Client) handleNonStreamingResponse(req *ChatCompletionRequest) (*SendRe client := &http.Client{} resp, err := client.Do(httpReq) if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) + return response, fmt.Errorf("failed to send request: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(bodyBytes)) + return response, fmt.Errorf("API error %d: %s", resp.StatusCode, string(bodyBytes)) } - var result SendResponse - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, fmt.Errorf("failed to decode response: %w", err) + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + return response, fmt.Errorf("failed to decode response: %w", err) } - return &result, nil + return } -func (c *Client) handleStreamingResponse(req *ChatCompletionRequest) (interface{}, error) { +func (c *Client) handleStreamingResponse(req *Request) (struct { + ChunkChan <-chan *StreamChunk + ErrChan <-chan error +}, error) { chunkChan := make(chan *StreamChunk, 10) errChan := make(chan error, 1) diff --git a/edgee/edgee_test.go 
b/edgee/edgee_test.go new file mode 100644 index 0000000..68317af --- /dev/null +++ b/edgee/edgee_test.go @@ -0,0 +1,1172 @@ +package edgee + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" +) + +func TestNewClient(t *testing.T) { + // Save original environment variables + originalAPIKey := os.Getenv("EDGEE_API_KEY") + originalBaseURL := os.Getenv("EDGEE_BASE_URL") + defer func() { + if originalAPIKey != "" { + os.Setenv("EDGEE_API_KEY", originalAPIKey) + } else { + os.Unsetenv("EDGEE_API_KEY") + } + if originalBaseURL != "" { + os.Setenv("EDGEE_BASE_URL", originalBaseURL) + } else { + os.Unsetenv("EDGEE_BASE_URL") + } + }() + + t.Run("with string API key", func(t *testing.T) { + os.Unsetenv("EDGEE_API_KEY") + os.Unsetenv("EDGEE_BASE_URL") + + client, err := NewClient("test-api-key") + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if client == nil { + t.Fatal("Expected client, got nil") + } + }) + + t.Run("with empty string API key", func(t *testing.T) { + os.Unsetenv("EDGEE_API_KEY") + os.Unsetenv("EDGEE_BASE_URL") + + _, err := NewClient("") + if err == nil { + t.Fatal("Expected error for empty API key") + } + if !strings.Contains(err.Error(), "EDGEE_API_KEY is not set") { + t.Errorf("Expected error about EDGEE_API_KEY, got %v", err) + } + }) + + t.Run("with Config struct", func(t *testing.T) { + os.Unsetenv("EDGEE_API_KEY") + os.Unsetenv("EDGEE_BASE_URL") + + config := &Config{ + APIKey: "test-key", + BaseURL: "https://custom.example.com", + } + client, err := NewClient(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if client == nil { + t.Fatal("Expected client, got nil") + } + }) + + t.Run("with Config struct and empty baseURL uses default", func(t *testing.T) { + os.Unsetenv("EDGEE_API_KEY") + os.Unsetenv("EDGEE_BASE_URL") + + config := &Config{ + APIKey: "test-key", + } + client, err := NewClient(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if client == nil { + t.Fatal("Expected client, got nil") + } + }) + + t.Run("with nil uses environment variables", func(t *testing.T) { + os.Setenv("EDGEE_API_KEY", "env-api-key") + os.Setenv("EDGEE_BASE_URL", "https://env-base-url.example.com") + + client, err := NewClient(nil) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if client == nil { + t.Fatal("Expected client, got nil") + } + }) + + t.Run("with nil and no env vars fails", func(t *testing.T) { + os.Unsetenv("EDGEE_API_KEY") + os.Unsetenv("EDGEE_BASE_URL") + + _, err := NewClient(nil) + if err == nil { + t.Fatal("Expected error when no API key provided") + } + if !strings.Contains(err.Error(), "EDGEE_API_KEY is not set") { + t.Errorf("Expected error about EDGEE_API_KEY, got %v", err) + } + }) + + t.Run("with nil and only EDGEE_API_KEY uses default baseURL", func(t *testing.T) { + os.Setenv("EDGEE_API_KEY", "env-api-key") + os.Unsetenv("EDGEE_BASE_URL") + + client, err := NewClient(nil) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if client == nil { + t.Fatal("Expected client, got nil") + } + }) + + t.Run("with unsupported config type", func(t *testing.T) { + _, err := NewClient(123) + if err == nil { + t.Fatal("Expected error for unsupported config type") + } + if !strings.Contains(err.Error(), "unsupported config type") { + t.Errorf("Expected error about unsupported config type, got %v", err) + } + }) + + t.Run("with Config struct and empty APIKey falls back to env", func(t *testing.T) { + 
os.Setenv("EDGEE_API_KEY", "env-api-key") + os.Unsetenv("EDGEE_BASE_URL") + + config := &Config{ + APIKey: "", + } + client, err := NewClient(config) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if client == nil { + t.Fatal("Expected client, got nil") + } + }) +} + +func TestClient_Send(t *testing.T) { + t.Run("with string input", func(t *testing.T) { + mockResponse := SendResponse{ + ID: "test-id", + Object: "chat.completion", + Created: 1234567890, + Model: "gpt-4", + Choices: []Choice{ + { + Index: 0, + Message: &Message{ + Role: "assistant", + Content: "Hello, world!", + }, + FinishReason: stringPtr("stop"), + }, + }, + Usage: &Usage{ + PromptTokens: 10, + CompletionTokens: 5, + TotalTokens: 15, + }, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Expected POST, got %s", r.Method) + } + if r.URL.Path != "/v1/chat/completions" { + t.Errorf("Expected /v1/chat/completions, got %s", r.URL.Path) + } + if r.Header.Get("Authorization") != "Bearer test-api-key" { + t.Errorf("Expected Bearer token, got %s", r.Header.Get("Authorization")) + } + + var req Request + body, _ := io.ReadAll(r.Body) + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("Failed to unmarshal request: %v", err) + } + + if req.Model != "gpt-4" { + t.Errorf("Expected model gpt-4, got %s", req.Model) + } + if len(req.Messages) != 1 { + t.Errorf("Expected 1 message, got %d", len(req.Messages)) + } + if req.Messages[0].Role != "user" || req.Messages[0].Content != "Hello" { + t.Errorf("Expected user message 'Hello', got %+v", req.Messages[0]) + } + if req.Stream { + t.Error("Expected stream=false") + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(mockResponse) + })) + defer server.Close() + + client, _ := NewClient(&Config{ + APIKey: "test-api-key", + BaseURL: server.URL, + }) + + response, err := client.Send("gpt-4", "Hello") + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if response.Choices[0].Message.Content != "Hello, world!" 
{ + t.Errorf("Expected 'Hello, world!', got %s", response.Choices[0].Message.Content) + } + if response.Usage.TotalTokens != 15 { + t.Errorf("Expected 15 total tokens, got %d", response.Usage.TotalTokens) + } + }) + + t.Run("with InputObject", func(t *testing.T) { + mockResponse := SendResponse{ + Choices: []Choice{ + { + Index: 0, + Message: &Message{ + Role: "assistant", + Content: "Response", + }, + FinishReason: stringPtr("stop"), + }, + }, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req Request + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &req) + + if len(req.Messages) != 2 { + t.Errorf("Expected 2 messages, got %d", len(req.Messages)) + } + if req.Messages[0].Role != "system" { + t.Errorf("Expected first message role 'system', got %s", req.Messages[0].Role) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(mockResponse) + })) + defer server.Close() + + client, _ := NewClient(&Config{ + APIKey: "test-api-key", + BaseURL: server.URL, + }) + + input := InputObject{ + Messages: []Message{ + {Role: "system", Content: "You are a helpful assistant"}, + {Role: "user", Content: "Hello"}, + }, + } + + response, err := client.Send("gpt-4", input) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if response.Choices[0].Message.Content != "Response" { + t.Errorf("Expected 'Response', got %s", response.Choices[0].Message.Content) + } + }) + + t.Run("with InputObject pointer", func(t *testing.T) { + mockResponse := SendResponse{ + Choices: []Choice{ + { + Index: 0, + Message: &Message{ + Role: "assistant", + Content: "Response", + }, + FinishReason: stringPtr("stop"), + }, + }, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(mockResponse) + })) + defer server.Close() + + client, _ := NewClient(&Config{ + APIKey: "test-api-key", + BaseURL: server.URL, + }) + + input := &InputObject{ + Messages: []Message{ + {Role: "user", Content: "Hello"}, + }, + } + + _, err := client.Send("gpt-4", input) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + }) + + t.Run("with map[string]any", func(t *testing.T) { + mockResponse := SendResponse{ + Choices: []Choice{ + { + Index: 0, + Message: &Message{ + Role: "assistant", + Content: "Response", + }, + FinishReason: stringPtr("stop"), + }, + }, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req Request + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &req) + + if len(req.Messages) != 1 { + t.Errorf("Expected 1 message, got %d", len(req.Messages)) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(mockResponse) + })) + defer server.Close() + + client, _ := NewClient(&Config{ + APIKey: "test-api-key", + BaseURL: server.URL, + }) + + input := map[string]any{ + "messages": []map[string]any{ + {"role": "user", "content": "Hello"}, + }, + } + + _, err := client.Send("gpt-4", input) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + }) + + t.Run("with tools", func(t *testing.T) { + mockResponse := SendResponse{ + Choices: []Choice{ + { + Index: 0, + Message: &Message{ + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{ + { + ID: "call_123", + Type: "function", + Function: FunctionCall{ + Name: "get_weather", + Arguments: `{"location": "San Francisco"}`, + }, + }, + }, + 
}, + FinishReason: stringPtr("tool_calls"), + }, + }, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req Request + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &req) + + if len(req.Tools) != 1 { + t.Errorf("Expected 1 tool, got %d", len(req.Tools)) + } + if req.Tools[0].Function.Name != "get_weather" { + t.Errorf("Expected tool name 'get_weather', got %s", req.Tools[0].Function.Name) + } + if req.ToolChoice != "auto" { + t.Errorf("Expected tool_choice 'auto', got %v", req.ToolChoice) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(mockResponse) + })) + defer server.Close() + + client, _ := NewClient(&Config{ + APIKey: "test-api-key", + BaseURL: server.URL, + }) + + tools := []Tool{ + { + Type: "function", + Function: FunctionDefinition{ + Name: "get_weather", + Description: stringPtr("Get the weather for a location"), + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "location": map[string]any{"type": "string"}, + }, + }, + }, + }, + } + + input := InputObject{ + Messages: []Message{ + {Role: "user", Content: "What is the weather?"}, + }, + Tools: tools, + ToolChoice: "auto", + } + + response, err := client.Send("gpt-4", input) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if len(response.Choices[0].Message.ToolCalls) == 0 { + t.Error("Expected tool calls in response") + } + }) + + t.Run("with tool_choice object", func(t *testing.T) { + mockResponse := SendResponse{ + Choices: []Choice{ + { + Index: 0, + Message: &Message{ + Role: "assistant", + Content: "", + }, + FinishReason: stringPtr("tool_calls"), + }, + }, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req Request + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &req) + + toolChoiceMap, ok := req.ToolChoice.(map[string]any) + if !ok { + t.Errorf("Expected tool_choice to be map, got %T", req.ToolChoice) + } + if toolChoiceMap["type"] != "function" { + t.Errorf("Expected tool_choice type 'function', got %v", toolChoiceMap["type"]) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(mockResponse) + })) + defer server.Close() + + client, _ := NewClient(&Config{ + APIKey: "test-api-key", + BaseURL: server.URL, + }) + + input := InputObject{ + Messages: []Message{ + {Role: "user", Content: "Test"}, + }, + ToolChoice: map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "specific_function", + }, + }, + } + + _, err := client.Send("gpt-4", input) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + }) + + t.Run("without usage field", func(t *testing.T) { + mockResponse := SendResponse{ + Choices: []Choice{ + { + Index: 0, + Message: &Message{ + Role: "assistant", + Content: "Response", + }, + FinishReason: stringPtr("stop"), + }, + }, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(mockResponse) + })) + defer server.Close() + + client, _ := NewClient(&Config{ + APIKey: "test-api-key", + BaseURL: server.URL, + }) + + response, err := client.Send("gpt-4", "Test") + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if response.Usage != nil { + t.Error("Expected nil usage") + } + if len(response.Choices) != 1 { + t.Errorf("Expected 1 choice, got %d", len(response.Choices)) + } + }) + + 
t.Run("with multiple choices", func(t *testing.T) { + mockResponse := SendResponse{ + Choices: []Choice{ + { + Index: 0, + Message: &Message{ + Role: "assistant", + Content: "First response", + }, + FinishReason: stringPtr("stop"), + }, + { + Index: 1, + Message: &Message{ + Role: "assistant", + Content: "Second response", + }, + FinishReason: stringPtr("stop"), + }, + }, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(mockResponse) + })) + defer server.Close() + + client, _ := NewClient(&Config{ + APIKey: "test-api-key", + BaseURL: server.URL, + }) + + response, err := client.Send("gpt-4", "Test") + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if len(response.Choices) != 2 { + t.Errorf("Expected 2 choices, got %d", len(response.Choices)) + } + if response.Choices[0].Message.Content != "First response" { + t.Errorf("Expected 'First response', got %s", response.Choices[0].Message.Content) + } + if response.Choices[1].Message.Content != "Second response" { + t.Errorf("Expected 'Second response', got %s", response.Choices[1].Message.Content) + } + }) + + t.Run("with custom baseURL", func(t *testing.T) { + mockResponse := SendResponse{ + Choices: []Choice{ + { + Index: 0, + Message: &Message{ + Role: "assistant", + Content: "Response", + }, + FinishReason: stringPtr("stop"), + }, + }, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(mockResponse) + })) + defer server.Close() + + client, _ := NewClient(&Config{ + APIKey: "test-api-key", + BaseURL: server.URL, + }) + + _, err := client.Send("gpt-4", "Test") + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + }) + + t.Run("with API error", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("Unauthorized")) + })) + defer server.Close() + + client, _ := NewClient(&Config{ + APIKey: "test-api-key", + BaseURL: server.URL, + }) + + _, err := client.Send("gpt-4", "Test") + if err == nil { + t.Fatal("Expected error for 401 status") + } + if !strings.Contains(err.Error(), "API error 401") { + t.Errorf("Expected error about API error 401, got %v", err) + } + }) + + t.Run("with 500 error", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("Internal Server Error")) + })) + defer server.Close() + + client, _ := NewClient(&Config{ + APIKey: "test-api-key", + BaseURL: server.URL, + }) + + _, err := client.Send("gpt-4", "Test") + if err == nil { + t.Fatal("Expected error for 500 status") + } + if !strings.Contains(err.Error(), "API error 500") { + t.Errorf("Expected error about API error 500, got %v", err) + } + }) + + t.Run("with unsupported input type", func(t *testing.T) { + client, _ := NewClient(&Config{ + APIKey: "test-api-key", + }) + + _, err := client.Send("gpt-4", 123) + if err == nil { + t.Fatal("Expected error for unsupported input type") + } + if !strings.Contains(err.Error(), "unsupported input type") { + t.Errorf("Expected error about unsupported input type, got %v", err) + } + }) +} + +func TestSendResponse_ConvenienceMethods(t *testing.T) { + t.Run("Text method", func(t *testing.T) { + 
response := &SendResponse{ + Choices: []Choice{ + { + Index: 0, + Message: &Message{ + Role: "assistant", + Content: "Hello, world!", + }, + }, + }, + } + + if response.Text() != "Hello, world!" { + t.Errorf("Expected 'Hello, world!', got %s", response.Text()) + } + }) + + t.Run("Text method with empty choices", func(t *testing.T) { + response := &SendResponse{ + Choices: []Choice{}, + } + + if response.Text() != "" { + t.Errorf("Expected empty string, got %s", response.Text()) + } + }) + + t.Run("Text method with nil message", func(t *testing.T) { + response := &SendResponse{ + Choices: []Choice{ + { + Index: 0, + Message: nil, + }, + }, + } + + if response.Text() != "" { + t.Errorf("Expected empty string, got %s", response.Text()) + } + }) + + t.Run("MessageContent method", func(t *testing.T) { + msg := &Message{ + Role: "assistant", + Content: "Hello, world!", + } + response := &SendResponse{ + Choices: []Choice{ + { + Index: 0, + Message: msg, + }, + }, + } + + if response.MessageContent() != msg { + t.Error("Expected message to match") + } + }) + + t.Run("MessageContent method with empty choices", func(t *testing.T) { + response := &SendResponse{ + Choices: []Choice{}, + } + + if response.MessageContent() != nil { + t.Error("Expected nil message") + } + }) + + t.Run("FinishReason method", func(t *testing.T) { + response := &SendResponse{ + Choices: []Choice{ + { + Index: 0, + FinishReason: stringPtr("stop"), + }, + }, + } + + if response.FinishReason() != "stop" { + t.Errorf("Expected 'stop', got %s", response.FinishReason()) + } + }) + + t.Run("FinishReason method with nil", func(t *testing.T) { + response := &SendResponse{ + Choices: []Choice{ + { + Index: 0, + FinishReason: nil, + }, + }, + } + + if response.FinishReason() != "" { + t.Errorf("Expected empty string, got %s", response.FinishReason()) + } + }) + + t.Run("ToolCalls method", func(t *testing.T) { + toolCalls := []ToolCall{ + { + ID: "call_123", + Type: "function", + Function: FunctionCall{ + Name: "get_weather", + Arguments: `{"location": "San Francisco"}`, + }, + }, + } + response := &SendResponse{ + Choices: []Choice{ + { + Index: 0, + Message: &Message{ + Role: "assistant", + ToolCalls: toolCalls, + }, + }, + }, + } + + result := response.ToolCalls() + if len(result) != 1 { + t.Errorf("Expected 1 tool call, got %d", len(result)) + } + if result[0].ID != "call_123" { + t.Errorf("Expected tool call ID 'call_123', got %s", result[0].ID) + } + }) + + t.Run("ToolCalls method with empty choices", func(t *testing.T) { + response := &SendResponse{ + Choices: []Choice{}, + } + + if response.ToolCalls() != nil { + t.Error("Expected nil tool calls") + } + }) +} + +func TestClient_ChatCompletion(t *testing.T) { + mockResponse := SendResponse{ + Choices: []Choice{ + { + Index: 0, + Message: &Message{ + Role: "assistant", + Content: "Hello, world!", + }, + FinishReason: stringPtr("stop"), + }, + }, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(mockResponse) + })) + defer server.Close() + + client, _ := NewClient(&Config{ + APIKey: "test-api-key", + BaseURL: server.URL, + }) + + result, err := client.ChatCompletion("gpt-4", "Hello") + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if result.Choices[0].Message.Content != "Hello, world!" 
{ + t.Errorf("Expected 'Hello, world!', got %s", result.Choices[0].Message.Content) + } +} + +func TestClient_Stream(t *testing.T) { + t.Run("with string input", func(t *testing.T) { + mockChunks := []string{ + `{"id":"test","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}`, + `{"id":"test","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}`, + `{"id":"test","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]}`, + `{"id":"test","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req Request + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &req) + + if !req.Stream { + t.Error("Expected stream=true") + } + + w.Header().Set("Content-Type", "text/event-stream") + for _, chunk := range mockChunks { + fmt.Fprintf(w, "data: %s\n\n", chunk) + } + fmt.Fprintf(w, "data: [DONE]\n\n") + })) + defer server.Close() + + client, _ := NewClient(&Config{ + APIKey: "test-api-key", + BaseURL: server.URL, + }) + + chunkChan, errChan := client.Stream("gpt-4", "Hello") + + chunks := []*StreamChunk{} + for chunk := range chunkChan { + chunks = append(chunks, chunk) + } + + // Check for errors + select { + case err := <-errChan: + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + default: + } + + if len(chunks) != 4 { + t.Errorf("Expected 4 chunks, got %d", len(chunks)) + return + } + if chunks[0].Role() != "assistant" { + t.Errorf("Expected role 'assistant', got %s", chunks[0].Role()) + } + if chunks[1].Text() != "Hello" { + t.Errorf("Expected 'Hello', got %s", chunks[1].Text()) + } + if chunks[2].Text() != " world" { + t.Errorf("Expected ' world', got %s", chunks[2].Text()) + } + if chunks[3].FinishReason() != "stop" { + t.Errorf("Expected finish_reason 'stop', got %s", chunks[3].FinishReason()) + } + }) + + t.Run("with InputObject", func(t *testing.T) { + mockChunk := `{"id":"test","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"content":"Response"},"finish_reason":null}]}` + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + fmt.Fprintf(w, "data: %s\n\n", mockChunk) + fmt.Fprintf(w, "data: [DONE]\n\n") + })) + defer server.Close() + + client, _ := NewClient(&Config{ + APIKey: "test-api-key", + BaseURL: server.URL, + }) + + input := InputObject{ + Messages: []Message{ + {Role: "user", Content: "Hello"}, + }, + } + + chunkChan, errChan := client.Stream("gpt-4", input) + + chunks := []*StreamChunk{} + for chunk := range chunkChan { + chunks = append(chunks, chunk) + } + + // Check for errors + select { + case err := <-errChan: + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + default: + } + + if len(chunks) != 1 { + t.Errorf("Expected 1 chunk, got %d", len(chunks)) + } + if chunks[0].Text() != "Response" { + t.Errorf("Expected 'Response', got %s", chunks[0].Text()) + } + }) + + t.Run("with streaming error", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + 
w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte("Rate limit exceeded")) + })) + defer server.Close() + + client, _ := NewClient(&Config{ + APIKey: "test-api-key", + BaseURL: server.URL, + }) + + chunkChan, errChan := client.Stream("gpt-4", "Hello") + + // Wait for error + err := <-errChan + if err == nil { + t.Fatal("Expected error") + } + if !strings.Contains(err.Error(), "API error 429") { + t.Errorf("Expected error about API error 429, got %v", err) + } + + // Channel should be closed + _, ok := <-chunkChan + if ok { + t.Error("Expected chunk channel to be closed") + } + }) + + t.Run("skips malformed JSON", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + fmt.Fprintf(w, "data: {invalid json}\n\n") + fmt.Fprintf(w, `data: {"id":"test","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"content":"Valid"},"finish_reason":null}]}`+"\n\n") + fmt.Fprintf(w, "data: [DONE]\n\n") + })) + defer server.Close() + + client, _ := NewClient(&Config{ + APIKey: "test-api-key", + BaseURL: server.URL, + }) + + chunkChan, errChan := client.Stream("gpt-4", "Hello") + + chunks := []*StreamChunk{} + for chunk := range chunkChan { + chunks = append(chunks, chunk) + } + + // Check for errors + select { + case err := <-errChan: + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + default: + } + + // Should skip the malformed JSON and only return the valid chunk + if len(chunks) != 1 { + t.Errorf("Expected 1 chunk, got %d", len(chunks)) + } + if len(chunks) > 0 && chunks[0].Text() != "Valid" { + t.Errorf("Expected 'Valid', got %s", chunks[0].Text()) + } + }) +} + +func TestStreamChunk_ConvenienceMethods(t *testing.T) { + t.Run("Text method", func(t *testing.T) { + content := "Hello" + chunk := &StreamChunk{ + Choices: []StreamChoice{ + { + Index: 0, + Delta: &StreamDelta{ + Content: &content, + }, + }, + }, + } + + if chunk.Text() != "Hello" { + t.Errorf("Expected 'Hello', got %s", chunk.Text()) + } + }) + + t.Run("Text method with nil delta", func(t *testing.T) { + chunk := &StreamChunk{ + Choices: []StreamChoice{ + { + Index: 0, + Delta: nil, + }, + }, + } + + if chunk.Text() != "" { + t.Errorf("Expected empty string, got %s", chunk.Text()) + } + }) + + t.Run("Text method with nil content", func(t *testing.T) { + chunk := &StreamChunk{ + Choices: []StreamChoice{ + { + Index: 0, + Delta: &StreamDelta{ + Content: nil, + }, + }, + }, + } + + if chunk.Text() != "" { + t.Errorf("Expected empty string, got %s", chunk.Text()) + } + }) + + t.Run("Role method", func(t *testing.T) { + role := "assistant" + chunk := &StreamChunk{ + Choices: []StreamChoice{ + { + Index: 0, + Delta: &StreamDelta{ + Role: &role, + }, + }, + }, + } + + if chunk.Role() != "assistant" { + t.Errorf("Expected 'assistant', got %s", chunk.Role()) + } + }) + + t.Run("Role method with nil", func(t *testing.T) { + chunk := &StreamChunk{ + Choices: []StreamChoice{ + { + Index: 0, + Delta: &StreamDelta{ + Role: nil, + }, + }, + }, + } + + if chunk.Role() != "" { + t.Errorf("Expected empty string, got %s", chunk.Role()) + } + }) + + t.Run("FinishReason method", func(t *testing.T) { + chunk := &StreamChunk{ + Choices: []StreamChoice{ + { + Index: 0, + FinishReason: stringPtr("stop"), + }, + }, + } + + if chunk.FinishReason() != "stop" { + t.Errorf("Expected 'stop', got %s", chunk.FinishReason()) + } + }) + + t.Run("FinishReason method with nil", func(t *testing.T) 
{ + chunk := &StreamChunk{ + Choices: []StreamChoice{ + { + Index: 0, + FinishReason: nil, + }, + }, + } + + if chunk.FinishReason() != "" { + t.Errorf("Expected empty string, got %s", chunk.FinishReason()) + } + }) + + t.Run("FinishReason method with empty choices", func(t *testing.T) { + chunk := &StreamChunk{ + Choices: []StreamChoice{}, + } + + if chunk.FinishReason() != "" { + t.Errorf("Expected empty string, got %s", chunk.FinishReason()) + } + }) +} + +// Helper function to create string pointer +func stringPtr(s string) *string { + return &s +}
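+
+// Illustrative sketch, not part of the original suite: Request tags its optional
+// fields with omitempty, so they should be absent from the marshaled body when unset.
+func TestRequest_MarshalOmitsOptionalFields(t *testing.T) {
+	req := Request{
+		Model:    "gpt-4",
+		Messages: []Message{{Role: "user", Content: "Hello"}},
+	}
+
+	body, err := json.Marshal(req)
+	if err != nil {
+		t.Fatalf("Expected no error, got %v", err)
+	}
+
+	// stream, tools, and tool_choice are all unset here and must not be serialized.
+	for _, key := range []string{"stream", "tools", "tool_choice"} {
+		if strings.Contains(string(body), `"`+key+`"`) {
+			t.Errorf("Expected %q to be omitted, got %s", key, string(body))
+		}
+	}
+}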