Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide embedding manager #16

Merged
merged 7 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/pages/app/customize-flows.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ information panel.
You can access users' collections of LLMs and embedding models with:

```python
from ktem.components import embeddings
from ktem.embeddings.manager import embeddings
from ktem.llms.manager import llms


Expand Down
9 changes: 5 additions & 4 deletions libs/kotaemon/kotaemon/embeddings/fastembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ class FastEmbedEmbeddings(BaseEmbeddings):
model_name: str = Param(
"BAAI/bge-small-en-v1.5",
help=(
"Model name for fastembed. "
"Supported model: "
"https://qdrant.github.io/fastembed/examples/Supported_Models/"
"Model name for fastembed. Please refer "
"[here](https://qdrant.github.io/fastembed/examples/Supported_Models/) "
"for the list of supported models."
),
required=True,
)
batch_size: int = Param(
256,
Expand All @@ -34,7 +35,7 @@ class FastEmbedEmbeddings(BaseEmbeddings):
"If > 1, data-parallel encoding will be used. "
"If 0, use all available CPUs. "
"If None, use default onnxruntime threading. "
"Defaults to None"
"Defaults to None."
),
)

Expand Down
14 changes: 10 additions & 4 deletions libs/ktem/flowsettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
if config("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT", default=""):
KH_EMBEDDINGS["azure"] = {
"spec": {
"__type__": "kotaemon.embeddings.LCAzureOpenAIEmbeddings",
"__type__": "kotaemon.embeddings.AzureOpenAIEmbeddings",
"azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
"api_key": config("AZURE_OPENAI_API_KEY", default=""),
"api_version": config("OPENAI_API_VERSION", default="")
Expand All @@ -68,8 +68,6 @@
"timeout": 10,
},
"default": False,
"accuracy": 5,
"cost": 5,
}

if config("OPENAI_API_KEY", default=""):
Expand All @@ -88,7 +86,7 @@
if len(KH_EMBEDDINGS) < 1:
KH_EMBEDDINGS["openai"] = {
"spec": {
"__type__": "kotaemon.embeddings.LCOpenAIEmbeddings",
"__type__": "kotaemon.embeddings.OpenAIEmbeddings",
"base_url": config("OPENAI_API_BASE", default="")
or "https://api.openai.com/v1",
"api_key": config("OPENAI_API_KEY", default=""),
Expand Down Expand Up @@ -120,6 +118,14 @@
"cost": 0,
}

if len(KH_EMBEDDINGS) < 1:
KH_EMBEDDINGS["local-mxbai-large-v1"] = {
"spec": {
"__type__": "kotaemon.embeddings.FastEmbedEmbeddings",
"model_name": "mixedbread-ai/mxbai-embed-large-v1",
},
"default": True,
}

KH_REASONINGS = ["ktem.reasoning.simple.FullQAPipeline"]
KH_VLM_ENDPOINT = "{0}/openai/deployments/{1}/chat/completions?api-version={2}".format(
Expand Down
2 changes: 0 additions & 2 deletions libs/ktem/ktem/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,5 @@ def get_lowest_cost(self) -> BaseComponent:
return self._models[self._cost[0]]


llms = ModelPool("LLMs", settings.KH_LLMS)
embeddings = ModelPool("Embeddings", settings.KH_EMBEDDINGS)
reasonings: dict = {}
tools = ModelPool("Tools", {})
Empty file.
36 changes: 36 additions & 0 deletions libs/ktem/ktem/embeddings/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Type

from ktem.db.engine import engine
from sqlalchemy import JSON, Boolean, Column, String
from sqlalchemy.orm import DeclarativeBase
from theflow.settings import settings as flowsettings
from theflow.utils.modules import import_dotted_string


class Base(DeclarativeBase):
    # Declarative base shared by the embedding tables defined in this module
    pass


class BaseEmbeddingTable(Base):
    """Base table to store embedding models

    Columns:
        name: unique identifier of the embedding model
        spec: JSON-serialized constructor spec of the model
        default: whether this model is the pool's default one
    """

    __abstract__ = True

    name = Column(String, primary_key=True, unique=True)
    # Use a callable default so every row gets a fresh dict instead of
    # all rows sharing one mutable dict object.
    spec = Column(JSON, default=dict)
    default = Column(Boolean, default=False)


# Allow deployments to substitute a custom base table through settings;
# fall back to the built-in schema otherwise.
_base_embedding: Type[BaseEmbeddingTable] = (
    import_dotted_string(flowsettings.KH_EMBEDDING_LLM, safe=False)
    if hasattr(flowsettings, "KH_EMBEDDING_LLM")
    else BaseEmbeddingTable
)


class EmbeddingTable(_base_embedding):  # type: ignore
    __tablename__ = "embedding"


# Create the table eagerly at import time unless schema migrations are
# managed externally via Alembic.
if not getattr(flowsettings, "KH_ENABLE_ALEMBIC", False):
    EmbeddingTable.metadata.create_all(engine)
199 changes: 199 additions & 0 deletions libs/ktem/ktem/embeddings/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
from typing import Optional, Type

from sqlalchemy import select
from sqlalchemy.orm import Session
from theflow.settings import settings as flowsettings
from theflow.utils.modules import deserialize

from kotaemon.embeddings.base import BaseEmbeddings

from .db import EmbeddingTable, engine


class EmbeddingManager:
    """Represent a pool of embedding models

    The pool is backed by the ``embedding`` database table. On first
    construction with an empty table it is seeded from
    ``flowsettings.KH_EMBEDDINGS``.
    """

    def __init__(self):
        # name -> deserialized embedding object
        self._models: dict[str, BaseEmbeddings] = {}
        # name -> raw record ({"name", "spec", "default"})
        self._info: dict[str, dict] = {}
        # name of the default model ("" when none is marked default)
        self._default: str = ""
        self._vendors: list[Type] = []

        # populate the pool if empty
        if hasattr(flowsettings, "KH_EMBEDDINGS"):
            with Session(engine) as sess:
                count = sess.query(EmbeddingTable).count()
                if not count:
                    for name, model in flowsettings.KH_EMBEDDINGS.items():
                        self.add(
                            name=name,
                            spec=model["spec"],
                            default=model.get("default", False),
                        )

        self.load()
        self.load_vendors()

    def load(self):
        """Rebuild the in-memory model pool from the database

        Resets ``_models``, ``_info`` and ``_default`` before reloading.
        (Fixes a typo where ``_defaut`` was assigned instead of
        ``_default``, which left a stale default name behind on reload.)
        """
        self._models, self._info, self._default = {}, {}, ""
        with Session(engine) as sess:
            stmt = select(EmbeddingTable)
            items = sess.execute(stmt)

            for (item,) in items:
                self._models[item.name] = deserialize(item.spec, safe=False)
                self._info[item.name] = {
                    "name": item.name,
                    "spec": item.spec,
                    "default": item.default,
                }
                if item.default:
                    self._default = item.name

    def load_vendors(self):
        """Collect the embedding classes users can pick from"""
        # imported lazily to avoid a hard import cost at module load
        from kotaemon.embeddings import (
            AzureOpenAIEmbeddings,
            FastEmbedEmbeddings,
            OpenAIEmbeddings,
        )

        self._vendors = [AzureOpenAIEmbeddings, OpenAIEmbeddings, FastEmbedEmbeddings]

    def __getitem__(self, key: str) -> BaseEmbeddings:
        """Get model by name"""
        return self._models[key]

    def __contains__(self, key: str) -> bool:
        """Check if model exists"""
        return key in self._models

    def get(
        self, key: str, default: Optional[BaseEmbeddings] = None
    ) -> Optional[BaseEmbeddings]:
        """Get model by name with default value"""
        return self._models.get(key, default)

    def settings(self) -> dict:
        """Present model pools option for gradio"""
        return {
            "label": "Embedding",
            "choices": list(self._models.keys()),
            "value": self.get_default_name(),
        }

    def options(self) -> dict:
        """Present a dict of models"""
        return self._models

    def get_random_name(self) -> str:
        """Get the name of random model

        Returns:
            str: random model name in the pool

        Raises:
            ValueError: if the pool is empty
        """
        import random

        if not self._models:
            raise ValueError("No models in pool")

        return random.choice(list(self._models.keys()))

    def get_default_name(self) -> str:
        """Get the name of default model

        In case there is no default model, choose random model from pool. In
        case there are multiple default models, choose random from them.

        Returns:
            str: model name

        Raises:
            ValueError: if the pool is empty
        """
        if not self._models:
            raise ValueError("No models in pool")

        if not self._default:
            return self.get_random_name()

        return self._default

    def get_random(self) -> BaseEmbeddings:
        """Get random model"""
        return self._models[self.get_random_name()]

    def get_default(self) -> BaseEmbeddings:
        """Get default model

        In case there is no default model, choose random model from pool. In
        case there are multiple default models, choose random from them.

        Returns:
            BaseEmbeddings: model
        """
        return self._models[self.get_default_name()]

    def info(self) -> dict:
        """List all models"""
        return self._info

    def add(self, name: str, spec: dict, default: bool):
        """Add a new model to the pool

        Args:
            name: unique model name
            spec: serialized constructor spec of the model
            default: whether to mark this model as the pool's default

        Raises:
            ValueError: if the name is empty or the insert fails
        """
        if not name:
            raise ValueError("Name must not be empty")

        try:
            with Session(engine) as sess:
                if default:
                    # turn all models to non-default
                    sess.query(EmbeddingTable).update({"default": False})
                    sess.commit()

                item = EmbeddingTable(name=name, spec=spec, default=default)
                sess.add(item)
                sess.commit()
        except Exception as e:
            raise ValueError(f"Failed to add model {name}: {e}")

        self.load()

    def delete(self, name: str):
        """Delete a model from the pool

        Raises:
            ValueError: if the model does not exist or the delete fails
        """
        try:
            with Session(engine) as sess:
                item = sess.query(EmbeddingTable).filter_by(name=name).first()
                if not item:
                    # fail with a clear message instead of passing None to
                    # sess.delete(), which raises an opaque SQLAlchemy error
                    raise ValueError(f"Model {name} not found")
                sess.delete(item)
                sess.commit()
        except Exception as e:
            raise ValueError(f"Failed to delete model {name}: {e}")

        self.load()

    def update(self, name: str, spec: dict, default: bool):
        """Update a model in the pool

        Args:
            name: name of the model to update
            spec: new serialized constructor spec
            default: whether to mark this model as the pool's default

        Raises:
            ValueError: if the name is empty, not found, or the update fails
        """
        if not name:
            raise ValueError("Name must not be empty")

        try:
            with Session(engine) as sess:

                if default:
                    # turn all models to non-default
                    sess.query(EmbeddingTable).update({"default": False})
                    sess.commit()

                item = sess.query(EmbeddingTable).filter_by(name=name).first()
                if not item:
                    raise ValueError(f"Model {name} not found")
                item.spec = spec
                item.default = default
                sess.commit()
        except Exception as e:
            raise ValueError(f"Failed to update model {name}: {e}")

        self.load()

    def vendors(self) -> dict:
        """Return list of vendors"""
        return {vendor.__qualname__: vendor for vendor in self._vendors}


embeddings = EmbeddingManager()
Loading
Loading