From 1f57b085f73bc069403b19545d39edc1fb7462c4 Mon Sep 17 00:00:00 2001
From: Talha SARI
Date: Fri, 8 Dec 2023 18:36:31 +0300
Subject: [PATCH 1/3] fix AutoQueryEngine bug causing qa_prompt_template not to be used when given (#177)
---
autollm/__init__.py | 2 +-
autollm/auto/query_engine.py | 19 ++++++++++++-------
autollm/auto/service_context.py | 5 ++++-
3 files changed, 17 insertions(+), 9 deletions(-)
diff --git a/autollm/__init__.py b/autollm/__init__.py
index 7c32b9d8..d7c61a62 100644
--- a/autollm/__init__.py
+++ b/autollm/__init__.py
@@ -4,7 +4,7 @@ and vector databases, along with various utility functions.
"""
-__version__ = '0.1.3'
+__version__ = '0.1.4'
__author__ = 'safevideo'
__license__ = 'AGPL-3.0'
diff --git a/autollm/auto/query_engine.py b/autollm/auto/query_engine.py
index 959fed8e..964847b0 100644
--- a/autollm/auto/query_engine.py
+++ b/autollm/auto/query_engine.py
@@ -3,7 +3,7 @@ from llama_index import Document, ServiceContext, VectorStoreIndex
from llama_index.embeddings.utils import EmbedType
from llama_index.indices.query.base import BaseQueryEngine
-from llama_index.prompts.base import PromptTemplate
+from llama_index.prompts.base import BasePromptTemplate, PromptTemplate
from llama_index.prompts.prompt_type import PromptType
from llama_index.response_synthesizers import get_response_synthesizer
from llama_index.schema import BaseNode
@@ -24,11 +24,11 @@ def create_query_engine(
llm_api_base: Optional[str] = None,
# service_context_params
system_prompt: str = None,
- query_wrapper_prompt: str = None,
+ query_wrapper_prompt: Union[str, BasePromptTemplate] = None,
enable_cost_calculator: bool = True,
embed_model: Union[str, EmbedType] = "default", # ["default", "local"]
chunk_size: Optional[int] = 512,
- chunk_overlap: Optional[int] = 200,
+ chunk_overlap: Optional[int] = 100,
context_window: Optional[int] = None,
enable_title_extractor: bool = False,
enable_summary_extractor: bool = False,
@@ -61,7 +61,7 @@ def create_query_engine(
llm_temperature (float): The temperature to use for the LLM.
llm_api_base (str): The API base to use for the LLM.
system_prompt (str): The system prompt to use for the query engine.
- query_wrapper_prompt (str): The query wrapper prompt to use for the query engine.
+ query_wrapper_prompt (Union[str, BasePromptTemplate]): The query wrapper prompt to use for the query engine.
enable_cost_calculator (bool): Flag to enable cost calculator logging.
embed_model (Union[str, EmbedType]): The embedding model to use for generating embeddings.
"default" for OpenAI, "local" for HuggingFace or use full identifier (e.g., local:intfloat/multilingual-e5-large) @@ -133,10 +133,15 @@ def create_query_engine( refine_prompt_template = PromptTemplate(refine_prompt, prompt_type=PromptType.REFINE) else: refine_prompt_template = None + + # Convert query_wrapper_prompt to PromptTemplate if it is a string + if isinstance(query_wrapper_prompt, str): + query_wrapper_prompt = PromptTemplate(template=query_wrapper_prompt) response_synthesizer = get_response_synthesizer( service_context=service_context, - response_mode=response_mode, + text_qa_template=query_wrapper_prompt, refine_template=refine_prompt_template, + response_mode=response_mode, structured_answer_filtering=structured_answer_filtering) return vector_store_index.as_query_engine( @@ -213,7 +218,7 @@ def from_defaults( llm_temperature: float = 0.1, # service_context_params system_prompt: str = None, - query_wrapper_prompt: str = None, + query_wrapper_prompt: Union[str, BasePromptTemplate] = None, enable_cost_calculator: bool = True, embed_model: Union[str, EmbedType] = "default", # ["default", "local"] chunk_size: Optional[int] = 512, @@ -246,7 +251,7 @@ def from_defaults( llm_temperature (float): The temperature to use for the LLM. llm_api_base (str): The API base to use for the LLM. system_prompt (str): The system prompt to use for the query engine. - query_wrapper_prompt (str): The query wrapper prompt to use for the query engine. + query_wrapper_prompt (Union[str, BasePromptTemplate]): The query wrapper prompt to use for the query engine. enable_cost_calculator (bool): Flag to enable cost calculator logging. embed_model (Union[str, EmbedType]): The embedding model to use for generating embeddings. "default" for OpenAI, "local" for HuggingFace or use full identifier (e.g., local:intfloat/multilingual-e5-large) diff --git a/autollm/auto/service_context.py b/autollm/auto/service_context.py index ff45e794..9aad5294 100644 --- a/autollm/auto/service_context.py +++ b/autollm/auto/service_context.py @@ -65,11 +65,14 @@ def from_defaults( """ if not system_prompt and not query_wrapper_prompt: system_prompt, query_wrapper_prompt = set_default_prompt_template() - # Convert system_prompt to ChatPromptTemplate if it is a string + # Convert query_wrapper_prompt to PromptTemplate if it is a string if isinstance(query_wrapper_prompt, str): query_wrapper_prompt = PromptTemplate(template=query_wrapper_prompt) callback_manager: CallbackManager = kwargs.get('callback_manager', CallbackManager()) + kwargs.pop( + 'callback_manager', None) # Make sure callback_manager is not passed to ServiceContext twice + if enable_cost_calculator: llm_model_name = llm.metadata.model_name if not "default" else "gpt-3.5-turbo" callback_manager.add_handler(CostCalculatingHandler(model_name=llm_model_name, verbose=True)) From 0407ac65fcab58c05f8a09876a06f703f6973be2 Mon Sep 17 00:00:00 2001 From: Talha SARI Date: Fri, 22 Dec 2023 15:14:12 +0300 Subject: [PATCH 2/3] feature: implement auto embedding (#181) * implement auto embedding * bump litellm * add autoEmbedding to autoQuery pipeline * fix tests * bump llama-index --- autollm/auto/embedding.py | 98 +++++++++++++++++++++++++++++++++ autollm/auto/llm.py | 4 +- autollm/auto/query_engine.py | 19 ++++--- autollm/auto/service_context.py | 2 +- requirements.txt | 4 +- tests/config.yaml | 2 +- tests/test_auto_lite_llm.py | 4 +- 7 files changed, 117 insertions(+), 16 deletions(-) create mode 100644 autollm/auto/embedding.py diff --git a/autollm/auto/embedding.py 
b/autollm/auto/embedding.py new file mode 100644 index 00000000..f95bc965 --- /dev/null +++ b/autollm/auto/embedding.py @@ -0,0 +1,98 @@ +import asyncio +from typing import Any, List + +from litellm import embedding as lite_embedding +from llama_index.bridge.pydantic import Field +from llama_index.embeddings.base import BaseEmbedding, Embedding + + +class AutoEmbedding(BaseEmbedding): + """ + Custom embedding class for flexible and efficient text embedding. + + This class interfaces with the LiteLLM library to use its embedding functionality, making it compatible + with a wide range of LLM models. + """ + + # Define the model attribute using Pydantic's Field + model: str = Field(default="text-embedding-ada-002", description="The name of the embedding model.") + + def __init__(self, model: str, **kwargs: Any) -> None: + """ + Initialize the AutoEmbedding with a specific model. + + Args: + model (str): ID of the embedding model to use. + **kwargs (Any): Additional keyword arguments. + """ + super().__init__(**kwargs) + self.model = model # Set the model ID for embedding + + def _get_query_embedding(self, query: str) -> Embedding: + """ + Synchronously get the embedding for a query string. + + Args: + query (str): The query text to embed. + + Returns: + Embedding: The embedding vector. + """ + response = lite_embedding(model=self.model, input=[query]) + return self._parse_embedding_response(response) + + async def _aget_query_embedding(self, query: str) -> Embedding: + """ + Asynchronously get the embedding for a query string. + + Args: + query (str): The query text to embed. + + Returns: + Embedding: The embedding vector. + """ + response = await asyncio.to_thread(lite_embedding, model=self.model, input=[query]) + return self._parse_embedding_response(response) + + def _get_text_embedding(self, text: str) -> Embedding: + """ + Synchronously get the embedding for a text string. + + Args: + text (str): The text to embed. + + Returns: + Embedding: The embedding vector. + """ + return self._get_query_embedding(text) + + async def _aget_text_embedding(self, text: str) -> Embedding: + """ + Asynchronously get the embedding for a text string. + + Args: + text (str): The text to embed. + + Returns: + Embedding: The embedding vector. + """ + return await self._aget_query_embedding(text) + + def _parse_embedding_response(self, response): + """ + Parse the embedding response from LiteLLM and extract the embedding data. + + Args: + response: The response object from LiteLLM's embedding function. + + Returns: + List[float]: The extracted embedding list. + """ + try: + if 'data' in response and len(response['data']) > 0 and 'embedding' in response['data'][0]: + return response['data'][0]['embedding'] + else: + raise ValueError("Invalid response structure from embedding function.") + except (TypeError, KeyError, IndexError) as e: + # Handle any parsing errors + raise ValueError(f"Error parsing embedding response: {e}") diff --git a/autollm/auto/llm.py b/autollm/auto/llm.py index 1616b501..36e323ed 100644 --- a/autollm/auto/llm.py +++ b/autollm/auto/llm.py @@ -1,7 +1,7 @@ from typing import Optional from llama_index.llms import LiteLLM -from llama_index.llms.base import LLM +from llama_index.llms.base import BaseLLM class AutoLiteLLM: @@ -14,7 +14,7 @@ def from_defaults( model: str = "gpt-3.5-turbo", max_tokens: Optional[int] = 256, temperature: float = 0.1, - api_base: Optional[str] = None) -> LLM: + api_base: Optional[str] = None) -> BaseLLM: """ Create any LLM by model name. 
Check https://docs.litellm.ai/docs/providers for a list of supported models. diff --git a/autollm/auto/query_engine.py b/autollm/auto/query_engine.py index 964847b0..70953037 100644 --- a/autollm/auto/query_engine.py +++ b/autollm/auto/query_engine.py @@ -8,6 +8,7 @@ from llama_index.response_synthesizers import get_response_synthesizer from llama_index.schema import BaseNode +from autollm.auto.embedding import AutoEmbedding from autollm.auto.llm import AutoLiteLLM from autollm.auto.service_context import AutoServiceContext from autollm.auto.vector_store_index import AutoVectorStoreIndex @@ -26,7 +27,7 @@ def create_query_engine( system_prompt: str = None, query_wrapper_prompt: Union[str, BasePromptTemplate] = None, enable_cost_calculator: bool = True, - embed_model: Union[str, EmbedType] = "default", # ["default", "local"] + embed_model: Optional[str] = "text-embedding-ada-002", chunk_size: Optional[int] = 512, chunk_overlap: Optional[int] = 100, context_window: Optional[int] = None, @@ -106,9 +107,12 @@ def create_query_engine( llm = AutoLiteLLM.from_defaults( model=llm_model, api_base=llm_api_base, max_tokens=llm_max_tokens, temperature=llm_temperature) + + embedding = AutoEmbedding(model=embed_model) + service_context = AutoServiceContext.from_defaults( llm=llm, - embed_model=embed_model, + embed_model=embedding, system_prompt=system_prompt, query_wrapper_prompt=query_wrapper_prompt, enable_cost_calculator=enable_cost_calculator, @@ -173,7 +177,7 @@ class AutoQueryEngine: system_prompt=None, query_wrapper_prompt=None, enable_cost_calculator=True, - embed_model="default", # ["default", "local"] + embed_model="text-embedding-ada-002", chunk_size=512, chunk_overlap=None, context_window=None, @@ -212,15 +216,15 @@ def from_defaults( documents: Optional[Sequence[Document]] = None, nodes: Optional[Sequence[BaseNode]] = None, # llm_params - llm_model: str = "gpt-3.5-turbo", + llm_model: Optional[str] = "gpt-3.5-turbo", llm_api_base: Optional[str] = None, llm_max_tokens: Optional[int] = None, - llm_temperature: float = 0.1, + llm_temperature: Optional[float] = 0.1, # service_context_params system_prompt: str = None, query_wrapper_prompt: Union[str, BasePromptTemplate] = None, enable_cost_calculator: bool = True, - embed_model: Union[str, EmbedType] = "default", # ["default", "local"] + embed_model: Optional[str] = "text-embedding-ada-002", chunk_size: Optional[int] = 512, chunk_overlap: Optional[int] = 200, context_window: Optional[int] = None, @@ -253,8 +257,7 @@ def from_defaults( system_prompt (str): The system prompt to use for the query engine. query_wrapper_prompt (Union[str, BasePromptTemplate]): The query wrapper prompt to use for the query engine. enable_cost_calculator (bool): Flag to enable cost calculator logging. - embed_model (Union[str, EmbedType]): The embedding model to use for generating embeddings. "default" for OpenAI, - "local" for HuggingFace or use full identifier (e.g., local:intfloat/multilingual-e5-large) + embed_model (Union[str, EmbedType]): The embedding model to use for generating embeddings. chunk_size (int): The token chunk size for each chunk. chunk_overlap (int): The token overlap between each chunk. context_window (int): The maximum context size that will get sent to the LLM. 
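With the changes above, `create_query_engine` now wraps the `embed_model` string in an `AutoEmbedding` instance, and (from PATCH 1/3) a plain-string `query_wrapper_prompt` is converted to a `PromptTemplate` and handed to the response synthesizer as `text_qa_template`. A minimal usage sketch of the combined behavior; the document text, prompts, and the assumption of an `OPENAI_API_KEY` in the environment are illustrative, not part of this patch:

```python
from llama_index import Document

from autollm.auto.embedding import AutoEmbedding
from autollm.auto.query_engine import AutoQueryEngine

# Direct use of the new LiteLLM-backed embedding wrapper; get_query_embedding()
# is inherited from llama_index's BaseEmbedding.
embedding = AutoEmbedding(model="text-embedding-ada-002")
vector = embedding.get_query_embedding("What does autollm do?")
print(len(vector))  # typically 1536 for text-embedding-ada-002

# End-to-end: a plain string query_wrapper_prompt is converted to a PromptTemplate
# and forwarded to the response synthesizer as text_qa_template.
query_engine = AutoQueryEngine.from_defaults(
    documents=[Document(text="autollm wraps llama-index with LiteLLM-based LLMs and embeddings.")],
    system_prompt="You are a helpful assistant.",
    query_wrapper_prompt=(
        "Context information is below.\n{context_str}\n"
        "Given the context, answer the query.\nQuery: {query_str}\nAnswer: "
    ),
    embed_model="text-embedding-ada-002",
)
print(query_engine.query("What is autollm?"))
```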
diff --git a/autollm/auto/service_context.py b/autollm/auto/service_context.py
index 9aad5294..0a36db28 100644
--- a/autollm/auto/service_context.py
+++ b/autollm/auto/service_context.py
@@ -32,7 +32,7 @@ def from_defaults(
query_wrapper_prompt: Union[str, BasePromptTemplate] = None,
enable_cost_calculator: bool = False,
chunk_size: Optional[int] = 512,
- chunk_overlap: Optional[int] = 200,
+ chunk_overlap: Optional[int] = 100,
context_window: Optional[int] = None,
enable_title_extractor: bool = False,
enable_summary_extractor: bool = False,
diff --git a/requirements.txt b/requirements.txt
index c57b1ddf..293458fb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
-llama-index==0.9.10
-litellm==1.1.1
+llama-index==0.9.21
+litellm==1.15.6
uvicorn
fastapi
python-dotenv
diff --git a/tests/config.yaml b/tests/config.yaml
index a939feb8..df3fd16e 100644
--- a/tests/config.yaml
+++ b/tests/config.yaml
@@ -6,7 +6,7 @@ tasks:
llm_temperature: 0.1
system_prompt: "You are a friendly chatbot that can summarize documents.:" # System prompt for this task
enable_cost_calculator: true
- embed_model: "default"
+ embed_model: "text-embedding-ada-002"
chunk_size: 512
chunk_overlap: 64
context_window: 2048
diff --git a/tests/test_auto_lite_llm.py b/tests/test_auto_lite_llm.py
index 737d46e1..9c3461c1 100644
--- a/tests/test_auto_lite_llm.py
+++ b/tests/test_auto_lite_llm.py
@@ -1,5 +1,5 @@
from llama_index import Document, ServiceContext, VectorStoreIndex
-from llama_index.llms.base import LLM
+from llama_index.llms.base import BaseLLM
from llama_index.query_engine import BaseQueryEngine
from autollm.auto.llm import AutoLiteLLM
@@ -11,7 +11,7 @@ def test_auto_lite_llm():
llm = AutoLiteLLM.from_defaults(model="gpt-3.5-turbo")
# Check if the llm is an instance of LLM
- assert isinstance(llm, LLM)
+ assert isinstance(llm, BaseLLM)
service_context = ServiceContext.from_defaults(llm=llm)
From 2ec1b21d32471da390f99cbe2409399185857ba9 Mon Sep 17 00:00:00 2001
From: Talha SARI
Date: Fri, 22 Dec 2023 19:25:53 +0300
Subject: [PATCH 3/3] refactor and enhance LanceDB URI Handling in AutoVectorStoreIndex (#182)
* bump autollm
* add increment method to lancedb
* add validate and setup method for lancedb_uri
* some fixes
* fix functionality of exist_ok
* remove accidental lancedb dir
* more fixes
* remove deprecated parameters & methods
---
.gitignore | 1 +
autollm/__init__.py | 2 +-
autollm/auto/query_engine.py | 85 +++---------------------
autollm/auto/vector_store_index.py | 100 +++++++++++++++++++++++++++--
autollm/utils/document_reading.py | 19 +-----
autollm/utils/env_utils.py | 16 +++++
6 files changed, 123 insertions(+), 100 deletions(-)
diff --git a/.gitignore b/.gitignore
index 170f1a07..665bb5be 100644
--- a/.gitignore
+++ b/.gitignore
@@ -170,3 +170,4 @@ storage/
# vscode settings
.vscode
.lancedb
+lancedb/
diff --git a/autollm/__init__.py b/autollm/__init__.py
index d7c61a62..763b64ba 100644
--- a/autollm/__init__.py
+++ b/autollm/__init__.py
@@ -4,7 +4,7 @@ and vector databases, along with various utility functions.
""" -__version__ = '0.1.4' +__version__ = '0.1.5' __author__ = 'safevideo' __license__ = 'AGPL-3.0' diff --git a/autollm/auto/query_engine.py b/autollm/auto/query_engine.py index 70953037..eb3235c1 100644 --- a/autollm/auto/query_engine.py +++ b/autollm/auto/query_engine.py @@ -45,11 +45,8 @@ def create_query_engine( vector_store_type: str = "LanceDBVectorStore", lancedb_uri: str = "./.lancedb", lancedb_table_name: str = "vectors", - # Deprecated parameters - llm_params: dict = None, - vector_store_params: dict = None, - service_context_params: dict = None, - query_engine_params: dict = None, + exist_ok: bool = False, + overwrite_existing: bool = False, **vector_store_kwargs) -> BaseQueryEngine: """ Create a query engine from parameters. @@ -84,27 +81,6 @@ def create_query_engine( Returns: A llama_index.BaseQueryEngine instance. """ - # Check for deprecated parameters - if llm_params is not None: - raise ValueError( - "llm_params is deprecated. Instead of llm_params={'llm_model': 'model_name', ...}, " - "use llm_model='model_name', llm_api_base='api_base', llm_max_tokens=1028, llm_temperature=0.1 directly as arguments." - ) - if vector_store_params is not None: - raise ValueError( - "vector_store_params is deprecated. Instead of vector_store_params={'vector_store_type': 'type', ...}, " - "use vector_store_type='type', lancedb_uri='uri', lancedb_table_name='table', enable_metadata_extraction=True directly as arguments." - ) - if service_context_params is not None: - raise ValueError( - "service_context_params is deprecated. Use the explicit parameters like system_prompt='prompt', " - "query_wrapper_prompt='wrapper', enable_cost_calculator=True, embed_model='model', chunk_size=512, " - "chunk_overlap=..., context_window=... directly as arguments.") - if query_engine_params is not None: - raise ValueError( - "query_engine_params is deprecated. Instead of query_engine_params={'similarity_top_k': 5, ...}, " - "use similarity_top_k=5 directly as an argument.") - llm = AutoLiteLLM.from_defaults( model=llm_model, api_base=llm_api_base, max_tokens=llm_max_tokens, temperature=llm_temperature) @@ -132,6 +108,8 @@ def create_query_engine( documents=documents, nodes=nodes, service_context=service_context, + exist_ok=exist_ok, + overwrite_existing=overwrite_existing, **vector_store_kwargs) if refine_prompt is not None: refine_prompt_template = PromptTemplate(refine_prompt, prompt_type=PromptType.REFINE) @@ -187,7 +165,6 @@ class AutoQueryEngine: vector_store_type="LanceDBVectorStore", lancedb_uri="./.lancedb", lancedb_table_name="vectors", - enable_metadata_extraction=False, **vector_store_kwargs) ) ``` @@ -237,12 +214,8 @@ def from_defaults( vector_store_type: str = "LanceDBVectorStore", lancedb_uri: str = "./.lancedb", lancedb_table_name: str = "vectors", - enable_metadata_extraction: bool = False, - # Deprecated parameters - llm_params: dict = None, - vector_store_params: dict = None, - service_context_params: dict = None, - query_engine_params: dict = None, + exist_ok: bool = False, + overwrite_existing: bool = False, **vector_store_kwargs) -> BaseQueryEngine: """ Create an AutoQueryEngine from default parameters. @@ -272,6 +245,8 @@ def from_defaults( vector_store_type (str): The vector store type to use for the query engine. lancedb_uri (str): The URI to use for the LanceDB vector store. lancedb_table_name (str): The table name to use for the LanceDB vector store. + exist_ok (bool): Flag to allow overwriting an existing vector store. 
+ overwrite_existing (bool): Flag to allow overwriting an existing vector store. Returns: A llama_index.BaseQueryEngine instance. @@ -302,50 +277,10 @@ def from_defaults( vector_store_type=vector_store_type, lancedb_uri=lancedb_uri, lancedb_table_name=lancedb_table_name, - enable_metadata_extraction=enable_metadata_extraction, - # Deprecated parameters - llm_params=llm_params, - vector_store_params=vector_store_params, - service_context_params=service_context_params, - query_engine_params=query_engine_params, + exist_ok=exist_ok, + overwrite_existing=overwrite_existing, **vector_store_kwargs) - @staticmethod - def from_parameters( - documents: Sequence[Document] = None, - system_prompt: str = None, - query_wrapper_prompt: str = None, - enable_cost_calculator: bool = True, - embed_model: Union[str, EmbedType] = "default", # ["default", "local"] - llm_params: dict = None, - vector_store_params: dict = None, - service_context_params: dict = None, - query_engine_params: dict = None) -> BaseQueryEngine: - """ - DEPRECATED. Use AutoQueryEngine.from_defaults instead. - - Create an AutoQueryEngine from parameters. - - Parameters: - documents (Sequence[Document]): Sequence of llama_index.Document instances. - system_prompt (str): The system prompt to use for the query engine. - query_wrapper_prompt (str): The query wrapper prompt to use for the query engine. - enable_cost_calculator (bool): Flag to enable cost calculator logging. - embed_model (Union[str, EmbedType]): The embedding model to use for generating embeddings. "default" for OpenAI, - "local" for HuggingFace or use full identifier (e.g., local:intfloat/multilingual-e5-large) - llm_params (dict): Parameters for the LLM. - vector_store_params (dict): Parameters for the vector store. - service_context_params (dict): Parameters for the service context. - query_engine_params (dict): Parameters for the query engine. - - Returns: - A llama_index.BaseQueryEngine instance. - """ - - # TODO: Remove this method in the next release - raise ValueError( - "AutoQueryEngine.from_parameters is deprecated. Use AutoQueryEngine.from_defaults instead.") - @staticmethod def from_config( config_file_path: str, diff --git a/autollm/auto/vector_store_index.py b/autollm/auto/vector_store_index.py index cef7628f..28b3ab96 100644 --- a/autollm/auto/vector_store_index.py +++ b/autollm/auto/vector_store_index.py @@ -1,8 +1,13 @@ +import os +import shutil from typing import Optional, Sequence from llama_index import Document, ServiceContext, StorageContext, VectorStoreIndex from llama_index.schema import BaseNode +from autollm.utils.env_utils import on_rm_error +from autollm.utils.logging import logger + def import_vector_store_class(vector_store_class_name: str): """ @@ -25,24 +30,33 @@ class name and additional parameters. @staticmethod def from_defaults( vector_store_type: str = "LanceDBVectorStore", - lancedb_uri: str = "./.lancedb", + lancedb_uri: str = None, lancedb_table_name: str = "vectors", documents: Optional[Sequence[Document]] = None, nodes: Optional[Sequence[BaseNode]] = None, service_context: Optional[ServiceContext] = None, + exist_ok: bool = False, + overwrite_existing: bool = False, **kwargs) -> VectorStoreIndex: """ - Initializes a Vector Store index from Vector Store type and additional parameters. + Initializes a Vector Store index from Vector Store type and additional parameters. Handles lancedb + path and document management according to specified behaviors. 
Parameters: - vector_store_type (str): The class name of the vector store (e.g., 'LanceDBVectorStore', 'SimpleVectorStore'..) + vector_store_type (str): The class name of the vector store. + lancedb_uri (str): The URI for the LanceDB vector store. + lancedb_table_name (str): The table name for the LanceDB vector store. documents (Optional[Sequence[Document]]): Documents to initialize the vector store index from. - nodes (Optional[Sequence[BaseNode]]): Nodes to initialize the vector store index from. - service_context (Optional[ServiceContext]): Service context to initialize the vector store index from. - **kwargs: Additional parameters for initializing the vector store + service_context (Optional[ServiceContext]): Service context for initialization. + exist_ok (bool): If True, allows adding to an existing database. + overwrite_existing (bool): If True, allows overwriting an existing database. + **kwargs: Additional parameters for initialization. Returns: - index (VectorStoreIndex): The initialized Vector Store index instance for given vector store type and parameter set. + VectorStoreIndex: The initialized Vector Store index instance. + + Raises: + ValueError: For invalid parameter combinations or missing information. """ if documents is None and nodes is None and vector_store_type == "SimpleVectorStore": raise ValueError("documents or nodes must be provided for SimpleVectorStore") @@ -55,6 +69,12 @@ def from_defaults( # If LanceDBVectorStore, use lancedb_uri and lancedb_table_name if vector_store_type == "LanceDBVectorStore": + lancedb_uri = AutoVectorStoreIndex._validate_and_setup_lancedb_uri( + lancedb_uri=lancedb_uri, + documents=documents, + exist_ok=exist_ok, + overwrite_existing=overwrite_existing) + vector_store = VectorStoreClass(uri=lancedb_uri, table_name=lancedb_table_name, **kwargs) else: vector_store = VectorStoreClass(**kwargs) @@ -82,3 +102,69 @@ def from_defaults( show_progress=True) return index + + @staticmethod + def _validate_and_setup_lancedb_uri(lancedb_uri, documents, exist_ok, overwrite_existing): + """ + Validates and sets up the lancedb_uri based on the given parameters. + + Parameters: + lancedb_uri (str): The URI for the LanceDB vector store. + documents (Sequence[Document]): Documents to initialize the vector store index from. + exist_ok (bool): Flag to allow adding to an existing database. + overwrite_existing (bool): Flag to allow overwriting an existing database. + + Returns: + str: The validated and potentially modified lancedb_uri. + """ + default_lancedb_uri = "./lancedb/db" + + # Scenario 0: Handle no lancedb uri and no documents provided + if not documents and not lancedb_uri: + raise ValueError( + "A lancedb uri is required to connect to a database. Please provide a lancedb uri.") + + # Scenario 1: Handle lancedb_uri given but no documents provided + if not documents and lancedb_uri: + # Check if the database exists + db_exists = os.path.exists(lancedb_uri) + if not db_exists: + raise ValueError( + f"No existing database found at {lancedb_uri}. Please provide a valid lancedb uri.") + + # Scenario 2: Handle no lancedb uri but documents provided + if documents and not lancedb_uri: + lancedb_uri = default_lancedb_uri + lancedb_uri = AutoVectorStoreIndex._increment_lancedb_uri(lancedb_uri) + logger.info( + f"A new database is being created at {lancedb_uri}. Please provide a lancedb path to use an existing database." 
+ ) + + # Scenario 3: Handle lancedb uri given and documents provided + if documents and lancedb_uri: + db_exists = os.path.exists(lancedb_uri) + if exist_ok and overwrite_existing: + if db_exists: + shutil.rmtree(lancedb_uri) + logger.info(f"Overwriting existing database at {lancedb_uri}.") + elif not exist_ok and overwrite_existing: + raise ValueError("Cannot overwrite existing database without exist_ok set to True.") + elif db_exists: + if not exist_ok: + lancedb_uri = AutoVectorStoreIndex._increment_lancedb_uri(lancedb_uri) + logger.info(f"Existing database found. Creating a new database at {lancedb_uri}.") + logger.info( + "Please use exist_ok=True to add to the existing database and overwrite_existing=True to overwrite the existing database." + ) + else: + logger.info(f"Adding documents to existing database at {lancedb_uri}.") + + return lancedb_uri + + @staticmethod + def _increment_lancedb_uri(base_uri: str) -> str: + """Increment the lancedb uri to create a new database.""" + i = 1 + while os.path.exists(f"{base_uri}_{i}"): + i += 1 + return f"{base_uri}_{i}" diff --git a/autollm/utils/document_reading.py b/autollm/utils/document_reading.py index f6685cb3..888899dc 100644 --- a/autollm/utils/document_reading.py +++ b/autollm/utils/document_reading.py @@ -1,12 +1,11 @@ -import os import shutil -import stat from pathlib import Path -from typing import Callable, List, Optional, Sequence, Tuple +from typing import List, Optional, Sequence from llama_index.readers.file.base import SimpleDirectoryReader from llama_index.schema import Document +from autollm.utils.env_utils import on_rm_error from autollm.utils.git_utils import clone_or_pull_repository from autollm.utils.logging import logger from autollm.utils.markdown_reader import MarkdownReader @@ -65,20 +64,6 @@ def read_files_as_documents( return documents -# From http://stackoverflow.com/a/4829285/548792 -def on_rm_error(func: Callable, path: str, exc_info: Tuple): - """ - Error handler for `shutil.rmtree` to handle permission errors. - - Parameters: - func (Callable): The function that raised the error. - path (str): The path to the file or directory which couldn't be removed. - exc_info (Tuple): Exception information returned by sys.exc_info(). - """ - os.chmod(path, stat.S_IWRITE) - os.unlink(path) - - def read_github_repo_as_documents( git_repo_url: str, relative_folder_path: Optional[str] = None, diff --git a/autollm/utils/env_utils.py b/autollm/utils/env_utils.py index 5499c9a1..60f43c7a 100644 --- a/autollm/utils/env_utils.py +++ b/autollm/utils/env_utils.py @@ -1,5 +1,7 @@ import os +import stat from pathlib import Path +from typing import Callable, Tuple import yaml from dotenv import load_dotenv @@ -55,3 +57,17 @@ def load_config_and_dotenv(config_file_path: str, env_file_path: str = None) -> config = yaml.safe_load(f) return config + + +# From http://stackoverflow.com/a/4829285/548792 +def on_rm_error(func: Callable, path: str, exc_info: Tuple): + """ + Error handler for `shutil.rmtree` to handle permission errors. + + Parameters: + func (Callable): The function that raised the error. + path (str): The path to the file or directory which couldn't be removed. + exc_info (Tuple): Exception information returned by sys.exc_info(). + """ + os.chmod(path, stat.S_IWRITE) + os.unlink(path)
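The scenario handling above gives the two new flags distinct roles: by default an existing database at `lancedb_uri` is left untouched and a new, auto-incremented URI is created; `exist_ok=True` appends to the existing database; `exist_ok=True` together with `overwrite_existing=True` deletes and recreates it; and `overwrite_existing=True` alone raises a `ValueError`. A hedged usage sketch of these combinations; the paths and documents are illustrative, and an embedding/LLM API key is assumed to be configured in the environment:

```python
from llama_index import Document

from autollm.auto.vector_store_index import AutoVectorStoreIndex

docs = [Document(text="hello vector store")]

# Documents but no lancedb_uri: a fresh database is created under ./lancedb/db_<n>.
index = AutoVectorStoreIndex.from_defaults(documents=docs)

# Documents plus an existing uri with default flags: the uri is auto-incremented
# instead of reusing the existing database.
index = AutoVectorStoreIndex.from_defaults(documents=docs, lancedb_uri="./lancedb/db_1")

# exist_ok=True adds the documents to the existing database at the given uri.
index = AutoVectorStoreIndex.from_defaults(
    documents=docs, lancedb_uri="./lancedb/db_1", exist_ok=True)

# exist_ok=True + overwrite_existing=True wipes the existing database and recreates it.
index = AutoVectorStoreIndex.from_defaults(
    documents=docs, lancedb_uri="./lancedb/db_1", exist_ok=True, overwrite_existing=True)

# No documents: the uri must point to an existing database, otherwise a ValueError is raised.
index = AutoVectorStoreIndex.from_defaults(lancedb_uri="./lancedb/db_1")
```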