Skip to content

Commit

Permalink
implement streaming support for anythingLLM
Browse files Browse the repository at this point in the history
  • Loading branch information
rainu committed Jan 25, 2025
1 parent 67ed5e0 commit 2d239b9
Showing 1 changed file with 139 additions and 15 deletions.
154 changes: 139 additions & 15 deletions llms/anythingllm.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package llms

import (
"bufio"
"bytes"
"context"
"encoding/json"
Expand Down Expand Up @@ -36,11 +37,36 @@ type chatAttachment struct {
}

// chatResponse is the reply of the blocking chat endpoint. It is also the
// value handed back by the streaming path, which fills it in from the
// individual stream events.
type chatResponse struct {
	ID           string      `json:"id"`
	ChatId       int         `json:"chatId"`
	Type         string      `json:"type"`
	TextResponse string      `json:"textResponse"`
	Close        bool        `json:"close"`
	Error        string      `json:"error"`
	Metrics      chatMetrics `json:"metrics"`
}

// streamChatDataEventPrefix marks the lines of the event stream that carry a
// JSON payload; lines without this prefix are not decoded.
const streamChatDataEventPrefix = "data: "

// streamChatResponse is the JSON payload of a single event of the streaming
// chat endpoint.
type streamChatResponse struct {
	Uuid string `json:"uuid"`
	Type string `json:"type"`

	TextResponse string `json:"textResponse"`

	Metrics chatMetrics `json:"metrics"`
	ChatId  int         `json:"chatId"`

	// NOTE(review): Error is decoded as a bool here while chatResponse.Error
	// is a string; confirm the server never sends a message string in stream
	// events, since that would fail decoding.
	Close bool `json:"close"`
	Error bool `json:"error"`
}

// chatMetrics holds the usage figures reported alongside a chat reply: the
// prompt, completion and total token counts plus the "outputTps" and
// "duration" values as sent by the server.
type chatMetrics struct {
	PromptTokens      int     `json:"prompt_tokens"`
	CompletionTokens  int     `json:"completion_tokens"`
	TotalTokens       int     `json:"total_tokens"`
	OutputTokenPerSec float64 `json:"outputTps"`
	Duration          float64 `json:"duration"`
}

type threadRequest struct {
Expand Down Expand Up @@ -100,15 +126,28 @@ func (a *AnythingLLM) GenerateContent(ctx context.Context, messages []llms.Messa
req.Attachments = append(req.Attachments, attachment)
}
}
opts := getOpts(options)

result, err := a.doChat(ctx, req)
var result *chatResponse
if opts.StreamingFunc == nil {
result, err = a.doChat(ctx, req)
} else {
result, err = a.doStreamChat(ctx, req, opts.StreamingFunc)
}
if err != nil {
return nil, fmt.Errorf("error calling anythingllm: %w", err)
}

return &llms.ContentResponse{
Choices: []*llms.ContentChoice{
{Content: result},
{
Content: result.TextResponse,
GenerationInfo: map[string]any{
"id": result.ID,
"chatId": result.ChatId,
"metrics": result.Metrics,
},
},
},
}, nil
}
Expand All @@ -119,11 +158,33 @@ func (a *AnythingLLM) Call(ctx context.Context, prompt string, options ...llms.C
return "", err
}

return a.doChat(ctx, chatRequest{
opts := getOpts(options)
req := chatRequest{
Message: prompt,
Mode: "chat",
SessionID: a.threadSlug,
})
}

var result *chatResponse
if opts.StreamingFunc == nil {
result, err = a.doChat(ctx, req)
} else {
result, err = a.doStreamChat(ctx, req, opts.StreamingFunc)
}

if err != nil {
return "", fmt.Errorf("error calling anythingllm: %w", err)
}

return result.TextResponse, nil
}

// getOpts folds the given call options into a single llms.CallOptions value
// by applying each option, in order, to a zero-valued struct.
func getOpts(options []llms.CallOption) llms.CallOptions {
	var resolved llms.CallOptions
	for i := range options {
		apply := options[i]
		apply(&resolved)
	}
	return resolved
}

