From 3e0bccea0f5dd732c4e56b0f84175f892076d24f Mon Sep 17 00:00:00 2001 From: Hugo Saporetti Junior Date: Tue, 22 Oct 2024 21:17:45 -0300 Subject: [PATCH] Replace actual processor by the pipeline - part2 --- src/main/askai/core/askai_messages.py | 3 + .../processors/splitter/splitter_executor.py | 14 ++- .../processors/splitter/splitter_pipeline.py | 87 +++++++++++++------ .../splitter/splitter_transitions.py | 2 +- .../askai/core/support/shared_instances.py | 4 + 5 files changed, 73 insertions(+), 37 deletions(-) diff --git a/src/main/askai/core/askai_messages.py b/src/main/askai/core/askai_messages.py index 5862c608..37b81e8f 100644 --- a/src/main/askai/core/askai_messages.py +++ b/src/main/askai/core/askai_messages.py @@ -257,5 +257,8 @@ def quote_exceeded(self) -> str: def interruption_requested(self, reason: str) -> str: return f"AI has interrupted the execution => {reason}" + def terminate_requested(self, reason: str) -> str: + return f"AI has terminated the execution => {reason}" + assert (msg := AskAiMessages().INSTANCE) is not None diff --git a/src/main/askai/core/processors/splitter/splitter_executor.py b/src/main/askai/core/processors/splitter/splitter_executor.py index ec6baebc..727d0ed6 100644 --- a/src/main/askai/core/processors/splitter/splitter_executor.py +++ b/src/main/askai/core/processors/splitter/splitter_executor.py @@ -16,7 +16,6 @@ from textwrap import indent from threading import Thread -from hspylib.core.decorator.decorators import profiled from rich.console import Console from askai.core.askai_configs import configs @@ -24,6 +23,7 @@ from askai.core.enums.acc_color import AccColor from askai.core.processors.splitter.splitter_pipeline import SplitterPipeline from askai.core.processors.splitter.splitter_states import States +from askai.core.support.shared_instances import shared class SplitterExecutor(Thread): @@ -43,12 +43,11 @@ def display(self, message: str) -> None: if configs.is_debug: self._console.print(message) - @profiled def run(self): with self._console.status(msg.wait(), spinner="dots") as spinner: while not self.pipeline.state == States.COMPLETE: self.pipeline.track_previous() - spinner.update(f"[green]{self.pipeline.state.value}[/green]") + spinner.update(f"{shared.nickname_spinner}[green]{self.pipeline.state.value}…[/green]") if 0 < configs.max_router_retries < self.pipeline.failures[self.pipeline.state.value]: self.display(f"\n[red] Max retries exceeded: {configs.max_router_retries}[/red]\n") break @@ -64,8 +63,8 @@ def run(self): self.pipeline.ev_model_selected() case States.TASK_SPLIT: if self.pipeline.st_task_split(): - if self.pipeline.is_direct: - spinner.update("[yellow] AI decided to respond directly[/yellow]") + if self.pipeline.is_direct(): + self.display("[yellow] AI decided to respond directly[/yellow]") self.pipeline.ev_direct_answer() else: spinner.update("[green] Executing action plan[/green]") @@ -76,7 +75,7 @@ def run(self): case States.ACCURACY_CHECK: acc_color: AccColor = self.pipeline.st_accuracy_check() c_name: str = acc_color.color.casefold() - spinner.update(f"[green] Accuracy check: [{c_name}]{c_name.upper()}[/{c_name}][/green]") + self.display(f"[green] Accuracy check: [{c_name}]{c_name.upper()}[/{c_name}][/green]") if acc_color.passed(AccColor.GOOD): self.pipeline.ev_accuracy_passed() elif acc_color.passed(AccColor.MODERATE): @@ -116,6 +115,3 @@ def run(self): if final_state != States.COMPLETE: retries: int = self.pipeline.failures[self.pipeline.state.value] self.display(f"Failed to generate a response after {retries} retries") - - if self.pipeline.state == States.COMPLETE and self.pipeline.final_answer: - print(self.pipeline.final_answer) diff --git a/src/main/askai/core/processors/splitter/splitter_pipeline.py b/src/main/askai/core/processors/splitter/splitter_pipeline.py index a9c551ec..6ef9e02f 100644 --- a/src/main/askai/core/processors/splitter/splitter_pipeline.py +++ b/src/main/askai/core/processors/splitter/splitter_pipeline.py @@ -14,11 +14,9 @@ """ import logging as log import os -import random from collections import defaultdict from typing import AnyStr, Optional -import pause from langchain_core.prompts import PromptTemplate from transitions import Machine @@ -33,7 +31,6 @@ from askai.core.processors.splitter.splitter_transitions import Transition, TRANSITIONS from askai.core.router.evaluation import assert_accuracy, EVALUATION_GUIDE from askai.core.support.shared_instances import shared -from askai.exception.exceptions import InterruptionRequest, TerminatingQuery class SplitterPipeline: @@ -58,9 +55,20 @@ def __init__(self, query: AnyStr): self._iteractions: int = 0 self._query: str = query self._plan: ActionPlan | None = None - self._final_answer: Optional[str] = None + self._direct_answer: Optional[str] = None self._model: ModelResult | None = None self._resp_history: list[str] = list() + self._last_acc_response: AccResponse | None = None + self._last_task: str | None = None + + def _invalidate(self) -> None: + """TODO""" + self._plan = None + self._direct_answer = None + self._model = None + self._resp_history = list() + self._last_acc_response = None + self._last_task = None @property def iteractions(self) -> int: @@ -70,6 +78,22 @@ def iteractions(self) -> int: def iteractions(self, value: int): self._iteractions = value + @property + def last_acc_response(self) -> AccResponse: + return self._last_acc_response + + @last_acc_response.setter + def last_acc_response(self, value: AccResponse) -> None: + self._last_acc_response = value + + @property + def last_task(self) -> str: + return self._last_task + + @last_task.setter + def last_task(self, value: str) -> None: + self._last_task = value + @property def failures(self) -> dict[str, int]: return self._failures @@ -88,11 +112,19 @@ def previous(self) -> States: @property def query(self) -> str: - return self._query + if self.last_task is not None: + question: str = self.last_task + else: + question: str = self._query + return question @property def final_answer(self) -> Optional[str]: - return self._final_answer + if self.is_direct(): + ai_response: str = self._direct_answer + else: + ai_response: str = os.linesep.join(self._resp_history) + return ai_response @property def resp_history(self) -> list[str]: @@ -104,14 +136,15 @@ def track_previous(self) -> None: def has_next(self) -> bool: """TODO""" - return len(self.plan.tasks) > 0 if self.plan and self.plan.tasks else False + return len(self.plan.tasks) > 0 if self.plan is not None and self.plan.tasks else False def is_direct(self) -> bool: """TODO""" - return self.plan.is_direct if self.plan else True + return self.plan.is_direct if self.plan is not None else True def st_startup(self) -> bool: log.info("Task Splitter pipeline has started!") + self._invalidate() return True def st_model_select(self) -> bool: @@ -123,15 +156,15 @@ def st_task_split(self) -> bool: log.info("Splitting tasks...") self._plan = actions.split(self.query, self.model) if self._plan.is_direct: - self._final_answer = self._plan.speak or msg.no_output("TaskSplitter") + self._direct_answer = self._plan.speak or msg.no_output("TaskSplitter") return True def st_execute_next(self) -> bool: _iter_ = self.plan.tasks.copy().__iter__() if action := next(_iter_, None): if agent_output := actions.process_action(action): - self.resp_history.append(agent_output) - self.plan.tasks.pop(0) + self.last_task = self.plan.tasks.pop(0).task if len(self.plan.tasks) > 0 else None + return self.last_task is not None return False def st_accuracy_check(self) -> AccColor: @@ -139,21 +172,18 @@ def st_accuracy_check(self) -> AccColor: # FIXME Hardcoded for now pass_threshold: AccColor = AccColor.GOOD - if self.is_direct: - ai_response: str = self.final_answer - else: - ai_response: str = os.linesep.join(self._resp_history) - - acc: AccResponse = assert_accuracy(self.query, ai_response, pass_threshold) + acc: AccResponse = assert_accuracy(self.query, self.final_answer, pass_threshold) if acc.is_interrupt: # AI flags that it can't continue interacting. - log.warning(msg.interruption_requested(ai_response)) - raise InterruptionRequest(ai_response) + log.warning(msg.interruption_requested(self.final_answer)) elif acc.is_terminate: # AI flags that the user wants to end the session. - raise TerminatingQuery(ai_response) + log.warning(msg.terminate_requested(self.final_answer)) elif acc.is_pass(pass_threshold): + # AI provided a good answer. + log.warning(f"AI provided a final answer: {self.final_answer}") + self.resp_history.append(self.final_answer) shared.memory.save_context({"input": self.query}, {"output": self.final_answer}) else: acc_template = PromptTemplate(input_variables=["problems"], template=prompt.read_prompt("acc-report")) @@ -162,15 +192,18 @@ def st_accuracy_check(self) -> AccColor: shared.context.push("EVALUATION", EVALUATION_GUIDE) shared.context.push("EVALUATION", acc_template.format(problems=acc.details)) + self.last_acc_response = acc + return acc.acc_color def st_refine_answer(self) -> bool: - result = random.choice([True, False]) - pause.seconds(self.FAKE_SLEEP) - return result + if self.is_direct: + ai_response: str = self.final_answer + else: + ai_response: str = os.linesep.join(self._resp_history) + + return actions.refine_answer(self.query, ai_response, self.last_acc_response) def st_final_answer(self) -> bool: - self._final_answer = "This is the final answer" - result = random.choice([True, False]) - pause.seconds(self.FAKE_SLEEP) - return result + + return actions.wrap_answer(self.query, self.final_answer, self.model) diff --git a/src/main/askai/core/processors/splitter/splitter_transitions.py b/src/main/askai/core/processors/splitter/splitter_transitions.py index 1d01f8ca..c743ade8 100644 --- a/src/main/askai/core/processors/splitter/splitter_transitions.py +++ b/src/main/askai/core/processors/splitter/splitter_transitions.py @@ -32,7 +32,7 @@ {'trigger': 'ev_task_executed', 'source': States.EXECUTE_TASK, 'dest': States.ACCURACY_CHECK}, {'trigger': 'ev_accuracy_passed', 'source': States.ACCURACY_CHECK, 'dest': States.EXECUTE_TASK, 'conditions': ['has_next']}, - {'trigger': 'ev_accuracy_passed', 'source': States.ACCURACY_CHECK, 'dest': States.COMPLETE, 'unless': ['has_next']}, + {'trigger': 'ev_accuracy_passed', 'source': States.ACCURACY_CHECK, 'dest': States.WRAP_ANSWER, 'unless': ['has_next']}, {'trigger': 'ev_accuracy_failed', 'source': States.ACCURACY_CHECK, 'dest': States.EXECUTE_TASK}, {'trigger': 'ev_refine_required', 'source': States.ACCURACY_CHECK, 'dest': States.REFINE_ANSWER, 'unless': ['has_next']}, diff --git a/src/main/askai/core/support/shared_instances.py b/src/main/askai/core/support/shared_instances.py index 1edb03f5..35f7867f 100644 --- a/src/main/askai/core/support/shared_instances.py +++ b/src/main/askai/core/support/shared_instances.py @@ -100,6 +100,10 @@ def nickname_md(self) -> str: def username_md(self) -> str: return f"** {prompt.user.title()}:** " + @property + def nickname_spinner(self) -> str: + return f"[green]{self.mode.icon}[bold] Taius:[/bold][/green] " + @property def idiom(self) -> str: return self._idiom