Skip to content

Commit

Permalink
Bugfixes and stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
yorevs committed Sep 28, 2024
1 parent 7153aaa commit 45de799
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 24 deletions.
27 changes: 17 additions & 10 deletions src/main/askai/core/askai_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@
Copyright (c) 2024, HomeSetup
"""
import logging as log
import os
from pathlib import Path
from threading import Thread
from typing import List, Optional, TypeAlias

import nltk
import pause
from askai.core.askai import AskAi
from askai.core.askai_configs import configs
from askai.core.askai_events import *
Expand All @@ -27,19 +35,14 @@
from askai.core.support.shared_instances import shared
from askai.core.support.text_formatter import text_formatter
from askai.core.support.utilities import display_text
from askai.tui.app_icons import AppIcons
from clitt.core.term.cursor import cursor
from clitt.core.term.screen import screen
from clitt.core.tui.line_input.keyboard_input import KeyboardInput
from hspylib.core.enums.charset import Charset
from hspylib.core.zoned_datetime import now, TIME_FORMAT
from hspylib.modules.eventbus.event import Event
from pathlib import Path
from rich.progress import Progress
from threading import Thread
from typing import List, Optional, TypeAlias

import logging as log
import nltk
import os
import pause

QueryString: TypeAlias = str | List[str] | None

Expand Down Expand Up @@ -79,7 +82,7 @@ def run(self) -> None:
elif output:
cache.save_reply(question, output)
cache.save_input_history()
with open(self._console_path, "a+") as f_console:
with open(self.console_path, "a+", encoding=Charset.UTF_8.val) as f_console:
f_console.write(f"{shared.username_md}{question}\n\n")
f_console.write(f"{shared.nickname_md}{output}\n\n")
f_console.flush()
Expand Down Expand Up @@ -209,8 +212,12 @@ def _startup(self) -> None:
askai_bus.subscribe(DEVICE_CHANGED_EVENT, self._cb_device_changed_event)
askai_bus.subscribe(MODE_CHANGED_EVENT, self._cb_mode_changed_event)
display_text(str(self), markdown=False)
self._reply(AIReply.info(self.mode.welcome()))
elif configs.is_speak:
recorder.setup()
player.start_delay()
# Register the startup
with open(self.console_path, "a+", encoding=Charset.UTF_8.val) as f_console:
f_console.write(f"\n\n## {AppIcons.STARTED} {now(TIME_FORMAT)}\n\n")
f_console.flush()
self._reply(AIReply.info(self.mode.welcome()))
log.info("AskAI is ready to use!")
6 changes: 3 additions & 3 deletions src/main/askai/core/component/internet_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from typing import List

import bs4
import openai
from askai.__classpath__ import API_KEYS
from askai.core.askai_configs import configs
from askai.core.askai_events import events
Expand All @@ -46,6 +45,7 @@
from langchain_core.tools import Tool
from langchain_google_community import GoogleSearchAPIWrapper
from langchain_text_splitters import RecursiveCharacterTextSplitter
from openai import APIError


class InternetService(metaclass=Singleton):
Expand Down Expand Up @@ -110,7 +110,7 @@ def wrap_response(cls, terms: str, output: str, result: SearchResult) -> str:
return (
f"Your {result.engine.title()} search has returned the following results:"
f"\n\n{output}\n\n---\n\n"
f"`{cls.CATEGORY_ICONS[result.category]:<2} {result.category}` **Sources:** {sources} "
f"`{cls.CATEGORY_ICONS[result.category]:<2}{result.category}` **Sources:** {sources} "
f"**Access:** {geo_location.location} - *{now('%B %d, %Y')}*\n\n"
f">  Terms: {terms}")
# fmt: on
Expand Down Expand Up @@ -181,7 +181,7 @@ def google_search(self, search: SearchResult) -> str:
lc_llm.create_chat_model(temperature=Temperature.COLDEST.temp), llm_prompt
)
output = chain.invoke({"question": question, "context": docs})
except (HttpError, openai.APIError) as err:
except (HttpError, APIError) as err:
return msg.fail_to_search(str(err))

return self.refine_search(terms, output, search)
Expand Down
18 changes: 11 additions & 7 deletions src/main/askai/core/features/processors/task_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,11 @@ def process(self, question: str, **_) -> Optional[str]:
shared.context.forget("EVALUATION") # Erase previous evaluation notes.
model: ModelResult = ModelResult.default() # Hard-coding the result model for now.
log.info("Router::[QUESTION] '%s'", question)
retries: int = 0

@retry(exceptions=self.RETRIABLE_ERRORS, tries=configs.max_router_retries, backoff=1, jitter=1)
def _splitter_wrapper_() -> Optional[str]:
def _splitter_wrapper_(retry_count: int) -> Optional[str]:
retry_count += 1
# Invoke the LLM to split the tasks and create an action plan.
runnable = self.template(question) | lc_llm.create_chat_model(Temperature.COLDEST.temp)
runnable = RunnableWithMessageHistory(
Expand All @@ -184,22 +186,23 @@ def _splitter_wrapper_() -> Optional[str]:
return response # Most of the times, indicates a failure.

try:
agent_output = self._process_tasks(task_list)
agent_output: str | None = self._process_tasks(task_list, retries)
acc_response: AccResponse = assert_accuracy(question, agent_output, AccColor.MODERATE)
except InterruptionRequest as err:
return str(err)
except self.RETRIABLE_ERRORS:
events.reply.emit(reply=AIReply.error(msg.sorry_retry()))
if retry_count <= 1:
events.reply.emit(reply=AIReply.error(msg.sorry_retry()))
raise

return self.wrap_answer(question, agent_output, plan.model, acc_response)

return _splitter_wrapper_()
return _splitter_wrapper_(retries)

@retry(exceptions=RETRIABLE_ERRORS, tries=configs.max_router_retries, backoff=1, jitter=1)
def _process_tasks(self, task_list: list[SimpleNamespace]) -> Optional[str]:
def _process_tasks(self, task_list: list[SimpleNamespace], retry_count: int) -> Optional[str]:
"""Wrapper to allow accuracy retries."""

retry_count += 1
resp_history: list[str] = list()

if not task_list:
Expand All @@ -222,7 +225,8 @@ def _process_tasks(self, task_list: list[SimpleNamespace]) -> Optional[str]:
except (InterruptionRequest, TerminatingQuery) as err:
return str(err)
except self.RETRIABLE_ERRORS:
events.reply.emit(reply=AIReply.error(msg.sorry_retry()))
if retry_count <= 1:
events.reply.emit(reply=AIReply.error(msg.sorry_retry()))
raise

return os.linesep.join(resp_history) if resp_history else msg.no_output("Task-Splitter")
Expand Down
12 changes: 10 additions & 2 deletions src/main/askai/core/features/router/task_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
from askai.core.model.ai_reply import AIReply
from askai.core.support.langchain_support import lc_llm
from askai.core.support.shared_instances import shared
from askai.exception.exceptions import InaccurateResponse
from hspylib.core.config.path_object import PathObject
from hspylib.core.metaclass.singleton import Singleton
from langchain.agents import AgentExecutor, create_structured_chat_agent
from langchain.memory.chat_memory import BaseChatMemory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable
from langchain_core.runnables.utils import Output
from openai import APIError


class TaskAgent(metaclass=Singleton):
Expand Down Expand Up @@ -94,8 +96,14 @@ def _exec_task(self, task: AnyStr) -> Optional[Output]:
:return: An instance of Output containing the result of the task, or None if the task fails or produces
no output.
"""
lc_agent: Runnable = self._create_lc_agent()
return lc_agent.invoke({"input": task})
output: str | None = None
try:
lc_agent: Runnable = self._create_lc_agent()
output = lc_agent.invoke({"input": task})
except APIError as err:
raise InaccurateResponse(str(err))

