Allow users to select reasoning pipeline. Fix small issues with user UI, cohere name (#50)

* Fix user page

* Allow changing LLM in reasoning pipeline

* Fix CohereEmbedding name
trducng authored Apr 25, 2024
1 parent e29bec6 commit a872571
Showing 6 changed files with 44 additions and 13 deletions.
4 changes: 2 additions & 2 deletions libs/kotaemon/kotaemon/embeddings/__init__.py
@@ -3,7 +3,7 @@
from .fastembed import FastEmbedEmbeddings
from .langchain_based import (
LCAzureOpenAIEmbeddings,
LCCohereEmbdeddings,
LCCohereEmbeddings,
LCHuggingFaceEmbeddings,
LCOpenAIEmbeddings,
)
@@ -14,7 +14,7 @@
"EndpointEmbeddings",
"LCOpenAIEmbeddings",
"LCAzureOpenAIEmbeddings",
"LCCohereEmbdeddings",
"LCCohereEmbeddings",
"LCHuggingFaceEmbeddings",
"OpenAIEmbeddings",
"AzureOpenAIEmbeddings",
2 changes: 1 addition & 1 deletion libs/kotaemon/kotaemon/embeddings/langchain_based.py
@@ -159,7 +159,7 @@ def _get_lc_class(self):
return AzureOpenAIEmbeddings


class LCCohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings):
class LCCohereEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
"""Wrapper around Langchain's Cohere embedding, focusing on key parameters"""

def __init__(
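Note (not part of the diff): the class is re-exported from kotaemon.embeddings, so any downstream code that imported the old misspelled name needs the same one-character fix. A minimal sketch of the corrected import and construction, reusing the placeholder model name and API key from the test file below:

    from kotaemon.embeddings import LCCohereEmbeddings

    # Placeholder values copied from test_embedding_models.py, not real credentials.
    embedding = LCCohereEmbeddings(
        model="embed-english-light-v2.0",
        cohere_api_key="my-api-key",
    )
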
4 changes: 2 additions & 2 deletions libs/kotaemon/tests/test_embedding_models.py
@@ -9,7 +9,7 @@
AzureOpenAIEmbeddings,
FastEmbedEmbeddings,
LCAzureOpenAIEmbeddings,
LCCohereEmbdeddings,
LCCohereEmbeddings,
LCHuggingFaceEmbeddings,
OpenAIEmbeddings,
)
@@ -148,7 +148,7 @@ def test_lchuggingface_embeddings(
side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]],
)
def test_lccohere_embeddings(langchain_cohere_embedding_call):
model = LCCohereEmbdeddings(
model = LCCohereEmbeddings(
model="embed-english-light-v2.0", cohere_api_key="my-api-key"
)

11 changes: 10 additions & 1 deletion libs/ktem/ktem/llms/manager.py
@@ -1,4 +1,4 @@
from typing import Optional, Type
from typing import Optional, Type, overload

from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -71,6 +71,14 @@ def __contains__(self, key: str) -> bool:
"""Check if model exists"""
return key in self._models

@overload
def get(self, key: str, default: None) -> Optional[ChatLLM]:
...

@overload
def get(self, key: str, default: ChatLLM) -> ChatLLM:
...

def get(self, key: str, default: Optional[ChatLLM] = None) -> Optional[ChatLLM]:
"""Get model by name with default value"""
return self._models.get(key, default)
@@ -138,6 +146,7 @@ def info(self) -> dict:

def add(self, name: str, spec: dict, default: bool):
"""Add a new model to the pool"""
name = name.strip()
if not name:
raise ValueError("Name must not be empty")

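Note (not part of the diff): the two @overload stubs only affect static typing. They let a type checker infer the return type from the default argument, which is what the reasoning pipeline change below relies on when it calls llms.get(llm_name, llms.get_default()). A minimal sketch with a placeholder model key:

    from ktem.llms.manager import llms

    # "my-llm" is a placeholder key; get() returns the supplied default when the key is missing.
    maybe_llm = llms.get("my-llm", None)                # typed as Optional[ChatLLM]
    chat_llm = llms.get("my-llm", llms.get_default())   # typed as ChatLLM
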
4 changes: 2 additions & 2 deletions libs/ktem/ktem/pages/resources/user.py
@@ -142,7 +142,7 @@ def on_building_ui(self):
)
self.admin_edit = gr.Checkbox(label="Admin")

with gr.Row() as self._selected_panel_btn:
with gr.Row(visible=False) as self._selected_panel_btn:
with gr.Column():
self.btn_edit_save = gr.Button("Save")
with gr.Column():
@@ -338,7 +338,7 @@ def select_user(self, user_list, ev: gr.SelectData):
if not ev.selected:
return -1

return user_list["id"][ev.index[0]]
return int(user_list["id"][ev.index[0]])

def on_selected_user_change(self, selected_user_id):
if selected_user_id == -1:
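Note (an assumption, not stated in the diff): the added int(...) cast is likely needed because the Gradio dataframe hands back pandas/numpy scalars rather than plain Python ints, which can trip up serialization and comparisons downstream. A small illustration under that assumption:

    import pandas as pd

    # Hypothetical user table, standing in for the Gradio dataframe contents.
    user_list = pd.DataFrame({"id": [1, 2], "username": ["admin", "alice"]})

    raw_id = user_list["id"][0]   # numpy.int64 on a pandas-backed table
    plain_id = int(raw_id)        # plain Python int, safe to keep in app state

    print(type(raw_id).__name__, type(plain_id).__name__)  # int64 int
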
32 changes: 27 additions & 5 deletions libs/ktem/ktem/reasoning/simple.py
@@ -680,12 +680,15 @@ def get_pipeline(cls, settings, states, retrievers):
retrievers: the retrievers to use
"""
prefix = f"reasoning.options.{cls.get_info()['id']}"
pipeline = FullQAPipeline(retrievers=retrievers)
pipeline = cls(retrievers=retrievers)

llm_name = settings.get(f"{prefix}.llm", None)
llm = llms.get(llm_name, llms.get_default())

# answering pipeline configuration
answer_pipeline = pipeline.answering_pipeline
answer_pipeline.llm = llms.get_default()
answer_pipeline.citation_pipeline.llm = llms.get_default()
answer_pipeline.llm = llm
answer_pipeline.citation_pipeline.llm = llm
answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"]
answer_pipeline.enable_citation = settings[f"{prefix}.highlight_citation"]
answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"]
@@ -694,22 +697,41 @@ def get_pipeline(cls, settings, states, retrievers):
settings["reasoning.lang"], "English"
)

pipeline.add_query_context.llm = llms.get_default()
pipeline.add_query_context.llm = llm
pipeline.add_query_context.n_last_interactions = settings[
f"{prefix}.n_last_interactions"
]

pipeline.trigger_context = settings[f"{prefix}.trigger_context"]
pipeline.use_rewrite = states.get("app", {}).get("regen", False)
pipeline.rewrite_pipeline.llm = llms.get_default()
pipeline.rewrite_pipeline.llm = llm
pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
settings["reasoning.lang"], "English"
)
return pipeline

@classmethod
def get_user_settings(cls) -> dict:
from ktem.llms.manager import llms

llm = ""
choices = [("(default)", "")]
try:
choices += [(_, _) for _ in llms.options().keys()]
except Exception as e:
logger.exception(f"Failed to get LLM options: {e}")

return {
"llm": {
"name": "Language model",
"value": llm,
"component": "dropdown",
"choices": choices,
"info": (
"The language model to use for generating the answer. If None, "
"the application default language model will be used."
),
},
"highlight_citation": {
"name": "Highlight Citation",
"value": False,
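Note (not part of the diff): taken together, the new "Language model" dropdown stores either an empty string (the "(default)" choice) or a key from llms.options(), and get_pipeline resolves it with a fallback to the application default. A rough sketch of that resolution, using a made-up pipeline id and model key:

    from ktem.llms.manager import llms

    # "simple" and "my-llm" are placeholders for the pipeline id and model key.
    prefix = "reasoning.options.simple"
    settings = {f"{prefix}.llm": "my-llm"}   # "" would keep the default model

    llm_name = settings.get(f"{prefix}.llm", None)
    llm = llms.get(llm_name, llms.get_default())   # falls back if the key is unknown
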
