diff --git a/backend/app/controllers/providers/providers.go b/backend/app/controllers/providers/providers.go
index 70de3d4..10ae1e5 100644
--- a/backend/app/controllers/providers/providers.go
+++ b/backend/app/controllers/providers/providers.go
@@ -89,7 +89,7 @@ func (h *Controller) Delete(c fiber.Ctx) error {
 
 func (h *Controller) Types(c fiber.Ctx) error {
     // For now, hardcoded list of supported provider types
     // In the future, this could be dynamic based on available providers
-    types := []string{"ollama", "openai"}
+    types := []string{"ollama", "openai", "litellm"}
     return c.JSON(fiber.Map{"types": types})
 }
@@ -102,6 +102,7 @@ func (h *Controller) Models(c fiber.Ctx) error {
     factory := &providers.ProviderFactory{}
     config := map[string]interface{}{
         "base_url": provider.BaseURL,
+        "api_key":  provider.ApiKey,
     }
 
     llmProvider, err := factory.NewProvider(provider.Type, config)
diff --git a/backend/app/controllers/settings/settings.go b/backend/app/controllers/settings/settings.go
index 508e28f..8096e88 100644
--- a/backend/app/controllers/settings/settings.go
+++ b/backend/app/controllers/settings/settings.go
@@ -3,7 +3,7 @@ package settings
 import (
     "fmt"
     "sef/app/entities"
-    "sef/pkg/ollama"
+    "sef/pkg/providers"
     "strconv"
 
     "github.com/gofiber/fiber/v3"
@@ -97,11 +97,20 @@ func (h *Controller) ListEmbeddingModels(c fiber.Ctx) error {
         return fiber.NewError(fiber.StatusNotFound, "Provider not found")
     }
 
-    // Create Ollama client
-    ollamaClient := ollama.NewOllamaClient(provider.BaseURL)
+    // Create embedding provider factory
+    factory := &providers.EmbeddingProviderFactory{}
+    config := map[string]interface{}{
+        "base_url": provider.BaseURL,
+        "api_key":  provider.ApiKey,
+    }
+
+    embedProvider, err := factory.NewProvider(provider.Type, config)
+    if err != nil {
+        return fiber.NewError(fiber.StatusInternalServerError, fmt.Sprintf("Failed to create provider: %v", err))
+    }
 
     // Get all models
-    allModels, err := ollamaClient.ListModels()
+    allModels, err := embedProvider.ListModels()
     if err != nil {
         return fiber.NewError(fiber.StatusInternalServerError, fmt.Sprintf("Failed to list models: %v", err))
     }
diff --git a/backend/app/entities/provider.go b/backend/app/entities/provider.go
index 780ced1..200d2ba 100644
--- a/backend/app/entities/provider.go
+++ b/backend/app/entities/provider.go
@@ -6,4 +6,5 @@ type Provider struct {
     Type        string `json:"type" gorm:"not null"`
     Description string `json:"description"`
     BaseURL     string `json:"base_url" gorm:"not null"`
+    ApiKey      string `json:"api_key"`
 }
diff --git a/backend/pkg/documentservice/service.go b/backend/pkg/documentservice/service.go
index 9c101ce..27b8477 100644
--- a/backend/pkg/documentservice/service.go
+++ b/backend/pkg/documentservice/service.go
@@ -5,7 +5,7 @@ import (
     "fmt"
     "sef/app/entities"
     "sef/pkg/chunking"
-    "sef/pkg/ollama"
+    "sef/pkg/providers"
     "sef/pkg/qdrant"
     "strings"
 
@@ -88,8 +88,19 @@ func (ds *DocumentService) ProcessDocument(ctx context.Context, document *entiti
         return err
     }
 
-    // Create Ollama client for this provider
-    ollamaClient := ollama.NewOllamaClient(provider.BaseURL)
+    // Create embedding provider
+    factory := &providers.EmbeddingProviderFactory{}
+    config := map[string]interface{}{
+        "base_url": provider.BaseURL,
+        "api_key":  provider.ApiKey,
+    }
+
+    embedProvider, err := factory.NewProvider(provider.Type, config)
+    if err != nil {
+        document.Status = "failed"
+        ds.DB.Save(document)
+        return fmt.Errorf("failed to create embedding provider: %w", err)
+    }
 
     // Ensure global collection exists
     exists, err := ds.QdrantClient.CollectionExists(GlobalCollectionName)
@@ -127,7 +138,7 @@ func (ds *DocumentService) ProcessDocument(ctx context.Context, document *entiti
     totalChunks := len(chunks)
     for _, chunk := range chunks {
         log.Infof("Generating embedding for document ID %d, chunk %d", document.ID, chunk.Index)
-        embedding, err := ollamaClient.GenerateEmbedding(ctx, embedModel, chunk.Text)
+        embedding, err := embedProvider.GenerateEmbedding(ctx, embedModel, chunk.Text)
         if err != nil {
             document.Status = "failed"
             ds.DB.Save(document)
@@ -182,11 +193,20 @@ func (ds *DocumentService) SearchDocuments(ctx context.Context, query string, li
         return nil, err
     }
 
-    // Create Ollama client
-    ollamaClient := ollama.NewOllamaClient(provider.BaseURL)
+    // Create embedding provider
+    factory := &providers.EmbeddingProviderFactory{}
+    config := map[string]interface{}{
+        "base_url": provider.BaseURL,
+        "api_key":  provider.ApiKey,
+    }
+
+    embedProvider, err := factory.NewProvider(provider.Type, config)
+    if err != nil {
+        return nil, fmt.Errorf("failed to create embedding provider: %w", err)
+    }
 
     // Generate embedding for query
-    queryEmbedding, err := ollamaClient.GenerateEmbedding(ctx, embedModel, query)
+    queryEmbedding, err := embedProvider.GenerateEmbedding(ctx, embedModel, query)
     if err != nil {
         return nil, fmt.Errorf("failed to generate query embedding: %w", err)
     }
diff --git a/backend/pkg/litellm/client.go b/backend/pkg/litellm/client.go
new file mode 100644
index 0000000..569fb8d
--- /dev/null
+++ b/backend/pkg/litellm/client.go
@@ -0,0 +1,235 @@
+package litellm
+
+import (
+    "bufio"
+    "bytes"
+    "context"
+    "encoding/json"
+    "fmt"
+    "io"
+    "net/http"
+    "strings"
+    "time"
+
+    "github.com/gofiber/fiber/v3/log"
+)
+
+// LiteLLMClient handles LiteLLM API interactions
+type LiteLLMClient struct {
+    baseURL string
+    apiKey  string
+    client  *http.Client
+}
+
+// NewLiteLLMClient creates a new LiteLLM client
+func NewLiteLLMClient(baseURL string, apiKey string) *LiteLLMClient {
+    if baseURL == "" {
+        baseURL = "http://localhost:4000"
+    }
+
+    return &LiteLLMClient{
+        baseURL: strings.TrimSuffix(baseURL, "/"),
+        apiKey:  apiKey,
+        client: &http.Client{
+            Timeout: 120 * time.Second,
+        },
+    }
+}
+
+// EmbeddingRequest represents an OpenAI-compatible embedding request for LiteLLM
+type EmbeddingRequest struct {
+    Model          string   `json:"model"`
+    Input          []string `json:"input"`
+    EncodingFormat string   `json:"encoding_format,omitempty"`
+}
+
+// EmbeddingResponse represents the response from the embedding API
+type EmbeddingResponse struct {
+    Object string `json:"object"`
+    Data   []struct {
+        Object    string    `json:"object"`
+        Embedding []float32 `json:"embedding"`
+        Index     int       `json:"index"`
+    } `json:"data"`
+    Model string `json:"model"`
+    Usage struct {
+        PromptTokens int `json:"prompt_tokens"`
+        TotalTokens  int `json:"total_tokens"`
+    } `json:"usage"`
+}
+
+// GenerateEmbedding generates an embedding vector for the given text
+func (c *LiteLLMClient) GenerateEmbedding(ctx context.Context, model string, text string) ([]float32, error) {
+    req := EmbeddingRequest{
+        Model: model,
+        Input: []string{text},
+    }
+
+    reqBody, err := json.Marshal(req)
+    if err != nil {
+        return nil, fmt.Errorf("failed to marshal request: %w", err)
+    }
+
+    // Use the /v1 prefix, the standard path for OpenAI-compatible proxies like LiteLLM
+    // (LiteLLM also serves routes at the root, but /v1 is the documented form)
+    url := c.baseURL + "/v1/embeddings"
+
+    httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(reqBody))
+    if err != nil {
+        return nil, fmt.Errorf("failed to create request: %w", err)
+    }
+
+    httpReq.Header.Set("Content-Type", "application/json")
+    if c.apiKey != "" {
+        httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
+    }
+
+    resp, err := c.client.Do(httpReq)
+    if err != nil {
+        return nil, fmt.Errorf("failed to send request: %w", err)
+    }
+    defer resp.Body.Close()
+
+    if resp.StatusCode != http.StatusOK {
+        body, _ := io.ReadAll(resp.Body)
+        log.Errorf("LiteLLM Embedding Request - URL: %s, Body: %s", url, string(reqBody))
+        log.Errorf("LiteLLM API error (Status %d): %s", resp.StatusCode, string(body))
+        return nil, fmt.Errorf("litellm API error: %s", string(body))
+    }
+
+    var embeddingResp EmbeddingResponse
+    if err := json.NewDecoder(resp.Body).Decode(&embeddingResp); err != nil {
+        return nil, fmt.Errorf("failed to decode response: %w", err)
+    }
+
+    if len(embeddingResp.Data) == 0 {
+        return nil, fmt.Errorf("no embedding data returned from LiteLLM")
+    }
+
+    return embeddingResp.Data[0].Embedding, nil
+}
+
+// ChatCompletionRequest represents a request for chat completion
+type ChatCompletionRequest struct {
+    Model    string               `json:"model"`
+    Messages []LiteLLMChatMessage `json:"messages"`
+    Stream   bool                 `json:"stream"`
+    Tools    []interface{}        `json:"tools,omitempty"`
+}
+
+// LiteLLMChatMessage represents a message in a LiteLLM chat
+type LiteLLMChatMessage struct {
+    Role    string `json:"role"`
+    Content string `json:"content"`
+}
+
+// ChatCompletionResponse represents a chunk of the streamed chat response
+type ChatCompletionResponse struct {
+    Choices []struct {
+        Delta struct {
+            Content string `json:"content"`
+        } `json:"delta"`
+        FinishReason string `json:"finish_reason"`
+    } `json:"choices"`
+}
+
+// GenerateChatStream handles streaming chat completion
+func (c *LiteLLMClient) GenerateChatStream(ctx context.Context, model string, messages []LiteLLMChatMessage) (<-chan string, error) {
+    req := ChatCompletionRequest{
+        Model:    model,
+        Messages: messages,
+        Stream:   true,
+    }
+
+    reqBody, err := json.Marshal(req)
+    if err != nil {
+        return nil, err
+    }
+
+    httpReq, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+"/v1/chat/completions", bytes.NewBuffer(reqBody))
+    if err != nil {
+        return nil, err
+    }
+
+    httpReq.Header.Set("Content-Type", "application/json")
+    if c.apiKey != "" {
+        httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
+    }
+
+    resp, err := c.client.Do(httpReq)
+    if err != nil {
+        return nil, err
+    }
+
+    if resp.StatusCode != http.StatusOK {
+        body, _ := io.ReadAll(resp.Body)
+        resp.Body.Close()
+        return nil, fmt.Errorf("litellm API error: %s", string(body))
+    }
+
+    ch := make(chan string)
+    go func() {
+        defer resp.Body.Close()
+        defer close(ch)
+
+        // Basic parsing of the OpenAI-style SSE stream: each event line looks like
+        // "data: {...}" and the stream ends with "data: [DONE]".
+        scanner := bufio.NewScanner(resp.Body)
+        scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
+        for scanner.Scan() {
+            line := strings.TrimSpace(scanner.Text())
+            if !strings.HasPrefix(line, "data:") {
+                continue
+            }
+            payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
+            if payload == "[DONE]" {
+                return
+            }
+            var chunk ChatCompletionResponse
+            if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
+                log.Errorf("failed to decode LiteLLM stream chunk: %v", err)
+                continue
+            }
+            if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
+                ch <- chunk.Choices[0].Delta.Content
+            }
+        }
+    }()
+
+    return ch, nil
+}
+
+// ListModelsResponse represents the response from /v1/models
+type ListModelsResponse struct {
+    Data []struct {
+        ID string `json:"id"`
+    } `json:"data"`
+}
+
+// ListModels returns available models from LiteLLM
+func (c *LiteLLMClient) ListModels() ([]string, error) {
+    url := c.baseURL + "/v1/models"
+
+    httpReq, err := http.NewRequest("GET", url, nil)
+    if err != nil {
+        return nil, err
+    }
+
+    if c.apiKey != "" {
+        httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
+    }
+
+    resp, err := c.client.Do(httpReq)
+    if err != nil {
+        return nil, err
+    }
+    defer resp.Body.Close()
+
+    if resp.StatusCode != http.StatusOK {
+        return nil, fmt.Errorf("failed to fetch models, status: %d", resp.StatusCode)
+    }
+
+    var result ListModelsResponse
+    if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+        return nil, err
+    }
+
+    var models []string
+    for _, m := range result.Data {
+        models = append(models, m.ID)
+    }
+
+    return models, nil
+}
diff --git a/backend/pkg/messaging/service.go b/backend/pkg/messaging/service.go
index 30e4a0c..7ef301b 100644
--- a/backend/pkg/messaging/service.go
+++ b/backend/pkg/messaging/service.go
@@ -9,8 +9,8 @@ import (
     "sef/internal/validation"
     "sef/pkg/providers"
     "sef/pkg/rag"
-    "sef/pkg/toon"
     "sef/pkg/toolrunners"
+    "sef/pkg/toon"
     "strconv"
     "strings"
     "time"
@@ -596,6 +596,7 @@ func (s *MessagingService) GenerateChatResponse(session *entities.Session, messa
     factory := &providers.ProviderFactory{}
     providerConfig := map[string]interface{}{
         "base_url": session.Chatbot.Provider.BaseURL,
+        "api_key":  session.Chatbot.Provider.ApiKey,
     }
 
     // Validate provider configuration
@@ -673,7 +674,7 @@
     for i, toolDef := range toolDefinitions {
         log.Infof("Tool %d: %s", i+1, toolDef.Function.Name)
         /*if toonContent, ok := toolDef.Function.Parameters["toon_content"].(string); ok {
-             log.Infof("TOON Content:\n%s", toonContent)
+            log.Infof("TOON Content:\n%s", toonContent)
         }*/ // Uncomment for full content logging
     }
     log.Info("====================================")
diff --git a/backend/pkg/providers/embedding.go b/backend/pkg/providers/embedding.go
new file mode 100644
index 0000000..8970dec
--- /dev/null
+++ b/backend/pkg/providers/embedding.go
@@ -0,0 +1,149 @@
+package providers
+
+import (
+    "context"
+    "fmt"
+    "sef/pkg/litellm"
+    "sef/pkg/ollama"
+    "strings"
+
+    openai "github.com/sashabaranov/go-openai"
+)
+
+// EmbeddingProvider defines the interface for embedding providers
+type EmbeddingProvider interface {
+    // GenerateEmbedding generates an embedding for the given text
+    GenerateEmbedding(ctx context.Context, model string, text string) ([]float32, error)
+    // ListModels returns available models for the provider
+    ListModels() ([]string, error)
+}
+
+// OllamaEmbeddingProvider implements EmbeddingProvider for Ollama
+type OllamaEmbeddingProvider struct {
+    client *ollama.OllamaClient
+}
+
+// NewOllamaEmbeddingProvider creates a new Ollama embedding provider
+func NewOllamaEmbeddingProvider(config map[string]interface{}) *OllamaEmbeddingProvider {
+    baseURL := "http://localhost:11434"
+    if url, ok := config["base_url"].(string); ok && url != "" {
+        baseURL = url
+    }
+
+    return &OllamaEmbeddingProvider{
+        client: ollama.NewOllamaClient(baseURL),
+    }
+}
+
+func (o *OllamaEmbeddingProvider) GenerateEmbedding(ctx context.Context, model string, text string) ([]float32, error) {
+    return o.client.GenerateEmbedding(ctx, model, text)
+}
+
+func (o *OllamaEmbeddingProvider) ListModels() ([]string, error) {
+    return o.client.ListModels()
+}
+
+// OpenAIEmbeddingProvider implements EmbeddingProvider for OpenAI
+type OpenAIEmbeddingProvider struct {
+    client *openai.Client
+}
+
+// NewOpenAIEmbeddingProvider creates a new OpenAI embedding provider
+func NewOpenAIEmbeddingProvider(config map[string]interface{}) *OpenAIEmbeddingProvider {
+    apiKey := ""
+    if key, ok := config["api_key"].(string); ok {
+        apiKey = key
+    }
+
+    configOpenAI := openai.DefaultConfig(apiKey)
+    if baseURL, ok := config["base_url"].(string); ok && baseURL != "" {
+        configOpenAI.BaseURL = baseURL
+    }
+
+    client := openai.NewClientWithConfig(configOpenAI)
+    return &OpenAIEmbeddingProvider{
+        client: client,
+    }
+}
+
+func (o *OpenAIEmbeddingProvider) GenerateEmbedding(ctx context.Context, model string, text string) ([]float32, error) {
+    resp, err := o.client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
+        Model: openai.EmbeddingModel(model),
+        Input: []string{text},
+    })
+    if err != nil {
+        return nil, err
+    }
+
+    if len(resp.Data) == 0 {
+        return nil, fmt.Errorf("no embedding data returned")
+    }
+
+    return resp.Data[0].Embedding, nil
+}
+
+func (o *OpenAIEmbeddingProvider) ListModels() ([]string, error) {
+    models, err := o.client.ListModels(context.Background())
+    if err != nil {
+        return nil, err
+    }
+
+    var modelNames []string
+    for _, model := range models.Models {
+        if strings.Contains(strings.ToLower(model.ID), "embed") {
+            modelNames = append(modelNames, model.ID)
+        }
+    }
+    if len(modelNames) == 0 {
+        for _, model := range models.Models {
+            modelNames = append(modelNames, model.ID)
+        }
+    }
+    return modelNames, nil
+}
+
+// LiteLLMEmbeddingProvider implements EmbeddingProvider for LiteLLM
+type LiteLLMEmbeddingProvider struct {
+    client *litellm.LiteLLMClient
+}
+
+// NewLiteLLMEmbeddingProvider creates a new LiteLLM embedding provider
+func NewLiteLLMEmbeddingProvider(config map[string]interface{}) *LiteLLMEmbeddingProvider {
+    baseURL := "http://localhost:4000"
+    if url, ok := config["base_url"].(string); ok && url != "" {
+        baseURL = url
+    }
+
+    apiKey := ""
+    if key, ok := config["api_key"].(string); ok {
+        apiKey = key
+    }
+
+    return &LiteLLMEmbeddingProvider{
+        client: litellm.NewLiteLLMClient(baseURL, apiKey),
+    }
+}
+
+func (o *LiteLLMEmbeddingProvider) GenerateEmbedding(ctx context.Context, model string, text string) ([]float32, error) {
+    return o.client.GenerateEmbedding(ctx, model, text)
+}
+
+func (o *LiteLLMEmbeddingProvider) ListModels() ([]string, error) {
+    return o.client.ListModels()
+}
+
+// Factory for embedding providers
+type EmbeddingProviderFactory struct{}
+
+func (f *EmbeddingProviderFactory) NewProvider(providerType string, config map[string]interface{}) (EmbeddingProvider, error) {
+    switch providerType {
+    case "ollama":
+        return NewOllamaEmbeddingProvider(config), nil
+    case "openai":
+        return NewOpenAIEmbeddingProvider(config), nil
+    case "litellm":
+        return NewLiteLLMEmbeddingProvider(config), nil
+    default:
+        return nil, fmt.Errorf("unsupported embedding provider type: %s", providerType)
+    }
+}
diff --git a/backend/pkg/providers/interface.go b/backend/pkg/providers/interface.go
index 89c8887..7c866b0 100644
--- a/backend/pkg/providers/interface.go
+++ b/backend/pkg/providers/interface.go
@@ -69,6 +69,8 @@ func (f *ProviderFactory) NewProvider(providerType string, config map[string]int
         return NewOllamaProvider(config), nil
     case "openai":
         return NewOpenAIProvider(config), nil
+    case "litellm":
+        return NewLiteLLMProvider(config), nil
     default:
         return nil, fmt.Errorf("unsupported provider type: %s", providerType)
     }
diff --git a/backend/pkg/providers/litellm.go b/backend/pkg/providers/litellm.go
new file mode 100644
index 0000000..9092af2
--- /dev/null
+++ b/backend/pkg/providers/litellm.go
@@ -0,0 +1,46 @@
+package providers
+
+import (
+    "sef/pkg/litellm"
+)
+
+// LiteLLMProvider implements the LLMProvider interface for LiteLLM
+// It embeds OpenAIProvider for chat and tool support, since LiteLLM exposes an
+// OpenAI-compatible API, and uses the dedicated LiteLLMClient for model discovery.
+type LiteLLMProvider struct {
+    *OpenAIProvider
+    liteClient *litellm.LiteLLMClient
+}
+
+// NewLiteLLMProvider creates a new LiteLLM provider instance
+func NewLiteLLMProvider(config map[string]interface{}) *LiteLLMProvider {
+    baseURL := "http://localhost:4000"
+    if url, ok := config["base_url"].(string); ok && url != "" {
+        baseURL = url
+    }
+
+    apiKey := ""
+    if key, ok := config["api_key"].(string); ok {
+        apiKey = key
+    }
+
+    // Initialize the underlying OpenAIProvider for chat/tool support,
+    // because go-openai already handles the streaming/SSE logic.
+    openaiProv := NewOpenAIProvider(config)
+
+    liteClient := litellm.NewLiteLLMClient(baseURL, apiKey)
+
+    return &LiteLLMProvider{
+        OpenAIProvider: openaiProv,
+        liteClient:     liteClient,
+    }
+}
+
+// ListModels returns available models from LiteLLM using the dedicated client
+func (l *LiteLLMProvider) ListModels() ([]string, error) {
+    return l.liteClient.ListModels()
+}
+
+// Generate, GenerateChat, GenerateChatWithTools, and ValidateConfig are inherited
+// from OpenAIProvider; only ListModels is overridden for LiteLLM-specific discovery.
diff --git a/backend/pkg/summary/service.go b/backend/pkg/summary/service.go
index b93ccd6..ea81593 100644
--- a/backend/pkg/summary/service.go
+++ b/backend/pkg/summary/service.go
@@ -197,6 +197,7 @@ func (s *SummaryService) generateSummaryWithProvider(session *entities.Session,
     factory := &providers.ProviderFactory{}
     providerConfig := map[string]interface{}{
         "base_url": session.Chatbot.Provider.BaseURL,
+        "api_key":  session.Chatbot.Provider.ApiKey,
     }
 
     provider, err := factory.NewProvider(session.Chatbot.Provider.Type, providerConfig)
diff --git a/frontend/public/locales/en/settings.json b/frontend/public/locales/en/settings.json
index d0bacef..78c9a65 100644
--- a/frontend/public/locales/en/settings.json
+++ b/frontend/public/locales/en/settings.json
@@ -219,6 +219,8 @@
       "description_placeholder": "",
       "base_url": "Base URL",
       "base_url_placeholder": "http://api.fabrika.com:11434",
+      "api_key": "API Key",
+      "api_key_placeholder": "sk-...",
       "create": "Create"
     },
     "edit": {
diff --git a/frontend/public/locales/tr/settings.json b/frontend/public/locales/tr/settings.json
index a9af3bc..718a180 100644
--- a/frontend/public/locales/tr/settings.json
+++ b/frontend/public/locales/tr/settings.json
@@ -219,6 +219,8 @@
       "description_placeholder": "",
       "base_url": "Temel URL",
       "base_url_placeholder": "http://api.fabrika.com:11434",
+      "api_key": "API Anahtarı",
+      "api_key_placeholder": "sk-...",
       "create": "Oluştur"
     },
     "edit": {
diff --git a/frontend/src/components/settings/create-provider.tsx b/frontend/src/components/settings/create-provider.tsx
index 96a4218..9920ce2 100644
--- a/frontend/src/components/settings/create-provider.tsx
+++ b/frontend/src/components/settings/create-provider.tsx
@@ -66,6 +66,7 @@ export default function CreateProvider() {
       .url({
         message: t("providers.validation.base_url"),
       }),
+    api_key: z.string().optional(),
   })
 
   const form = useForm<z.infer<typeof formSchema>>({
@@ -75,9 +76,12 @@ export default function CreateProvider() {
       type: "",
       description: "",
       base_url: "",
+      api_key: "",
     },
   })
 
+  const watchType = form.watch("type")
+
   const [open, setOpen] = useState(false)
   const [providerTypes, setProviderTypes] = useState([])
 
@@ -206,6 +210,25 @@
              )}
            />
+          {(watchType === "openai" || watchType === "litellm") && (
+            <FormField
+              control={form.control}
+              name="api_key"
+              render={({ field }) => (
+                <FormItem>
+                  <FormLabel>{t("providers.create.api_key")}</FormLabel>
+                  <FormControl>
+                    <Input
+                      placeholder={t("providers.create.api_key_placeholder")}
+                      {...field}
+                    />
+                  </FormControl>
+                  <FormMessage />
+                </FormItem>
+              )}
+            />
+          )}
+
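A minimal usage sketch of the embedding factory added above, assuming a "litellm" provider; the base URL, API key, and embedding model name here are placeholder values, not part of the change:

package main

import (
    "context"
    "fmt"

    "sef/pkg/providers"
)

func main() {
    // Placeholder config; in the application these values come from the Provider entity.
    factory := &providers.EmbeddingProviderFactory{}
    config := map[string]interface{}{
        "base_url": "http://localhost:4000",
        "api_key":  "sk-placeholder",
    }

    embedProvider, err := factory.NewProvider("litellm", config)
    if err != nil {
        panic(err)
    }

    // Models exposed by the LiteLLM proxy (GET /v1/models).
    models, err := embedProvider.ListModels()
    if err != nil {
        panic(err)
    }
    fmt.Println("models:", models)

    // Single-text embedding via POST /v1/embeddings; the model name is a placeholder.
    vec, err := embedProvider.GenerateEmbedding(context.Background(), "text-embedding-3-small", "hello world")
    if err != nil {
        panic(err)
    }
    fmt.Println("dimensions:", len(vec))
}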