Skip to content

Commit

Permalink
Add mode change
Browse files Browse the repository at this point in the history
  • Loading branch information
yorevs committed Sep 27, 2024
1 parent e868085 commit e43972e
Show file tree
Hide file tree
Showing 12 changed files with 137 additions and 44 deletions.
7 changes: 4 additions & 3 deletions src/main/askai/core/askai.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(
configs.model = model_name

self._session_id = now("%Y%m%d")[:8]
self._engine: AIEngine = shared.create_engine(engine_name, model_name, RouterMode.CHAT)
self._engine: AIEngine = shared.create_engine(engine_name, model_name, RouterMode.default())
self._context: ChatContext = shared.create_context(self._engine.ai_token_limit())
self._mode: RouterMode = shared.mode
self._console_path = Path(f"{CACHE_DIR}/askai-{self.session_id}.md")
Expand Down Expand Up @@ -165,8 +165,8 @@ def ask_and_reply(self, question: str) -> tuple[bool, Optional[str]]:
elif not (output := cache.read_reply(question)):
log.debug('Response not found for "%s" in cache. Querying from %s.', question, self.engine.nickname())
events.reply.emit(reply=AIReply.detailed(msg.wait()))
output = processor.process(question, context=read_stdin(), query_prompt=self._query_prompt)
events.reply.emit(reply=AIReply.info(output or msg.no_output("processor")))
if output := processor.process(question, context=read_stdin(), query_prompt=self._query_prompt):
events.reply.emit(reply=AIReply.info(output or msg.no_output("processor")))
else:
log.debug("Reply found for '%s' in cache.", question)
events.reply.emit(reply=AIReply.info(output))
Expand All @@ -184,6 +184,7 @@ def ask_and_reply(self, question: str) -> tuple[bool, Optional[str]]:
events.reply.emit(reply=AIReply.error(msg.quote_exceeded()))
status = False
except TerminatingQuery:
self._reply(AIReply.info(msg.goodbye()))
status = False
finally:
if output:
Expand Down
15 changes: 7 additions & 8 deletions src/main/askai/core/askai_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from askai.core.askai_configs import configs
from askai.core.askai_events import *
from askai.core.askai_messages import msg
from askai.core.askai_prompt import prompt
from askai.core.commander.commander import commands
from askai.core.component.audio_player import player
from askai.core.component.cache_service import cache, CACHE_DIR
Expand Down Expand Up @@ -138,13 +139,11 @@ def _cb_mode_changed_event(self, ev: Event) -> None:
"""
self.mode: RouterMode = RouterMode.of_name(ev.args.mode)
if self.mode == RouterMode.QNA:
sum_msg: str = (
f"{msg.enter_qna()} \n"
f"```\nContext:  {ev.args.sum_path},  {ev.args.glob} \n```\n"
f"`{msg.press_esc_enter()}` \n\n"
f"> {msg.qna_welcome()}"
)
events.reply.emit(reply=AIReply.info(sum_msg))
welcome_msg = self.mode.welcome(sum_path=ev.args.sum_path, sum_glob=ev.args.glob)
else:
welcome_msg = self.mode.welcome()

events.reply.emit(reply=AIReply.info(welcome_msg or msg.welcome(prompt.user.title())))

def _cb_mic_listening_event(self, ev: Event) -> None:
"""Callback to handle microphone listening events.
Expand Down Expand Up @@ -210,7 +209,7 @@ def _startup(self) -> None:
askai_bus.subscribe(DEVICE_CHANGED_EVENT, self._cb_device_changed_event)
askai_bus.subscribe(MODE_CHANGED_EVENT, self._cb_mode_changed_event)
display_text(str(self), markdown=False)
self._reply(AIReply.info(msg.welcome(os.getenv("USER", "you"))))
self._reply(AIReply.info(self.mode.welcome()))
elif configs.is_speak:
recorder.setup()
player.start_delay()
Expand Down
13 changes: 11 additions & 2 deletions src/main/askai/core/askai_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,23 @@ def summary_succeeded(self, path: str, glob: str) -> str:
def enter_qna(self) -> str:
return "You have *entered* the **Summarization Q & A**"

def qna_welcome(self) -> str:
return " What specific information are you seeking about this content ?"

def enter_rag(self) -> str:
return "You have *entered* the **RAG Mode**"

def enter_chat(self) -> str:
return "Welcome back, Sir! Ready for more epic adventures?"

def leave_qna(self) -> str:
return "You have *left* the **Summarization Q & A**"

def leave_rag(self) -> str:
return "You have *left* the **RAG Mode**"

def qna_welcome(self) -> str:
return " What specific information are you seeking about this content ?"
def leave_chat(self) -> str:
return f"Bye, Sir! If you need anything else, **just let me rock**!"

def press_esc_enter(self) -> str:
return "Type [exit] to exit Q & A mode"
Expand Down
54 changes: 37 additions & 17 deletions src/main/askai/core/commander/commander.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,17 @@
Copyright (c) 2024, HomeSetup
"""
import os
import re
from functools import lru_cache
from os.path import dirname
from pathlib import Path
from string import Template
from textwrap import dedent

import click
from askai.core.askai_configs import configs
from askai.core.askai_events import ASKAI_BUS_NAME, AskAiEvents, REPLY_EVENT
from askai.core.askai_events import ASKAI_BUS_NAME, AskAiEvents, REPLY_EVENT, events
from askai.core.commander.commands.cache_cmd import CacheCmd
from askai.core.commander.commands.camera_cmd import CameraCmd
from askai.core.commander.commands.general_cmd import GeneralCmd
Expand All @@ -26,18 +35,9 @@
from askai.language.language import AnyLocale, Language
from click import Command, Group
from clitt.core.term.cursor import cursor
from functools import lru_cache
from hspylib.core.enums.charset import Charset
from hspylib.core.tools.commons import sysout, to_bool
from hspylib.modules.eventbus.event import Event
from os.path import dirname
from pathlib import Path
from string import Template
from textwrap import dedent

import click
import os
import re

COMMANDER_HELP_TPL = Template(
dedent(
Expand Down Expand Up @@ -207,7 +207,7 @@ def context(operation: str, name: str | None = None) -> None:
:param operation: The operation to perform on contexts. Options: [list | forget].
:param name: The name of the context to target (default is "ALL").
"""
match operation:
match operation.casefold():
case "list":
HistoryCmd.context_list()
case "forget":
Expand All @@ -223,7 +223,7 @@ def history(operation: str) -> None:
"""Manages the current input history.
:param operation: The operation to perform on contexts. Options: [list|forget].
"""
match operation:
match operation.casefold():
case "list":
HistoryCmd.history_list()
case "forget":
Expand All @@ -250,7 +250,7 @@ def devices(operation: str, name: str | None = None) -> None:
:param operation: Specifies the device operation. Options: [list|set].
:param name: The target device name for setting.
"""
match operation:
match operation.casefold():
case "list":
TtsSttCmd.device_list()
case "set":
Expand All @@ -270,7 +270,7 @@ def settings(operation: str, name: str | None = None, value: str | None = None)
:param name: The key for the setting to modify.
:param value: The new value for the specified setting.
"""
match operation:
match operation.casefold():
case "list":
SettingsCmd.list(name)
case "get":
Expand All @@ -292,7 +292,7 @@ def cache(operation: str, args: tuple[str, ...]) -> None:
:param operation: Specifies the cache operation. Options: [list|get|clear|files|enable|ttl]
:param args: Arguments relevant to the chosen operation.
"""
match operation:
match operation.casefold():
case "list":
CacheCmd.list()
case "get":
Expand Down Expand Up @@ -349,7 +349,7 @@ def voices(operation: str, name: str | int | None = None) -> None:
:param operation: The action to perform on voices. Options: [list/set/play]
:param name: The voice name.
"""
match operation:
match operation.casefold():
case "list":
TtsSttCmd.voice_list()
case "set":
Expand Down Expand Up @@ -426,7 +426,7 @@ def camera(operation: str, args: tuple[str, ...]) -> None:
:param operation: The camera operation to perform. Options: [capture|identify|import]
:param args: The arguments required for the operation.
"""
match operation:
match operation.casefold():
case "capture" | "photo":
CameraCmd.capture(*args)
case "identify" | "id":
Expand All @@ -436,3 +436,23 @@ def camera(operation: str, args: tuple[str, ...]) -> None:
case _:
err = str(click.BadParameter(f"Invalid camera operation: '{operation}'"))
text_formatter.commander_print(f"%RED%{err}%NC%")


@ask_commander.command()
@click.argument("router_mode", default="")
def mode(router_mode: str) -> None:
"""Change the AskAI routing mode.
:param router_mode: The routing mode. Options: [rag|chat|splitter|qstring]
"""
if not router_mode:
text_formatter.commander_print(f"Available routing modes: **[rag|chat|splitter]**. Current: `{shared.mode}`")
else:
match router_mode.casefold():
case "rag":
events.mode_changed.emit(mode="RAG")
case "chat":
events.mode_changed.emit(mode="CHAT")
case "splitter":
events.mode_changed.emit(mode="TASK_SPLIT")
case _:
events.mode_changed.emit(mode="DEFAULT")
28 changes: 28 additions & 0 deletions src/main/askai/core/enums/router_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
Copyright (c) 2024, HomeSetup
"""
import os
from functools import lru_cache

from askai.core.askai_configs import configs
from askai.core.askai_messages import msg
from askai.core.features.processors.ai_processor import AIProcessor
from askai.core.features.processors.qna import qna
from askai.core.features.processors.qstring import qstring
Expand All @@ -22,6 +26,8 @@
from hspylib.core.enums.enumeration import Enumeration
from typing import Optional

from hspylib.core.tools.dict_tools import get_or_default_by_key


class RouterMode(Enumeration):
"""Enumeration of available router modes used to determine the type of response provided to the user. This class
Expand Down Expand Up @@ -50,6 +56,7 @@ def modes(cls) -> list[str]:
return RouterMode.names()

@staticmethod
@lru_cache
def default() -> "RouterMode":
"""Return the default routing mode.
:return: The default RouterMode instance.
Expand Down Expand Up @@ -87,6 +94,27 @@ def processor(self) -> AIProcessor:
def is_default(self) -> bool:
return self == RouterMode.default()

def welcome(self, **kwargs) -> Optional[str]:
"""TODO"""
match self:
case RouterMode.QNA:
sum_path: str = get_or_default_by_key(kwargs, "sum_path", None)
sum_glob: str = get_or_default_by_key(kwargs, "sum_glob", None)
welcome_msg = msg.t(
f"{msg.enter_qna()} \n"
f"```\nContext:  {sum_path},  {sum_glob} \n```\n"
f"`{msg.press_esc_enter()}` \n\n"
f"> {msg.qna_welcome()}"
)
case RouterMode.RAG:
welcome_msg = msg.enter_rag()
case RouterMode.CHAT:
welcome_msg = msg.enter_chat()
case _:
welcome_msg = msg.welcome(os.getenv("USER", "user"))

return welcome_msg

def process(self, question: str, **kwargs) -> Optional[str]:
"""Invoke the processor associated with the current mode to handle the given question.
:param question: The question to be processed.
Expand Down
11 changes: 11 additions & 0 deletions src/main/askai/core/features/processors/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
"""
from typing import Optional, Any

from askai.core.askai_events import events
from askai.core.askai_messages import msg
from askai.core.askai_prompt import prompt
from askai.core.engine.openai.temperature import Temperature
from askai.core.model.ai_reply import AIReply
from askai.core.support.langchain_support import lc_llm
from askai.core.support.shared_instances import shared
from askai.exception.exceptions import TerminatingQuery
from hspylib.core.config.path_object import PathObject
from hspylib.core.metaclass.singleton import Singleton
from hspylib.core.tools.dict_tools import get_or_default_by_key
Expand Down Expand Up @@ -49,6 +53,13 @@ def process(self, question: str, **kwargs) -> Optional[str]:
:return: The final response after processing the question.
"""

if not question:
raise TerminatingQuery("The user wants to exit!")
if question.casefold() in ["exit", "leave", "quit", "q"]:
events.reply.emit(reply=AIReply.info(msg.leave_chat()))
events.mode_changed.emit(mode="DEFAULT")
return None

response = None
prompt_file: str = get_or_default_by_key(kwargs, "prompt_file", None)
history_ctx: Any | None = get_or_default_by_key(kwargs, "history_ctx", "HISTORY")
Expand Down
12 changes: 7 additions & 5 deletions src/main/askai/core/features/processors/qna.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from askai.core.askai_events import events
from askai.core.askai_messages import msg
from askai.core.component.summarizer import summarizer
from askai.core.model.ai_reply import AIReply
from askai.core.model.summary_result import SummaryResult
from hspylib.core.metaclass.singleton import Singleton
from hspylib.core.preconditions import check_state
Expand All @@ -16,12 +17,13 @@ def process(self, question: str, **_) -> Optional[str]:
"""Process the user question against a summarized context to retrieve answers.
:param question: The user question to process.
"""
if question.casefold() == "exit" or not (response := summarizer.query(question)):
if question.casefold() in ["exit", "leave", "quit", "q"] or not (response := summarizer.query(question)):
events.reply.emit(reply=AIReply.info(msg.leave_qna()))
events.mode_changed.emit(mode="DEFAULT")
output = msg.leave_qna()
else:
check_state(isinstance(response[0], SummaryResult))
output = response[0].answer
return None

check_state(isinstance(response[0], SummaryResult))
output = response[0].answer

return output

Expand Down
19 changes: 13 additions & 6 deletions src/main/askai/core/features/processors/rag.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import logging as log
import os
from pathlib import Path
from typing import Optional

from askai.core.askai_configs import configs
from askai.core.askai_events import events
from askai.core.askai_messages import msg
from askai.core.askai_prompt import prompt
from askai.core.component.cache_service import RAG_DIR, PERSIST_DIR
from askai.core.engine.openai.temperature import Temperature
from askai.core.model.ai_reply import AIReply
from askai.core.support.langchain_support import lc_llm
from askai.core.support.spinner import Spinner
from askai.exception.exceptions import DocumentsNotFound
from askai.exception.exceptions import DocumentsNotFound, TerminatingQuery
from hspylib.core.config.path_object import PathObject
from hspylib.core.metaclass.classpath import AnyPath
from hspylib.core.metaclass.singleton import Singleton
Expand Down Expand Up @@ -50,20 +52,25 @@ def rag_template(self) -> BasePromptTemplate:
# fmt: on

def persist_dir(self, file_glob: AnyPath) -> Path:
"""TODO"""
summary_hash = hash_text(file_glob)
return Path(f"{PERSIST_DIR}/{summary_hash}")

def process(self, question: str, **_) -> str:
def process(self, question: str, **_) -> Optional[str]:
"""Process the user question to retrieve the final response.
:param question: The user question to process.
:return: The final response after processing the question.
"""
if not question:
raise TerminatingQuery("The user wants to exit!")
if question.casefold() in ["exit", "leave", "quit", "q"]:
events.reply.emit(reply=AIReply.info(msg.leave_rag()))
events.mode_changed.emit(mode="DEFAULT")
return None

self.generate()

if question.casefold() == "exit":
events.mode_changed.emit(mode="DEFAULT")
output = msg.leave_rag()
elif not (output := self._rag_chain.invoke(question)):
if not (output := self._rag_chain.invoke(question)):
output = msg.invalid_response(output)

return output
Expand Down
3 changes: 3 additions & 0 deletions src/main/askai/core/features/processors/task_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ def process(self, question: str, **_) -> Optional[str]:
:param question: The user question to process.
"""

if not question or question.casefold() in ["exit", "leave", "quit", "q"]:
raise TerminatingQuery("The user wants to exit!")

os.chdir(Path.home())
shared.context.forget("EVALUATION") # Erase previous evaluation notes.
model: ModelResult = ModelResult.default() # Hard-coding the result model for now.
Expand Down
2 changes: 1 addition & 1 deletion src/main/askai/core/support/shared_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def app_info(self) -> str:
f"{dtm.center(80, '=')} %EOL%"
f" Language: {configs.language} {translator} %EOL%"
f" Engine: {shared.engine} %EOL%"
f" Mode: {self.mode} %EOL%"
f" Mode: %CYAN%{self.mode}%GREEN% %EOL%"
f" Dir: {cur_dir} %EOL%"
f" OS: {prompt.os_type}/{prompt.shell} %EOL%"
f"{'-' * 80} %EOL%"
Expand Down
Loading

0 comments on commit e43972e

Please sign in to comment.