Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/improve simple reasoning #2

Merged
merged 2 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions libs/kotaemon/kotaemon/contribs/promptui/tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
BINARY_REMOTE_NAME = f"frpc_{platform.system().lower()}_{machine.lower()}"
EXTENSION = ".exe" if os.name == "nt" else ""
BINARY_URL = (
"some-endpoint.com"
f"/kotaemon/tunneling/{VERSION}/{BINARY_REMOTE_NAME}{EXTENSION}"
"some-endpoint.com" f"/kotaemon/tunneling/{VERSION}/{BINARY_REMOTE_NAME}{EXTENSION}"
)

BINARY_FILENAME = f"{BINARY_REMOTE_NAME}_v{VERSION}"
Expand Down
1 change: 0 additions & 1 deletion libs/kotaemon/kotaemon/llms/chats/langchain_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,6 @@ def _get_lc_class(self):


class AzureChatOpenAI(LCChatMixin, ChatLLM): # type: ignore

def __init__(
self,
azure_endpoint: str | None = None,
Expand Down
3 changes: 2 additions & 1 deletion libs/ktem/ktem/components.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Common components, some kind of config"""

import logging
from functools import cache
from pathlib import Path
Expand Down Expand Up @@ -71,7 +72,7 @@ def settings(self) -> dict:
}

def options(self) -> dict:
"""Present a list of models"""
"""Present a dict of models"""
return self._models

def get_random_name(self) -> str:
Expand Down
32 changes: 27 additions & 5 deletions libs/ktem/ktem/index/file/pipelines.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
import shutil
import warnings
from collections import defaultdict
Expand All @@ -8,7 +9,7 @@
from pathlib import Path
from typing import Optional

from ktem.components import embeddings, filestorage_path, llms
from ktem.components import embeddings, filestorage_path
from ktem.db.models import engine
from llama_index.vector_stores import (
FilterCondition,
Expand All @@ -25,10 +26,12 @@
from kotaemon.base import RetrievedDocument
from kotaemon.indices import VectorIndexing, VectorRetrieval
from kotaemon.indices.ingests import DocumentIngestor
from kotaemon.indices.rankings import BaseReranking, LLMReranking
from kotaemon.indices.rankings import BaseReranking

from .base import BaseFileIndexIndexing, BaseFileIndexRetriever

logger = logging.getLogger(__name__)


@lru_cache
def dev_settings():
Expand Down Expand Up @@ -67,7 +70,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
vector_retrieval: VectorRetrieval = VectorRetrieval.withx(
embedding=embeddings.get_default(),
)
reranker: BaseReranking = LLMReranking.withx(llm=llms.get_lowest_cost())
reranker: BaseReranking
get_extra_table: bool = False

def run(
Expand Down Expand Up @@ -153,7 +156,23 @@ def run(

@classmethod
def get_user_settings(cls) -> dict:
from ktem.components import llms

try:
reranking_llm = llms.get_lowest_cost_name()
reranking_llm_choices = list(llms.options().keys())
except Exception as e:
logger.error(e)
reranking_llm = None
reranking_llm_choices = []

return {
"reranking_llm": {
"name": "LLM for reranking",
"value": reranking_llm,
"component": "dropdown",
"choices": reranking_llm_choices,
},
"separate_embedding": {
"name": "Use separate embedding",
"value": False,
Expand Down Expand Up @@ -185,7 +204,7 @@ def get_user_settings(cls) -> dict:
},
"use_reranking": {
"name": "Use reranking",
"value": True,
"value": False,
"choices": [True, False],
"component": "checkbox",
},
Expand All @@ -199,7 +218,10 @@ def get_pipeline(cls, user_settings, index_settings, selected):
settings: the settings of the app
kwargs: other arguments
"""
retriever = cls(get_extra_table=user_settings["prioritize_table"])
retriever = cls(
get_extra_table=user_settings["prioritize_table"],
reranker=user_settings["reranking_llm"],
)
if not user_settings["use_reranking"]:
retriever.reranker = None # type: ignore

Expand Down
31 changes: 31 additions & 0 deletions libs/ktem/ktem/pages/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,28 @@ def on_building_ui(self):
self.reasoning_tab()

def on_subscribe_public_events(self):
"""
Subscribes to public events related to user management.

This function is responsible for subscribing to the "onSignIn" event, which is
triggered when a user signs in. It registers two event handlers for this event.

The first event handler, "load_setting", is responsible for loading the user's
settings when they sign in. It takes the user ID as input and returns the
settings state and a list of component outputs. The progress indicator for this
event is set to "hidden".

The second event handler, "get_name", is responsible for retrieving the
username of the current user. It takes the user ID as input and returns the
username if it exists, otherwise it returns "___". The progress indicator for
this event is also set to "hidden".

Parameters:
self (object): The instance of the class.

Returns:
None
"""
if self._app.f_user_management:
self._app.subscribe_event(
name="onSignIn",
Expand Down Expand Up @@ -290,3 +312,12 @@ def components(self) -> list:
def component_names(self):
"""Get the setting components"""
return self._settings_keys

def _on_app_created(self):
    """Hook the settings loader to the app's load event when user management is off.

    If ``self._app.f_user_management`` is truthy, nothing is registered here
    and the method returns immediately. Otherwise ``self.load_setting`` is
    registered through ``self._app.app.load`` with ``self._user_id`` as its
    input; its outputs are ``self._settings_state`` followed by every
    component returned by ``self.components()``.
    """
    if self._app.f_user_management:
        return

    # Settings state comes first, then the individual setting components.
    load_outputs = [self._settings_state] + self.components()

    # NOTE(review): `self._app.app` is presumably the Gradio Blocks instance,
    # so this runs on page load — confirm against the app class.
    self._app.app.load(
        self.load_setting,
        inputs=self._user_id,
        outputs=load_outputs,
        show_progress="hidden",
    )
18 changes: 14 additions & 4 deletions libs/ktem/ktem/reasoning/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ class AnswerWithContextPipeline(BaseComponent):
qa_table_template: str = DEFAULT_QA_TABLE_PROMPT
qa_chatbot_template: str = DEFAULT_QA_CHATBOT_PROMPT

enable_citation: bool = False
system_prompt: str = ""
lang: str = "English" # support English and Japanese

Expand Down Expand Up @@ -200,7 +201,8 @@ async def run( # type: ignore
lang=self.lang,
)

if evidence:
citation_task = None
if evidence and self.enable_citation:
citation_task = asyncio.create_task(
self.citation_pipeline.ainvoke(context=evidence, question=question)
)
Expand All @@ -226,7 +228,7 @@ async def run( # type: ignore

# retrieve the citation
print("Waiting for citation task")
if evidence:
if citation_task is not None:
citation = await citation_task
else:
citation = None
Expand Down Expand Up @@ -353,7 +355,15 @@ def get_pipeline(cls, settings, retrievers):
_id = cls.get_info()["id"]

pipeline = FullQAPipeline(retrievers=retrievers)
pipeline.answering_pipeline.llm = llms.get_highest_accuracy()
pipeline.answering_pipeline.llm = llms[
settings[f"reasoning.options.{_id}.main_llm"]
]
pipeline.answering_pipeline.citation_pipeline.llm = llms[
settings[f"reasoning.options.{_id}.citation_llm"]
]
pipeline.answering_pipeline.enable_citation = settings[
f"reasoning.options.{_id}.highlight_citation"
]
pipeline.answering_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
settings["reasoning.lang"], "English"
)
Expand Down Expand Up @@ -384,7 +394,7 @@ def get_user_settings(cls) -> dict:
return {
"highlight_citation": {
"name": "Highlight Citation",
"value": True,
"value": False,
"component": "checkbox",
},
"citation_llm": {
Expand Down
Loading