From b1706e1b31aa29455a966add978731cffb7b1404 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Sat, 9 Mar 2024 16:26:20 +0100 Subject: [PATCH] feat: add anthropic support (#178) --- examples/llm/antropic/multimodal/main.go | 28 +++ examples/llm/antropic/stream/main.go | 34 +++ llm/antropic/antropic.go | 251 +++++++++++++++++++++++ llm/antropic/api.go | 169 +++++++++++++++ llm/antropic/formatter.go | 80 ++++++++ 5 files changed, 562 insertions(+) create mode 100644 examples/llm/antropic/multimodal/main.go create mode 100644 examples/llm/antropic/stream/main.go create mode 100644 llm/antropic/antropic.go create mode 100644 llm/antropic/api.go create mode 100644 llm/antropic/formatter.go diff --git a/examples/llm/antropic/multimodal/main.go b/examples/llm/antropic/multimodal/main.go new file mode 100644 index 00000000..2240c188 --- /dev/null +++ b/examples/llm/antropic/multimodal/main.go @@ -0,0 +1,28 @@ +package main + +import ( + "context" + "fmt" + + "github.com/henomis/lingoose/llm/antropic" + "github.com/henomis/lingoose/thread" +) + +func main() { + antropicllm := antropic.New().WithModel("claude-3-opus-20240229") + + t := thread.New().AddMessage( + thread.NewUserMessage().AddContent( + thread.NewTextContent("Can you describe the image?"), + ).AddContent( + thread.NewImageContentFromURL("https://upload.wikimedia.org/wikipedia/commons/thumb/3/34/Anser_anser_1_%28Piotr_Kuczynski%29.jpg/1280px-Anser_anser_1_%28Piotr_Kuczynski%29.jpg"), + ), + ) + + err := antropicllm.Generate(context.Background(), t) + if err != nil { + panic(err) + } + + fmt.Println(t) +} diff --git a/examples/llm/antropic/stream/main.go b/examples/llm/antropic/stream/main.go new file mode 100644 index 00000000..a42b4a87 --- /dev/null +++ b/examples/llm/antropic/stream/main.go @@ -0,0 +1,34 @@ +package main + +import ( + "context" + "fmt" + + "github.com/henomis/lingoose/llm/antropic" + "github.com/henomis/lingoose/thread" +) + +func main() { + antropicllm := antropic.New().WithModel("claude-3-opus-20240229").WithStream( + func(response string) { + if response != antropic.EOS { + fmt.Print(response) + } else { + fmt.Println() + } + }, + ) + + t := thread.New().AddMessage( + thread.NewUserMessage().AddContent( + thread.NewTextContent("How are you?"), + ), + ) + + err := antropicllm.Generate(context.Background(), t) + if err != nil { + panic(err) + } + + fmt.Println(t) +} diff --git a/llm/antropic/antropic.go b/llm/antropic/antropic.go new file mode 100644 index 00000000..3c75f0c7 --- /dev/null +++ b/llm/antropic/antropic.go @@ -0,0 +1,251 @@ +package antropic + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "os" + "strings" + + "github.com/henomis/lingoose/llm/cache" + "github.com/henomis/lingoose/thread" + "github.com/henomis/restclientgo" +) + +const ( + defaultModel = "claude-3-opus-20240229" + eventStreamContentType = "text/event-stream" + jsonContentType = "application/json" + defaultEndpoint = "https://api.anthropic.com/v1" +) + +var ( + ErrAnthropicChat = fmt.Errorf("anthropic chat error") +) + +var threadRoleToAnthropicRole = map[thread.Role]string{ + thread.RoleSystem: "system", + thread.RoleUser: "user", + thread.RoleAssistant: "assistant", +} + +const ( + defaultAPIVersion = "2023-06-01" + defaultMaxTokens = 1024 + EOS = "\x00" +) + +type StreamCallbackFn func(string) + +type Antropic struct { + model string + temperature float64 + restClient *restclientgo.RestClient + streamCallbackFn StreamCallbackFn + cache *cache.Cache + apiVersion string + apiKey string + maxTokens int +} + +func New() *Antropic { + apiKey := os.Getenv("ANTHROPIC_API_KEY") + + return &Antropic{ + restClient: restclientgo.New(defaultEndpoint).WithRequestModifier( + func(req *http.Request) *http.Request { + req.Header.Set("x-api-key", apiKey) + req.Header.Set("anthropic-version", defaultAPIVersion) + return req + }, + ), + model: defaultModel, + apiVersion: defaultAPIVersion, + apiKey: apiKey, + maxTokens: defaultMaxTokens, + } +} + +func (o *Antropic) WithModel(model string) *Antropic { + o.model = model + return o +} + +func (o *Antropic) WithStream(callbackFn StreamCallbackFn) *Antropic { + o.streamCallbackFn = callbackFn + return o +} + +func (o *Antropic) WithCache(cache *cache.Cache) *Antropic { + o.cache = cache + return o +} + +func (o *Antropic) WithTemperature(temperature float64) *Antropic { + o.temperature = temperature + return o +} + +func (o *Antropic) WithMaxTokens(maxTokens int) *Antropic { + o.maxTokens = maxTokens + return o +} + +func (o *Antropic) getCache(ctx context.Context, t *thread.Thread) (*cache.Result, error) { + messages := t.UserQuery() + cacheQuery := strings.Join(messages, "\n") + cacheResult, err := o.cache.Get(ctx, cacheQuery) + if err != nil { + return cacheResult, err + } + + t.AddMessage(thread.NewAssistantMessage().AddContent( + thread.NewTextContent(strings.Join(cacheResult.Answer, "\n")), + )) + + return cacheResult, nil +} + +func (o *Antropic) setCache(ctx context.Context, t *thread.Thread, cacheResult *cache.Result) error { + lastMessage := t.LastMessage() + + if lastMessage.Role != thread.RoleAssistant || len(lastMessage.Contents) == 0 { + return nil + } + + contents := make([]string, 0) + for _, content := range lastMessage.Contents { + if content.Type == thread.ContentTypeText { + contents = append(contents, content.Data.(string)) + } else { + contents = make([]string, 0) + break + } + } + + err := o.cache.Set(ctx, cacheResult.Embedding, strings.Join(contents, "\n")) + if err != nil { + return err + } + + return nil +} + +func (o *Antropic) Generate(ctx context.Context, t *thread.Thread) error { + if t == nil { + return nil + } + + var err error + var cacheResult *cache.Result + if o.cache != nil { + cacheResult, err = o.getCache(ctx, t) + if err == nil { + return nil + } else if !errors.Is(err, cache.ErrCacheMiss) { + return fmt.Errorf("%w: %w", ErrAnthropicChat, err) + } + } + + chatRequest := o.buildChatCompletionRequest(t) + + if o.streamCallbackFn != nil { + err = o.stream(ctx, t, chatRequest) + } else { + err = o.generate(ctx, t, chatRequest) + } + + if err != nil { + return err + } + + if o.cache != nil { + err = o.setCache(ctx, t, cacheResult) + if err != nil { + return fmt.Errorf("%w: %w", ErrAnthropicChat, err) + } + } + + return nil +} + +func (o *Antropic) generate(ctx context.Context, t *thread.Thread, chatRequest *request) error { + var resp response + + err := o.restClient.Post( + ctx, + chatRequest, + &resp, + ) + if err != nil { + return fmt.Errorf("%w: %w", ErrAnthropicChat, err) + } + + m := thread.NewAssistantMessage() + + for _, content := range resp.Content { + if content.Type == messageTypeText && content.Text != nil { + m.AddContent( + thread.NewTextContent(*content.Text), + ) + } + } + + t.AddMessage(m) + + return nil +} + +func (o *Antropic) stream(ctx context.Context, t *thread.Thread, chatRequest *request) error { + var resp response + var assistantMessage string + + resp.SetAcceptContentType(eventStreamContentType) + resp.SetStreamCallback( + func(data []byte) error { + dataAsString := string(data) + if !strings.HasPrefix(dataAsString, "data: ") { + return nil + } + + dataAsString = strings.Replace(dataAsString, "data: ", "", -1) + + var e event + _ = json.Unmarshal([]byte(dataAsString), &e) + + if e.Type == "content_block_delta" { + if e.Delta != nil { + assistantMessage += e.Delta.Text + o.streamCallbackFn(e.Delta.Text) + } + } else if e.Type == "message_stop" { + o.streamCallbackFn(EOS) + } + + return nil + }, + ) + + chatRequest.Stream = true + + err := o.restClient.Post( + ctx, + chatRequest, + &resp, + ) + if err != nil { + return fmt.Errorf("%w: %w", ErrAnthropicChat, err) + } + + if resp.HTTPStatusCode >= http.StatusBadRequest { + return fmt.Errorf("%w: %s", ErrAnthropicChat, resp.RawBody) + } + + t.AddMessage(thread.NewAssistantMessage().AddContent( + thread.NewTextContent(assistantMessage), + )) + + return nil +} diff --git a/llm/antropic/api.go b/llm/antropic/api.go new file mode 100644 index 00000000..51584088 --- /dev/null +++ b/llm/antropic/api.go @@ -0,0 +1,169 @@ +package antropic + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "os" + "strings" + + "github.com/henomis/restclientgo" +) + +type request struct { + Model string `json:"model"` + Messages []message `json:"messages"` + System string `json:"system"` + MaxTokens int `json:"max_tokens"` + Metadata metadata `json:"metadata"` + StopSequences []string `json:"stop_sequences"` + Stream bool `json:"stream"` + Temperature float64 `json:"temperature"` + TopP float64 `json:"top_p"` + TopK int `json:"top_k"` +} + +type metadata struct { + UserID string `json:"user_id"` +} + +func (r *request) Path() (string, error) { + return "/messages", nil +} + +func (r *request) Encode() (io.Reader, error) { + jsonBytes, err := json.Marshal(r) + if err != nil { + return nil, err + } + + return bytes.NewReader(jsonBytes), nil +} + +func (r *request) ContentType() string { + return jsonContentType +} + +type response struct { + HTTPStatusCode int `json:"-"` + acceptContentType string `json:"-"` + ID string `json:"id"` + Type string `json:"type"` + Error aerror `json:"error"` + Role string `json:"role"` + Content []content `json:"content"` + Model string `json:"model"` + StopReason *string `json:"stop_reason"` + StopSequence *string `json:"stop_sequence"` + Usage usage `json:"usage"` + streamCallbackFn restclientgo.StreamCallback + RawBody []byte `json:"-"` +} + +type aerror struct { + Type string `json:"type"` + Message string `json:"message"` +} + +type content struct { + Type contentType `json:"type"` + Text *string `json:"text,omitempty"` + Source *contentSource `json:"source,omitempty"` +} + +type contentSource struct { + Type string `json:"type"` + MediaType string `json:"media_type"` + Data string `json:"data"` +} + +type usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +func (r *response) SetAcceptContentType(contentType string) { + r.acceptContentType = contentType +} + +func (r *response) Decode(body io.Reader) error { + return json.NewDecoder(body).Decode(r) +} + +func (r *response) SetBody(body io.Reader) error { + r.RawBody, _ = io.ReadAll(body) + return nil +} + +func (r *response) AcceptContentType() string { + if r.acceptContentType != "" { + return r.acceptContentType + } + return jsonContentType +} + +func (r *response) SetStatusCode(code int) error { + r.HTTPStatusCode = code + return nil +} + +func (r *response) SetHeaders(_ restclientgo.Headers) error { return nil } + +func (r *response) SetStreamCallback(fn restclientgo.StreamCallback) { + r.streamCallbackFn = fn +} + +func (r *response) StreamCallback() restclientgo.StreamCallback { + return r.streamCallbackFn +} + +type message struct { + Role string `json:"role"` + Content []content `json:"content"` +} + +type contentType string + +const ( + messageTypeText contentType = "text" + messageTypeImage contentType = "image" +) + +type event struct { + Type string `json:"type"` + Index *int `json:"index,omitempty"` + Delta *delta `json:"delta,omitempty"` +} + +type delta struct { + Type string `json:"type"` + Text string `json:"text"` +} + +func getImageDataAsBase64(imageURL string) (string, string, error) { + var imageData []byte + var err error + + if strings.HasPrefix(imageURL, "http://") || strings.HasPrefix(imageURL, "https://") { + //nolint:gosec + resp, fetchErr := http.Get(imageURL) + if fetchErr != nil { + return "", "", fetchErr + } + defer resp.Body.Close() + + imageData, err = io.ReadAll(resp.Body) + } else { + imageData, err = os.ReadFile(imageURL) + } + if err != nil { + return "", "", err + } + + // Detect image type + mimeType := http.DetectContentType(imageData) + + return base64.StdEncoding.EncodeToString(imageData), mimeType, nil +} diff --git a/llm/antropic/formatter.go b/llm/antropic/formatter.go new file mode 100644 index 00000000..014adb7e --- /dev/null +++ b/llm/antropic/formatter.go @@ -0,0 +1,80 @@ +package antropic + +import ( + "github.com/henomis/lingoose/thread" +) + +func (o *Antropic) buildChatCompletionRequest(t *thread.Thread) *request { + messages, systemPrompt := threadToChatMessages(t) + + return &request{ + Model: o.model, + Messages: messages, + System: systemPrompt, + MaxTokens: o.maxTokens, + Temperature: o.temperature, + } +} + +//nolint:gocognit +func threadToChatMessages(t *thread.Thread) ([]message, string) { + var systemPrompt string + var chatMessages []message + for _, m := range t.Messages { + switch m.Role { + case thread.RoleSystem: + for _, content := range m.Contents { + contentData, ok := content.Data.(string) + if !ok { + continue + } + + systemPrompt += contentData + } + case thread.RoleUser, thread.RoleAssistant: + chatMessage := message{ + Role: threadRoleToAnthropicRole[m.Role], + } + for _, c := range m.Contents { + contentData, ok := c.Data.(string) + if !ok { + continue + } + + if c.Type == thread.ContentTypeText { + chatMessage.Content = append( + chatMessage.Content, + content{ + Type: messageTypeText, + Text: &contentData, + }, + ) + } else if c.Type == thread.ContentTypeImage { + imageData, mimeType, err := getImageDataAsBase64(contentData) + if err != nil { + continue + } + + chatMessage.Content = append( + chatMessage.Content, + content{ + Type: messageTypeImage, + Source: &contentSource{ + Type: "base64", + Data: imageData, + MediaType: mimeType, + }, + }, + ) + } else { + continue + } + } + chatMessages = append(chatMessages, chatMessage) + case thread.RoleTool: + continue + } + } + + return chatMessages, systemPrompt +}