return output


assert (agent := TaskAgent().INSTANCE) is not None
4 changes: 2 additions & 2 deletions src/main/askai/core/support/text_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ def beautify(self, text: Any) -> str:
text = re.sub(self.RE_TYPES[''], r" [\1](\1)", text)
text = re.sub(self.RE_TYPES['MD'], r"\n\1\n", text)
text = re.sub(r'```(.+)```\s+', r"\n```\1```\n", text)
text = re.sub(rf"\s+{os.getenv('USER', 'user')}", f'` {os.getenv("USER", "user")}`', text)
text = re.sub(r"\s+Taius", f' **Taius**', text)
text = re.sub(rf"(\s+)({os.getenv('USER', 'user')})", r'\1*\2*', text)
text = re.sub(r"(\s+)([Tt]aius)", r'\1**\2**', text)

# fmt: on

Expand Down
1 change: 1 addition & 0 deletions src/main/askai/resources/prompts/search-builder.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Your task is to respond to a user query following the steps below. You MUST foll

8. **Map Inquiries:** For map-related inquiries, add the filter: 'map:"<location>"' to your list.

9. **General Search:** For broad inquiries or searches where the nature of the query cannot be determined, avoid using restrictive filters. Instead, rely on general search engines such as "google.com", "bing.com", "duckduckgo.com", and "ask.com".

The response should follow this format:

Expand Down

0 comments on commit 45de799

Please sign in to comment.