Skip to content

Commit

Permalink
feat: add anthropic support (#178)
Browse files Browse the repository at this point in the history
  • Loading branch information
henomis committed Mar 9, 2024
1 parent 32ee8b0 commit b1706e1
Show file tree
Hide file tree
Showing 5 changed files with 562 additions and 0 deletions.
28 changes: 28 additions & 0 deletions examples/llm/antropic/multimodal/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package main

import (
"context"
"fmt"

"github.com/henomis/lingoose/llm/antropic"
"github.com/henomis/lingoose/thread"
)

func main() {
antropicllm := antropic.New().WithModel("claude-3-opus-20240229")

t := thread.New().AddMessage(
thread.NewUserMessage().AddContent(
thread.NewTextContent("Can you describe the image?"),
).AddContent(
thread.NewImageContentFromURL("https://upload.wikimedia.org/wikipedia/commons/thumb/3/34/Anser_anser_1_%28Piotr_Kuczynski%29.jpg/1280px-Anser_anser_1_%28Piotr_Kuczynski%29.jpg"),
),
)

err := antropicllm.Generate(context.Background(), t)
if err != nil {
panic(err)
}

fmt.Println(t)
}
34 changes: 34 additions & 0 deletions examples/llm/antropic/stream/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package main

import (
"context"
"fmt"

"github.com/henomis/lingoose/llm/antropic"
"github.com/henomis/lingoose/thread"
)

func main() {
antropicllm := antropic.New().WithModel("claude-3-opus-20240229").WithStream(
func(response string) {
if response != antropic.EOS {
fmt.Print(response)
} else {
fmt.Println()
}
},
)

t := thread.New().AddMessage(
thread.NewUserMessage().AddContent(
thread.NewTextContent("How are you?"),
),
)

err := antropicllm.Generate(context.Background(), t)
if err != nil {
panic(err)
}

fmt.Println(t)
}
251 changes: 251 additions & 0 deletions llm/antropic/antropic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
package antropic

import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"strings"

"github.com/henomis/lingoose/llm/cache"
"github.com/henomis/lingoose/thread"
"github.com/henomis/restclientgo"
)

const (
defaultModel = "claude-3-opus-20240229"
eventStreamContentType = "text/event-stream"
jsonContentType = "application/json"
defaultEndpoint = "https://api.anthropic.com/v1"
)

var (
ErrAnthropicChat = fmt.Errorf("anthropic chat error")
)

var threadRoleToAnthropicRole = map[thread.Role]string{
thread.RoleSystem: "system",
thread.RoleUser: "user",
thread.RoleAssistant: "assistant",
}

const (
defaultAPIVersion = "2023-06-01"
defaultMaxTokens = 1024
EOS = "\x00"
)

type StreamCallbackFn func(string)

type Antropic struct {
model string
temperature float64
restClient *restclientgo.RestClient
streamCallbackFn StreamCallbackFn
cache *cache.Cache
apiVersion string
apiKey string
maxTokens int
}

func New() *Antropic {
apiKey := os.Getenv("ANTHROPIC_API_KEY")

return &Antropic{
restClient: restclientgo.New(defaultEndpoint).WithRequestModifier(
func(req *http.Request) *http.Request {
req.Header.Set("x-api-key", apiKey)
req.Header.Set("anthropic-version", defaultAPIVersion)
return req
},
),
model: defaultModel,
apiVersion: defaultAPIVersion,
apiKey: apiKey,
maxTokens: defaultMaxTokens,
}
}

func (o *Antropic) WithModel(model string) *Antropic {
o.model = model
return o
}

func (o *Antropic) WithStream(callbackFn StreamCallbackFn) *Antropic {
o.streamCallbackFn = callbackFn
return o
}

func (o *Antropic) WithCache(cache *cache.Cache) *Antropic {
o.cache = cache
return o
}

func (o *Antropic) WithTemperature(temperature float64) *Antropic {
o.temperature = temperature
return o
}

