diff --git a/docs/pages/app/customize-flows.md b/docs/pages/app/customize-flows.md
index 3dd005e24..40d320fa2 100644
--- a/docs/pages/app/customize-flows.md
+++ b/docs/pages/app/customize-flows.md
@@ -193,7 +193,7 @@ information panel.
 You can access users' collections of LLMs and embedding models with:

 ```python
-from ktem.components import embeddings
+from ktem.embeddings.manager import embeddings
 from ktem.llms.manager import llms


diff --git a/libs/ktem/flowsettings.py b/libs/ktem/flowsettings.py
index d73b9cfb3..9246b7a69 100644
--- a/libs/ktem/flowsettings.py
+++ b/libs/ktem/flowsettings.py
@@ -57,7 +57,7 @@ if config("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT", default=""):

     KH_EMBEDDINGS["azure"] = {
         "spec": {
-            "__type__": "kotaemon.embeddings.LCAzureOpenAIEmbeddings",
+            "__type__": "kotaemon.embeddings.AzureOpenAIEmbeddings",
             "azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
             "api_key": config("AZURE_OPENAI_API_KEY", default=""),
             "api_version": config("OPENAI_API_VERSION", default="")
@@ -88,7 +88,7 @@ if len(KH_EMBEDDINGS) < 1:

     KH_EMBEDDINGS["openai"] = {
         "spec": {
-            "__type__": "kotaemon.embeddings.LCOpenAIEmbeddings",
+            "__type__": "kotaemon.embeddings.OpenAIEmbeddings",
             "base_url": config("OPENAI_API_BASE", default="")
             or "https://api.openai.com/v1",
             "api_key": config("OPENAI_API_KEY", default=""),
diff --git a/libs/ktem/ktem/components.py b/libs/ktem/ktem/components.py
index 182cb91d0..6674430c5 100644
--- a/libs/ktem/ktem/components.py
+++ b/libs/ktem/ktem/components.py
@@ -182,7 +182,5 @@ def get_lowest_cost(self) -> BaseComponent:
         return self._models[self._cost[0]]


-llms = ModelPool("LLMs", settings.KH_LLMS)
-embeddings = ModelPool("Embeddings", settings.KH_EMBEDDINGS)
 reasonings: dict = {}
 tools = ModelPool("Tools", {})
diff --git a/libs/ktem/ktem/embeddings/manager.py b/libs/ktem/ktem/embeddings/manager.py
index 2118fd7e2..db0aa4143 100644
--- a/libs/ktem/ktem/embeddings/manager.py
+++ b/libs/ktem/ktem/embeddings/manager.py
@@ -5,7 +5,7 @@
 from theflow.settings import settings as flowsettings
 from theflow.utils.modules import deserialize

-from kotaemon.base import BaseComponent
+from kotaemon.embeddings.base import BaseEmbeddings

 from .db import EmbeddingTable, engine

@@ -14,7 +14,7 @@ class EmbeddingManager:
     """Represent a pool of models"""

     def __init__(self):
-        self._models: dict[str, BaseComponent] = {}
+        self._models: dict[str, BaseEmbeddings] = {}
         self._info: dict[str, dict] = {}
         self._default: str = ""
         self._vendors: list[Type] = []
@@ -60,7 +60,7 @@ def load_vendors(self):

         self._vendors = [AzureOpenAIEmbeddings, OpenAIEmbeddings, FastEmbedEmbeddings]

-    def __getitem__(self, key: str) -> BaseComponent:
+    def __getitem__(self, key: str) -> BaseEmbeddings:
         """Get model by name"""
         return self._models[key]

@@ -69,8 +69,8 @@ def __contains__(self, key: str) -> bool:
         return key in self._models

     def get(
-        self, key: str, default: Optional[BaseComponent] = None
-    ) -> Optional[BaseComponent]:
+        self, key: str, default: Optional[BaseEmbeddings] = None
+    ) -> Optional[BaseEmbeddings]:
         """Get model by name with default value"""
         return self._models.get(key, default)

@@ -116,18 +116,18 @@ def get_default_name(self) -> str:

         return self._default

-    def get_random(self) -> BaseComponent:
+    def get_random(self) -> BaseEmbeddings:
         """Get random model"""
         return self._models[self.get_random_name()]

-    def get_default(self) -> BaseComponent:
+    def get_default(self) -> BaseEmbeddings:
         """Get default model

         In case there is no default model, choose random model from pool.
         In case there are multiple default models, choose random from them.

         Returns:
-            BaseComponent: model
+            BaseEmbeddings: model
         """
         return self._models[self.get_default_name()]

diff --git a/libs/ktem/ktem/index/file/index.py b/libs/ktem/ktem/index/file/index.py
index 5fe395596..4725cd926 100644
--- a/libs/ktem/ktem/index/file/index.py
+++ b/libs/ktem/ktem/index/file/index.py
@@ -236,17 +236,26 @@ def on_create(self):
         """Create the index for the first time

         For the file index, this will:
-        1. Create the index and the source table if not already exists
-        2. Create the vectorstore
-        3. Create the docstore
+        1. Postprocess the config
+        2. Create the index and the source table if not already exists
+        3. Create the vectorstore
+        4. Create the docstore
         """

-        file_types_str = self.config.get(
-            "supported_file_types",
-            self.get_admin_settings()["supported_file_types"]["value"],
-        )
+        # default user's value
+        config = {}
+        for key, value in self.get_admin_settings().items():
+            config[key] = value["value"]
+
+        # user's modification
+        config.update(self.config)
+
+        # clean
+        file_types_str = config["supported_file_types"]
         file_types = [each.strip() for each in file_types_str.split(",")]
-        self.config["supported_file_types"] = file_types
+        config["supported_file_types"] = file_types
+        self.config = config
+        # create the resources
         self._resources["Source"].metadata.create_all(engine)  # type: ignore
         self._resources["Index"].metadata.create_all(engine)  # type: ignore
         self._fs_path.mkdir(parents=True, exist_ok=True)
@@ -285,7 +294,7 @@ def get_user_settings(self):

     @classmethod
     def get_admin_settings(cls):
-        from ktem.components import embeddings
+        from ktem.embeddings.manager import embeddings

         embedding_default = embeddings.get_default_name()
         embedding_choices = list(embeddings.options().keys())
diff --git a/libs/ktem/ktem/index/file/pipelines.py b/libs/ktem/ktem/index/file/pipelines.py
index 13036f33f..05df23c30 100644
--- a/libs/ktem/ktem/index/file/pipelines.py
+++ b/libs/ktem/ktem/index/file/pipelines.py
@@ -10,8 +10,9 @@
 from typing import Optional

 import gradio as gr
-from ktem.components import embeddings, filestorage_path
+from ktem.components import filestorage_path
 from ktem.db.models import engine
+from ktem.embeddings.manager import embeddings
 from llama_index.vector_stores import (
     FilterCondition,
     FilterOperator,
@@ -68,9 +69,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
             for surrounding tables (e.g. within the page)
     """

-    vector_retrieval: VectorRetrieval = VectorRetrieval.withx(
-        embedding=embeddings.get_default(),
-    )
+    vector_retrieval: VectorRetrieval = VectorRetrieval.withx()
     reranker: BaseReranking
     get_extra_table: bool = False

@@ -226,6 +225,7 @@ def get_pipeline(cls, user_settings, index_settings, selected):
         if not user_settings["use_reranking"]:
             retriever.reranker = None  # type: ignore

+        retriever.vector_retrieval.embedding = embeddings[index_settings["embedding"]]
         kwargs = {
             ".top_k": int(user_settings["num_retrieval"]),
             ".mmr": user_settings["mmr"],
@@ -248,9 +248,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
         file_ingestor: ingestor to ingest the documents
     """

-    indexing_vector_pipeline: VectorIndexing = VectorIndexing.withx(
-        embedding=embeddings.get_default(),
-    )
+    indexing_vector_pipeline: VectorIndexing = VectorIndexing.withx()
     file_ingestor: DocumentIngestor = DocumentIngestor.withx()

     def run(
@@ -438,6 +436,8 @@ def get_pipeline(cls, user_settings, index_settings) -> "IndexDocumentPipeline":
         if chunk_overlap:
             obj.file_ingestor.text_splitter.chunk_overlap = chunk_overlap

+        obj.indexing_vector_pipeline.embedding = embeddings[index_settings["embedding"]]
+
         return obj

     def set_resources(self, resources: dict):
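
For downstream code, the consumer-side pattern after this change is sketched below. It is a minimal sketch, assuming the `embeddings` singleton exported by `ktem.embeddings.manager`, and it uses only methods visible in the hunks above (`get_default`, `get`, `options`, `in`, item access); the pool name `"openai"` is illustrative and not guaranteed to exist:

```python
from ktem.embeddings.manager import embeddings  # was: from ktem.components import embeddings

# The default model is now typed as BaseEmbeddings rather than BaseComponent.
default_emb = embeddings.get_default()

# Name-based lookup, as get_pipeline() now does with index_settings["embedding"].
# Item access raises KeyError for unknown names; .get() accepts a fallback.
if "openai" in embeddings:  # "openai" is an illustrative pool name
    emb = embeddings["openai"]
else:
    emb = embeddings.get("openai", default_emb)

# Enumerate the pool, as get_admin_settings() does to build the choices list.
choices = list(embeddings.options().keys())
```

Resolving the embedding inside `get_pipeline` via `index_settings["embedding"]`, instead of freezing `embeddings.get_default()` into the class-level `withx(...)` defaults, lets each file index carry its own embedding model rather than capturing the pool default at import time.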
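The config post-processing that `on_create` now performs can also be read in isolation as the sketch below; the admin-settings dict is hypothetical sample data in the `{key: {"value": ...}}` shape that `get_admin_settings()` returns:

```python
# Hypothetical admin defaults; real ones come from FileIndex.get_admin_settings().
admin_settings = {
    "supported_file_types": {"value": ".pdf, .txt, .md"},
    "embedding": {"value": "openai"},
}
user_config = {"supported_file_types": ".pdf,.docx"}  # stands in for self.config

# 1. start from admin defaults, 2. apply the user's modification,
# 3. clean supported_file_types from a comma-separated string into a list
config = {key: value["value"] for key, value in admin_settings.items()}
config.update(user_config)
file_types = [each.strip() for each in config["supported_file_types"].split(",")]
config["supported_file_types"] = file_types

print(config)
# {'supported_file_types': ['.pdf', '.docx'], 'embedding': 'openai'}
```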