Skip to content

Commit

Permalink
Stop using the old ktem embeddings manager
Browse files (browse the repository at this point in the history)
  • Loading branch information
trducng committed Apr 10, 2024
1 parent 1445e8b commit 74109d6
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 29 deletions.
2 changes: 1 addition & 1 deletion docs/pages/app/customize-flows.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ information panel.
You can access users' collections of LLMs and embedding models with:

```python
from ktem.components import embeddings
from ktem.embeddings.manager import embeddings
from ktem.llms.manager import llms


Expand Down
4 changes: 2 additions & 2 deletions libs/ktem/flowsettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
if config("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT", default=""):
KH_EMBEDDINGS["azure"] = {
"spec": {
"__type__": "kotaemon.embeddings.LCAzureOpenAIEmbeddings",
"__type__": "kotaemon.embeddings.AzureOpenAIEmbeddings",
"azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
"api_key": config("AZURE_OPENAI_API_KEY", default=""),
"api_version": config("OPENAI_API_VERSION", default="")
Expand Down Expand Up @@ -88,7 +88,7 @@
if len(KH_EMBEDDINGS) < 1:
KH_EMBEDDINGS["openai"] = {
"spec": {
"__type__": "kotaemon.embeddings.LCOpenAIEmbeddings",
"__type__": "kotaemon.embeddings.OpenAIEmbeddings",
"base_url": config("OPENAI_API_BASE", default="")
or "https://api.openai.com/v1",
"api_key": config("OPENAI_API_KEY", default=""),
Expand Down
2 changes: 0 additions & 2 deletions libs/ktem/ktem/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,5 @@ def get_lowest_cost(self) -> BaseComponent:
return self._models[self._cost[0]]


llms = ModelPool("LLMs", settings.KH_LLMS)
embeddings = ModelPool("Embeddings", settings.KH_EMBEDDINGS)
reasonings: dict = {}
tools = ModelPool("Tools", {})
16 changes: 8 additions & 8 deletions libs/ktem/ktem/embeddings/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from theflow.settings import settings as flowsettings
from theflow.utils.modules import deserialize

from kotaemon.base import BaseComponent
from kotaemon.embeddings.base import BaseEmbeddings

from .db import EmbeddingTable, engine

Expand All @@ -14,7 +14,7 @@ class EmbeddingManager:
"""Represent a pool of models"""

def __init__(self):
self._models: dict[str, BaseComponent] = {}
self._models: dict[str, BaseEmbeddings] = {}
self._info: dict[str, dict] = {}
self._default: str = ""
self._vendors: list[Type] = []
Expand Down Expand Up @@ -60,7 +60,7 @@ def load_vendors(self):

self._vendors = [AzureOpenAIEmbeddings, OpenAIEmbeddings, FastEmbedEmbeddings]

def __getitem__(self, key: str) -> BaseComponent:
def __getitem__(self, key: str) -> BaseEmbeddings:
"""Get model by name"""
return self._models[key]

Expand All @@ -69,8 +69,8 @@ def __contains__(self, key: str) -> bool:
return key in self._models

def get(
self, key: str, default: Optional[BaseComponent] = None
) -> Optional[BaseComponent]:
self, key: str, default: Optional[BaseEmbeddings] = None
) -> Optional[BaseEmbeddings]:
"""Get model by name with default value"""
return self._models.get(key, default)

Expand Down Expand Up @@ -116,18 +116,18 @@ def get_default_name(self) -> str:

return self._default

def get_random(self) -> BaseComponent:
def get_random(self) -> BaseEmbeddings:
"""Get random model"""
return self._models[self.get_random_name()]

def get_default(self) -> BaseComponent:
def get_default(self) -> BaseEmbeddings:
"""Get default model
In case there is no default model, choose random model from pool. In
case there are multiple default models, choose random from them.
Returns:
BaseComponent: model
BaseEmbeddings: model
"""
return self._models[self.get_default_name()]

Expand Down
27 changes: 18 additions & 9 deletions libs/ktem/ktem/index/file/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,17 +236,26 @@ def on_create(self):
"""Create the index for the first time
For the file index, this will:
1. Create the index and the source table if not already exists
2. Create the vectorstore
3. Create the docstore
1. Postprocess the config
2. Create the index and the source table if not already exists
3. Create the vectorstore
4. Create the docstore
"""
file_types_str = self.config.get(
"supported_file_types",
self.get_admin_settings()["supported_file_types"]["value"],
)
# default user's value
config = {}
for key, value in self.get_admin_settings().items():
config[key] = value["value"]

# user's modification
config.update(self.config)

# clean
file_types_str = config["supported_file_types"]
file_types = [each.strip() for each in file_types_str.split(",")]
self.config["supported_file_types"] = file_types
config["supported_file_types"] = file_types
self.config = config

# create the resources
self._resources["Source"].metadata.create_all(engine) # type: ignore
self._resources["Index"].metadata.create_all(engine) # type: ignore
self._fs_path.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -285,7 +294,7 @@ def get_user_settings(self):

@classmethod
def get_admin_settings(cls):
from ktem.components import embeddings
from ktem.embeddings.manager import embeddings

embedding_default = embeddings.get_default_name()
embedding_choices = list(embeddings.options().keys())
Expand Down
14 changes: 7 additions & 7 deletions libs/ktem/ktem/index/file/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from typing import Optional

import gradio as gr
from ktem.components import embeddings, filestorage_path
from ktem.components import filestorage_path
from ktem.db.models import engine
from ktem.embeddings.manager import embeddings
from llama_index.vector_stores import (
FilterCondition,
FilterOperator,
Expand Down Expand Up @@ -68,9 +69,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
for surrounding tables (e.g. within the page)
"""

vector_retrieval: VectorRetrieval = VectorRetrieval.withx(
embedding=embeddings.get_default(),
)
vector_retrieval: VectorRetrieval = VectorRetrieval.withx()
reranker: BaseReranking
get_extra_table: bool = False

Expand Down Expand Up @@ -226,6 +225,7 @@ def get_pipeline(cls, user_settings, index_settings, selected):
if not user_settings["use_reranking"]:
retriever.reranker = None # type: ignore

retriever.vector_retrieval.embedding = embeddings[index_settings["embedding"]]
kwargs = {
".top_k": int(user_settings["num_retrieval"]),
".mmr": user_settings["mmr"],
Expand All @@ -248,9 +248,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
file_ingestor: ingestor to ingest the documents
"""

indexing_vector_pipeline: VectorIndexing = VectorIndexing.withx(
embedding=embeddings.get_default(),
)
indexing_vector_pipeline: VectorIndexing = VectorIndexing.withx()
file_ingestor: DocumentIngestor = DocumentIngestor.withx()

def run(
Expand Down Expand Up @@ -438,6 +436,8 @@ def get_pipeline(cls, user_settings, index_settings) -> "IndexDocumentPipeline":
if chunk_overlap:
obj.file_ingestor.text_splitter.chunk_overlap = chunk_overlap

obj.indexing_vector_pipeline.embedding = embeddings[index_settings["embedding"]]

return obj

def set_resources(self, resources: dict):
Expand Down

0 comments on commit 74109d6

Please sign in to comment.