Skip to content

Commit

Permalink
Refactorings, bugfixes and parse improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
yorevs committed Sep 26, 2024
1 parent 31a16ba commit cb7d7b2
Show file tree
Hide file tree
Showing 20 changed files with 550 additions and 242 deletions.
2 changes: 1 addition & 1 deletion src/demo/features/rag/x_refs_demo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from askai.core.features.router.task_accuracy import resolve_x_refs
from askai.core.features.router.evaluation import resolve_x_refs
from askai.core.support.shared_instances import shared
from askai.core.support.utilities import display_text
from utils import get_resource, init_context
Expand Down
106 changes: 106 additions & 0 deletions src/main/askai/core/enums/acc_color.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
@project: HsPyLib-AskAI
@package: askai.core.enums.acc_response
@file: acc_color.py
@created: thu, 26 Sep 2024
@author: <B>H</B>ugo <B>S</B>aporetti <B>J</B>unior
@site: https://github.com/yorevs/askai
@license: MIT - Please refer to <https://opensource.org/licenses/MIT>
Copyright (c) 2024, HomeSetup
"""
from typing import Literal, TypeAlias

from hspylib.core.enums.enumeration import Enumeration

AccuracyColors: TypeAlias = Literal["Blue", "Green", "Yellow", "Orange", "Red"]


class AccColor(Enumeration):
    """Color-coded accuracy classification, ordered by weight (lower weight = better accuracy).

    Each member pairs a display color with an integer weight used for ordering:
    INTERRUPT(-1) < EXCELLENT(0) < GOOD(1) < MODERATE(2) < INCOMPLETE(3) < BAD(4).
    """

    # fmt: off

    INTERRUPT = 'Black', -1

    EXCELLENT = 'Blue', 0

    GOOD = 'Green', 1

    MODERATE = 'Yellow', 2

    INCOMPLETE = 'Orange', 3

    BAD = 'Red', 4

    # fmt: on

    def __init__(self, color: AccuracyColors, weight: int):
        # Enumeration unpacks the member's value tuple into (color, weight).
        self._color: AccuracyColors = color
        self._weight: int = weight

    def __eq__(self, other: "AccColor") -> bool:
        # Return NotImplemented for foreign types so Python can fall back to its
        # default comparison, instead of raising AttributeError on `other.val`.
        if not isinstance(other, AccColor):
            return NotImplemented
        return self.val == other.val

    def __hash__(self) -> int:
        # Defining __eq__ implicitly sets __hash__ to None; restore hashability
        # so members can be used as dict keys and in sets.
        return hash(self.value)

    def __lt__(self, other: "AccColor") -> bool:
        return self.val < other.val

    def __le__(self, other: "AccColor") -> bool:
        return self.val <= other.val

    def __gt__(self, other: "AccColor") -> bool:
        return self.val > other.val

    def __ge__(self, other: "AccColor") -> bool:
        return self.val >= other.val

    def __str__(self) -> str:
        return self.color

    @classmethod
    def of_color(cls, color_str: AccuracyColors) -> 'AccColor':
        """Create an AccColor instance from its color name (case-insensitive).
        :param color_str: The color as a string, e.g. 'blue' or 'Blue'.
        :return: The AccColor member matching the given color.
        :raises ValueError: If the given color does not match any member.
        """
        acc_color: tuple[str, int] = next((c for c in cls.values() if c[0] == color_str.title()), None)
        if acc_color and isinstance(acc_color, tuple):
            return cls.of_value(acc_color)
        raise ValueError(f"'{color_str}' is not a valid AccColor")

    @property
    def color(self) -> AccuracyColors:
        """The display color of this accuracy classification."""
        return self.value[0]

    @property
    def val(self) -> int:
        """Get the integer weight of this accuracy color.
        :return: The weight (lower means better accuracy).
        """
        return int(self.value[1])

    @property
    def is_bad(self) -> bool:
        return self in [self.BAD, self.INCOMPLETE]

    @property
    def is_moderate(self) -> bool:
        return self == self.MODERATE

    @property
    def is_good(self) -> bool:
        return self in [self.GOOD, self.EXCELLENT]

    @property
    def is_interrupt(self) -> bool:
        return self == self.INTERRUPT

    def passed(self, threshold: "AccColor") -> bool:
        """Determine whether this color meets a 'PASS' classification.
        :param threshold: The threshold color used to determine a 'PASS' classification.
        :return: True if this color's weight is at or below the threshold's, otherwise False.
        """
        if isinstance(threshold, AccColor):
            return self.val <= threshold.val
        return False
117 changes: 36 additions & 81 deletions src/main/askai/core/enums/acc_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,107 +12,62 @@
Copyright (c) 2024, HomeSetup
"""
import os
from dataclasses import dataclass

