Skip to content

Commit

Permalink
Response and agent fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
yorevs committed Oct 23, 2024
1 parent 7f143df commit 8140951
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 49 deletions.
21 changes: 7 additions & 14 deletions src/main/askai/__classpath__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,37 +12,30 @@
Copyright (c) 2024, HomeSetup
"""
from hspylib.modules.application.exit_status import ExitStatus

from askai.core.model.api_keys import ApiKeys
from clitt.core.term.commons import is_a_tty
from hspylib.core.metaclass.classpath import Classpath
from hspylib.core.tools.commons import is_debugging, parent_path, root_dir

import logging as log
import os
import pydantic
import sys
import warnings

import pydantic
from hspylib.core.metaclass.classpath import Classpath
from hspylib.core.tools.commons import is_debugging, parent_path, root_dir
from hspylib.modules.application.exit_status import ExitStatus

from askai.core.model.api_keys import ApiKeys

if not is_debugging():
warnings.simplefilter("ignore", category=FutureWarning)
warnings.simplefilter("ignore", category=UserWarning)
warnings.simplefilter("ignore", category=DeprecationWarning)
warnings.simplefilter("ignore", category=UserWarning)
warnings.filterwarnings("ignore", module="chromadb.db.impl.sqlite")

if not is_a_tty():
log.getLogger().setLevel(log.ERROR)
else:
log.getLogger().setLevel(log.INFO)

if not os.environ.get("USER_AGENT"):
# The AskAI User Agent, required by the langchain framework
ASKAI_USER_AGENT: str = "AskAI-User-Agent"
os.environ["USER_AGENT"] = ASKAI_USER_AGENT


try:
API_KEYS: ApiKeys = ApiKeys()
except pydantic.v1.error_wrappers.ValidationError as err:
Expand Down
33 changes: 23 additions & 10 deletions src/main/askai/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@
Copyright (c) 2024, HomeSetup
"""
from askai.__classpath__ import classpath
from askai.core.askai_configs import configs
from askai.core.enums.run_modes import RunModes

import logging as log
import os
import re
import sys
from textwrap import dedent
from typing import Any, AnyStr, Optional

import click
from clitt.core.tui.tui_application import TUIApplication
from hspylib.core.enums.charset import Charset
from hspylib.core.tools.commons import syserr, to_bool
Expand All @@ -23,14 +29,11 @@
from hspylib.modules.application.argparse.parser_action import ParserAction
from hspylib.modules.application.exit_status import ExitStatus
from hspylib.modules.application.version import Version
from textwrap import dedent
from typing import Any, AnyStr, Optional

import click
import logging as log
import os
import re
import sys
from askai.__classpath__ import classpath
from askai.core.askai_configs import configs
from askai.core.enums.run_modes import RunModes
from askai.core.support.shared_instances import LOGGER_NAME


class Main(TUIApplication):
Expand All @@ -47,9 +50,19 @@ class Main(TUIApplication):

INSTANCE: "Main"

@staticmethod
def setup_logs() -> None:
"""TODO"""
# FIXME: Move this code to hspylib Application FW
log.basicConfig(level=log.WARNING)
logger = log.getLogger(LOGGER_NAME)
logger.setLevel(log.INFO)
logger.propagate = False

def __init__(self, app_name: str):
super().__init__(app_name, self.VERSION, self.DESCRIPTION.format(self.VERSION), resource_dir=self.RESOURCE_DIR)
self._askai: Any
Main.setup_logs()

@property
def askai(self) -> Any:
Expand Down
10 changes: 5 additions & 5 deletions src/main/askai/core/commander/commands/history_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,18 @@ def context_list() -> None:
"""List the entries in the chat context window."""

