3 changes: 2 additions & 1 deletion backend/app/controllers/providers/providers.go
@@ -89,7 +89,7 @@ func (h *Controller) Delete(c fiber.Ctx) error {
func (h *Controller) Types(c fiber.Ctx) error {
// For now, hardcoded list of supported provider types
// In the future, this could be dynamic based on available providers
- types := []string{"ollama", "openai"}
+ types := []string{"ollama", "openai", "litellm"}
return c.JSON(fiber.Map{"types": types})
}

@@ -102,6 +102,7 @@ func (h *Controller) Models(c fiber.Ctx) error {
factory := &providers.ProviderFactory{}
config := map[string]interface{}{
"base_url": provider.BaseURL,
"api_key": provider.ApiKey,
}

llmProvider, err := factory.NewProvider(provider.Type, config)
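The factory's dispatch on the new "litellm" type isn't part of this diff. The sketch below shows how ProviderFactory.NewProvider might wire it up, assuming the factory switches on the type string and reads base_url/api_key out of the config map; the LLMProvider return type is a guess, while the ollama and litellm constructors are the ones that appear elsewhere in this PR:

package providers

import (
	"fmt"

	"sef/pkg/litellm"
	"sef/pkg/ollama"
)

// NewProvider picks a concrete client for the given provider type.
// LLMProvider stands in for whatever chat-provider interface the
// repository actually defines; it is not shown in this diff.
func (f *ProviderFactory) NewProvider(providerType string, config map[string]interface{}) (LLMProvider, error) {
	baseURL, _ := config["base_url"].(string)
	apiKey, _ := config["api_key"].(string)

	switch providerType {
	case "ollama":
		return ollama.NewOllamaClient(baseURL), nil
	case "litellm":
		return litellm.NewLiteLLMClient(baseURL, apiKey), nil
	// the "openai" case is elided; its constructor isn't shown in this diff
	default:
		return nil, fmt.Errorf("unsupported provider type: %s", providerType)
	}
}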
17 changes: 13 additions & 4 deletions backend/app/controllers/settings/settings.go
@@ -3,7 +3,7 @@ package settings
import (
"fmt"
"sef/app/entities"
"sef/pkg/ollama"
"sef/pkg/providers"
"strconv"

"github.com/gofiber/fiber/v3"
@@ -97,11 +97,20 @@ func (h *Controller) ListEmbeddingModels(c fiber.Ctx) error {
return fiber.NewError(fiber.StatusNotFound, "Provider not found")
}

- // Create Ollama client
- ollamaClient := ollama.NewOllamaClient(provider.BaseURL)
+ // Create embedding provider factory
+ factory := &providers.EmbeddingProviderFactory{}
+ config := map[string]interface{}{
+ "base_url": provider.BaseURL,
+ "api_key": provider.ApiKey,
+ }

+ embedProvider, err := factory.NewProvider(provider.Type, config)
+ if err != nil {
+ return fiber.NewError(fiber.StatusInternalServerError, fmt.Sprintf("Failed to create provider: %v", err))
+ }

// Get all models
- allModels, err := ollamaClient.ListModels()
+ allModels, err := embedProvider.ListModels()
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, fmt.Sprintf("Failed to list models: %v", err))
}
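The sef/pkg/providers package itself isn't shown in this diff, but the call sites pin down the surface the embedding factory must return: this handler calls ListModels, and documentservice below calls GenerateEmbedding. A minimal sketch of the assumed interface (the type names are guesses; only the two method signatures are implied by the code, and LiteLLMClient satisfies both):

package providers

import "context"

// EmbeddingProvider is the interface the call sites in this PR rely on.
type EmbeddingProvider interface {
	GenerateEmbedding(ctx context.Context, model string, text string) ([]float32, error)
	ListModels() ([]string, error)
}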
1 change: 1 addition & 0 deletions backend/app/entities/provider.go
@@ -6,4 +6,5 @@ type Provider struct {
Type string `json:"type" gorm:"not null"`
Description string `json:"description"`
BaseURL string `json:"base_url" gorm:"not null"`
+ ApiKey string `json:"api_key"`
}
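The new ApiKey field only reaches the database once the schema is migrated. Assuming the project relies on GORM's AutoMigrate at startup (that wiring isn't shown in this diff), the existing call would pick the column up automatically:

// Hypothetical startup wiring: AutoMigrate adds the new api_key column
// to the providers table without touching existing rows.
if err := db.AutoMigrate(&entities.Provider{}); err != nil {
	log.Fatalf("failed to migrate schema: %v", err)
}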
34 changes: 27 additions & 7 deletions backend/pkg/documentservice/service.go
@@ -5,7 +5,7 @@ import (
"fmt"
"sef/app/entities"
"sef/pkg/chunking"
"sef/pkg/ollama"
"sef/pkg/providers"
"sef/pkg/qdrant"
"strings"

@@ -88,8 +88,19 @@ func (ds *DocumentService) ProcessDocument(ctx context.Context, document *entiti
return err
}

- // Create Ollama client for this provider
- ollamaClient := ollama.NewOllamaClient(provider.BaseURL)
+ // Create embedding provider
+ factory := &providers.EmbeddingProviderFactory{}
+ config := map[string]interface{}{
+ "base_url": provider.BaseURL,
+ "api_key": provider.ApiKey,
+ }

+ embedProvider, err := factory.NewProvider(provider.Type, config)
+ if err != nil {
+ document.Status = "failed"
+ ds.DB.Save(document)
+ return fmt.Errorf("failed to create embedding provider: %w", err)
+ }

// Ensure global collection exists
exists, err := ds.QdrantClient.CollectionExists(GlobalCollectionName)
@@ -127,7 +138,7 @@ func (ds *DocumentService) ProcessDocument(ctx context.Context, document *entiti
totalChunks := len(chunks)
for _, chunk := range chunks {
log.Infof("Generating embedding for document ID %d, chunk %d", document.ID, chunk.Index)
- embedding, err := ollamaClient.GenerateEmbedding(ctx, embedModel, chunk.Text)
+ embedding, err := embedProvider.GenerateEmbedding(ctx, embedModel, chunk.Text)
if err != nil {
document.Status = "failed"
ds.DB.Save(document)
@@ -182,11 +193,20 @@ func (ds *DocumentService) SearchDocuments(ctx context.Context, query string, li
return nil, err
}

- // Create Ollama client
- ollamaClient := ollama.NewOllamaClient(provider.BaseURL)
+ // Create embedding provider
+ factory := &providers.EmbeddingProviderFactory{}
+ config := map[string]interface{}{
+ "base_url": provider.BaseURL,
+ "api_key": provider.ApiKey,
+ }

+ embedProvider, err := factory.NewProvider(provider.Type, config)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create embedding provider: %w", err)
+ }

// Generate embedding for query
- queryEmbedding, err := ollamaClient.GenerateEmbedding(ctx, embedModel, query)
+ queryEmbedding, err := embedProvider.GenerateEmbedding(ctx, embedModel, query)
if err != nil {
return nil, fmt.Errorf("failed to generate query embedding: %w", err)
}
235 changes: 235 additions & 0 deletions backend/pkg/litellm/client.go
@@ -0,0 +1,235 @@
package litellm

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/gofiber/fiber/v3/log"
)

// LiteLLMClient handles LiteLLM API interactions
type LiteLLMClient struct {
baseURL string
apiKey string
client *http.Client
}

// NewLiteLLMClient creates a new LiteLLM client
func NewLiteLLMClient(baseURL string, apiKey string) *LiteLLMClient {
if baseURL == "" {
baseURL = "http://localhost:4000"
}

return &LiteLLMClient{
baseURL: strings.TrimSuffix(baseURL, "/"),
apiKey: apiKey,
client: &http.Client{
Timeout: 120 * time.Second,
},
}
}

// EmbeddingRequest represents an OpenAI-compatible embedding request for LiteLLM
type EmbeddingRequest struct {
Model string `json:"model"`
Input []string `json:"input"`
EncodingFormat string `json:"encoding_format,omitempty"`
}

// EmbeddingResponse represents the response from embedding API
type EmbeddingResponse struct {
Object string `json:"object"`
Data []struct {
Object string `json:"object"`
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
} `json:"data"`
Model string `json:"model"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}

// GenerateEmbedding generates an embedding vector for the given text
func (c *LiteLLMClient) GenerateEmbedding(ctx context.Context, model string, text string) ([]float32, error) {
req := EmbeddingRequest{
Model: model,
Input: []string{text},
}

reqBody, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}

// LiteLLM serves the OpenAI-compatible API under /v1; endpoint() avoids
// doubling the prefix when the configured base URL already includes it.
url := c.endpoint("/embeddings")

httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(reqBody))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}

httpReq.Header.Set("Content-Type", "application/json")
if c.apiKey != "" {
httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
}

resp, err := c.client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
	body, _ := io.ReadAll(resp.Body)
	log.Errorf("LiteLLM embedding request failed (status %d): url=%s, request=%s, response=%s",
		resp.StatusCode, url, string(reqBody), string(body))
	return nil, fmt.Errorf("litellm API error (status %d): %s", resp.StatusCode, string(body))
}

var embeddingResp EmbeddingResponse
if err := json.NewDecoder(resp.Body).Decode(&embeddingResp); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}

if len(embeddingResp.Data) == 0 {
return nil, fmt.Errorf("no embedding data returned from LiteLLM")
}

return embeddingResp.Data[0].Embedding, nil
}

// ChatCompletionRequest represents a request for chat completion
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []LiteLLMChatMessage `json:"messages"`
Stream bool `json:"stream"`
Tools []interface{} `json:"tools,omitempty"`
}

// LiteLLMChatMessage represents a message in LiteLLM chat
type LiteLLMChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}

// ChatCompletionResponse represents a chunk of the chat response
type ChatCompletionResponse struct {
Choices []struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}

// GenerateChatStream handles streaming chat completion
func (c *LiteLLMClient) GenerateChatStream(ctx context.Context, model string, messages []LiteLLMChatMessage) (<-chan string, error) {
req := ChatCompletionRequest{
Model: model,
Messages: messages,
Stream: true,
}

reqBody, err := json.Marshal(req)
if err != nil {
return nil, err
}

httpReq, err := http.NewRequestWithContext(ctx, "POST", c.endpoint("/chat/completions"), bytes.NewBuffer(reqBody))
if err != nil {
return nil, err
}

httpReq.Header.Set("Content-Type", "application/json")
if c.apiKey != "" {
httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
}

resp, err := c.client.Do(httpReq)
if err != nil {
return nil, err
}

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
return nil, fmt.Errorf("litellm API error: %s", string(body))
}

ch := make(chan string)
go func() {
	defer resp.Body.Close()
	defer close(ch)

	// Parse the OpenAI-style SSE stream: each event arrives as a
	// "data: {...}" line, and the stream terminates with "data: [DONE]".
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		payload := strings.TrimPrefix(line, "data: ")
		if payload == "[DONE]" {
			return
		}
		var chunk ChatCompletionResponse
		if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
			continue // skip malformed chunks rather than aborting the stream
		}
		if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
			select {
			case ch <- chunk.Choices[0].Delta.Content:
			case <-ctx.Done():
				return
			}
		}
	}
}()

return ch, nil
}

// ListModelsResponse represents the response from GET /v1/models
type ListModelsResponse struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}

// ListModels returns available models from LiteLLM
func (c *LiteLLMClient) ListModels() ([]string, error) {
url := c.endpoint("/models")

httpReq, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}

if c.apiKey != "" {
httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
}

resp, err := c.client.Do(httpReq)
if err != nil {
return nil, err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to fetch models, status: %d", resp.StatusCode)
}

var result ListModelsResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, err
}

var models []string
for _, m := range result.Data {
models = append(models, m.ID)
}

return models, nil
}

// endpoint joins the base URL with an OpenAI-style path, avoiding a
// doubled /v1 prefix when the configured base URL already includes it.
func (c *LiteLLMClient) endpoint(path string) string {
	if strings.HasSuffix(c.baseURL, "/v1") {
		return c.baseURL + path
	}
	return c.baseURL + "/v1" + path
}
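A quick usage sketch for the client above; the proxy URL, API key, and model names are illustrative and depend entirely on how your LiteLLM instance is configured:

package main

import (
	"context"
	"fmt"
	"log"

	"sef/pkg/litellm"
)

func main() {
	ctx := context.Background()
	client := litellm.NewLiteLLMClient("http://localhost:4000", "sk-example-key")

	// Models the proxy currently exposes.
	models, err := client.ListModels()
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("models:", models)

	// One embedding vector per call; the model name must match one
	// configured in the proxy.
	vec, err := client.GenerateEmbedding(ctx, "text-embedding-3-small", "hello world")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("dimensions:", len(vec))

	// Stream a chat completion; the channel closes when the stream ends.
	stream, err := client.GenerateChatStream(ctx, "gpt-4o-mini", []litellm.LiteLLMChatMessage{
		{Role: "user", Content: "Say hi in one word."},
	})
	if err != nil {
		log.Fatal(err)
	}
	for token := range stream {
		fmt.Print(token)
	}
	fmt.Println()
}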
5 changes: 3 additions & 2 deletions backend/pkg/messaging/service.go
@@ -9,8 +9,8 @@ import (
"sef/internal/validation"
"sef/pkg/providers"
"sef/pkg/rag"
"sef/pkg/toon"
"sef/pkg/toolrunners"
"sef/pkg/toon"
"strconv"
"strings"
"time"
@@ -596,6 +596,7 @@ func (s *MessagingService) GenerateChatResponse(session *entities.Session, messa
factory := &providers.ProviderFactory{}
providerConfig := map[string]interface{}{
"base_url": session.Chatbot.Provider.BaseURL,
"api_key": session.Chatbot.Provider.ApiKey,
}

// Validate provider configuration
@@ -673,7 +674,7 @@ func (s *MessagingService) GenerateChatResponse(session *entities.Session, messa
for i, toolDef := range toolDefinitions {
log.Infof("Tool %d: %s", i+1, toolDef.Function.Name)
/*if toonContent, ok := toolDef.Function.Parameters["toon_content"].(string); ok {
log.Infof("TOON Content:\n%s", toonContent)
log.Infof("TOON Content:\n%s", toonContent)
}*/ // Uncomment for full content logging
}
log.Info("====================================")