import re
from typing import Literal
from askai.core.enums.acc_color import AccColor, AccuracyColors
from askai.core.support.utilities import parse_field
from hspylib.core.tools.text_tools import ensure_endswith

from hspylib.core.enums.enumeration import Enumeration


class AccResponse(Enumeration):
@dataclass(frozen=True)
class AccResponse:
"""Track and classify accuracy responses based on color classifications. This class provides an enumeration of
possible accuracy responses, which are typically represented by different colors.
"""

# fmt: off

EXCELLENT = 'Blue'

GOOD = 'Green'

MODERATE = 'Yellow'

INCOMPLETE = 'Orange'

BAD = 'Red'

INTERRUPT = 'Black'

# fmt: on
acc_color: AccColor
accuracy: float
reasoning: str
tips: str

@classmethod
def matches(cls, output: str) -> re.Match:
"""Find a match in the given output string.
:param output: The string to search for a match.
:return: A match object if a match is found.
:raises: re.error if an error occurs during the matching process.
"""
flags: int = re.IGNORECASE | re.MULTILINE | re.DOTALL
return re.search(cls._re(), output.replace("\n", " "), flags=flags)

@classmethod
def _re(cls) -> str:
def parse_response(cls, response: str) -> "AccResponse":
"""TODO"""
return rf"^\$?({'|'.join(cls.values())})[:,-]\s*[0-9]+%\s+(.+)"

@classmethod
def strip_code(cls, message: str) -> str:
"""Strip the color code from the message.
:param message: The message from which to strip color codes.
:return: The message with color codes removed.
"""
mat = cls.matches(message)
return str(mat.group(2)).strip() if mat else message.strip()
# FIXME: Remove log the response
with open("/Users/hjunior/Desktop/acc-response-resp.txt", "w") as f_bosta:
f_bosta.write(response + os.linesep)
f_bosta.flush()

@classmethod
def of_status(cls, status: str, reasoning: str | None) -> "AccResponse":
"""Create an AccResponse instance based on status and optional reasoning.
:param status: The status as a string.
:param reasoning: Optional reasoning for the status, formatted as '<percentage>% <details>'.
:return: An instance of AccResponse with the given status and reasoning.
"""
resp = cls.of_value(status.title())
if reasoning and (mat := re.match(r"(^[0-9]{1,3})%\s+(.*)", reasoning)):
resp.rate = float(mat.group(1))
resp.reasoning = mat.group(2)
return resp

def __init__(self, color: Literal["Blue", "Green", "Yellow", "Orange", "Red"]):
self.color = color
self.reasoning: str | None = None
self.rate: float | None = None
# Parse fields
acc_color: AccColor = AccColor.of_color(parse_field("@color", response))
accuracy: float = float(parse_field("@accuracy", response).strip("%"))
reasoning: str = parse_field("@reasoning", response)
tips: str = parse_field("@tips", response)

return AccResponse(acc_color, accuracy, reasoning, tips)

def __str__(self):
details: str = f"{' -> ' + str(self.rate) + '% ' + self.reasoning if self.reasoning else ''}"
return f"{self.name}{details}"
return f"{self.status} -> {self.details}"

@property
def is_bad(self) -> bool:
return self in [self.BAD, self.INCOMPLETE]
def color(self) -> AccuracyColors:
return self.acc_color.color

@property
def is_moderate(self) -> bool:
return self == self.MODERATE
def status(self) -> str:
return f"{self.color}, {str(self.accuracy)}%"

