diff --git a/llms/anythingllm.go b/llms/anythingllm.go index 637961a..4b473a6 100644 --- a/llms/anythingllm.go +++ b/llms/anythingllm.go @@ -1,6 +1,7 @@ package llms import ( + "bufio" "bytes" "context" "encoding/json" @@ -36,11 +37,36 @@ type chatAttachment struct { } type chatResponse struct { - ID string `json:"id"` - Type string `json:"type"` + ID string `json:"id"` + ChatId int `json:"chatId"` + Type string `json:"type"` + TextResponse string `json:"textResponse"` + Close bool `json:"close"` + Error string `json:"error"` + Metrics chatMetrics `json:"metrics"` +} + +const streamChatDataEventPrefix = "data: " + +type streamChatResponse struct { + Uuid string `json:"uuid"` + Type string `json:"type"` + TextResponse string `json:"textResponse"` - Close bool `json:"close"` - Error string `json:"error"` + + Metrics chatMetrics `json:"metrics"` + ChatId int `json:"chatId"` + + Close bool `json:"close"` + Error bool `json:"error"` +} + +type chatMetrics struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + OutputTokenPerSec float64 `json:"outputTps"` + Duration float64 `json:"duration"` } type threadRequest struct { @@ -100,15 +126,28 @@ func (a *AnythingLLM) GenerateContent(ctx context.Context, messages []llms.Messa req.Attachments = append(req.Attachments, attachment) } } + opts := getOpts(options) - result, err := a.doChat(ctx, req) + var result *chatResponse + if opts.StreamingFunc == nil { + result, err = a.doChat(ctx, req) + } else { + result, err = a.doStreamChat(ctx, req, opts.StreamingFunc) + } if err != nil { return nil, fmt.Errorf("error calling anythingllm: %w", err) } return &llms.ContentResponse{ Choices: []*llms.ContentChoice{ - {Content: result}, + { + Content: result.TextResponse, + GenerationInfo: map[string]any{ + "id": result.ID, + "chatId": result.ChatId, + "metrics": result.Metrics, + }, + }, }, }, nil } @@ -119,11 +158,33 @@ func (a *AnythingLLM) Call(ctx context.Context, prompt string, options ...llms.C return "", err } - return a.doChat(ctx, chatRequest{ + opts := getOpts(options) + req := chatRequest{ Message: prompt, Mode: "chat", SessionID: a.threadSlug, - }) + } + + var result *chatResponse + if opts.StreamingFunc == nil { + result, err = a.doChat(ctx, req) + } else { + result, err = a.doStreamChat(ctx, req, opts.StreamingFunc) + } + + if err != nil { + return "", fmt.Errorf("error calling anythingllm: %w", err) + } + + return result.TextResponse, nil +} + +func getOpts(options []llms.CallOption) llms.CallOptions { + opts := llms.CallOptions{} + for _, o := range options { + o(&opts) + } + return opts } func (a *AnythingLLM) ensureThread(ctx context.Context) error { @@ -136,16 +197,16 @@ func (a *AnythingLLM) ensureThread(ctx context.Context) error { return nil } -func (a *AnythingLLM) doChat(ctx context.Context, request chatRequest) (string, error) { +func (a *AnythingLLM) doChat(ctx context.Context, request chatRequest) (*chatResponse, error) { url := fmt.Sprintf("%s/api/v1/workspace/%s/thread/%s/chat", a.baseURL, a.workspace, a.threadSlug) jsonPayload, err := json.Marshal(request) if err != nil { - return "", fmt.Errorf("error marshalling payload: %w", err) + return nil, fmt.Errorf("error marshalling payload: %w", err) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonPayload)) if err != nil { - return "", fmt.Errorf("error creating request: %w", err) + return nil, fmt.Errorf("error creating request: %w", err) } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.token)) req.Header.Set("Content-Type", "application/json") @@ -153,20 +214,83 @@ func (a *AnythingLLM) doChat(ctx context.Context, request chatRequest) (string, resp, err := a.client.Do(req) if err != nil { - return "", fmt.Errorf("error making request: %w", err) + return nil, fmt.Errorf("error making request: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode) + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) } var result chatResponse if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return "", fmt.Errorf("error decoding response: %w", err) + return nil, fmt.Errorf("error decoding response: %w", err) } - return result.TextResponse, nil + return &result, nil +} + +func (a *AnythingLLM) doStreamChat(ctx context.Context, request chatRequest, streamFn func(ctx context.Context, chunk []byte) error) (*chatResponse, error) { + url := fmt.Sprintf("%s/api/v1/workspace/%s/thread/%s/stream-chat", a.baseURL, a.workspace, a.threadSlug) + jsonPayload, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("error marshalling payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonPayload)) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.token)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + + resp, err := a.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + result := &chatResponse{} + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, streamChatDataEventPrefix) { + continue + } + + line = line[len(streamChatDataEventPrefix):] + + var parsed streamChatResponse + if err := json.Unmarshal([]byte(line), &parsed); err != nil { + return nil, fmt.Errorf("error decoding response: %w", err) + } + + if parsed.Error { + return nil, fmt.Errorf("error in response: %s", line) + } + + switch parsed.Type { + case "textResponseChunk": + se := streamFn(ctx, []byte(parsed.TextResponse)) + if se != nil { + return nil, fmt.Errorf("error in streaming function: %w", se) + } + result.TextResponse += parsed.TextResponse + case "finalizeResponseStream": + result.ChatId = parsed.ChatId + result.Metrics = parsed.Metrics + } + } + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("error reading response: %w", err) + } + + return result, nil } func (a *AnythingLLM) createNewThread(ctx context.Context) error {