Skip to content

Commit

Permalink
Merge pull request #9 from LyzrCore/wip-khush/refactor-chatbot
Browse files Browse the repository at this point in the history
fix: Modified default index naming in vector_store to use a unique identifier
  • Loading branch information
patel-lyzr authored Jan 27, 2024
2 parents 079d01f + 2ac04b1 commit 98e1cd8
Show file tree
Hide file tree
Showing 13 changed files with 169 additions and 45 deletions.
39 changes: 28 additions & 11 deletions build/lib/lyzr/base/vector_store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import Optional, Sequence

import os
import uuid
import weaviate
from weaviate.embedded import EmbeddedOptions
from llama_index import Document, ServiceContext, VectorStoreIndex, StorageContext
from llama_index.node_parser import SimpleNodeParser

Expand All @@ -13,30 +17,43 @@ def import_vector_store_class(vector_store_class_name: str):
class LyzrVectorStoreIndex:
@staticmethod
def from_defaults(
vector_store_type: str = "LanceDBVectorStore",
vector_store_type: str = "WeaviateVectorStore",
documents: Optional[Sequence[Document]] = None,
service_context: Optional[ServiceContext] = None,
**kwargs
) -> VectorStoreIndex:
if documents is None and vector_store_type == "SimpleVectorStore":
raise ValueError("documents must be provided for SimpleVectorStore")

vector_store_class = import_vector_store_class(vector_store_type)
VectorStoreClass = import_vector_store_class(vector_store_type)

if vector_store_type == "WeaviateVectorStore":
weaviate_client = weaviate.Client(
embedded_options=weaviate.embedded.EmbeddedOptions(),
additional_headers={"X-OpenAI-Api-Key": os.environ["OPENAI_API_KEY"]},
)
kwargs["weaviate_client"] = (
weaviate_client
if "weaviate_client" not in kwargs
else kwargs["weaviate_client"]
)
kwargs["index_name"] = (
f"DB_{uuid.uuid4().hex}" if "index_name" not in kwargs else kwargs["index_name"]
)

vector_store = VectorStoreClass(**kwargs)
else:
vector_store = VectorStoreClass(**kwargs)

if documents is None:
vector_store = vector_store_class(**kwargs)
index = VectorStoreIndex.from_vector_store(
vector_store=vector_store, service_context=service_context
)
else:
if vector_store_type == "LanceDBVectorStore":
kwargs["uri"] = "./.lancedb" if "uri" not in kwargs else kwargs["uri"]
kwargs["table_name"] = (
"vectors" if "table_name" not in kwargs else kwargs["table_name"]
)
vector_store = vector_store_class(**kwargs)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
return index

storage_context = StorageContext.from_defaults(vector_store=vector_store)