@property
def is_good(self) -> bool:
return self in [self.GOOD, self.EXCELLENT]
def details(self) -> str:
return f"{ensure_endswith(self.reasoning, '.')} {'**' + self.tips + '**' if self.tips else ''}"

@property
def is_interrupt(self) -> bool:
return self == self.INTERRUPT

def passed(self, threshold: "AccResponse") -> bool:
"""Determine whether the response matches a 'PASS' classification.
:param threshold: The threshold or criteria used to determine a 'PASS' classification.
:return: True if the response meets or exceeds the 'PASS' threshold, otherwise False.
"""
if isinstance(threshold, AccResponse):
idx_self, idx_threshold = None, None
for i, v in enumerate(AccResponse.values()):
if v == self.value:
idx_self = i
if v == threshold.value:
idx_threshold = i
return idx_self is not None and idx_threshold is not None and idx_self <= idx_threshold
return False
"""TODO"""
return self.acc_color.is_interrupt

def is_pass(self, threshold: AccColor) -> bool:
"""TODO"""
return self.acc_color.passed(threshold)
8 changes: 4 additions & 4 deletions src/main/askai/core/features/processors/task_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
from askai.core.askai_prompt import prompt
from askai.core.component.geo_location import geo_location
from askai.core.engine.openai.temperature import Temperature
from askai.core.enums.acc_color import AccColor
from askai.core.enums.acc_response import AccResponse
from askai.core.enums.routing_model import RoutingModel
from askai.core.features.router.agent_tools import features
from askai.core.features.router.task_accuracy import assert_accuracy
from askai.core.features.router.evaluation import assert_accuracy
from askai.core.features.router.task_agent import agent
from askai.core.features.tools.general import final_answer
from askai.core.model.action_plan import ActionPlan
Expand Down Expand Up @@ -165,8 +166,7 @@ def _splitter_wrapper_() -> Optional[str]:
if response := runnable.invoke({"input": question}, config={"configurable": {"session_id": "HISTORY"}}):
log.info("Router::[RESPONSE] Received from AI: \n%s.", str(response.content))
plan = ActionPlan.create(question, response, model)
task_list = plan.tasks
if task_list:
if task_list := plan.tasks:
events.reply.emit(reply=AIReply.debug(msg.action_plan(str(plan))))
if plan.speak:
events.reply.emit(reply=AIReply.info(plan.speak))
Expand All @@ -183,7 +183,7 @@ def _splitter_wrapper_() -> Optional[str]:

try:
wrapper_output = self._process_tasks(task_list)
assert_accuracy(question, wrapper_output, AccResponse.MODERATE)
assert_accuracy(question, wrapper_output, AccColor.MODERATE)
except (InterruptionRequest, TerminatingQuery) as err:
return str(err)
except self.RETRIABLE_ERRORS:
Expand Down
2 changes: 1 addition & 1 deletion src/main/askai/core/features/router/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
# Package: main.askai.core.features.router
"""Package initialization."""

__all__ = ["model_selector", "task_accuracy", "task_agent", "agent_tools.py"]
__all__ = ["model_selector", "evaluation.py", "task_agent", "agent_tools.py"]
__version__ = "1.0.13"
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from askai.core.askai_messages import msg
from askai.core.askai_prompt import prompt
from askai.core.engine.openai.temperature import Temperature
from askai.core.enums.acc_color import AccColor
from askai.core.enums.acc_response import AccResponse
from askai.core.model.ai_reply import AIReply
from askai.core.support.langchain_support import lc_llm
Expand All @@ -30,21 +31,21 @@

import logging as log

EVALUATION_GUIDE: str = dedent(
"""
# fmt: off
EVALUATION_GUIDE: str = dedent("""
**Accuracy Evaluation Guidelines:**
1. Analyze past responses to ensure accuracy.
2. Regularly self-critique overall responses.
3. Reflect on past strategies to refine your approach.
4. Experiment with different methods or solutions.
"""
).strip()
""").strip()
# fmt: on

RAG: RAGProvider = RAGProvider("accuracy.csv")


