From bfb5d2ad43cc42d9902657e2677cf5885e0c81e0 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Sun, 20 Aug 2023 17:53:55 +0200 Subject: [PATCH] chore: refactor and document (#109) * chore: refactor and document * chore: refactor index * fix --- chat/chat.go | 2 + embedder/cohere/cohere.go | 3 + embedder/huggingface/huggingface.go | 3 + embedder/llamacpp/llamacpp.go | 4 + embedder/openai/openai.go | 2 + examples/embeddings/knowledge_base/main.go | 9 +- examples/embeddings/pinecone/main.go | 15 +-- examples/embeddings/qdrant/main.go | 17 +-- examples/embeddings/simpleVector/main.go | 14 ++- examples/embeddings/simplekb/main.go | 7 +- examples/pipeline/summarize/main.go | 2 +- index/index.go | 15 ++- index/option/option.go | 20 ++++ index/options.go | 20 ---- index/{ => pinecone}/pinecone.go | 103 +++++++++--------- index/{ => qdrant}/qdrant.go | 94 ++++++++-------- .../simpleVectorIndex.go | 77 ++++++------- llm/cohere/cohere.go | 9 ++ llm/huggingface/huggingface.go | 11 ++ llm/openai/openai.go | 97 ++++------------- 20 files changed, 260 insertions(+), 264 deletions(-) create mode 100644 index/option/option.go delete mode 100644 index/options.go rename index/{ => pinecone}/pinecone.go (70%) rename index/{ => qdrant}/qdrant.go (65%) rename index/{ => simpleVectorIndex}/simpleVectorIndex.go (51%) diff --git a/chat/chat.go b/chat/chat.go index 9f1f754d..2f67442b 100644 --- a/chat/chat.go +++ b/chat/chat.go @@ -56,6 +56,7 @@ func New(promptMessages ...PromptMessage) *Chat { return chatPromptTemplate } +// AddPromptMessages adds a list of chat prompt templates to the chat prompt template. func (c *Chat) AddPromptMessages(messages []PromptMessage) { for _, message := range messages { c.addMessagePromptTemplate(message) @@ -92,6 +93,7 @@ func (p *Chat) ToMessages() (Messages, error) { return messages, nil } +// PromptMessages returns the chat prompt messages. 
func (c *Chat) PromptMessages() PromptMessages { return c.promptMessages } diff --git a/embedder/cohere/cohere.go b/embedder/cohere/cohere.go index c3c66839..6da3aa8b 100644 --- a/embedder/cohere/cohere.go +++ b/embedder/cohere/cohere.go @@ -39,16 +39,19 @@ func New() *Embedder { } } +// WithAPIKey sets the API key to use for the embedder func (e *Embedder) WithAPIKey(apiKey string) *Embedder { e.client = coherego.New(apiKey) return e } +// WithModel sets the model to use for the embedder func (e *Embedder) WithModel(model EmbedderModel) *Embedder { e.model = model return e } +// Embed returns the embeddings for the given texts func (h *Embedder) Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error) { resp := &response.Embed{} err := h.client.Embed( diff --git a/embedder/huggingface/huggingface.go b/embedder/huggingface/huggingface.go index 12ae8b4c..432c0a24 100644 --- a/embedder/huggingface/huggingface.go +++ b/embedder/huggingface/huggingface.go @@ -23,16 +23,19 @@ func New() *HuggingFaceEmbedder { } } +// WithToken sets the API key to use for the embedder func (h *HuggingFaceEmbedder) WithToken(token string) *HuggingFaceEmbedder { h.token = token return h } +// WithModel sets the model to use for the embedder func (h *HuggingFaceEmbedder) WithModel(model string) *HuggingFaceEmbedder { h.model = model return h } +// Embed returns the embeddings for the given texts func (h *HuggingFaceEmbedder) Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error) { return h.featureExtraction(ctx, texts) } diff --git a/embedder/llamacpp/llamacpp.go b/embedder/llamacpp/llamacpp.go index 249e717d..33598f73 100644 --- a/embedder/llamacpp/llamacpp.go +++ b/embedder/llamacpp/llamacpp.go @@ -24,21 +24,25 @@ func New() *LlamaCppEmbedder { } } +// WithLlamaCppPath sets the path to the llamacpp binary func (l *LlamaCppEmbedder) WithLlamaCppPath(llamacppPath string) *LlamaCppEmbedder { l.llamacppPath = llamacppPath return l } +// WithModel sets the 
model to use for the embedder func (l *LlamaCppEmbedder) WithModel(modelPath string) *LlamaCppEmbedder { l.modelPath = modelPath return l } +// WithArgs sets the args to pass to the llamacpp binary func (l *LlamaCppEmbedder) WithArgs(llamacppArgs []string) *LlamaCppEmbedder { l.llamacppArgs = llamacppArgs return l } +// Embed returns the embeddings for the given texts func (o *LlamaCppEmbedder) Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error) { embeddings := make([]embedder.Embedding, len(texts)) diff --git a/embedder/openai/openai.go b/embedder/openai/openai.go index 1aec7b25..17648c28 100644 --- a/embedder/openai/openai.go +++ b/embedder/openai/openai.go @@ -71,11 +71,13 @@ func New(model Model) *OpenAIEmbedder { } } +// WithClient sets the OpenAI client to use for the embedder func (o *OpenAIEmbedder) WithClient(client *openai.Client) *OpenAIEmbedder { o.openAIClient = client return o } +// Embed returns the embeddings for the given texts func (o *OpenAIEmbedder) Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error) { maxTokens := o.getMaxTokens() diff --git a/examples/embeddings/knowledge_base/main.go b/examples/embeddings/knowledge_base/main.go index e0526387..1cf900c8 100644 --- a/examples/embeddings/knowledge_base/main.go +++ b/examples/embeddings/knowledge_base/main.go @@ -8,7 +8,8 @@ import ( "github.com/henomis/lingoose/chat" openaiembedder "github.com/henomis/lingoose/embedder/openai" - "github.com/henomis/lingoose/index" + indexoption "github.com/henomis/lingoose/index/option" + simplevectorindex "github.com/henomis/lingoose/index/simpleVectorIndex" "github.com/henomis/lingoose/llm/openai" "github.com/henomis/lingoose/loader" "github.com/henomis/lingoose/prompt" @@ -24,7 +25,7 @@ func main() { openaiEmbedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2) - docsVectorIndex := index.NewSimpleVectorIndex("db", ".", openaiEmbedder) + docsVectorIndex := simplevectorindex.New("db", ".", openaiEmbedder) 
indexIsEmpty, _ := docsVectorIndex.IsEmpty() if indexIsEmpty { @@ -48,7 +49,7 @@ func main() { break } - similarities, err := docsVectorIndex.SimilaritySearch(context.Background(), query, index.WithTopK(3)) + similarities, err := docsVectorIndex.SimilaritySearch(context.Background(), query, indexoption.WithTopK(3)) if err != nil { panic(err) } @@ -97,7 +98,7 @@ func main() { } -func ingestData(docsVectorIndex *index.SimpleVectorIndex) error { +func ingestData(docsVectorIndex *simplevectorindex.Index) error { fmt.Printf("Learning Knowledge Base...") diff --git a/examples/embeddings/pinecone/main.go b/examples/embeddings/pinecone/main.go index 1d601f3e..4fb4dd5e 100644 --- a/examples/embeddings/pinecone/main.go +++ b/examples/embeddings/pinecone/main.go @@ -5,25 +5,26 @@ import ( "fmt" openaiembedder "github.com/henomis/lingoose/embedder/openai" - "github.com/henomis/lingoose/index" + indexoption "github.com/henomis/lingoose/index/option" + pineconeindex "github.com/henomis/lingoose/index/pinecone" "github.com/henomis/lingoose/llm/openai" "github.com/henomis/lingoose/loader" "github.com/henomis/lingoose/prompt" "github.com/henomis/lingoose/textsplitter" ) -// download https://frontiernerds.com/files/state_of_the_union.txt +// download https://raw.githubusercontent.com/hwchase17/chat-your-data/master/state_of_the_union.txt func main() { openaiEmbedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2) - pineconeIndex := index.NewPinecone( - index.PineconeOptions{ + pineconeIndex := pineconeindex.New( + pineconeindex.Options{ IndexName: "test", Namespace: "test-namespace", IncludeContent: true, - CreateIndex: &index.PineconeCreateIndexOptions{ + CreateIndex: &pineconeindex.CreateIndexOptions{ Dimension: 1536, Replicas: 1, Metric: "cosine", @@ -49,7 +50,7 @@ func main() { similarities, err := pineconeIndex.SimilaritySearch( context.Background(), query, - index.WithTopK(3), + indexoption.WithTopK(3), ) if err != nil { panic(err) @@ -87,7 +88,7 @@ func main() { } -func 
ingestData(pineconeIndex *index.Pinecone) error { +func ingestData(pineconeIndex *pineconeindex.Index) error { documents, err := loader.NewDirectoryLoader(".", ".txt").Load(context.Background()) if err != nil { diff --git a/examples/embeddings/qdrant/main.go b/examples/embeddings/qdrant/main.go index 366984b4..c717f656 100644 --- a/examples/embeddings/qdrant/main.go +++ b/examples/embeddings/qdrant/main.go @@ -5,7 +5,8 @@ import ( "fmt" openaiembedder "github.com/henomis/lingoose/embedder/openai" - "github.com/henomis/lingoose/index" + indexoption "github.com/henomis/lingoose/index/option" + qdrantindex "github.com/henomis/lingoose/index/qdrant" "github.com/henomis/lingoose/llm/openai" "github.com/henomis/lingoose/loader" "github.com/henomis/lingoose/prompt" @@ -13,19 +14,19 @@ import ( ) // download https://raw.githubusercontent.com/hwchase17/chat-your-data/master/state_of_the_union.txt -// run qdrant docker run -p 6333:6333 qdrant/qdrant +// run qdrant docker run --rm -p 6333:6333 qdrant/qdrant func main() { openaiEmbedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2) - qdrantIndex := index.NewQdrant( - index.QdrantOptions{ + qdrantIndex := qdrantindex.New( + qdrantindex.Options{ CollectionName: "test", IncludeContent: true, - CreateCollection: &index.QdrantCreateCollectionOptions{ + CreateCollection: &qdrantindex.CreateCollectionOptions{ Dimension: 1536, - Distance: index.QdrantDistanceCosine, + Distance: qdrantindex.DistanceCosine, }, }, openaiEmbedder, @@ -47,7 +48,7 @@ func main() { similarities, err := qdrantIndex.SimilaritySearch( context.Background(), query, - index.WithTopK(3), + indexoption.WithTopK(3), ) if err != nil { panic(err) @@ -85,7 +86,7 @@ func main() { } -func ingestData(qdrantIndex *index.Qdrant) error { +func ingestData(qdrantIndex *qdrantindex.Index) error { documents, err := loader.NewDirectoryLoader(".", ".txt").Load(context.Background()) if err != nil { diff --git a/examples/embeddings/simpleVector/main.go 
b/examples/embeddings/simpleVector/main.go index cc6e4969..d210a08f 100644 --- a/examples/embeddings/simpleVector/main.go +++ b/examples/embeddings/simpleVector/main.go @@ -6,23 +6,25 @@ import ( openaiembedder "github.com/henomis/lingoose/embedder/openai" "github.com/henomis/lingoose/index" + indexoption "github.com/henomis/lingoose/index/option" + simplevectorindex "github.com/henomis/lingoose/index/simpleVectorIndex" "github.com/henomis/lingoose/llm/openai" "github.com/henomis/lingoose/loader" "github.com/henomis/lingoose/prompt" "github.com/henomis/lingoose/textsplitter" ) -// download https://frontiernerds.com/files/state_of_the_union.txt +// download https://raw.githubusercontent.com/hwchase17/chat-your-data/master/state_of_the_union.txt func main() { openaiEmbedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2) - docsVectorIndex := index.NewSimpleVectorIndex("docs", ".", openaiEmbedder) + docsVectorIndex := simplevectorindex.New("docs", ".", openaiEmbedder) indexIsEmpty, _ := docsVectorIndex.IsEmpty() if indexIsEmpty { - err := ingestData(openaiEmbedder) + err := ingestData(docsVectorIndex, openaiEmbedder) if err != nil { panic(err) } @@ -32,7 +34,7 @@ func main() { similarities, err := docsVectorIndex.SimilaritySearch( context.Background(), query, - index.WithTopK(3), + indexoption.WithTopK(3), ) if err != nil { panic(err) @@ -72,7 +74,7 @@ func main() { fmt.Println(output) } -func ingestData(openaiEmbedder index.Embedder) error { +func ingestData(docsVectorIndex *simplevectorindex.Index, openaiEmbedder index.Embedder) error { fmt.Printf("Ingesting data...") @@ -85,7 +87,7 @@ func ingestData(openaiEmbedder index.Embedder) error { documentChunks := textSplitter.SplitDocuments(documents) - err = index.NewSimpleVectorIndex("docs", ".", openaiEmbedder).LoadFromDocuments(context.Background(), documentChunks) + err = docsVectorIndex.LoadFromDocuments(context.Background(), documentChunks) if err != nil { return err } diff --git 
a/examples/embeddings/simplekb/main.go b/examples/embeddings/simplekb/main.go index df1af84f..82668a82 100644 --- a/examples/embeddings/simplekb/main.go +++ b/examples/embeddings/simplekb/main.go @@ -4,7 +4,8 @@ import ( "context" openaiembedder "github.com/henomis/lingoose/embedder/openai" - "github.com/henomis/lingoose/index" + indexoption "github.com/henomis/lingoose/index/option" + simplevectorindex "github.com/henomis/lingoose/index/simpleVectorIndex" "github.com/henomis/lingoose/llm/openai" "github.com/henomis/lingoose/loader" qapipeline "github.com/henomis/lingoose/pipeline/qa" @@ -15,7 +16,7 @@ func main() { query := "What is the NATO purpose?" docs, _ := loader.NewPDFToTextLoader("./kb").WithTextSplitter(textsplitter.NewRecursiveCharacterTextSplitter(2000, 200)).Load(context.Background()) openaiEmbedder := openaiembedder.New(openaiembedder.AdaEmbeddingV2) - index.NewSimpleVectorIndex("db", ".", openaiEmbedder).LoadFromDocuments(context.Background(), docs) - similarities, _ := index.NewSimpleVectorIndex("db", ".", openaiEmbedder).SimilaritySearch(context.Background(), query, index.WithTopK(3)) + simplevectorindex.New("db", ".", openaiEmbedder).LoadFromDocuments(context.Background(), docs) + similarities, _ := simplevectorindex.New("db", ".", openaiEmbedder).SimilaritySearch(context.Background(), query, indexoption.WithTopK(3)) qapipeline.New(openai.NewChat().WithVerbose(true)).Run(context.Background(), query, similarities.ToDocuments()) } diff --git a/examples/pipeline/summarize/main.go b/examples/pipeline/summarize/main.go index 8909b799..06783b48 100644 --- a/examples/pipeline/summarize/main.go +++ b/examples/pipeline/summarize/main.go @@ -9,7 +9,7 @@ import ( "github.com/henomis/lingoose/textsplitter" ) -// download https://frontiernerds.com/files/state_of_the_union.txt +// download https://raw.githubusercontent.com/hwchase17/chat-your-data/master/state_of_the_union.txt func main() { diff --git a/index/index.go b/index/index.go index 15f49ee7..b99a0c4b 
100644 --- a/index/index.go +++ b/index/index.go @@ -6,6 +6,7 @@ import ( "github.com/henomis/lingoose/document" "github.com/henomis/lingoose/embedder" + "github.com/henomis/lingoose/types" ) var ( @@ -13,8 +14,8 @@ var ( ) const ( - defaultKeyID = "id" - defaultKeyContent = "content" + DefaultKeyID = "id" + DefaultKeyContent = "content" ) type SearchResponse struct { @@ -37,7 +38,7 @@ type Embedder interface { Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error) } -func filterSearchResponses(searchResponses SearchResponses, topK int) SearchResponses { +func FilterSearchResponses(searchResponses SearchResponses, topK int) SearchResponses { //sort by similarity score sort.Slice(searchResponses, func(i, j int) bool { return searchResponses[i].Score > searchResponses[j].Score @@ -50,3 +51,11 @@ func filterSearchResponses(searchResponses SearchResponses, topK int) SearchResp return searchResponses[:maxTopK] } + +func DeepCopyMetadata(metadata types.Meta) types.Meta { + metadataCopy := make(types.Meta) + for k, v := range metadata { + metadataCopy[k] = v + } + return metadataCopy +} diff --git a/index/option/option.go b/index/option/option.go new file mode 100644 index 00000000..f47137b0 --- /dev/null +++ b/index/option/option.go @@ -0,0 +1,20 @@ +package option + +type Option func(*Options) + +type Options struct { + TopK int + Filter interface{} +} + +func WithTopK(topK int) Option { + return func(opts *Options) { + opts.TopK = topK + } +} + +func WithFilter(filter interface{}) Option { + return func(opts *Options) { + opts.Filter = filter + } +} diff --git a/index/options.go b/index/options.go deleted file mode 100644 index 59022bd9..00000000 --- a/index/options.go +++ /dev/null @@ -1,20 +0,0 @@ -package index - -type Option func(*options) - -type options struct { - topK int - filter interface{} -} - -func WithTopK(topK int) Option { - return func(opts *options) { - opts.topK = topK - } -} - -func WithFilter(filter interface{}) Option { - return 
func(opts *options) { - opts.filter = filter - } -} diff --git a/index/pinecone.go b/index/pinecone/pinecone.go similarity index 70% rename from index/pinecone.go rename to index/pinecone/pinecone.go index ecc680d9..ae11fed3 100644 --- a/index/pinecone.go +++ b/index/pinecone/pinecone.go @@ -1,4 +1,4 @@ -package index +package pinecone import ( "context" @@ -9,57 +9,58 @@ import ( "github.com/google/uuid" "github.com/henomis/lingoose/document" "github.com/henomis/lingoose/embedder" - "github.com/henomis/lingoose/types" + "github.com/henomis/lingoose/index" + "github.com/henomis/lingoose/index/option" pineconego "github.com/henomis/pinecone-go" pineconerequest "github.com/henomis/pinecone-go/request" pineconeresponse "github.com/henomis/pinecone-go/response" ) const ( - defaultPineconeTopK = 10 - defaultPineconeBatchUpsertSize = 32 + defaultTopK = 10 + defaultBatchUpsertSize = 32 ) -type Pinecone struct { +type Index struct { pineconeClient *pineconego.PineconeGo indexName string projectID *string namespace string - embedder Embedder + embedder index.Embedder includeContent bool batchUpsertSize int - createIndex *PineconeCreateIndexOptions + createIndex *CreateIndexOptions } -type PineconeCreateIndexOptions struct { +type CreateIndexOptions struct { Dimension int Replicas int Metric string PodType string } -type PineconeOptions struct { +type Options struct { IndexName string Namespace string IncludeContent bool BatchUpsertSize *int - CreateIndex *PineconeCreateIndexOptions + CreateIndex *CreateIndexOptions } -func NewPinecone(options PineconeOptions, embedder Embedder) *Pinecone { +func New(options Options, embedder index.Embedder) *Index { apiKey := os.Getenv("PINECONE_API_KEY") environment := os.Getenv("PINECONE_ENVIRONMENT") pineconeClient := pineconego.New(environment, apiKey) - batchUpsertSize := defaultPineconeBatchUpsertSize + batchUpsertSize := defaultBatchUpsertSize if options.BatchUpsertSize != nil { batchUpsertSize = *options.BatchUpsertSize } - return 
&Pinecone{ + return &Index{ pineconeClient: pineconeClient, indexName: options.IndexName, embedder: embedder, @@ -70,35 +71,35 @@ func NewPinecone(options PineconeOptions, embedder Embedder) *Pinecone { } } -func (p *Pinecone) WithAPIKeyAndEnvironment(apiKey, environment string) *Pinecone { +func (p *Index) WithAPIKeyAndEnvironment(apiKey, environment string) *Index { p.pineconeClient = pineconego.New(environment, apiKey) return p } -func (p *Pinecone) LoadFromDocuments(ctx context.Context, documents []document.Document) error { +func (p *Index) LoadFromDocuments(ctx context.Context, documents []document.Document) error { err := p.createIndexIfRequired(ctx) if err != nil { - return fmt.Errorf("%s: %w", ErrInternal, err) + return fmt.Errorf("%s: %w", index.ErrInternal, err) } err = p.batchUpsert(ctx, documents) if err != nil { - return fmt.Errorf("%s: %w", ErrInternal, err) + return fmt.Errorf("%s: %w", index.ErrInternal, err) } return nil } -func (p *Pinecone) IsEmpty(ctx context.Context) (bool, error) { +func (p *Index) IsEmpty(ctx context.Context) (bool, error) { err := p.createIndexIfRequired(ctx) if err != nil { - return true, fmt.Errorf("%s: %w", ErrInternal, err) + return true, fmt.Errorf("%s: %w", index.ErrInternal, err) } err = p.getProjectID(ctx) if err != nil { - return true, fmt.Errorf("%s: %w", ErrInternal, err) + return true, fmt.Errorf("%s: %w", index.ErrInternal, err) } req := &pineconerequest.VectorDescribeIndexStats{ @@ -109,7 +110,7 @@ func (p *Pinecone) IsEmpty(ctx context.Context) (bool, error) { err = p.pineconeClient.VectorDescribeIndexStats(ctx, req, res) if err != nil { - return true, fmt.Errorf("%s: %w", ErrInternal, err) + return true, fmt.Errorf("%s: %w", index.ErrInternal, err) } namespace, ok := res.Namespaces[p.namespace] @@ -118,38 +119,42 @@ func (p *Pinecone) IsEmpty(ctx context.Context) (bool, error) { } if namespace.VectorCount == nil { - return false, fmt.Errorf("%s: failed to get total index size", ErrInternal) + return false, 
fmt.Errorf("%s: failed to get total index size", index.ErrInternal) } return *namespace.VectorCount == 0, nil } -func (p *Pinecone) SimilaritySearch(ctx context.Context, query string, opts ...Option) (SearchResponses, error) { +func (p *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResponses, error) { - pineconeOptions := &options{ - topK: defaultPineconeTopK, + pineconeOptions := &option.Options{ + TopK: defaultTopK, } for _, opt := range opts { opt(pineconeOptions) } + if pineconeOptions.Filter == nil { + pineconeOptions.Filter = map[string]string{} + } + matches, err := p.similaritySearch(ctx, query, pineconeOptions) if err != nil { - return nil, fmt.Errorf("%s: %w", ErrInternal, err) + return nil, fmt.Errorf("%s: %w", index.ErrInternal, err) } searchResponses := buildSearchReponsesFromPineconeMatches(matches, p.includeContent) - return filterSearchResponses(searchResponses, pineconeOptions.topK), nil + return index.FilterSearchResponses(searchResponses, pineconeOptions.TopK), nil } -func (p *Pinecone) similaritySearch(ctx context.Context, query string, opts *options) ([]pineconeresponse.QueryMatch, error) { +func (p *Index) similaritySearch(ctx context.Context, query string, opts *option.Options) ([]pineconeresponse.QueryMatch, error) { err := p.getProjectID(ctx) if err != nil { - return nil, fmt.Errorf("%s: %w", ErrInternal, err) + return nil, fmt.Errorf("%s: %w", index.ErrInternal, err) } embeddings, err := p.embedder.Embed(ctx, []string{query}) @@ -164,11 +169,11 @@ func (p *Pinecone) similaritySearch(ctx context.Context, query string, opts *opt &pineconerequest.VectorQuery{ IndexName: p.indexName, ProjectID: *p.projectID, - TopK: int32(opts.topK), + TopK: int32(opts.TopK), Vector: embeddings[0], IncludeMetadata: &includeMetadata, Namespace: &p.namespace, - Filter: opts.filter.(map[string]string), + Filter: opts.Filter.(map[string]string), }, res, ) @@ -179,7 +184,7 @@ func (p *Pinecone) similaritySearch(ctx 
context.Context, query string, opts *opt return res.Matches, nil } -func (p *Pinecone) getProjectID(ctx context.Context) error { +func (p *Index) getProjectID(ctx context.Context) error { if p.projectID != nil { return nil @@ -197,7 +202,7 @@ func (p *Pinecone) getProjectID(ctx context.Context) error { return nil } -func (p *Pinecone) createIndexIfRequired(ctx context.Context) error { +func (p *Index) createIndexIfRequired(ctx context.Context) error { if p.createIndex == nil { return nil @@ -252,7 +257,7 @@ func (p *Pinecone) createIndexIfRequired(ctx context.Context) error { } -func (p *Pinecone) batchUpsert(ctx context.Context, documents []document.Document) error { +func (p *Index) batchUpsert(ctx context.Context, documents []document.Document) error { for i := 0; i < len(documents); i += p.batchUpsertSize { @@ -285,11 +290,11 @@ func (p *Pinecone) batchUpsert(ctx context.Context, documents []document.Documen return nil } -func (p *Pinecone) vectorUpsert(ctx context.Context, vectors []pineconerequest.Vector) error { +func (p *Index) vectorUpsert(ctx context.Context, vectors []pineconerequest.Vector) error { err := p.getProjectID(ctx) if err != nil { - return fmt.Errorf("%s: %w", ErrInternal, err) + return fmt.Errorf("%s: %w", index.ErrInternal, err) } req := &pineconerequest.VectorUpsert{ @@ -312,14 +317,6 @@ func (p *Pinecone) vectorUpsert(ctx context.Context, vectors []pineconerequest.V return nil } -func deepCopyMetadata(metadata types.Meta) types.Meta { - metadataCopy := make(types.Meta) - for k, v := range metadata { - metadataCopy[k] = v - } - return metadataCopy -} - func buildPineconeVectorsFromEmbeddingsAndDocuments( embeddings []embedder.Embedding, documents []document.Document, @@ -331,11 +328,11 @@ func buildPineconeVectorsFromEmbeddingsAndDocuments( for i, embedding := range embeddings { - metadata := deepCopyMetadata(documents[startIndex+i].Metadata) + metadata := index.DeepCopyMetadata(documents[startIndex+i].Metadata) // inject document content 
into vector metadata if includeContent { - metadata[defaultKeyContent] = documents[startIndex+i].Content + metadata[index.DefaultKeyContent] = documents[startIndex+i].Content } vectorID, err := uuid.NewUUID() @@ -350,24 +347,24 @@ func buildPineconeVectorsFromEmbeddingsAndDocuments( }) // inject vector ID into document metadata - documents[startIndex+i].Metadata[defaultKeyID] = vectorID.String() + documents[startIndex+i].Metadata[index.DefaultKeyID] = vectorID.String() } return vectors, nil } -func buildSearchReponsesFromPineconeMatches(matches []pineconeresponse.QueryMatch, includeContent bool) SearchResponses { - searchResponses := make([]SearchResponse, len(matches)) +func buildSearchReponsesFromPineconeMatches(matches []pineconeresponse.QueryMatch, includeContent bool) index.SearchResponses { + searchResponses := make([]index.SearchResponse, len(matches)) for i, match := range matches { - metadata := deepCopyMetadata(match.Metadata) + metadata := index.DeepCopyMetadata(match.Metadata) content := "" // extract document content from vector metadata if includeContent { - content = metadata[defaultKeyContent].(string) - delete(metadata, defaultKeyContent) + content = metadata[index.DefaultKeyContent].(string) + delete(metadata, index.DefaultKeyContent) } id := "" @@ -380,7 +377,7 @@ func buildSearchReponsesFromPineconeMatches(matches []pineconeresponse.QueryMatc score = *match.Score } - searchResponses[i] = SearchResponse{ + searchResponses[i] = index.SearchResponse{ ID: id, Document: document.Document{ Metadata: metadata, diff --git a/index/qdrant.go b/index/qdrant/qdrant.go similarity index 65% rename from index/qdrant.go rename to index/qdrant/qdrant.go index d010201b..0c7a243a 100644 --- a/index/qdrant.go +++ b/index/qdrant/qdrant.go @@ -1,4 +1,4 @@ -package index +package qdrant import ( "context" @@ -8,60 +8,62 @@ import ( "github.com/google/uuid" "github.com/henomis/lingoose/document" "github.com/henomis/lingoose/embedder" + 
"github.com/henomis/lingoose/index" + "github.com/henomis/lingoose/index/option" qdrantgo "github.com/henomis/qdrant-go" qdrantrequest "github.com/henomis/qdrant-go/request" qdrantresponse "github.com/henomis/qdrant-go/response" ) const ( - defaultQdrantTopK = 10 - defaultQdrantBatchUpsertSize = 32 + defaultTopK = 10 + defaultBatchUpsertSize = 32 ) -type Qdrant struct { +type Index struct { qdrantClient *qdrantgo.Client collectionName string - embedder Embedder + embedder index.Embedder includeContent bool batchUpsertSize int - createCollection *QdrantCreateCollectionOptions + createCollection *CreateCollectionOptions } -type QdrantDistance string +type Distance string const ( - QdrantDistanceCosine QdrantDistance = QdrantDistance(qdrantrequest.DistanceCosine) - QdrantDistanceEuclidean QdrantDistance = QdrantDistance(qdrantrequest.DistanceEuclidean) - QdrantDistanceDot QdrantDistance = QdrantDistance(qdrantrequest.DistanceDot) + DistanceCosine Distance = Distance(qdrantrequest.DistanceCosine) + DistanceEuclidean Distance = Distance(qdrantrequest.DistanceEuclidean) + DistanceDot Distance = Distance(qdrantrequest.DistanceDot) ) -type QdrantCreateCollectionOptions struct { +type CreateCollectionOptions struct { Dimension uint64 - Distance QdrantDistance + Distance Distance OnDisk bool } -type QdrantOptions struct { +type Options struct { CollectionName string IncludeContent bool BatchUpsertSize *int - CreateCollection *QdrantCreateCollectionOptions + CreateCollection *CreateCollectionOptions } -func NewQdrant(options QdrantOptions, embedder Embedder) *Qdrant { +func New(options Options, embedder index.Embedder) *Index { apiKey := os.Getenv("QDRANT_API_KEY") endpoint := os.Getenv("QDRANT_ENDPOINT") qdrantClient := qdrantgo.New(endpoint, apiKey) - batchUpsertSize := defaultQdrantBatchUpsertSize + batchUpsertSize := defaultBatchUpsertSize if options.BatchUpsertSize != nil { batchUpsertSize = *options.BatchUpsertSize } - return &Qdrant{ + return &Index{ qdrantClient: 
qdrantClient, collectionName: options.CollectionName, embedder: embedder, @@ -71,30 +73,30 @@ func NewQdrant(options QdrantOptions, embedder Embedder) *Qdrant { } } -func (q *Qdrant) WithAPIKeyAndEdpoint(apiKey, endpoint string) *Qdrant { +func (q *Index) WithAPIKeyAndEdpoint(apiKey, endpoint string) *Index { q.qdrantClient = qdrantgo.New(endpoint, apiKey) return q } -func (q *Qdrant) LoadFromDocuments(ctx context.Context, documents []document.Document) error { +func (q *Index) LoadFromDocuments(ctx context.Context, documents []document.Document) error { err := q.createCollectionIfRequired(ctx) if err != nil { - return fmt.Errorf("%s: %w", ErrInternal, err) + return fmt.Errorf("%s: %w", index.ErrInternal, err) } err = q.batchUpsert(ctx, documents) if err != nil { - return fmt.Errorf("%s: %w", ErrInternal, err) + return fmt.Errorf("%s: %w", index.ErrInternal, err) } return nil } -func (p *Qdrant) IsEmpty(ctx context.Context) (bool, error) { +func (p *Index) IsEmpty(ctx context.Context) (bool, error) { err := p.createCollectionIfRequired(ctx) if err != nil { - return true, fmt.Errorf("%s: %w", ErrInternal, err) + return true, fmt.Errorf("%s: %w", index.ErrInternal, err) } res := &qdrantresponse.CollectionCollectInfo{} @@ -106,17 +108,17 @@ func (p *Qdrant) IsEmpty(ctx context.Context) (bool, error) { res, ) if err != nil { - return true, fmt.Errorf("%s: %w", ErrInternal, err) + return true, fmt.Errorf("%s: %w", index.ErrInternal, err) } return res.Result.VectorsCount == 0, nil } -func (q *Qdrant) SimilaritySearch(ctx context.Context, query string, opts ...Option) (SearchResponses, error) { +func (q *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResponses, error) { - qdrantOptions := &options{ - topK: defaultQdrantTopK, + qdrantOptions := &option.Options{ + TopK: defaultTopK, } for _, opt := range opts { @@ -125,31 +127,35 @@ func (q *Qdrant) SimilaritySearch(ctx context.Context, query string, opts ...Opt matches, err 
:= q.similaritySearch(ctx, query, qdrantOptions) if err != nil { - return nil, fmt.Errorf("%s: %w", ErrInternal, err) + return nil, fmt.Errorf("%s: %w", index.ErrInternal, err) } searchResponses := buildSearchReponsesFromQdrantMatches(matches, q.includeContent) - return filterSearchResponses(searchResponses, qdrantOptions.topK), nil + return index.FilterSearchResponses(searchResponses, qdrantOptions.TopK), nil } -func (p *Qdrant) similaritySearch(ctx context.Context, query string, opts *options) ([]qdrantresponse.PointSearchResult, error) { +func (p *Index) similaritySearch(ctx context.Context, query string, opts *option.Options) ([]qdrantresponse.PointSearchResult, error) { embeddings, err := p.embedder.Embed(ctx, []string{query}) if err != nil { return nil, err } + if opts.Filter == nil { + opts.Filter = qdrantrequest.Filter{} + } + includeMetadata := true res := &qdrantresponse.PointSearch{} err = p.qdrantClient.PointSearch( ctx, &qdrantrequest.PointSearch{ CollectionName: p.collectionName, - Limit: opts.topK, + Limit: opts.TopK, Vector: embeddings[0], WithPayload: &includeMetadata, - Filter: opts.filter.(qdrantrequest.Filter), + Filter: opts.Filter.(qdrantrequest.Filter), }, res, ) @@ -160,7 +166,7 @@ func (p *Qdrant) similaritySearch(ctx context.Context, query string, opts *optio return res.Result, nil } -func (q *Qdrant) createCollectionIfRequired(ctx context.Context) error { +func (q *Index) createCollectionIfRequired(ctx context.Context) error { if q.createCollection == nil { return nil @@ -195,7 +201,7 @@ func (q *Qdrant) createCollectionIfRequired(ctx context.Context) error { return nil } -func (q *Qdrant) batchUpsert(ctx context.Context, documents []document.Document) error { +func (q *Index) batchUpsert(ctx context.Context, documents []document.Document) error { for i := 0; i < len(documents); i += q.batchUpsertSize { @@ -228,7 +234,7 @@ func (q *Qdrant) batchUpsert(ctx context.Context, documents []document.Document) return nil } -func (q *Qdrant) 
pointUpsert(ctx context.Context, points []qdrantrequest.Point) error { +func (q *Index) pointUpsert(ctx context.Context, points []qdrantrequest.Point) error { wait := true req := &qdrantrequest.PointUpsert{ @@ -257,11 +263,11 @@ func buildQdrantPointsFromEmbeddingsAndDocuments( for i, embedding := range embeddings { - metadata := deepCopyMetadata(documents[startIndex+i].Metadata) + metadata := index.DeepCopyMetadata(documents[startIndex+i].Metadata) // inject document content into vector metadata if includeContent { - metadata[defaultKeyContent] = documents[startIndex+i].Content + metadata[index.DefaultKeyContent] = documents[startIndex+i].Content } vectorID, err := uuid.NewUUID() @@ -276,27 +282,27 @@ func buildQdrantPointsFromEmbeddingsAndDocuments( }) // inject vector ID into document metadata - documents[startIndex+i].Metadata[defaultKeyID] = vectorID.String() + documents[startIndex+i].Metadata[index.DefaultKeyID] = vectorID.String() } return vectors, nil } -func buildSearchReponsesFromQdrantMatches(matches []qdrantresponse.PointSearchResult, includeContent bool) SearchResponses { - searchResponses := make([]SearchResponse, len(matches)) +func buildSearchReponsesFromQdrantMatches(matches []qdrantresponse.PointSearchResult, includeContent bool) index.SearchResponses { + searchResponses := make([]index.SearchResponse, len(matches)) for i, match := range matches { - metadata := deepCopyMetadata(match.Payload) + metadata := index.DeepCopyMetadata(match.Payload) content := "" // extract document content from vector metadata if includeContent { - content = metadata[defaultKeyContent].(string) - delete(metadata, defaultKeyContent) + content = metadata[index.DefaultKeyContent].(string) + delete(metadata, index.DefaultKeyContent) } - searchResponses[i] = SearchResponse{ + searchResponses[i] = index.SearchResponse{ ID: match.ID, Document: document.Document{ Metadata: metadata, diff --git a/index/simpleVectorIndex.go b/index/simpleVectorIndex/simpleVectorIndex.go 
similarity index 51% rename from index/simpleVectorIndex.go rename to index/simpleVectorIndex/simpleVectorIndex.go index 35a4fe1e..0a138626 100644 --- a/index/simpleVectorIndex.go +++ b/index/simpleVectorIndex/simpleVectorIndex.go @@ -1,4 +1,4 @@ -package index +package simplevectorindex import ( "context" @@ -10,31 +10,32 @@ import ( "github.com/henomis/lingoose/document" "github.com/henomis/lingoose/embedder" + "github.com/henomis/lingoose/index" + "github.com/henomis/lingoose/index/option" ) const ( - defaultSimpleVectorIndexBatchSize = 32 - - defaultSimpleVectorIndexTopK = 10 + defaultBatchSize = 32 + defaultTopK = 10 ) -type simpleVectorIndexData struct { +type data struct { Document document.Document `json:"document"` Embedding embedder.Embedding `json:"embedding"` } -type SimpleVectorIndex struct { - data []simpleVectorIndexData +type Index struct { + data []data outputPath string name string - embedder Embedder + embedder index.Embedder } -type SimpleVectorIndexFilterFn func([]SearchResponse) []SearchResponse +type SimpleVectorIndexFilterFn func([]index.SearchResponse) []index.SearchResponse -func NewSimpleVectorIndex(name string, outputPath string, embedder Embedder) *SimpleVectorIndex { - simpleVectorIndex := &SimpleVectorIndex{ - data: []simpleVectorIndexData{}, +func New(name string, outputPath string, embedder index.Embedder) *Index { + simpleVectorIndex := &Index{ + data: []data{}, outputPath: outputPath, name: name, embedder: embedder, @@ -43,14 +44,14 @@ func NewSimpleVectorIndex(name string, outputPath string, embedder Embedder) *Si return simpleVectorIndex } -func (s *SimpleVectorIndex) LoadFromDocuments(ctx context.Context, documents []document.Document) error { +func (s *Index) LoadFromDocuments(ctx context.Context, documents []document.Document) error { - s.data = []simpleVectorIndexData{} + s.data = []data{} documentIndex := 0 - for i := 0; i < len(documents); i += defaultSimpleVectorIndexBatchSize { + for i := 0; i < len(documents); i += 
defaultBatchSize { - end := i + defaultSimpleVectorIndexBatchSize + end := i + defaultBatchSize if end > len(documents) { end = len(documents) } @@ -62,16 +63,16 @@ func (s *SimpleVectorIndex) LoadFromDocuments(ctx context.Context, documents []d embeddings, err := s.embedder.Embed(ctx, texts) if err != nil { - return fmt.Errorf("%s: %w", ErrInternal, err) + return fmt.Errorf("%s: %w", index.ErrInternal, err) } for j, document := range documents[i:end] { - s.data = append(s.data, simpleVectorIndexData{ + s.data = append(s.data, data{ Document: document, Embedding: embeddings[j], }) - documents[documentIndex].Metadata[defaultKeyID] = fmt.Sprintf("%d", documentIndex) + documents[documentIndex].Metadata[index.DefaultKeyID] = fmt.Sprintf("%d", documentIndex) documentIndex++ } @@ -79,13 +80,13 @@ func (s *SimpleVectorIndex) LoadFromDocuments(ctx context.Context, documents []d err := s.save() if err != nil { - return fmt.Errorf("%s: %w", ErrInternal, err) + return fmt.Errorf("%s: %w", index.ErrInternal, err) } return nil } -func (s SimpleVectorIndex) save() error { +func (s Index) save() error { jsonContent, err := json.Marshal(s.data) if err != nil { @@ -95,7 +96,7 @@ func (s SimpleVectorIndex) save() error { return os.WriteFile(s.database(), jsonContent, 0644) } -func (s *SimpleVectorIndex) load() error { +func (s *Index) load() error { content, err := os.ReadFile(s.database()) if err != nil { @@ -105,24 +106,24 @@ func (s *SimpleVectorIndex) load() error { return json.Unmarshal(content, &s.data) } -func (s *SimpleVectorIndex) database() string { +func (s *Index) database() string { return strings.Join([]string{s.outputPath, s.name + ".json"}, string(os.PathSeparator)) } -func (s *SimpleVectorIndex) IsEmpty() (bool, error) { +func (s *Index) IsEmpty() (bool, error) { err := s.load() if err != nil { - return true, fmt.Errorf("%s: %w", ErrInternal, err) + return true, fmt.Errorf("%s: %w", index.ErrInternal, err) } return len(s.data) == 0, nil } -func (s 
*SimpleVectorIndex) SimilaritySearch(ctx context.Context, query string, opts ...Option) (SearchResponses, error) { +func (s *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResponses, error) { - sviOptions := &options{ - topK: defaultSimpleVectorIndexTopK, + sviOptions := &option.Options{ + TopK: defaultTopK, } for _, opt := range opts { @@ -131,37 +132,37 @@ func (s *SimpleVectorIndex) SimilaritySearch(ctx context.Context, query string, err := s.load() if err != nil { - return nil, fmt.Errorf("%s: %w", ErrInternal, err) + return nil, fmt.Errorf("%s: %w", index.ErrInternal, err) } embeddings, err := s.embedder.Embed(ctx, []string{query}) if err != nil { - return nil, fmt.Errorf("%s: %w", ErrInternal, err) + return nil, fmt.Errorf("%s: %w", index.ErrInternal, err) } scores := s.cosineSimilarityBatch(embeddings[0]) - searchResponses := make([]SearchResponse, len(scores)) + searchResponses := make([]index.SearchResponse, len(scores)) for i, score := range scores { - id := s.data[i].Document.Metadata[defaultKeyID].(string) + id := s.data[i].Document.Metadata[index.DefaultKeyID].(string) - searchResponses[i] = SearchResponse{ + searchResponses[i] = index.SearchResponse{ ID: id, Document: s.data[i].Document, Score: score, } } - if sviOptions.filter != nil { - searchResponses = sviOptions.filter.(SimpleVectorIndexFilterFn)(searchResponses) + if sviOptions.Filter != nil { + searchResponses = sviOptions.Filter.(SimpleVectorIndexFilterFn)(searchResponses) } - return filterSearchResponses(searchResponses, sviOptions.topK), nil + return index.FilterSearchResponses(searchResponses, sviOptions.TopK), nil } -func (s *SimpleVectorIndex) cosineSimilarity(a embedder.Embedding, b embedder.Embedding) float64 { +func (s *Index) cosineSimilarity(a embedder.Embedding, b embedder.Embedding) float64 { dotProduct := float64(0.0) normA := float64(0.0) normB := float64(0.0) @@ -179,7 +180,7 @@ func (s *SimpleVectorIndex) cosineSimilarity(a 
embedder.Embedding, b embedder.Em return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB)) } -func (s *SimpleVectorIndex) cosineSimilarityBatch(a embedder.Embedding) []float64 { +func (s *Index) cosineSimilarityBatch(a embedder.Embedding) []float64 { scores := make([]float64, len(s.data)) diff --git a/llm/cohere/cohere.go b/llm/cohere/cohere.go index c7e214a6..51ce488f 100644 --- a/llm/cohere/cohere.go +++ b/llm/cohere/cohere.go @@ -36,6 +36,7 @@ type Cohere struct { stop []string } +// NewCompletion returns a new completion LLM func NewCompletion() *Cohere { return &Cohere{ client: coherego.New(os.Getenv("COHERE_API_KEY")), @@ -45,36 +46,43 @@ func NewCompletion() *Cohere { } } +// WithModel sets the model to use for the LLM func (c *Cohere) WithModel(model Model) *Cohere { c.model = model return c } +// WithTemperature sets the temperature to use for the LLM func (c *Cohere) WithTemperature(temperature float64) *Cohere { c.temperature = temperature return c } +// WithMaxTokens sets the max tokens to use for the LLM func (c *Cohere) WithMaxTokens(maxTokens int) *Cohere { c.maxTokens = maxTokens return c } +// WithAPIKey sets the API key to use for the LLM func (c *Cohere) WithAPIKey(apiKey string) *Cohere { c.client = coherego.New(apiKey) return c } +// WithVerbose sets the verbosity of the LLM func (c *Cohere) WithVerbose(verbose bool) *Cohere { c.verbose = verbose return c } +// WithStop sets the stop sequences to use for the LLM func (o *Cohere) WithStop(stop []string) *Cohere { o.stop = stop return o } +// Completion returns the completion for the given prompt func (c *Cohere) Completion(ctx context.Context, prompt string) (string, error) { resp := &response.Generate{} @@ -107,6 +115,7 @@ func (c *Cohere) Completion(ctx context.Context, prompt string) (string, error) return output, nil } +// Chat is not implemented func (c *Cohere) Chat(ctx context.Context, prompt *chat.Chat) (string, error) { return "", fmt.Errorf("not implemented") } diff --git 
a/llm/huggingface/huggingface.go b/llm/huggingface/huggingface.go index 802c3044..14abe758 100644 --- a/llm/huggingface/huggingface.go +++ b/llm/huggingface/huggingface.go @@ -41,51 +41,61 @@ func New(model string, temperature float32, verbose bool) *HuggingFace { } } +// WithModel sets the model to use for the LLM func (h *HuggingFace) WithModel(model string) *HuggingFace { h.model = model return h } +// WithTemperature sets the temperature to use for the LLM func (h *HuggingFace) WithTemperature(temperature float32) *HuggingFace { h.temperature = temperature return h } +// WithMaxLength sets the maxLength to use for the LLM func (h *HuggingFace) WithMaxLength(maxLength int) *HuggingFace { h.maxLength = &maxLength return h } +// WithMinLength sets the minLength to use for the LLM func (h *HuggingFace) WithMinLength(minLength int) *HuggingFace { h.minLength = &minLength return h } +// WithToken sets the API key to use for the LLM func (h *HuggingFace) WithToken(token string) *HuggingFace { h.token = token return h } +// WithVerbose sets the verbose flag to use for the LLM func (h *HuggingFace) WithVerbose(verbose bool) *HuggingFace { h.verbose = verbose return h } +// WithTopK sets the topK to use for the LLM func (h *HuggingFace) WithTopK(topK int) *HuggingFace { h.topK = &topK return h } +// WithTopP sets the topP to use for the LLM func (h *HuggingFace) WithTopP(topP float32) *HuggingFace { h.topP = &topP return h } +// WithMode sets the mode to use for the LLM func (h *HuggingFace) WithMode(mode HuggingFaceMode) *HuggingFace { h.mode = mode return h } +// Completion returns the completion for the given prompt func (h *HuggingFace) Completion(ctx context.Context, prompt string) (string, error) { var output string @@ -110,6 +120,7 @@ func (h *HuggingFace) Completion(ctx context.Context, prompt string) (string, er return output, nil } +// BatchCompletion returns the completion for the given prompts func (h *HuggingFace) BatchCompletion(ctx context.Context, prompts 
[]string) ([]string, error) { var outputs []string diff --git a/llm/openai/openai.go b/llm/openai/openai.go index e84353e4..7219cc7a 100644 --- a/llm/openai/openai.go +++ b/llm/openai/openai.go @@ -88,45 +88,54 @@ func New(model Model, temperature float32, maxTokens int, verbose bool) *OpenAI } } +// WithModel sets the model to use for the OpenAI instance. func (o *OpenAI) WithModel(model Model) *OpenAI { o.model = model return o } +// WithTemperature sets the temperature to use for the OpenAI instance. func (o *OpenAI) WithTemperature(temperature float32) *OpenAI { o.temperature = temperature return o } +// WithMaxTokens sets the max tokens to use for the OpenAI instance. func (o *OpenAI) WithMaxTokens(maxTokens int) *OpenAI { o.maxTokens = maxTokens return o } +// WithCallback sets the usage callback to use for the OpenAI instance. func (o *OpenAI) WithCallback(callback OpenAIUsageCallback) *OpenAI { o.usageCallback = callback return o } +// WithStop sets the stop sequences to use for the OpenAI instance. func (o *OpenAI) WithStop(stop []string) *OpenAI { o.stop = stop return o } +// WithClient sets the client to use for the OpenAI instance. func (o *OpenAI) WithClient(client *openai.Client) *OpenAI { o.openAIClient = client return o } +// WithVerbose sets the verbose flag to use for the OpenAI instance. func (o *OpenAI) WithVerbose(verbose bool) *OpenAI { o.verbose = verbose return o } +// CalledFunctionName returns the name of the function that was called. func (o *OpenAI) CalledFunctionName() *string { return o.calledFunctionName } +// FinishReason returns the LLM finish reason. func (o *OpenAI) FinishReason() string { return o.finishReason } @@ -149,41 +158,17 @@ func NewChat() *OpenAI { ) } +// Completion returns a single completion for the given prompt. 
func (o *OpenAI) Completion(ctx context.Context, prompt string) (string, error) { - - response, err := o.openAIClient.CreateCompletion( - ctx, - openai.CompletionRequest{ - Model: string(o.model), - Prompt: prompt, - MaxTokens: o.maxTokens, - Temperature: o.temperature, - N: DefaultOpenAINumResults, - TopP: DefaultOpenAITopP, - Stop: o.stop, - }, - ) - + outputs, err := o.BatchCompletion(ctx, []string{prompt}) if err != nil { - return "", fmt.Errorf("%s: %w", ErrOpenAICompletion, err) - } - - if o.usageCallback != nil { - o.setUsageMetadata(response.Usage) + return "", err } - if len(response.Choices) == 0 { - return "", fmt.Errorf("%s: no choices returned", ErrOpenAICompletion) - } - - output := strings.TrimSpace(response.Choices[0].Text) - if o.verbose { - debugCompletion(prompt, output) - } - - return output, nil + return outputs[0], nil } +// BatchCompletion returns multiple completions for the given prompts. func (o *OpenAI) BatchCompletion(ctx context.Context, prompts []string) ([]string, error) { response, err := o.openAIClient.CreateCompletion( @@ -223,57 +208,12 @@ func (o *OpenAI) BatchCompletion(ctx context.Context, prompts []string) ([]strin return outputs, nil } +// CompletionStream returns a single completion stream for the given prompt. 
func (o *OpenAI) CompletionStream(ctx context.Context, callbackFn OpenAIStreamCallback, prompt string) error { - - stream, err := o.openAIClient.CreateCompletionStream( - ctx, - openai.CompletionRequest{ - Model: string(o.model), - Prompt: prompt, - MaxTokens: o.maxTokens, - Temperature: o.temperature, - N: DefaultOpenAINumResults, - TopP: DefaultOpenAITopP, - Stop: o.stop, - }, - ) - if err != nil { - return fmt.Errorf("%s: %w", ErrOpenAICompletion, err) - } - - defer stream.Close() - - for { - - response, err := stream.Recv() - if errors.Is(err, io.EOF) { - break - } - - if err != nil { - return fmt.Errorf("%s: %w", ErrOpenAICompletion, err) - } - - if o.usageCallback != nil { - o.setUsageMetadata(response.Usage) - } - - if len(response.Choices) == 0 { - return fmt.Errorf("%s: no choices returned", ErrOpenAICompletion) - } - - output := response.Choices[0].Text - if o.verbose { - debugCompletion(prompt, output) - } - - callbackFn(output) - - } - - return nil + return o.BatchCompletionStream(ctx, []OpenAIStreamCallback{callbackFn}, []string{prompt}) } +// BatchCompletionStream returns multiple completion streams for the given prompts. func (o *OpenAI) BatchCompletionStream(ctx context.Context, callbackFn []OpenAIStreamCallback, prompts []string) error { stream, err := o.openAIClient.CreateCompletionStream( @@ -328,6 +268,7 @@ func (o *OpenAI) BatchCompletionStream(ctx context.Context, callbackFn []OpenAIS return nil } +// Chat returns a single chat completion for the given prompt. func (o *OpenAI) Chat(ctx context.Context, prompt *chat.Chat) (string, error) { messages, err := buildMessages(prompt) @@ -389,6 +330,7 @@ func (o *OpenAI) Chat(ctx context.Context, prompt *chat.Chat) (string, error) { return content, nil } +// ChatStream returns a single chat stream for the given prompt. 
func (o *OpenAI) ChatStream(ctx context.Context, callbackFn OpenAIStreamCallback, prompt *chat.Chat) error { messages, err := buildMessages(prompt) @@ -440,6 +382,7 @@ func (o *OpenAI) ChatStream(ctx context.Context, callbackFn OpenAIStreamCallback return nil } +// SetStop sets the stop sequences for the completion. func (o *OpenAI) SetStop(stop []string) { o.stop = stop }