Skip to content

Commit

Permalink
Provide embedding manager (#16)
Browse files Browse the repository at this point in the history
* Provide the Embedding management UI

* Update Fastembed documentation

* Add validation when adding / updating embeddings

* Stop using the old ktem embeddings manager

* Set default local embedding models

* Move the local embeddings below in flowsettings

* Update flowsettings
  • Loading branch information
trducng authored Apr 10, 2024
1 parent ed10020 commit 7b3307e
Show file tree
Hide file tree
Showing 12 changed files with 607 additions and 29 deletions.
2 changes: 1 addition & 1 deletion docs/pages/app/customize-flows.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ information panel.
You can access users' collections of LLMs and embedding models with:

```python
from ktem.components import embeddings
from ktem.embeddings.manager import embeddings
from ktem.llms.manager import llms


Expand Down
9 changes: 5 additions & 4 deletions libs/kotaemon/kotaemon/embeddings/fastembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ class FastEmbedEmbeddings(BaseEmbeddings):
model_name: str = Param(
"BAAI/bge-small-en-v1.5",
help=(
"Model name for fastembed. "
"Supported model: "
"https://qdrant.github.io/fastembed/examples/Supported_Models/"
"Model name for fastembed. Please refer "
"[here](https://qdrant.github.io/fastembed/examples/Supported_Models/) "
"for the list of supported models."
),
required=True,
)
batch_size: int = Param(
256,
Expand All @@ -34,7 +35,7 @@ class FastEmbedEmbeddings(BaseEmbeddings):
"If > 1, data-parallel encoding will be used. "
"If 0, use all available CPUs. "
"If None, use default onnxruntime threading. "
"Defaults to None"
"Defaults to None."
),
)

Expand Down
14 changes: 10 additions & 4 deletions libs/ktem/flowsettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
if config("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT", default=""):
KH_EMBEDDINGS["azure"] = {
"spec": {
"__type__": "kotaemon.embeddings.LCAzureOpenAIEmbeddings",
"__type__": "kotaemon.embeddings.AzureOpenAIEmbeddings",
"azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
"api_key": config("AZURE_OPENAI_API_KEY", default=""),
"api_version": config("OPENAI_API_VERSION", default="")
Expand All @@ -68,8 +68,6 @@
"timeout": 10,
},
"default": False,
"accuracy": 5,
"cost": 5,
}

if config("OPENAI_API_KEY", default=""):
Expand All @@ -88,7 +86,7 @@
if len(KH_EMBEDDINGS) < 1:
KH_EMBEDDINGS["openai"] = {
"spec": {
"__type__": "kotaemon.embeddings.LCOpenAIEmbeddings",
"__type__": "kotaemon.embeddings.OpenAIEmbeddings",
"base_url": config("OPENAI_API_BASE", default="")
or "https://api.openai.com/v1",
"api_key": config("OPENAI_API_KEY", default=""),
Expand Down Expand Up @@ -120,6 +118,14 @@
"cost": 0,
}

if len(KH_EMBEDDINGS) < 1:
KH_EMBEDDINGS["local-mxbai-large-v1"] = {
"spec": {
"__type__": "kotaemon.embeddings.FastEmbedEmbeddings",
"model_name": "mixedbread-ai/mxbai-embed-large-v1",
},
"default": True,
}

KH_REASONINGS = ["ktem.reasoning.simple.FullQAPipeline"]
KH_VLM_ENDPOINT = "{0}/openai/deployments/{1}/chat/completions?api-version={2}".format(
Expand Down
2 changes: 0 additions & 2 deletions libs/ktem/ktem/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,5 @@ def get_lowest_cost(self) -> BaseComponent:
return self._models[self._cost[0]]


llms = ModelPool("LLMs", settings.KH_LLMS)
embeddings = ModelPool("Embeddings", settings.KH_EMBEDDINGS)
reasonings: dict = {}
tools = ModelPool("Tools", {})
Empty file.
36 changes: 36 additions & 0 deletions libs/ktem/ktem/embeddings/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Type

from ktem.db.engine import engine
from sqlalchemy import JSON, Boolean, Column, String
from sqlalchemy.orm import DeclarativeBase
from theflow.settings import settings as flowsettings
from theflow.utils.modules import import_dotted_string


class Base(DeclarativeBase):
pass


class BaseEmbeddingTable(Base):
"""Base table to store language model"""

__abstract__ = True

name = Column(String, primary_key=True, unique=True)
spec = Column(JSON, default={})
default = Column(Boolean, default=False)


_base_llm: Type[BaseEmbeddingTable] = (
import_dotted_string(flowsettings.KH_EMBEDDING_LLM, safe=False)
if hasattr(flowsettings, "KH_EMBEDDING_LLM")
else BaseEmbeddingTable
)


class EmbeddingTable(_base_llm): # type: ignore
__tablename__ = "embedding"


if not getattr(flowsettings, "KH_ENABLE_ALEMBIC", False):
EmbeddingTable.metadata.create_all(engine)
199 changes: 199 additions & 0 deletions libs/ktem/ktem/embeddings/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
from typing import Optional, Type

from sqlalchemy import select
from sqlalchemy.orm import Session
from theflow.settings import settings as flowsettings
from theflow.utils.modules import deserialize

from kotaemon.embeddings.base import BaseEmbeddings

from .db import EmbeddingTable, engine


class EmbeddingManager:
"""Represent a pool of models"""

def __init__(self):
self._models: dict[str, BaseEmbeddings] = {}
self._info: dict[str, dict] = {}
self._default: str = ""
self._vendors: list[Type] = []

# populate the pool if empty
if hasattr(flowsettings, "KH_EMBEDDINGS"):
with Session(engine) as sess:
count = sess.query(EmbeddingTable).count()
if not count:
for name, model in flowsettings.KH_EMBEDDINGS.items():
self.add(
name=name,
spec=model["spec"],
default=model.get("default", False),
)

self.load()
self.load_vendors()

def load(self):
"""Load the model pool from database"""
self._models, self._info, self._defaut = {}, {}, ""
with Session(engine) as sess:
stmt = select(EmbeddingTable)
items = sess.execute(stmt)

for (item,) in items:
self._models[item.name] = deserialize(item.spec, safe=False)
self._info[item.name] = {
"name": item.name,
"spec": item.spec,
"default": item.default,
}
if item.default:
self._default = item.name

def load_vendors(self):
from kotaemon.embeddings import (
AzureOpenAIEmbeddings,
FastEmbedEmbeddings,
OpenAIEmbeddings,
)

self._vendors = [AzureOpenAIEmbeddings, OpenAIEmbeddings, FastEmbedEmbeddings]

def __getitem__(self, key: str) -> BaseEmbeddings:
"""Get model by name"""
return self._models[key]

def __contains__(self, key: str) -> bool:
"""Check if model exists"""
return key in self._models

def get(
self, key: str, default: Optional[BaseEmbeddings] = None
) -> Optional[BaseEmbeddings]:
"""Get model by name with default value"""
return self._models.get(key, default)

def settings(self) -> dict:
"""Present model pools option for gradio"""
return {
"label": "Embedding",
"choices": list(self._models.keys()),
"value": self.get_default_name(),
}

def options(self) -> dict:
"""Present a dict of models"""
return self._models

def get_random_name(self) -> str:
"""Get the name of random model
Returns:
str: random model name in the pool
"""
import random

if not self._models:
raise ValueError("No models in pool")

return random.choice(list(self._models.keys()))

def get_default_name(self) -> str:
"""Get the name of default model
In case there is no default model, choose random model from pool. In
case there are multiple default models, choose random from them.
Returns:
str: model name
"""
if not self._models:
raise ValueError("No models in pool")

if not self._default:
return self.get_random_name()

return self._default

def get_random(self) -> BaseEmbeddings:
"""Get random model"""
return self._models[self.get_random_name()]

def get_default(self) -> BaseEmbeddings:
"""Get default model
In case there is no default model, choose random model from pool. In
case there are multiple default models, choose random from them.
Returns:
BaseEmbeddings: model
"""
return self._models[self.get_default_name()]

def info(self) -> dict:
"""List all models"""
return self._info

def add(self, name: str, spec: dict, default: bool):
"""Add a new model to the pool"""
if not name:
raise ValueError("Name must not be empty")

try:
with Session(engine) as sess:
if default:
# turn all models to non-default
sess.query(EmbeddingTable).update({"default": False})
sess.commit()

item = EmbeddingTable(name=name, spec=spec, default=default)
sess.add(item)
sess.commit()
except Exception as e:
raise ValueError(f"Failed to add model {name}: {e}")

self.load()

def delete(self, name: str):
"""Delete a model from the pool"""
try:
with Session(engine) as sess:
item = sess.query(EmbeddingTable).filter_by(name=name).first()
sess.delete(item)
sess.commit()
except Exception as e:
raise ValueError(f"Failed to delete model {name}: {e}")

self.load()

def update(self, name: str, spec: dict, default: bool):
"""Update a model in the pool"""
if not name:
raise ValueError("Name must not be empty")

try:
with Session(engine) as sess:

if default:
# turn all models to non-default
sess.query(EmbeddingTable).update({"default": False})
sess.commit()

item = sess.query(EmbeddingTable).filter_by(name=name).first()
if not item:
raise ValueError(f"Model {name} not found")
item.spec = spec
item.default = default
sess.commit()
except Exception as e:
raise ValueError(f"Failed to update model {name}: {e}")

self.load()

def vendors(self) -> dict:
"""Return list of vendors"""
return {vendor.__qualname__: vendor for vendor in self._vendors}


embeddings = EmbeddingManager()
Loading

0 comments on commit 7b3307e

Please sign in to comment.