From abfe5acb1f91e7dc96784569fee73bf13ddcadd9 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 16 Jun 2023 19:14:45 +0200 Subject: [PATCH] Feat/81 implement qdrant support (#89) --- .github/FUNDING.yml | 3 - embedder/embedding.go | 10 + examples/embeddings/qdrant/main.go | 109 ++++++++++ go.mod | 1 + go.sum | 2 + index/pinecone.go | 18 +- index/qdrant.go | 309 +++++++++++++++++++++++++++++ 7 files changed, 440 insertions(+), 12 deletions(-) delete mode 100644 .github/FUNDING.yml create mode 100644 examples/embeddings/qdrant/main.go create mode 100644 index/qdrant.go diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml deleted file mode 100644 index c2bd53aa..00000000 --- a/.github/FUNDING.yml +++ /dev/null @@ -1,3 +0,0 @@ -# Supported funding model platforms - -github: [henomis] \ No newline at end of file diff --git a/embedder/embedding.go b/embedder/embedding.go index 68838839..b371a893 100644 --- a/embedder/embedding.go +++ b/embedder/embedding.go @@ -6,3 +6,13 @@ var ( // Embedding is the result of an embedding operation. type Embedding []float64 + +func (e Embedding) ToFloat32() []float32 { + + vect := make([]float32, len(e)) + for i, v := range e { + vect[i] = float32(v) + } + + return vect +} diff --git a/examples/embeddings/qdrant/main.go b/examples/embeddings/qdrant/main.go new file mode 100644 index 00000000..1b7a843f --- /dev/null +++ b/examples/embeddings/qdrant/main.go @@ -0,0 +1,109 @@ +package main + +import ( + "context" + "fmt" + + openaiembedder "github.com/henomis/lingoose/embedder/openai" + "github.com/henomis/lingoose/index" + "github.com/henomis/lingoose/llm/openai" + "github.com/henomis/lingoose/loader" + "github.com/henomis/lingoose/prompt" + "github.com/henomis/lingoose/textsplitter" +) + +// download https://frontiernerds.com/files/state_of_the_union.txt + +func main() { + + openaiEmbedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2) + + qdrantIndex := index.NewQdrant( + index.QdrantOptions{ + CollectionName: "test", + IncludeContent: true, + CreateCollection: &index.QdrantCreateCollectionOptions{ + Dimension: 1536, + Distance: index.QdrantDistanceCosine, + }, + }, + openaiEmbedder, + ).WithAPIKeyAndEdpoint("", "http://localhost:6333") + + indexIsEmpty, err := qdrantIndex.IsEmpty(context.Background()) + if err != nil { + panic(err) + } + + if indexIsEmpty { + err = ingestData(qdrantIndex) + if err != nil { + panic(err) + } + } + + query := "What is the purpose of the NATO Alliance?" + similarities, err := qdrantIndex.SimilaritySearch( + context.Background(), + query, + index.WithTopK(3), + ) + if err != nil { + panic(err) + } + + content := "" + for _, similarity := range similarities { + fmt.Printf("Similarity: %f\n", similarity.Score) + fmt.Printf("Document: %s\n", similarity.Document.Content) + fmt.Println("Metadata: ", similarity.Document.Metadata) + fmt.Println("ID: ", similarity.ID) + fmt.Println("----------") + content += similarity.Document.Content + "\n" + } + + llmOpenAI := openai.NewCompletion().WithVerbose(true) + + prompt1 := prompt.NewPromptTemplate( + "Based on the following context answer to the question.\n\nContext:\n{{.context}}\n\nQuestion: {{.query}}").WithInputs( + map[string]string{ + "query": query, + "context": content, + }, + ) + + err = prompt1.Format(nil) + if err != nil { + panic(err) + } + + _, err = llmOpenAI.Completion(context.Background(), prompt1.String()) + if err != nil { + panic(err) + } + +} + +func ingestData(qdrantIndex *index.Qdrant) error { + + documents, err := loader.NewDirectoryLoader(".", ".txt").Load(context.Background()) + if err != nil { + return err + } + + textSplitter := textsplitter.NewRecursiveCharacterTextSplitter(1000, 20) + + documentChunks := textSplitter.SplitDocuments(documents) + + for _, doc := range documentChunks { + fmt.Println(doc.Content) + fmt.Println("----------") + fmt.Println(doc.Metadata) + fmt.Println("----------") + fmt.Println() + + } + + return qdrantIndex.LoadFromDocuments(context.Background(), documentChunks) + +} diff --git a/go.mod b/go.mod index 316617b8..5c8eca9b 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require github.com/mitchellh/mapstructure v1.5.0 require ( github.com/google/uuid v1.3.0 github.com/henomis/pinecone-go v1.1.1 + github.com/henomis/qdrant-go v1.0.0 github.com/pkoukk/tiktoken-go v0.1.1 github.com/sashabaranov/go-openai v1.11.1 ) diff --git a/go.sum b/go.sum index ed72eff6..67222962 100644 --- a/go.sum +++ b/go.sum @@ -5,6 +5,8 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/henomis/pinecone-go v1.1.1 h1:uZmRi1XD6J/fA02Nsbb5TuGggwLet9//t0LLX/dXOds= github.com/henomis/pinecone-go v1.1.1/go.mod h1:FsMMRjLyiJ9zHqGOlmGvjolqOp2kkbMsRm8oc85vykU= +github.com/henomis/qdrant-go v1.0.0 h1:KVd9aTvObVJgQFznM0FPMn3+zC4O1ekXgRcG61bW130= +github.com/henomis/qdrant-go v1.0.0/go.mod h1:CJ+imAe+WK3ntoIn7v7sSqimGu+/In/7ijhhT0MC5WU= github.com/henomis/restclientgo v1.0.3 h1:y5+ydfvWJ0/7crObdnCHSn7ya/h1whD+PV4Ir2dZ9Ig= github.com/henomis/restclientgo v1.0.3/go.mod h1:xIeTCu2ZstvRn0fCukNpzXLN3m/kRTU0i0RwAbv7Zug= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= diff --git a/index/pinecone.go b/index/pinecone.go index 3a31438c..cf025ddc 100644 --- a/index/pinecone.go +++ b/index/pinecone.go @@ -16,8 +16,8 @@ import ( ) const ( - defaultPineconeTopK = 10 - defaultBatchUpsertSize = 32 + defaultPineconeTopK = 10 + defaultPineconeBatchUpsertSize = 32 ) type Pinecone struct { @@ -54,7 +54,7 @@ func NewPinecone(options PineconeOptions, embedder Embedder) *Pinecone { pineconeClient := pineconego.New(environment, apiKey) - batchUpsertSize := defaultBatchUpsertSize + batchUpsertSize := defaultPineconeBatchUpsertSize if options.BatchUpsertSize != nil { batchUpsertSize = *options.BatchUpsertSize } @@ -140,7 +140,7 @@ func (p *Pinecone) SimilaritySearch(ctx context.Context, query string, opts ...O return nil, fmt.Errorf("%s: %w", ErrInternal, err) } - searchResponses := buildSearchReponsesFromMatches(matches, p.includeContent) + searchResponses := buildSearchReponsesFromPineconeMatches(matches, p.includeContent) return filterSearchResponses(searchResponses, pineconeOptions.topK), nil } @@ -253,9 +253,9 @@ func (p *Pinecone) createIndexIfRequired(ctx context.Context) error { func (p *Pinecone) batchUpsert(ctx context.Context, documents []document.Document) error { - for i := 0; i < len(documents); i += defaultBatchUpsertSize { + for i := 0; i < len(documents); i += defaultPineconeBatchUpsertSize { - batchEnd := i + defaultBatchUpsertSize + batchEnd := i + defaultPineconeBatchUpsertSize if batchEnd > len(documents) { batchEnd = len(documents) } @@ -270,7 +270,7 @@ func (p *Pinecone) batchUpsert(ctx context.Context, documents []document.Documen return err } - vectors, err := buildVectorsFromEmbeddingsAndDocuments(embeddings, documents, i, p.includeContent) + vectors, err := buildPineconeVectorsFromEmbeddingsAndDocuments(embeddings, documents, i, p.includeContent) if err != nil { return err } @@ -319,7 +319,7 @@ func deepCopyMetadata(metadata types.Meta) types.Meta { return metadataCopy } -func buildVectorsFromEmbeddingsAndDocuments( +func buildPineconeVectorsFromEmbeddingsAndDocuments( embeddings []embedder.Embedding, documents []document.Document, startIndex int, @@ -355,7 +355,7 @@ func buildVectorsFromEmbeddingsAndDocuments( return vectors, nil } -func buildSearchReponsesFromMatches(matches []pineconeresponse.QueryMatch, includeContent bool) SearchResponses { +func buildSearchReponsesFromPineconeMatches(matches []pineconeresponse.QueryMatch, includeContent bool) SearchResponses { searchResponses := make([]SearchResponse, len(matches)) for i, match := range matches { diff --git a/index/qdrant.go b/index/qdrant.go new file mode 100644 index 00000000..6d678ffe --- /dev/null +++ b/index/qdrant.go @@ -0,0 +1,309 @@ +package index + +import ( + "context" + "fmt" + "os" + + "github.com/google/uuid" + "github.com/henomis/lingoose/document" + "github.com/henomis/lingoose/embedder" + qdrantgo "github.com/henomis/qdrant-go" + qdrantrequest "github.com/henomis/qdrant-go/request" + qdrantresponse "github.com/henomis/qdrant-go/response" +) + +const ( + defaultQdrantTopK = 10 + defaultQdrantBatchUpsertSize = 32 +) + +type Qdrant struct { + qdrantClient *qdrantgo.Client + collectionName string + embedder Embedder + includeContent bool + batchUpsertSize int + + createCollection *QdrantCreateCollectionOptions +} + +type QdrantDistance string + +const ( + QdrantDistanceCosine QdrantDistance = QdrantDistance(qdrantrequest.DistanceCosine) + QdrantDistanceEuclidean QdrantDistance = QdrantDistance(qdrantrequest.DistanceEuclidean) + QdrantDistanceDot QdrantDistance = QdrantDistance(qdrantrequest.DistanceDot) +) + +type QdrantCreateCollectionOptions struct { + Dimension uint64 + Distance QdrantDistance + OnDisk bool +} + +type QdrantOptions struct { + CollectionName string + IncludeContent bool + BatchUpsertSize *int + CreateCollection *QdrantCreateCollectionOptions +} + +func NewQdrant(options QdrantOptions, embedder Embedder) *Qdrant { + + apiKey := os.Getenv("QDRANT_API_KEY") + endpoint := os.Getenv("QDRANT_ENDPOINT") + + qdrantClient := qdrantgo.New(endpoint, apiKey) + + batchUpsertSize := defaultQdrantBatchUpsertSize + if options.BatchUpsertSize != nil { + batchUpsertSize = *options.BatchUpsertSize + } + + return &Qdrant{ + qdrantClient: qdrantClient, + collectionName: options.CollectionName, + embedder: embedder, + includeContent: options.IncludeContent, + batchUpsertSize: batchUpsertSize, + createCollection: options.CreateCollection, + } +} + +func (q *Qdrant) WithAPIKeyAndEdpoint(apiKey, endpoint string) *Qdrant { + q.qdrantClient = qdrantgo.New(endpoint, apiKey) + return q +} + +func (q *Qdrant) LoadFromDocuments(ctx context.Context, documents []document.Document) error { + + err := q.createCollectionIfRequired(ctx) + if err != nil { + return fmt.Errorf("%s: %w", ErrInternal, err) + } + + err = q.batchUpsert(ctx, documents) + if err != nil { + return fmt.Errorf("%s: %w", ErrInternal, err) + } + return nil +} + +func (p *Qdrant) IsEmpty(ctx context.Context) (bool, error) { + + err := p.createCollectionIfRequired(ctx) + if err != nil { + return true, fmt.Errorf("%s: %w", ErrInternal, err) + } + + res := &qdrantresponse.CollectionCollectInfo{} + err = p.qdrantClient.CollectionCollectInfo( + ctx, + &qdrantrequest.CollectionCollectInfo{ + CollectionName: p.collectionName, + }, + res, + ) + if err != nil { + return true, fmt.Errorf("%s: %w", ErrInternal, err) + } + + return res.Result.VectorsCount == 0, nil + +} + +func (q *Qdrant) SimilaritySearch(ctx context.Context, query string, opts ...Option) (SearchResponses, error) { + + qdrantOptions := &options{ + topK: defaultQdrantTopK, + } + + for _, opt := range opts { + opt(qdrantOptions) + } + + matches, err := q.similaritySearch(ctx, query, qdrantOptions.topK) + if err != nil { + return nil, fmt.Errorf("%s: %w", ErrInternal, err) + } + + searchResponses := buildSearchReponsesFromQdrantMatches(matches, q.includeContent) + + return filterSearchResponses(searchResponses, qdrantOptions.topK), nil +} + +func (p *Qdrant) similaritySearch(ctx context.Context, query string, topK int) ([]qdrantresponse.PointSearchResult, error) { + + embeddings, err := p.embedder.Embed(ctx, []string{query}) + if err != nil { + return nil, err + } + + includeMetadata := true + res := &qdrantresponse.PointSearch{} + err = p.qdrantClient.PointSearch( + ctx, + &qdrantrequest.PointSearch{ + CollectionName: p.collectionName, + Limit: topK, + Vector: embeddings[0], + WithPayload: &includeMetadata, + }, + res, + ) + if err != nil { + return nil, err + } + + return res.Result, nil +} + +func (q *Qdrant) createCollectionIfRequired(ctx context.Context) error { + + if q.createCollection == nil { + return nil + } + + resp := &qdrantresponse.CollectionList{} + err := q.qdrantClient.CollectionList(ctx, &qdrantrequest.CollectionList{}, resp) + if err != nil { + return err + } + + for _, collection := range resp.Result.Collections { + if collection.Name == q.collectionName { + return nil + } + } + + req := &qdrantrequest.CollectionCreate{ + CollectionName: q.collectionName, + Vectors: qdrantrequest.VectorsParams{ + Size: q.createCollection.Dimension, + Distance: qdrantrequest.Distance(q.createCollection.Distance), + OnDisk: &q.createCollection.OnDisk, + }, + } + + err = q.qdrantClient.CollectionCreate(ctx, req, &qdrantresponse.CollectionCreate{}) + if err != nil { + return err + } + + return nil +} + +func (q *Qdrant) batchUpsert(ctx context.Context, documents []document.Document) error { + + for i := 0; i < len(documents); i += q.batchUpsertSize { + + batchEnd := i + q.batchUpsertSize + if batchEnd > len(documents) { + batchEnd = len(documents) + } + + texts := []string{} + for _, document := range documents[i:batchEnd] { + texts = append(texts, document.Content) + } + + embeddings, err := q.embedder.Embed(ctx, texts) + if err != nil { + return err + } + + points, err := buildQdrantPointsFromEmbeddingsAndDocuments(embeddings, documents, i, q.includeContent) + if err != nil { + return err + } + + err = q.pointUpsert(ctx, points) + if err != nil { + return err + } + } + + return nil +} + +func (q *Qdrant) pointUpsert(ctx context.Context, points []qdrantrequest.Point) error { + + wait := true + req := &qdrantrequest.PointUpsert{ + Wait: &wait, + CollectionName: q.collectionName, + Points: points, + } + res := &qdrantresponse.PointUpsert{} + + err := q.qdrantClient.PointUpsert(ctx, req, res) + if err != nil { + return err + } + + return nil +} + +func buildQdrantPointsFromEmbeddingsAndDocuments( + embeddings []embedder.Embedding, + documents []document.Document, + startIndex int, + includeContent bool, +) ([]qdrantrequest.Point, error) { + + var vectors []qdrantrequest.Point + + for i, embedding := range embeddings { + + metadata := deepCopyMetadata(documents[startIndex+i].Metadata) + + // inject document content into vector metadata + if includeContent { + metadata[defaultKeyContent] = documents[startIndex+i].Content + } + + vectorID, err := uuid.NewUUID() + if err != nil { + return nil, err + } + + vectors = append(vectors, qdrantrequest.Point{ + ID: vectorID.String(), + Vector: embedding, + Payload: metadata, + }) + + // inject vector ID into document metadata + documents[startIndex+i].Metadata[defaultKeyID] = vectorID.String() + } + + return vectors, nil +} + +func buildSearchReponsesFromQdrantMatches(matches []qdrantresponse.PointSearchResult, includeContent bool) SearchResponses { + searchResponses := make([]SearchResponse, len(matches)) + + for i, match := range matches { + + metadata := deepCopyMetadata(match.Payload) + + content := "" + // extract document content from vector metadata + if includeContent { + content = metadata[defaultKeyContent].(string) + delete(metadata, defaultKeyContent) + } + + searchResponses[i] = SearchResponse{ + ID: match.ID, + Document: document.Document{ + Metadata: metadata, + Content: content, + }, + Score: match.Score, + } + } + + return searchResponses +}