if (all_context := shared.context) and (length := len(all_context)) > 0:
ln: str = os.linesep
display_text(f"### Listing ALL ({length}) Chat Contexts:\n\n---\n\n")
for c in all_context:
ctx, ctx_val = c[0], c[1]
display_text(
f"- {ctx} ({len(ctx_val)}/{all_context.max_context_size} "
f"tk [{all_context.length(ctx)}/{all_context.token_limit}]) \n"
+ indent(
"\n".join(
[
f'{i}. **{e.role.title()}:**\n\n{indent(e.content, " " * 4)}' + os.linesep
for i, e in enumerate(ctx_val, start=1)
]
ln.join([
f'{i}. **{e.role.title()}:**\n\n{indent(text_formatter.strip_format(e.content), " " * 4)}'
+ os.linesep
for i, e in enumerate(ctx_val, start=1)]
),
" " * 4,
)
Expand Down
5 changes: 3 additions & 2 deletions src/main/askai/core/processors/splitter/splitter_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def display(self, message: str) -> None:
if configs.is_debug:
self._console.print(message)

def run(self):
with self._console.status(msg.wait(), spinner="dots") as spinner:
def run(self) -> None:
with self._console.status(f"[green]{msg.wait()}[/green]", spinner="dots") as spinner:
while not self.pipeline.state == States.COMPLETE:
self.pipeline.track_previous()
if 1 < configs.max_router_retries < 1 + self.pipeline.failures[self.pipeline.state.value]:
Expand Down Expand Up @@ -89,6 +89,7 @@ def run(self):
case _:
self.display(f"[red] Error: Machine halted before complete!({self.pipeline.state})[/red]")
break

execution_status: bool = self.pipeline.previous != self.pipeline.state
execution_status_str: str = (
f"{'[green]√[/green]' if execution_status else '[red]X[/red]'}"
Expand Down
43 changes: 28 additions & 15 deletions src/main/askai/core/processors/splitter/splitter_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
"""
import logging as log
from collections import defaultdict
from typing import AnyStr
from typing import AnyStr, Optional

from hspylib.core.preconditions import check_state
from hspylib.core.tools.validator import Validator
from langchain_core.prompts import PromptTemplate
from transitions import Machine

Expand All @@ -31,7 +32,7 @@
from askai.core.processors.splitter.splitter_states import States
from askai.core.processors.splitter.splitter_transitions import Transition, TRANSITIONS
from askai.core.router.evaluation import eval_response, EVALUATION_GUIDE
from askai.core.support.shared_instances import shared
from askai.core.support.shared_instances import shared, LOGGER_NAME


class SplitterPipeline:
Expand All @@ -44,7 +45,7 @@ class SplitterPipeline:
def __init__(self, question: AnyStr):
self._transitions: list[Transition] = [t for t in TRANSITIONS]
self._machine: Machine = Machine(
name="Taius-Coder", model=self,
name=LOGGER_NAME, model=self,
initial=States.STARTUP, states=States, transitions=self._transitions,
auto_transitions=False
)
Expand Down Expand Up @@ -82,28 +83,28 @@ def question(self) -> str:
return self.result.question

@property
def last_query(self) -> str:
return self.responses[-1].query
def last_query(self) -> Optional[str]:
return self.responses[-1].query if self.responses else None

@last_query.setter
def last_query(self, value: str) -> None:
self.responses[-1].query = value
self.responses[-1].query = value if self.responses else None

@property
def last_answer(self) -> str:
return self.responses[-1].answer
def last_answer(self) -> Optional[str]:
return self.responses[-1].answer if self.responses else None

@last_answer.setter
def last_answer(self, value: str) -> None:
self.responses[-1].answer = value
self.responses[-1].answer = value if self.responses else None

@property
def last_accuracy(self) -> AccResponse:
return self.responses[-1].accuracy
def last_accuracy(self) -> Optional[AccResponse]:
return self.responses[-1].accuracy if self.responses else None

@last_accuracy.setter
def last_accuracy(self, value: AccResponse) -> None:
self.responses[-1].accuracy = value
self.responses[-1].accuracy = value if self.responses else None

@property
def plan(self) -> ActionPlan:
Expand Down Expand Up @@ -173,7 +174,7 @@ def st_execute_task(self) -> bool:
def st_accuracy_check(self) -> AccColor:
"""TODO"""

if self.last_query is None or self.last_answer is None:
if not Validator.has_no_nulls(self.last_query, self.last_answer):
return AccColor.BAD

# FIXME Hardcoded for now
Expand Down Expand Up @@ -204,7 +205,19 @@ def st_accuracy_check(self) -> AccColor:
return acc.acc_color

def st_refine_answer(self) -> bool:
return actions.refine_answer(self.question, self.final_answer, self.last_accuracy)
"""TODO"""
if refined := actions.refine_answer(self.question, self.final_answer, self.last_accuracy):
final_response: PipelineResponse = PipelineResponse(self.question, refined, self.last_accuracy)
self.responses.clear()
self.responses.append(final_response)
return True
return False

def st_final_answer(self) -> bool:
return actions.wrap_answer(self.question, self.final_answer, self.model)
"""TODO"""
if wrapped := actions.wrap_answer(self.question, self.final_answer, self.model):
final_response: PipelineResponse = PipelineResponse(self.question, wrapped, self.last_accuracy)
self.responses.clear()
self.responses.append(final_response)
return True
return False
4 changes: 2 additions & 2 deletions src/main/askai/core/processors/splitter/splitter_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ class SplitterResult:
plan: ActionPlan | None = None
model: ModelResult | None = None

def final_response(self) -> str:
def final_response(self, acc_threshold: AccColor = AccColor.MODERATE) -> str:
"""TODO"""
return os.linesep.join(
list(map(lambda r: r.answer, filter(
lambda acc: acc.accuracy and acc.accuracy.acc_color.passed(AccColor.MODERATE), self.responses)))
lambda acc: acc.accuracy and acc.accuracy.acc_color.passed(acc_threshold), self.responses)))
)
4 changes: 3 additions & 1 deletion src/main/askai/core/router/task_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _create_lc_agent(self, temperature: Temperature = Temperature.COLDEST) -> Ru
tools=tools,
max_iterations=configs.max_agent_retries,
memory=chat_memory,
handle_parsing_errors=True,
handle_parsing_errors="Generate a JSON blob that is fully parseable using the Python `json` module.",
max_execution_time=configs.max_agent_execution_time_seconds,
verbose=configs.is_debug,
)
Expand All @@ -105,6 +105,8 @@ def _exec_task(self, task: AnyStr) -> Optional[Output]:
return lc_agent.invoke({"input": task})
except openai.APIError as err:
log.error(str(err))
except ValueError as err:
log.error(str(err))

return None

Expand Down
2 changes: 2 additions & 0 deletions src/main/askai/core/support/shared_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@

import os

LOGGER_NAME: str = 'Askai-Taius'


class SharedInstances(metaclass=Singleton):
"""Provides access to shared instances."""
Expand Down

0 comments on commit 8140951

Please sign in to comment.