From 83195d23bb1a421f9f088eba8f51af6f0795b3c6 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Thu, 2 Nov 2023 14:14:16 +0100 Subject: [PATCH] Feat Add redis as vector database (#145) --- .github/workflows/checks.yml | 7 +- examples/embeddings/qdrant/main.go | 1 - examples/embeddings/redis/main.go | 107 +++++++++++++ go.mod | 2 + go.sum | 10 ++ index/vectordb/milvus/milvus.go | 2 - index/vectordb/qdrant/qdrant.go | 16 +- index/vectordb/redis/redis.go | 244 +++++++++++++++++++++++++++++ 8 files changed, 372 insertions(+), 17 deletions(-) create mode 100644 examples/embeddings/redis/main.go create mode 100644 index/vectordb/redis/redis.go diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index d95ab420..c299e963 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -33,4 +33,9 @@ jobs: cache: false - uses: actions/checkout@v3 - name: golangci-lint - uses: golangci/golangci-lint-action@v3 \ No newline at end of file + uses: golangci/golangci-lint-action@v3 + with: + # Require: The version of golangci-lint to use. + # When `install-mode` is `binary` (default) the value can be v1.2 or v1.2.3 or `latest` to use the latest version. + # When `install-mode` is `goinstall` the value can be v1.2.3, `latest`, or the hash of a commit. + version: v1.54.2 \ No newline at end of file diff --git a/examples/embeddings/qdrant/main.go b/examples/embeddings/qdrant/main.go index 6c24ad26..546aaf41 100644 --- a/examples/embeddings/qdrant/main.go +++ b/examples/embeddings/qdrant/main.go @@ -23,7 +23,6 @@ func main() { qdrantdb.New( qdrantdb.Options{ CollectionName: "test", - IncludeContent: true, CreateCollection: &qdrantdb.CreateCollectionOptions{ Dimension: 1536, Distance: qdrantdb.DistanceCosine, diff --git a/examples/embeddings/redis/main.go b/examples/embeddings/redis/main.go new file mode 100644 index 00000000..dc6eb295 --- /dev/null +++ b/examples/embeddings/redis/main.go @@ -0,0 +1,107 @@ +package main + +import ( + "context" + "fmt" + + "github.com/RediSearch/redisearch-go/v2/redisearch" + openaiembedder "github.com/henomis/lingoose/embedder/openai" + "github.com/henomis/lingoose/index" + indexoption "github.com/henomis/lingoose/index/option" + "github.com/henomis/lingoose/index/vectordb/redis" + "github.com/henomis/lingoose/llm/openai" + "github.com/henomis/lingoose/loader" + "github.com/henomis/lingoose/prompt" + "github.com/henomis/lingoose/textsplitter" +) + +// download https://raw.githubusercontent.com/hwchase17/chat-your-data/master/state_of_the_union.txt + +func main() { + index := index.New( + redis.New( + redis.Options{ + RedisearchClient: redisearch.NewClient("localhost:6379", "test"), + CreateIndex: &redis.CreateIndexOptions{ + Dimension: 1536, + Distance: redis.DistanceCosine, + }, + }, + ), + openaiembedder.New(openaiembedder.AdaEmbeddingV2), + ).WithIncludeContents(true) + + indexIsEmpty, err := index.IsEmpty(context.Background()) + if err != nil { + panic(err) + } + + if indexIsEmpty { + err = ingestData(index) + if err != nil { + panic(err) + } + } + + query := "What is the purpose of the NATO Alliance?" + similarities, err := index.Query( + context.Background(), + query, + indexoption.WithTopK(3), + ) + if err != nil { + panic(err) + } + + content := "" + for _, similarity := range similarities { + fmt.Printf("Similarity: %f\n", similarity.Score) + fmt.Printf("Document: %s\n", similarity.Content()) + fmt.Println("Metadata: ", similarity.Metadata) + fmt.Println("ID: ", similarity.ID) + fmt.Println("----------") + content += similarity.Content() + "\n" + } + + llmOpenAI := openai.NewCompletion().WithVerbose(true) + + prompt1 := prompt.NewPromptTemplate( + "Based on the following context answer to the question.\n\nContext:\n{{.context}}\n\nQuestion: {{.query}}"). + WithInputs( + map[string]string{ + "query": query, + "context": content, + }, + ) + + err = prompt1.Format(nil) + if err != nil { + panic(err) + } + + _, err = llmOpenAI.Completion(context.Background(), prompt1.String()) + if err != nil { + panic(err) + } +} + +func ingestData(redisIndex *index.Index) error { + documents, err := loader.NewDirectoryLoader(".", ".txt").Load(context.Background()) + if err != nil { + return err + } + + textSplitter := textsplitter.NewRecursiveCharacterTextSplitter(1000, 20) + + documentChunks := textSplitter.SplitDocuments(documents) + + for _, doc := range documentChunks { + fmt.Println(doc.Content) + fmt.Println("----------") + fmt.Println(doc.Metadata) + fmt.Println("----------") + fmt.Println() + } + + return redisIndex.LoadFromDocuments(context.Background(), documentChunks) +} diff --git a/go.mod b/go.mod index e9f77a30..ae7f7359 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ toolchain go1.21.1 require github.com/mitchellh/mapstructure v1.5.0 require ( + github.com/RediSearch/redisearch-go/v2 v2.1.1 github.com/google/uuid v1.3.0 github.com/henomis/cohere-go v1.0.1 github.com/henomis/milvus-go v0.0.4 @@ -19,6 +20,7 @@ require ( require ( github.com/dlclark/regexp2 v1.8.1 // indirect + github.com/gomodule/redigo v1.8.9 // indirect github.com/henomis/restclientgo v1.0.6 // indirect github.com/iancoleman/orderedmap v0.0.0-20190318233801-ac98e3ecb4b0 // indirect ) diff --git a/go.sum b/go.sum index e6c67ff5..9a0bc354 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,12 @@ +github.com/RediSearch/redisearch-go/v2 v2.1.1 h1:cCn3i40uLsVD8cxwrdrGfhdAgbR5Cld9q11eYyVOwpM= +github.com/RediSearch/redisearch-go/v2 v2.1.1/go.mod h1:Uw93Wi97QqAsw1DwbQrhVd88dBorGTfSuCS42zfh1iA= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0= github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/gomodule/redigo v1.8.9 h1:Sl3u+2BI/kk+VEatbj0scLdrFhjPmbxOc1myhDP41ws= +github.com/gomodule/redigo v1.8.9/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/henomis/cohere-go v1.0.1 h1:a47gIN29tqAl4yBTAT+BzQMjsWG94Fz07u9AE4Md+a8= @@ -28,8 +32,14 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/sashabaranov/go-openai v1.12.0 h1:aRNHH0gtVfrpIaEolD0sWrLLRnYQNK4cH/bIAHwL8Rk= github.com/sashabaranov/go-openai v1.12.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.3.1-0.20190311161405-34c6fa2dc709/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/index/vectordb/milvus/milvus.go b/index/vectordb/milvus/milvus.go index a64507bd..8c75136d 100644 --- a/index/vectordb/milvus/milvus.go +++ b/index/vectordb/milvus/milvus.go @@ -42,8 +42,6 @@ type CreateCollectionOptions struct { type Options struct { DatabaseName *string CollectionName string - IncludeContent bool - IncludeValues bool BatchUpsertSize *int CreateCollection *CreateCollectionOptions } diff --git a/index/vectordb/qdrant/qdrant.go b/index/vectordb/qdrant/qdrant.go index 1882e843..66d7c7d2 100644 --- a/index/vectordb/qdrant/qdrant.go +++ b/index/vectordb/qdrant/qdrant.go @@ -16,8 +16,6 @@ import ( type DB struct { qdrantClient *qdrantgo.Client collectionName string - includeContent bool - includeValues bool createCollection *CreateCollectionOptions } @@ -38,9 +36,6 @@ type CreateCollectionOptions struct { type Options struct { CollectionName string - IncludeContent bool - IncludeValues bool - BatchUpsertSize *int CreateCollection *CreateCollectionOptions } @@ -53,8 +48,6 @@ func New(options Options) *DB { return &DB{ qdrantClient: qdrantClient, collectionName: options.CollectionName, - includeContent: options.IncludeContent, - includeValues: options.IncludeValues, createCollection: options.CreateCollection, } } @@ -126,7 +119,7 @@ func (d *DB) Search(ctx context.Context, values []float64, options *option.Optio return nil, fmt.Errorf("%w: %w", index.ErrInternal, err) } - return buildSearchResultsFromQdrantMatches(matches, d.includeContent), nil + return buildSearchResultsFromQdrantMatches(matches), nil } func (d *DB) similaritySearch( @@ -139,6 +132,7 @@ func (d *DB) similaritySearch( } includeMetadata := true + includeValues := true res := &qdrantresponse.PointSearch{} err := d.qdrantClient.PointSearch( ctx, @@ -147,7 +141,7 @@ func (d *DB) similaritySearch( Limit: opts.TopK, Vector: values, WithPayload: &includeMetadata, - WithVector: &d.includeValues, + WithVector: &includeValues, Filter: opts.Filter.(qdrantrequest.Filter), }, res, @@ -195,15 +189,11 @@ func (d *DB) createCollectionIfRequired(ctx context.Context) error { func buildSearchResultsFromQdrantMatches( matches []qdrantresponse.PointSearchResult, - includeContent bool, ) index.SearchResults { searchResults := make([]index.SearchResult, len(matches)) for i, match := range matches { metadata := index.DeepCopyMetadata(match.Payload) - if !includeContent { - delete(metadata, index.DefaultKeyContent) - } searchResults[i] = index.SearchResult{ Data: index.Data{ diff --git a/index/vectordb/redis/redis.go b/index/vectordb/redis/redis.go new file mode 100644 index 00000000..8a8c3041 --- /dev/null +++ b/index/vectordb/redis/redis.go @@ -0,0 +1,244 @@ +package redis + +import ( + "context" + "encoding/binary" + "fmt" + "math" + "strconv" + + "github.com/google/uuid" + "github.com/henomis/lingoose/index" + "github.com/henomis/lingoose/index/option" + + "github.com/RediSearch/redisearch-go/v2/redisearch" +) + +const ( + errUnknownIndexName = "Unknown index name" +) + +type DB struct { + redisearchClient *redisearch.Client + createIndex *CreateIndexOptions +} + +type Distance string + +const ( + DistanceCosine Distance = "COSINE" + DistanceEuclidean Distance = "IP" + DistanceDot Distance = "L2" + + defaultVectorFieldName = "vec" + defaultVectorScoreFieldName = "__vec_score" +) + +type CreateIndexOptions struct { + Dimension uint64 + Distance Distance +} + +type Options struct { + RedisearchClient *redisearch.Client + CreateIndex *CreateIndexOptions +} + +func New(options Options) *DB { + return &DB{ + redisearchClient: options.RedisearchClient, + createIndex: options.CreateIndex, + } +} + +func (d *DB) IsEmpty(ctx context.Context) (bool, error) { + err := d.createIndexIfRequired(ctx) + if err != nil { + return true, fmt.Errorf("%w: %w", index.ErrInternal, err) + } + + indexInfo, err := d.redisearchClient.Info() + if err != nil { + return true, fmt.Errorf("%w: %w", index.ErrInternal, err) + } + + return indexInfo.DocCount == 0, nil +} + +func (d *DB) Insert(ctx context.Context, datas []index.Data) error { + err := d.createIndexIfRequired(ctx) + if err != nil { + return fmt.Errorf("%w: %w", index.ErrInternal, err) + } + + var documents []redisearch.Document + for _, data := range datas { + if data.ID == "" { + id, errUUID := uuid.NewUUID() + if errUUID != nil { + return errUUID + } + data.ID = id.String() + } + + document := redisearch.NewDocument(data.ID, 1.0) + + for key, value := range data.Metadata { + document.Set(key, value) + } + + document.Set(defaultVectorFieldName, float64tobytes(data.Values)) + + documents = append(documents, document) + } + + if err = d.redisearchClient.Index(documents...); err != nil { + return fmt.Errorf("%w: %w", index.ErrInternal, err) + } + + return nil +} + +func (d *DB) Search(ctx context.Context, values []float64, options *option.Options) (index.SearchResults, error) { + matches, err := d.similaritySearch(ctx, values, options) + if err != nil { + return nil, fmt.Errorf("%w: %w", index.ErrInternal, err) + } + + return buildSearchResultsFromRedisDocuments(matches), nil +} + +func (d *DB) similaritySearch( + _ context.Context, + values []float64, + opts *option.Options, +) ([]redisearch.Document, error) { + if opts.Filter == nil { + opts.Filter = redisearch.Filter{} + } + + docs, _, err := d.redisearchClient.Search( + redisearch.NewQuery(fmt.Sprintf("(*)=>[KNN %d @vec $query_vector]", opts.TopK)). + SetSortBy(defaultVectorScoreFieldName, true). + SetFlags(redisearch.QueryWithPayloads). + SetDialect(2). + Limit(0, opts.TopK). + AddParam("query_vector", float64tobytes(values)). + AddFilter(opts.Filter.(redisearch.Filter)), + ) + + return docs, err +} + +func (d *DB) createIndexIfRequired(_ context.Context) error { + if d.createIndex == nil { + return nil + } + + indexName := "" + indexInfo, err := d.redisearchClient.Info() + if err != nil && (err.Error() != errUnknownIndexName) { + return err + } else if err == nil { + indexName = indexInfo.Name + } + + indexes, err := d.redisearchClient.List() + if err != nil { + return err + } + + if len(indexes) > 0 && len(indexName) > 0 { + for _, index := range indexes { + if index == indexInfo.Name { + return nil + } + } + } + + return d.redisearchClient.CreateIndex( + redisearch.NewSchema(redisearch.DefaultOptions). + AddField(redisearch.NewVectorFieldOptions( + defaultVectorFieldName, + redisearch.VectorFieldOptions{ + Algorithm: redisearch.Flat, + Attributes: map[string]interface{}{ + "TYPE": "FLOAT32", + "DIM": d.createIndex.Dimension, + "DISTANCE_METRIC": d.createIndex.Distance, + }})), + ) +} + +func buildSearchResultsFromRedisDocuments( + documents []redisearch.Document, +) index.SearchResults { + searchResults := make([]index.SearchResult, len(documents)) + + for i, match := range documents { + metadata := index.DeepCopyMetadata(match.Properties) + + score := 0.0 + scoreField, fieldExists := match.Properties[defaultVectorScoreFieldName] + if fieldExists { + scoreAsString, ok := scoreField.(string) + if ok { + score, _ = strconv.ParseFloat(scoreAsString, 64) + delete(metadata, defaultVectorScoreFieldName) + } + } + + values := []float64{} + vectorField, fieldExists := match.Properties[defaultVectorFieldName] + if fieldExists { + vectorAsString, ok := vectorField.(string) + if ok { + values = bytestofloat64([]byte(vectorAsString)) + delete(metadata, defaultVectorFieldName) + } + } + + searchResults[i] = index.SearchResult{ + Data: index.Data{ + ID: match.Id, + Metadata: metadata, + Values: values, + }, + Score: score, + } + } + + return searchResults +} + +func float64to32(floats []float64) []float32 { + floats32 := make([]float32, len(floats)) + for i, f := range floats { + floats32[i] = float32(f) + } + return floats32 +} + +func float64tobytes(floats64 []float64) []byte { + floats := float64to32(floats64) + + byteSlice := make([]byte, len(floats)*4) + for i, f := range floats { + bits := math.Float32bits(f) + binary.LittleEndian.PutUint32(byteSlice[i*4:], bits) + } + return byteSlice +} +func bytestofloat64(byteSlice []byte) []float64 { + floats := make([]float32, len(byteSlice)/4) + for i := 0; i < len(byteSlice); i += 4 { + bits := binary.LittleEndian.Uint32(byteSlice[i : i+4]) + floats[i/4] = math.Float32frombits(bits) + } + + floats64 := make([]float64, len(floats)) + for i, f := range floats { + floats64[i] = float64(f) + } + return floats64 +}