func (a *AnythingLLM) ensureThread(ctx context.Context) error {
Expand All @@ -136,37 +197,100 @@ func (a *AnythingLLM) ensureThread(ctx context.Context) error {
return nil
}

func (a *AnythingLLM) doChat(ctx context.Context, request chatRequest) (string, error) {
func (a *AnythingLLM) doChat(ctx context.Context, request chatRequest) (*chatResponse, error) {
url := fmt.Sprintf("%s/api/v1/workspace/%s/thread/%s/chat", a.baseURL, a.workspace, a.threadSlug)
jsonPayload, err := json.Marshal(request)
if err != nil {
return "", fmt.Errorf("error marshalling payload: %w", err)
return nil, fmt.Errorf("error marshalling payload: %w", err)
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonPayload))
if err != nil {
return "", fmt.Errorf("error creating request: %w", err)
return nil, fmt.Errorf("error creating request: %w", err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.token))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")

resp, err := a.client.Do(req)
if err != nil {
return "", fmt.Errorf("error making request: %w", err)
return nil, fmt.Errorf("error making request: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

var result chatResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("error decoding response: %w", err)
return nil, fmt.Errorf("error decoding response: %w", err)
}

return result.TextResponse, nil
return &result, nil
}

// doStreamChat sends request to the stream-chat endpoint of the configured
// workspace thread and consumes the event stream it answers with.
//
// Only lines starting with streamChatDataEventPrefix are decoded. Each
// "textResponseChunk" event is passed to streamFn as it arrives and is also
// accumulated, so the returned chatResponse carries the complete text. ChatId
// and Metrics are copied from the "finalizeResponseStream" event, and ID from
// the uuid carried by the events.
//
// The stream is abandoned with an error when an event cannot be decoded, when
// an event has its error flag set, when streamFn fails, or when reading the
// body fails. On every failure the returned *chatResponse is nil.
func (a *AnythingLLM) doStreamChat(ctx context.Context, request chatRequest, streamFn func(ctx context.Context, chunk []byte) error) (*chatResponse, error) {
	// Upper bound for a single event line; the scanner default of 64 KiB
	// would abort the whole stream on one unusually large event.
	const maxEventSize = 1 << 20

	url := fmt.Sprintf("%s/api/v1/workspace/%s/thread/%s/stream-chat", a.baseURL, a.workspace, a.threadSlug)
	jsonPayload, err := json.Marshal(request)
	if err != nil {
		return nil, fmt.Errorf("error marshalling payload: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonPayload))
	if err != nil {
		return nil, fmt.Errorf("error creating request: %w", err)
	}
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.token))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "text/event-stream")

	resp, err := a.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("error making request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
	}

	result := &chatResponse{}
	// Collect the chunks in a builder instead of repeated string
	// concatenation, which would copy the whole text on every chunk.
	var text strings.Builder

	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 4096), maxEventSize)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, streamChatDataEventPrefix) {
			continue
		}
		payload := line[len(streamChatDataEventPrefix):]

		var parsed streamChatResponse
		if err := json.Unmarshal([]byte(payload), &parsed); err != nil {
			return nil, fmt.Errorf("error decoding response: %w", err)
		}

		if parsed.Error {
			return nil, fmt.Errorf("error in response: %s", payload)
		}

		// NOTE(review): assumes the event uuid is the same identifier the
		// blocking endpoint returns as "id" - confirm against the server.
		if parsed.Uuid != "" {
			result.ID = parsed.Uuid
		}

		switch parsed.Type {
		case "textResponseChunk":
			if err := streamFn(ctx, []byte(parsed.TextResponse)); err != nil {
				return nil, fmt.Errorf("error in streaming function: %w", err)
			}
			text.WriteString(parsed.TextResponse)
		case "finalizeResponseStream":
			result.ChatId = parsed.ChatId
			result.Metrics = parsed.Metrics
		}
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("error reading response: %w", err)
	}

	result.TextResponse = text.String()
	return result, nil
}

func (a *AnythingLLM) createNewThread(ctx context.Context) error {
Expand Down

0 comments on commit 2d239b9

Please sign in to comment.