From f6fd161dd742b7ba3a5d32cb3b5d88b56cbace9c Mon Sep 17 00:00:00 2001 From: JoeEmp <2057567309@qq.com> Date: Wed, 25 Feb 2026 13:59:03 +0800 Subject: [PATCH] feat: add coding plan support --- pkg/providers/anthropic_compat/provider.go | 262 ++++++++++ .../anthropic_compat/provider_test.go | 464 ++++++++++++++++++ pkg/providers/factory_provider.go | 3 +- pkg/providers/http_provider.go | 23 +- 4 files changed, 749 insertions(+), 3 deletions(-) create mode 100644 pkg/providers/anthropic_compat/provider.go create mode 100644 pkg/providers/anthropic_compat/provider_test.go diff --git a/pkg/providers/anthropic_compat/provider.go b/pkg/providers/anthropic_compat/provider.go new file mode 100644 index 000000000..a7beb05be --- /dev/null +++ b/pkg/providers/anthropic_compat/provider.go @@ -0,0 +1,262 @@ +package anthropic_compat + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/url" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +type ( + ToolCall = protocoltypes.ToolCall + FunctionCall = protocoltypes.FunctionCall + LLMResponse = protocoltypes.LLMResponse + UsageInfo = protocoltypes.UsageInfo + Message = protocoltypes.Message + ToolDefinition = protocoltypes.ToolDefinition + ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition +) + +type Provider struct { + apiKey string + apiBase string + httpClient *http.Client +} + +func NewProvider(apiKey, apiBase, proxy string) *Provider { + client := &http.Client{ + Timeout: 120 * time.Second, + } + + if proxy != "" { + parsed, err := url.Parse(proxy) + if err == nil { + client.Transport = &http.Transport{ + Proxy: http.ProxyURL(parsed), + } + } else { + log.Printf("anthropic_compat: invalid proxy URL %q: %v", proxy, err) + } + } + + return &Provider{ + apiKey: apiKey, + apiBase: strings.TrimRight(apiBase, "/"), + httpClient: client, + } +} + +func (p *Provider) Chat( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, +) (*LLMResponse, error) { + if p.apiBase == "" { + return nil, fmt.Errorf("API base not configured") + } + // logger.InfoC("LLM-Chat", p.apiBase) + + requestBody, err := buildRequestBody(messages, tools, model, options) + if err != nil { + return nil, err + } + + jsonData, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/messages", bytes.NewReader(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-api-key", p.apiKey) + req.Header.Set("anthropic-version", "2023-06-01") + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body)) + } + + return parseResponse(body) +} + +func buildRequestBody( + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, +) (map[string]any, error) { + var system []string + var anthropicMessages []map[string]any + + for _, msg := range messages { + switch msg.Role { + case "system": + system = append(system, msg.Content) + case "user": + msgMap := map[string]any{ + "role": "user", + "content": []map[string]any{ + {"type": "text", "text": msg.Content}, + }, + } + // Handle tool result messages + if msg.ToolCallID != "" { + msgMap["content"] = []map[string]any{ + {"type": "tool_result", "tool_use_id": msg.ToolCallID, "content": msg.Content}, + } + } + anthropicMessages = append(anthropicMessages, msgMap) + case "assistant": + content := make([]map[string]any, 0) + if msg.Content != "" { + content = append(content, map[string]any{"type": "text", "text": msg.Content}) + } + for _, tc := range msg.ToolCalls { + content = append(content, map[string]any{ + "type": "tool_use", + "id": tc.ID, + "name": tc.Name, + "input": tc.Arguments, + }) + } + anthropicMessages = append(anthropicMessages, map[string]any{ + "role": "assistant", + "content": content, + }) + case "tool": + anthropicMessages = append(anthropicMessages, map[string]any{ + "role": "user", + "content": []map[string]any{ + {"type": "tool_result", "tool_use_id": msg.ToolCallID, "content": msg.Content}, + }, + }) + } + } + + requestBody := map[string]any{ + "model": model, + "messages": anthropicMessages, + } + + if len(system) > 0 { + requestBody["system"] = system + } + + if len(tools) > 0 { + requestBody["tools"] = translateTools(tools) + } + + maxTokens := 4096 + if mt, ok := options["max_tokens"].(int); ok { + maxTokens = mt + } + requestBody["max_tokens"] = maxTokens + + if temperature, ok := options["temperature"].(float64); ok { + requestBody["temperature"] = temperature + } + + return requestBody, nil +} + +func translateTools(tools []ToolDefinition) []map[string]any { + result := make([]map[string]any, 0, len(tools)) + for _, t := range tools { + tool := map[string]any{ + "name": t.Function.Name, + "input_schema": t.Function.Parameters, + } + if desc := t.Function.Description; desc != "" { + tool["description"] = desc + } + result = append(result, tool) + } + return result +} + +func parseResponse(body []byte) (*LLMResponse, error) { + var apiResponse struct { + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + ID string `json:"id"` + Name string `json:"name"` + Input json.RawMessage `json:"input"` + } `json:"content"` + StopReason string `json:"stop_reason"` + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"usage"` + } + + if err := json.Unmarshal(body, &apiResponse); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + var content string + var toolCalls []ToolCall + + for _, block := range apiResponse.Content { + switch block.Type { + case "text": + content += block.Text + case "tool_use": + var args map[string]any + if err := json.Unmarshal(block.Input, &args); err != nil { + log.Printf("anthropic_compat: failed to decode tool call input for %q: %v", block.Name, err) + args = map[string]any{"raw": string(block.Input)} + } + toolCalls = append(toolCalls, ToolCall{ + ID: block.ID, + Name: block.Name, + Arguments: args, + }) + } + } + + finishReason := "stop" + switch apiResponse.StopReason { + case "tool_use": + finishReason = "tool_calls" + case "max_tokens": + finishReason = "length" + case "end_turn", "stop": + finishReason = "stop" + } + + return &LLMResponse{ + Content: content, + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: &UsageInfo{ + PromptTokens: apiResponse.Usage.InputTokens, + CompletionTokens: apiResponse.Usage.OutputTokens, + TotalTokens: apiResponse.Usage.InputTokens + apiResponse.Usage.OutputTokens, + }, + }, nil +} diff --git a/pkg/providers/anthropic_compat/provider_test.go b/pkg/providers/anthropic_compat/provider_test.go new file mode 100644 index 000000000..cd69365a1 --- /dev/null +++ b/pkg/providers/anthropic_compat/provider_test.go @@ -0,0 +1,464 @@ +package anthropic_compat + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestProviderChat_BasicChat(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/messages" { + http.Error(w, "not found", http.StatusNotFound) + return + } + + // Verify headers + if r.Header.Get("x-api-key") != "test-key" { + t.Fatalf("expected x-api-key header, got %q", r.Header.Get("x-api-key")) + } + if r.Header.Get("anthropic-version") != "2023-06-01" { + t.Fatalf("expected anthropic-version header, got %q", r.Header.Get("anthropic-version")) + } + + resp := map[string]any{ + "content": []map[string]any{ + {"type": "text", "text": "Hello!"}, + }, + "stop_reason": "end_turn", + "usage": map[string]any{ + "input_tokens": 10, + "output_tokens": 5, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("test-key", server.URL, "") + out, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "Hi"}}, + nil, + "claude-3-5-sonnet-20241022", + nil, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if out.Content != "Hello!" { + t.Fatalf("Content = %q, want %q", out.Content, "Hello!") + } +} + +func TestProviderChat_ParsesToolCalls(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := map[string]any{ + "content": []map[string]any{ + { + "type": "text", + "text": "I'll get the weather for you.", + }, + { + "type": "tool_use", + "id": "toolu_123", + "name": "get_weather", + "input": map[string]any{ + "city": "Beijing", + }, + }, + }, + "stop_reason": "tool_use", + "usage": map[string]any{ + "input_tokens": 10, + "output_tokens": 20, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "What's the weather?"}}, nil, "claude-3-5-sonnet", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if out.Content != "I'll get the weather for you." { + t.Fatalf("Content = %q, want %q", out.Content, "I'll get the weather for you.") + } + if len(out.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls)) + } + if out.ToolCalls[0].Name != "get_weather" { + t.Fatalf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather") + } + if out.ToolCalls[0].Arguments["city"] != "Beijing" { + t.Fatalf("ToolCalls[0].Arguments[city] = %v, want Beijing", out.ToolCalls[0].Arguments["city"]) + } + if out.ToolCalls[0].ID != "toolu_123" { + t.Fatalf("ToolCalls[0].ID = %q, want toolu_123", out.ToolCalls[0].ID) + } + if out.FinishReason != "tool_calls" { + t.Fatalf("FinishReason = %q, want tool_calls", out.FinishReason) + } +} + +func TestProviderChat_HandlesSystemMessages(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]any{ + "content": []map[string]any{ + {"type": "text", "text": "ok"}, + }, + "stop_reason": "end_turn", + "usage": map[string]any{"input_tokens": 10, "output_tokens": 5}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + _, err := p.Chat( + t.Context(), + []Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello"}, + }, + nil, + "claude-3-5-sonnet", + nil, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + system, ok := requestBody["system"].([]any) + if !ok { + t.Fatalf("expected system in request body") + } + if len(system) != 1 || system[0] != "You are a helpful assistant." { + t.Fatalf("system = %v, want [You are a helpful assistant.]", system) + } +} + +func TestProviderChat_PassesToolsToRequest(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]any{ + "content": []map[string]any{ + {"type": "text", "text": "ok"}, + }, + "stop_reason": "end_turn", + "usage": map[string]any{"input_tokens": 10, "output_tokens": 5}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get weather for a city", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "city": map[string]any{"type": "string"}, + }, + "required": []string{"city"}, + }, + }, + }, + } + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "Weather in Beijing?"}}, + tools, + "claude-3-5-sonnet", + nil, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + toolsInReq, ok := requestBody["tools"].([]any) + if !ok { + t.Fatalf("expected tools in request body") + } + if len(toolsInReq) != 1 { + t.Fatalf("len(tools) = %d, want 1", len(toolsInReq)) + } + tool := toolsInReq[0].(map[string]any) + if tool["name"] != "get_weather" { + t.Fatalf("tool name = %v, want get_weather", tool["name"]) + } +} + +func TestProviderChat_HTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "bad request", http.StatusBadRequest) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "claude-3-5-sonnet", nil) + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestProviderChat_EmptyAPIBase(t *testing.T) { + p := NewProvider("key", "", "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "claude-3-5-sonnet", nil) + if err == nil { + t.Fatal("expected error for empty API base, got nil") + } + if err.Error() != "API base not configured" { + t.Fatalf("error = %q, want %q", err.Error(), "API base not configured") + } +} + +func TestProvider_ProxyConfigured(t *testing.T) { + proxyURL := "http://127.0.0.1:8080" + p := NewProvider("key", "https://example.com", proxyURL) + + transport, ok := p.httpClient.Transport.(*http.Transport) + if !ok || transport == nil { + t.Fatalf("expected http transport with proxy, got %T", p.httpClient.Transport) + } + + req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}} + gotProxy, err := transport.Proxy(req) + if err != nil { + t.Fatalf("proxy function returned error: %v", err) + } + if gotProxy == nil || gotProxy.String() != proxyURL { + t.Fatalf("proxy = %v, want %s", gotProxy, proxyURL) + } +} + +func TestProviderChat_AcceptsOptions(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]any{ + "content": []map[string]any{ + {"type": "text", "text": "ok"}, + }, + "stop_reason": "end_turn", + "usage": map[string]any{"input_tokens": 10, "output_tokens": 5}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + // Note: current implementation only accepts int type for max_tokens + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + nil, + "claude-3-5-sonnet", + map[string]any{"max_tokens": 2048, "temperature": 0.5}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if requestBody["temperature"] != 0.5 { + t.Fatalf("temperature = %v, want 0.5", requestBody["temperature"]) + } +} + +func TestProviderChat_DefaultMaxTokens(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]any{ + "content": []map[string]any{ + {"type": "text", "text": "ok"}, + }, + "stop_reason": "end_turn", + "usage": map[string]any{"input_tokens": 10, "output_tokens": 5}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "claude-3-5-sonnet", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + // JSON unmarshals numbers as float64 by default + maxTokens, ok := requestBody["max_tokens"].(float64) + if !ok || maxTokens != 4096 { + t.Fatalf("max_tokens = %v (type %T), want 4096", requestBody["max_tokens"], requestBody["max_tokens"]) + } +} + +func TestProviderChat_HandleToolResultMessages(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]any{ + "content": []map[string]any{ + {"type": "text", "text": "The weather is sunny."}, + }, + "stop_reason": "end_turn", + "usage": map[string]any{"input_tokens": 10, "output_tokens": 5}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + _, err := p.Chat( + t.Context(), + []Message{ + {Role: "assistant", Content: "", ToolCalls: []ToolCall{ + {ID: "toolu_1", Name: "get_weather", Arguments: map[string]any{"city": "Beijing"}}, + }}, + {Role: "tool", Content: "Sunny, 25°C", ToolCallID: "toolu_1"}, + }, + nil, + "claude-3-5-sonnet", + nil, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + messages, ok := requestBody["messages"].([]any) + if !ok { + t.Fatalf("expected messages in request body") + } + // Should have 2 messages: assistant (with tool_use) and user (with tool_result) + if len(messages) != 2 { + t.Fatalf("len(messages) = %d, want 2", len(messages)) + } +} + +func TestProviderChat_StopReasonMaxTokens(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := map[string]any{ + "content": []map[string]any{ + {"type": "text", "text": "partial..."}, + }, + "stop_reason": "max_tokens", + "usage": map[string]any{"input_tokens": 10, "output_tokens": 4096}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "claude-3-5-sonnet", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if out.FinishReason != "length" { + t.Fatalf("FinishReason = %q, want length", out.FinishReason) + } +} + +func TestProviderChat_UsageInfo(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := map[string]any{ + "content": []map[string]any{ + {"type": "text", "text": "ok"}, + }, + "stop_reason": "end_turn", + "usage": map[string]any{ + "input_tokens": 100, + "output_tokens": 50, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "claude-3-5-sonnet", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if out.Usage.PromptTokens != 100 { + t.Fatalf("PromptTokens = %d, want 100", out.Usage.PromptTokens) + } + if out.Usage.CompletionTokens != 50 { + t.Fatalf("CompletionTokens = %d, want 50", out.Usage.CompletionTokens) + } + if out.Usage.TotalTokens != 150 { + t.Fatalf("TotalTokens = %d, want 150", out.Usage.TotalTokens) + } +} + +func TestBuildRequestBody_PassesModelDirectly(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var requestBody map[string]any + json.NewDecoder(r.Body).Decode(&requestBody) + + resp := map[string]any{ + "content": []map[string]any{{"type": "text", "text": "ok"}}, + "stop_reason": "end_turn", + "usage": map[string]any{"input_tokens": 10, "output_tokens": 5}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + + if requestBody["model"] != "claude-3-5-sonnet" { + t.Errorf("model = %q, want %q", requestBody["model"], "claude-3-5-sonnet") + } + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "claude-3-5-sonnet", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } +} + +func TestNewProvider_TrimsTrailingSlash(t *testing.T) { + p := NewProvider("key", "https://api.example.com/", "") + if p.apiBase != "https://api.example.com" { + t.Fatalf("apiBase = %q, want %q", p.apiBase, "https://api.example.com") + } +} diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index 7d5566eef..406f36d0e 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -65,7 +65,6 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err } protocol, modelID := ExtractProtocol(cfg.Model) - switch protocol { case "openai": // OpenAI with OAuth/token auth (Codex-style) @@ -116,7 +115,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err if cfg.APIKey == "" { return nil, "", fmt.Errorf("api_key is required for anthropic protocol (model: %s)", cfg.Model) } - return NewHTTPProviderWithMaxTokensField(cfg.APIKey, apiBase, cfg.Proxy, cfg.MaxTokensField), modelID, nil + return NewHTTPProviderWithProtocol(cfg.APIKey, apiBase, cfg.Proxy, "anthropic"), modelID, nil case "antigravity": return NewAntigravityProvider(), modelID, nil diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index d0c4344f3..c967c85d2 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -9,11 +9,17 @@ package providers import ( "context" + "github.com/sipeed/picoclaw/pkg/providers/anthropic_compat" "github.com/sipeed/picoclaw/pkg/providers/openai_compat" ) +// httpDelegate is the interface that both openai_compat and anthropic_compat providers implement +type httpDelegate interface { + Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]any) (*LLMResponse, error) +} + type HTTPProvider struct { - delegate *openai_compat.Provider + delegate httpDelegate } func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider { @@ -28,6 +34,21 @@ func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField st } } +// NewHTTPProviderWithProtocol creates an HTTP provider with the specified protocol type. +// protocol should be "openai" for OpenAI-compatible APIs or "anthropic" for Anthropic-compatible APIs. +func NewHTTPProviderWithProtocol(apiKey, apiBase, proxy, protocol string) *HTTPProvider { + switch protocol { + case "anthropic": + return &HTTPProvider{ + delegate: anthropic_compat.NewProvider(apiKey, apiBase, proxy), + } + default: + return &HTTPProvider{ + delegate: openai_compat.NewProvider(apiKey, apiBase, proxy), + } + } +} + func (p *HTTPProvider) Chat( ctx context.Context, messages []Message,