Skip to content

Commit

Permalink
feat: add voyage embedd and rerank (#180)
Browse files Browse the repository at this point in the history
  • Loading branch information
henomis committed Mar 9, 2024
1 parent ba47574 commit 525cbb0
Show file tree
Hide file tree
Showing 5 changed files with 453 additions and 0 deletions.
79 changes: 79 additions & 0 deletions embedder/voyage/api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package voyageembedder

import (
"bytes"
"encoding/json"
"io"

"github.com/henomis/lingoose/embedder"
"github.com/henomis/restclientgo"
)

type request struct {
Model string `json:"model"`
Input []string `json:"input"`
}

func (r *request) Path() (string, error) {
return "/embeddings", nil
}

func (r *request) Encode() (io.Reader, error) {
jsonBytes, err := json.Marshal(r)
if err != nil {
return nil, err
}

return bytes.NewReader(jsonBytes), nil
}

func (r *request) ContentType() string {
return "application/json"
}

type response struct {
HTTPStatusCode int `json:"-"`
acceptContentType string `json:"-"`
Object string `json:"object"`
Data []data `json:"data"`
Model string `json:"model"`
RawBody []byte `json:"-"`
}

type data struct {
Object string `json:"object"`
Embedding embedder.Embedding `json:"embedding"`
Index int `json:"index"`
}

func (r *response) SetAcceptContentType(contentType string) {
r.acceptContentType = contentType
}

func (r *response) Decode(body io.Reader) error {
return json.NewDecoder(body).Decode(r)
}

func (r *response) SetBody(body io.Reader) error {
b, err := io.ReadAll(body)
if err != nil {
return err
}

r.RawBody = b
return nil
}

func (r *response) AcceptContentType() string {
if r.acceptContentType != "" {
return r.acceptContentType
}
return "application/json"
}

func (r *response) SetStatusCode(code int) error {
r.HTTPStatusCode = code
return nil
}

func (r *response) SetHeaders(_ restclientgo.Headers) error { return nil }
66 changes: 66 additions & 0 deletions embedder/voyage/voyage.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package voyageembedder

import (
"context"
"net/http"
"os"

"github.com/henomis/lingoose/embedder"
"github.com/henomis/restclientgo"
)

const (
defaultModel = "voyage-2"
defaultEndpoint = "https://api.voyageai.com/v1"
)

type Embedder struct {
model string
restClient *restclientgo.RestClient
}

func New() *Embedder {
apiKey := os.Getenv("VOYAGE_API_KEY")

return &Embedder{
restClient: restclientgo.New(defaultEndpoint).WithRequestModifier(
func(req *http.Request) *http.Request {
req.Header.Set("Authorization", "Bearer "+apiKey)
return req
}),
model: defaultModel,
}
}

func (e *Embedder) WithModel(model string) *Embedder {
e.model = model
return e
}

// Embed returns the embeddings for the given texts
func (e *Embedder) Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error) {
return e.embed(ctx, texts)
}

// Embed returns the embeddings for the given texts
func (e *Embedder) embed(ctx context.Context, text []string) ([]embedder.Embedding, error) {
resp := &response{}
err := e.restClient.Post(
ctx,
&request{
Input: text,
Model: e.model,
},
resp,
)
if err != nil {
return nil, err
}

embeddings := make([]embedder.Embedding, len(resp.Data))
for i, data := range resp.Data {
embeddings[i] = data.Embedding
}

return embeddings, nil
}
102 changes: 102 additions & 0 deletions examples/embeddings/voyage/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package main

import (
"context"
"fmt"

voyageembedder "github.com/henomis/lingoose/embedder/voyage"
"github.com/henomis/lingoose/index"
indexoption "github.com/henomis/lingoose/index/option"
"github.com/henomis/lingoose/index/vectordb/jsondb"
"github.com/henomis/lingoose/llm/antropic"
"github.com/henomis/lingoose/loader"
"github.com/henomis/lingoose/textsplitter"
"github.com/henomis/lingoose/thread"
"github.com/henomis/lingoose/types"
)

// download https://raw.githubusercontent.com/hwchase17/chat-your-data/master/state_of_the_union.txt

func main() {

index := index.New(
jsondb.New().WithPersist("db.json"),
voyageembedder.New().WithModel("voyage-2"),
).WithIncludeContents(true).WithAddDataCallback(func(data *index.Data) error {
data.Metadata["contentLen"] = len(data.Metadata["content"].(string))
return nil
})

indexIsEmpty, _ := index.IsEmpty(context.Background())

if indexIsEmpty {
err := ingestData(index)
if err != nil {
panic(err)
}
}

query := "What is the purpose of the NATO Alliance?"
similarities, err := index.Query(
context.Background(),
query,
indexoption.WithTopK(3),
)
if err != nil {
panic(err)
}

for _, similarity := range similarities {
fmt.Printf("Similarity: %f\n", similarity.Score)
fmt.Printf("Document: %s\n", similarity.Content())
fmt.Println("Metadata: ", similarity.Metadata)
fmt.Println("----------")
}

documentContext := ""
for _, similarity := range similarities {
documentContext += similarity.Content() + "\n\n"
}

antropicllm := antropic.New().WithModel("claude-3-opus-20240229")
t := thread.New()
t.AddMessage(thread.NewUserMessage().AddContent(
thread.NewTextContent("Based on the following context answer to the" +
"question.\n\nContext:\n{{.context}}\n\nQuestion: {{.query}}").Format(
types.M{
"query": query,
"context": documentContext,
},
),
))

err = antropicllm.Generate(context.Background(), t)
if err != nil {
panic(err)
}

fmt.Println(t)
}

func ingestData(index *index.Index) error {

fmt.Printf("Ingesting data...")

documents, err := loader.NewDirectoryLoader(".", ".txt").Load(context.Background())
if err != nil {
return err
}

textSplitter := textsplitter.NewRecursiveCharacterTextSplitter(1000, 20)

documentChunks := textSplitter.SplitDocuments(documents)

err = index.LoadFromDocuments(context.Background(), documentChunks)
if err != nil {
return err
}

fmt.Printf("Done!\n")

return nil
}
36 changes: 36 additions & 0 deletions examples/transformer/voyage-rerank/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package main

import (
"context"
"fmt"

"github.com/henomis/lingoose/document"
"github.com/henomis/lingoose/transformer"
)

func main() {

r := transformer.NewVoyageRerank()

documents, err := r.Rerank(
context.Background(),
"What is the capital of the United States?",
[]document.Document{
{
Content: "Carson City is the capital city of the American state of Nevada.",
}, {
Content: "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
}, {
Content: "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
},
},
)
if err != nil {
panic(err)
}

for _, doc := range documents {
fmt.Println(doc.GetEnrichedContent())
fmt.Println("-----")
}
}
Loading

0 comments on commit 525cbb0

Please sign in to comment.