if documents is not None:
index = VectorStoreIndex.from_documents(
documents=documents,
storage_context=storage_context,
Expand Down
12 changes: 12 additions & 0 deletions build/lib/lyzr/chatqa/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def pdf_chat(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
return pdf_chat_(
input_dir=input_dir,
Expand All @@ -57,6 +58,7 @@ def pdf_chat(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
chat_engine_params=chat_engine_params,
retriever_params=retriever_params,
)

@staticmethod
Expand All @@ -74,6 +76,7 @@ def docx_chat(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
return docx_chat_(
input_dir=input_dir,
Expand All @@ -89,6 +92,7 @@ def docx_chat(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
chat_engine_params=chat_engine_params,
retriever_params=retriever_params,
)

@staticmethod
Expand All @@ -106,6 +110,7 @@ def txt_chat(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
return txt_chat_(
input_dir=input_dir,
Expand All @@ -121,6 +126,7 @@ def txt_chat(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
chat_engine_params=chat_engine_params,
retriever_params=retriever_params,
)

@staticmethod
Expand All @@ -133,6 +139,7 @@ def webpage_chat(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
return webpage_chat_(
url=url,
Expand All @@ -143,6 +150,7 @@ def webpage_chat(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
chat_engine_params=chat_engine_params,
retriever_params=retriever_params,
)

@staticmethod
Expand All @@ -155,6 +163,7 @@ def website_chat(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
return website_chat_(
url=url,
Expand All @@ -165,6 +174,7 @@ def website_chat(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
chat_engine_params=chat_engine_params,
retriever_params=retriever_params,
)

@staticmethod
Expand All @@ -177,6 +187,7 @@ def youtube_chat(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
return youtube_chat_(
urls=urls,
Expand All @@ -187,4 +198,5 @@ def youtube_chat(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
chat_engine_params=chat_engine_params,
retriever_params=retriever_params,
)
12 changes: 12 additions & 0 deletions build/lib/lyzr/chatqa/qa_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def pdf_qa(
vector_store_params: dict = None,
service_context_params: dict = None,
query_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseQueryEngine:
return pdf_rag(
input_dir=input_dir,
Expand All @@ -56,6 +57,7 @@ def pdf_qa(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
query_engine_params=query_engine_params,
retriever_params=retriever_params,
)

@staticmethod
Expand All @@ -73,6 +75,7 @@ def docx_qa(
vector_store_params: dict = None,
service_context_params: dict = None,
query_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseQueryEngine:
return docx_rag(
input_dir=input_dir,
Expand All @@ -88,6 +91,7 @@ def docx_qa(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
query_engine_params=query_engine_params,
retriever_params=retriever_params,
)

@staticmethod
Expand All @@ -105,6 +109,7 @@ def txt_qa(
vector_store_params: dict = None,
service_context_params: dict = None,
query_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseQueryEngine:
return txt_rag(
input_dir=input_dir,
Expand All @@ -120,6 +125,7 @@ def txt_qa(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
query_engine_params=query_engine_params,
retriever_params=retriever_params,
)

@staticmethod
Expand All @@ -132,6 +138,7 @@ def webpage_qa(
vector_store_params: dict = None,
service_context_params: dict = None,
query_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseQueryEngine:
return webpage_rag(
url=url,
Expand All @@ -142,6 +149,7 @@ def webpage_qa(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
query_engine_params=query_engine_params,
retriever_params=retriever_params,
)

@staticmethod
Expand All @@ -154,6 +162,7 @@ def website_qa(
vector_store_params: dict = None,
service_context_params: dict = None,
query_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseQueryEngine:
return website_rag(
url=url,
Expand All @@ -164,6 +173,7 @@ def website_qa(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
query_engine_params=query_engine_params,
retriever_params=retriever_params,
)

@staticmethod
Expand All @@ -176,6 +186,7 @@ def youtube_qa(
vector_store_params: dict = None,
service_context_params: dict = None,
query_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseQueryEngine:
return youtube_rag(
urls=urls,
Expand All @@ -186,4 +197,5 @@ def youtube_qa(
vector_store_params=vector_store_params,
service_context_params=service_context_params,
query_engine_params=query_engine_params,
retriever_params=retriever_params,
)
17 changes: 11 additions & 6 deletions build/lib/lyzr/utils/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def pdf_chat_(

llm_params = {} if llm_params is None else llm_params
vector_store_params = (
{"vector_store_type": "LanceDBVectorStore"}
{"vector_store_type": "WeaviateVectorStore"}
if vector_store_params is None
else vector_store_params
)
Expand Down Expand Up @@ -106,6 +106,7 @@ def txt_chat_(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
documents = read_txt_as_documents(
input_dir=input_dir,
Expand All @@ -118,7 +119,7 @@ def txt_chat_(

llm_params = {} if llm_params is None else llm_params
vector_store_params = (
{"vector_store_type": "LanceDBVectorStore"}
{"vector_store_type": "WeaviateVectorStore"}
if vector_store_params is None
else vector_store_params
)
Expand Down Expand Up @@ -177,6 +178,7 @@ def docx_chat_(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
documents = read_docx_as_documents(
input_dir=input_dir,
Expand All @@ -189,7 +191,7 @@ def docx_chat_(

llm_params = {} if llm_params is None else llm_params
vector_store_params = (
{"vector_store_type": "LanceDBVectorStore"}
{"vector_store_type": "WeaviateVectorStore"}
if vector_store_params is None
else vector_store_params
)
Expand Down Expand Up @@ -243,14 +245,15 @@ def webpage_chat_(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
documents = read_webpage_as_documents(
url=url,
)

llm_params = {} if llm_params is None else llm_params
vector_store_params = (
{"vector_store_type": "LanceDBVectorStore"}
{"vector_store_type": "WeaviateVectorStore"}
if vector_store_params is None
else vector_store_params
)
Expand Down Expand Up @@ -304,14 +307,15 @@ def website_chat_(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
documents = read_website_as_documents(
url=url,
)

llm_params = {} if llm_params is None else llm_params
vector_store_params = (
{"vector_store_type": "LanceDBVectorStore"}
{"vector_store_type": "WeaviateVectorStore"}
if vector_store_params is None
else vector_store_params
)
Expand Down Expand Up @@ -365,14 +369,15 @@ def youtube_chat_(
vector_store_params: dict = None,
service_context_params: dict = None,
chat_engine_params: dict = None,
retriever_params: dict = None,
) -> BaseChatEngine:
documents = read_youtube_as_documents(
urls=urls,
)

llm_params = {} if llm_params is None else llm_params
vector_store_params = (
{"vector_store_type": "LanceDBVectorStore"}
{"vector_store_type": "WeaviateVectorStore"}
if vector_store_params is None
else vector_store_params
)
Expand Down
Loading

0 comments on commit 98e1cd8

Please sign in to comment.