From f9cc40ca25878d780752b27780efad127d85163d Mon Sep 17 00:00:00 2001 From: ian Date: Thu, 28 Mar 2024 16:35:13 +0700 Subject: [PATCH 1/2] improve llm selection for simple reasoning pipeline --- .../kotaemon/contribs/promptui/tunnel.py | 3 +- .../kotaemon/llms/chats/langchain_based.py | 1 - libs/ktem/ktem/components.py | 3 +- libs/ktem/ktem/index/file/pipelines.py | 32 ++++++++++++++++--- libs/ktem/ktem/reasoning/simple.py | 18 ++++++++--- 5 files changed, 44 insertions(+), 13 deletions(-) diff --git a/libs/kotaemon/kotaemon/contribs/promptui/tunnel.py b/libs/kotaemon/kotaemon/contribs/promptui/tunnel.py index 897a438e8..711585da0 100644 --- a/libs/kotaemon/kotaemon/contribs/promptui/tunnel.py +++ b/libs/kotaemon/kotaemon/contribs/promptui/tunnel.py @@ -17,8 +17,7 @@ BINARY_REMOTE_NAME = f"frpc_{platform.system().lower()}_{machine.lower()}" EXTENSION = ".exe" if os.name == "nt" else "" BINARY_URL = ( - "some-endpoint.com" - f"/kotaemon/tunneling/{VERSION}/{BINARY_REMOTE_NAME}{EXTENSION}" + "some-endpoint.com" f"/kotaemon/tunneling/{VERSION}/{BINARY_REMOTE_NAME}{EXTENSION}" ) BINARY_FILENAME = f"{BINARY_REMOTE_NAME}_v{VERSION}" diff --git a/libs/kotaemon/kotaemon/llms/chats/langchain_based.py b/libs/kotaemon/kotaemon/llms/chats/langchain_based.py index 6c87c720b..526eaf868 100644 --- a/libs/kotaemon/kotaemon/llms/chats/langchain_based.py +++ b/libs/kotaemon/kotaemon/llms/chats/langchain_based.py @@ -194,7 +194,6 @@ def _get_lc_class(self): class AzureChatOpenAI(LCChatMixin, ChatLLM): # type: ignore - def __init__( self, azure_endpoint: str | None = None, diff --git a/libs/ktem/ktem/components.py b/libs/ktem/ktem/components.py index e8343267f..6cfb2e345 100644 --- a/libs/ktem/ktem/components.py +++ b/libs/ktem/ktem/components.py @@ -1,4 +1,5 @@ """Common components, some kind of config""" + import logging from functools import cache from pathlib import Path @@ -71,7 +72,7 @@ def settings(self) -> dict: } def options(self) -> dict: - """Present a list of models""" + """Present a dict of models""" return self._models def get_random_name(self) -> str: diff --git a/libs/ktem/ktem/index/file/pipelines.py b/libs/ktem/ktem/index/file/pipelines.py index d15d2fb49..1d813f508 100644 --- a/libs/ktem/ktem/index/file/pipelines.py +++ b/libs/ktem/ktem/index/file/pipelines.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import shutil import warnings from collections import defaultdict @@ -8,7 +9,7 @@ from pathlib import Path from typing import Optional -from ktem.components import embeddings, filestorage_path, llms +from ktem.components import embeddings, filestorage_path from ktem.db.models import engine from llama_index.vector_stores import ( FilterCondition, @@ -25,10 +26,12 @@ from kotaemon.base import RetrievedDocument from kotaemon.indices import VectorIndexing, VectorRetrieval from kotaemon.indices.ingests import DocumentIngestor -from kotaemon.indices.rankings import BaseReranking, LLMReranking +from kotaemon.indices.rankings import BaseReranking from .base import BaseFileIndexIndexing, BaseFileIndexRetriever +logger = logging.getLogger(__name__) + @lru_cache def dev_settings(): @@ -67,7 +70,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever): vector_retrieval: VectorRetrieval = VectorRetrieval.withx( embedding=embeddings.get_default(), ) - reranker: BaseReranking = LLMReranking.withx(llm=llms.get_lowest_cost()) + reranker: BaseReranking get_extra_table: bool = False def run( @@ -153,7 +156,23 @@ def run( @classmethod def get_user_settings(cls) -> dict: + from ktem.components import llms + + try: + reranking_llm = llms.get_lowest_cost_name() + reranking_llm_choices = list(llms.options().keys()) + except Exception as e: + logger.error(e) + reranking_llm = None + reranking_llm_choices = [] + return { + "reranking_llm": { + "name": "LLM for reranking", + "value": reranking_llm, + "component": "dropdown", + "choices": reranking_llm_choices, + }, "separate_embedding": { "name": "Use separate embedding", "value": False, @@ -185,7 +204,7 @@ def get_user_settings(cls) -> dict: }, "use_reranking": { "name": "Use reranking", - "value": True, + "value": False, "choices": [True, False], "component": "checkbox", }, @@ -199,7 +218,10 @@ def get_pipeline(cls, user_settings, index_settings, selected): settings: the settings of the app kwargs: other arguments """ - retriever = cls(get_extra_table=user_settings["prioritize_table"]) + retriever = cls( + get_extra_table=user_settings["prioritize_table"], + reranker=user_settings["reranking_llm"], + ) if not user_settings["use_reranking"]: retriever.reranker = None # type: ignore diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py index 9bcc3c4a9..47f7c92d5 100644 --- a/libs/ktem/ktem/reasoning/simple.py +++ b/libs/ktem/ktem/reasoning/simple.py @@ -159,6 +159,7 @@ class AnswerWithContextPipeline(BaseComponent): qa_table_template: str = DEFAULT_QA_TABLE_PROMPT qa_chatbot_template: str = DEFAULT_QA_CHATBOT_PROMPT + enable_citation: bool = False system_prompt: str = "" lang: str = "English" # support English and Japanese @@ -200,7 +201,8 @@ async def run( # type: ignore lang=self.lang, ) - if evidence: + citation_task = None + if evidence and self.enable_citation: citation_task = asyncio.create_task( self.citation_pipeline.ainvoke(context=evidence, question=question) ) @@ -226,7 +228,7 @@ async def run( # type: ignore # retrieve the citation print("Waiting for citation task") - if evidence: + if citation_task is not None: citation = await citation_task else: citation = None @@ -353,7 +355,15 @@ def get_pipeline(cls, settings, retrievers): _id = cls.get_info()["id"] pipeline = FullQAPipeline(retrievers=retrievers) - pipeline.answering_pipeline.llm = llms.get_highest_accuracy() + pipeline.answering_pipeline.llm = llms[ + settings[f"reasoning.options.{_id}.main_llm"] + ] + pipeline.answering_pipeline.citation_pipeline.llm = llms[ + settings[f"reasoning.options.{_id}.citation_llm"] + ] + pipeline.answering_pipeline.enable_citation = settings[ + f"reasoning.options.{_id}.highlight_citation" + ] pipeline.answering_pipeline.lang = {"en": "English", "ja": "Japanese"}.get( settings["reasoning.lang"], "English" ) @@ -384,7 +394,7 @@ def get_user_settings(cls) -> dict: return { "highlight_citation": { "name": "Highlight Citation", - "value": True, + "value": False, "component": "checkbox", }, "citation_llm": { From 14482e9755cd5ae0c816cd006feda0cd47871e97 Mon Sep 17 00:00:00 2001 From: ian Date: Thu, 28 Mar 2024 16:36:05 +0700 Subject: [PATCH 2/2] bug fix: settings are not persistent --- libs/ktem/ktem/pages/settings.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/libs/ktem/ktem/pages/settings.py b/libs/ktem/ktem/pages/settings.py index c396e3386..0fce2e8f7 100644 --- a/libs/ktem/ktem/pages/settings.py +++ b/libs/ktem/ktem/pages/settings.py @@ -87,6 +87,28 @@ def on_building_ui(self): self.reasoning_tab() def on_subscribe_public_events(self): + """ + Subscribes to public events related to user management. + + This function is responsible for subscribing to the "onSignIn" event, which is + triggered when a user signs in. It registers two event handlers for this event. + + The first event handler, "load_setting", is responsible for loading the user's + settings when they sign in. It takes the user ID as input and returns the + settings state and a list of component outputs. The progress indicator for this + event is set to "hidden". + + The second event handler, "get_name", is responsible for retrieving the + username of the current user. It takes the user ID as input and returns the + username if it exists, otherwise it returns "___". The progress indicator for + this event is also set to "hidden". + + Parameters: + self (object): The instance of the class. + + Returns: + None + """ if self._app.f_user_management: self._app.subscribe_event( name="onSignIn", @@ -290,3 +312,12 @@ def components(self) -> list: def component_names(self): """Get the setting components""" return self._settings_keys + + def _on_app_created(self): + if not self._app.f_user_management: + self._app.app.load( + self.load_setting, + inputs=self._user_id, + outputs=[self._settings_state] + self.components(), + show_progress="hidden", + )