func (o *Antropic) WithMaxTokens(maxTokens int) *Antropic {
o.maxTokens = maxTokens
return o
}

func (o *Antropic) getCache(ctx context.Context, t *thread.Thread) (*cache.Result, error) {
messages := t.UserQuery()
cacheQuery := strings.Join(messages, "\n")
cacheResult, err := o.cache.Get(ctx, cacheQuery)
if err != nil {
return cacheResult, err
}

t.AddMessage(thread.NewAssistantMessage().AddContent(
thread.NewTextContent(strings.Join(cacheResult.Answer, "\n")),
))

return cacheResult, nil
}

func (o *Antropic) setCache(ctx context.Context, t *thread.Thread, cacheResult *cache.Result) error {
lastMessage := t.LastMessage()

if lastMessage.Role != thread.RoleAssistant || len(lastMessage.Contents) == 0 {
return nil
}

contents := make([]string, 0)
for _, content := range lastMessage.Contents {
if content.Type == thread.ContentTypeText {
contents = append(contents, content.Data.(string))
} else {
contents = make([]string, 0)
break
}
}

err := o.cache.Set(ctx, cacheResult.Embedding, strings.Join(contents, "\n"))
if err != nil {
return err
}

return nil
}

func (o *Antropic) Generate(ctx context.Context, t *thread.Thread) error {
if t == nil {
return nil
}

var err error
var cacheResult *cache.Result
if o.cache != nil {
cacheResult, err = o.getCache(ctx, t)
if err == nil {
return nil
} else if !errors.Is(err, cache.ErrCacheMiss) {
return fmt.Errorf("%w: %w", ErrAnthropicChat, err)
}
}

chatRequest := o.buildChatCompletionRequest(t)

if o.streamCallbackFn != nil {
err = o.stream(ctx, t, chatRequest)
} else {
err = o.generate(ctx, t, chatRequest)
}

if err != nil {
return err
}

if o.cache != nil {
err = o.setCache(ctx, t, cacheResult)
if err != nil {
return fmt.Errorf("%w: %w", ErrAnthropicChat, err)
}
}

return nil
}

func (o *Antropic) generate(ctx context.Context, t *thread.Thread, chatRequest *request) error {
var resp response

err := o.restClient.Post(
ctx,
chatRequest,
&resp,
)
if err != nil {
return fmt.Errorf("%w: %w", ErrAnthropicChat, err)
}

m := thread.NewAssistantMessage()

for _, content := range resp.Content {
if content.Type == messageTypeText && content.Text != nil {
m.AddContent(
thread.NewTextContent(*content.Text),
)
}
}

t.AddMessage(m)

return nil
}

func (o *Antropic) stream(ctx context.Context, t *thread.Thread, chatRequest *request) error {
var resp response
var assistantMessage string

resp.SetAcceptContentType(eventStreamContentType)
resp.SetStreamCallback(
func(data []byte) error {
dataAsString := string(data)
if !strings.HasPrefix(dataAsString, "data: ") {
return nil
}

dataAsString = strings.Replace(dataAsString, "data: ", "", -1)

var e event
_ = json.Unmarshal([]byte(dataAsString), &e)

if e.Type == "content_block_delta" {
if e.Delta != nil {
assistantMessage += e.Delta.Text
o.streamCallbackFn(e.Delta.Text)
}
} else if e.Type == "message_stop" {
o.streamCallbackFn(EOS)
}

return nil
},
)

chatRequest.Stream = true

err := o.restClient.Post(
ctx,
chatRequest,
&resp,
)
if err != nil {
return fmt.Errorf("%w: %w", ErrAnthropicChat, err)
}

if resp.HTTPStatusCode >= http.StatusBadRequest {
return fmt.Errorf("%w: %s", ErrAnthropicChat, resp.RawBody)
}

t.AddMessage(thread.NewAssistantMessage().AddContent(
thread.NewTextContent(assistantMessage),
))

return nil
}
Loading

0 comments on commit b1706e1

Please sign in to comment.