def assert_accuracy(question: str, ai_response: str, pass_threshold: AccResponse = AccResponse.MODERATE) -> AccResponse:
def assert_accuracy(question: str, ai_response: str, pass_threshold: AccColor = AccColor.MODERATE) -> AccResponse:
"""Assert that the AI's response to the question meets the required accuracy threshold.
:param question: The user's question.
:param ai_response: The AI's response to be analyzed for accuracy.
Expand All @@ -53,32 +54,31 @@ def assert_accuracy(question: str, ai_response: str, pass_threshold: AccResponse
:return: The accuracy classification of the AI's response as an AccResponse enum value.
"""
if ai_response and ai_response not in msg.accurate_responses:
issues_prompt = PromptTemplate(input_variables=["problems"], template=prompt.read_prompt("evaluation"))
assert_template = PromptTemplate(
input_variables=["rag", "input", "response"], template=prompt.read_prompt("accuracy")
acc_template = PromptTemplate(input_variables=["problems"], template=prompt.read_prompt("acc-report"))
eval_template = PromptTemplate(
input_variables=["rag", "input", "response"], template=prompt.read_prompt("evaluation")
)
final_prompt = assert_template.format(rag=RAG.get_rag_examples(question), input=question, response=ai_response)
final_prompt = eval_template.format(rag=RAG.get_rag_examples(question), input=question, response=ai_response)
log.info("Assert::[QUESTION] '%s' context: '%s'", question, ai_response)
llm = lc_llm.create_chat_model(Temperature.COLDEST.temp)
response: AIMessage = llm.invoke(final_prompt)

if response and (output := response.content):
if mat := AccResponse.matches(output):
status, details = mat.group(1), mat.group(2)
log.info("Accuracy check -> status: '%s' reason: '%s'", status, details)
events.reply.emit(reply=AIReply.debug(msg.assert_acc(status, details)))
if (rag_resp := AccResponse.of_status(status, details)).is_interrupt:
if acc := AccResponse.parse_response(output):
log.info("Accuracy check -> status: '%s' details: '%s'", acc.status, acc.details)
events.reply.emit(reply=AIReply.debug(msg.assert_acc(acc.status, acc.details)))
if acc.is_interrupt:
# AI flags that it can't continue interacting.
log.warning(msg.interruption_requested(output))
raise InterruptionRequest(ai_response)
elif not rag_resp.passed(pass_threshold):
elif not acc.is_pass(pass_threshold):
# Include the guidelines for the first mistake.
if not shared.context.get("EVALUATION"):
shared.context.push("EVALUATION", EVALUATION_GUIDE)
shared.context.push("EVALUATION", issues_prompt.format(problems=AccResponse.strip_code(output)))
shared.context.push("EVALUATION", acc_template.format(problems=acc.details))
raise InaccurateResponse(f"AI Assistant failed to respond => '{response.content}'")
return rag_resp
# At this point, the response was not Good.
return acc
# At this point, the response was inaccurate.
raise InaccurateResponse(f"AI Assistant didn't respond accurately. Response: '{response}'")


Expand Down
6 changes: 3 additions & 3 deletions src/main/askai/core/features/router/task_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from askai.core.askai_messages import msg
from askai.core.askai_prompt import prompt
from askai.core.engine.openai.temperature import Temperature
from askai.core.enums.acc_response import AccResponse
from askai.core.enums.acc_color import AccColor
from askai.core.features.router.agent_tools import features
from askai.core.features.router.task_accuracy import assert_accuracy
from askai.core.features.router.evaluation import assert_accuracy
from askai.core.model.ai_reply import AIReply
from askai.core.support.langchain_support import lc_llm
from askai.core.support.shared_instances import shared
Expand Down Expand Up @@ -60,7 +60,7 @@ def invoke(self, task: str) -> str:
shared.context.push("HISTORY", task, "assistant")
shared.context.push("HISTORY", output, "assistant")
shared.memory.save_context({"input": task}, {"output": output})
assert_accuracy(task, output, AccResponse.MODERATE)
assert_accuracy(task, output, AccColor.MODERATE)
else:
output = msg.no_output("AI")

Expand Down
Loading

0 comments on commit cb7d7b2

Please sign in to comment.