Commit: Replace actual processor by the pipeline - part2

yorevs committed Oct 23, 2024
1 parent 5981283 commit 3e0bcce
Showing 5 changed files with 73 additions and 37 deletions.
3 changes: 3 additions & 0 deletions src/main/askai/core/askai_messages.py

@@ -257,5 +257,8 @@ def quote_exceeded(self) -> str:
     def interruption_requested(self, reason: str) -> str:
         return f"AI has interrupted the execution => {reason}"
 
+    def terminate_requested(self, reason: str) -> str:
+        return f"AI has terminated the execution => {reason}"
+
 
 assert (msg := AskAiMessages().INSTANCE) is not None
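
The new terminate_requested message mirrors interruption_requested and is logged by the pipeline's accuracy check when the evaluator flags a session-ending query (see splitter_pipeline.py below). A minimal usage sketch, assuming only the msg singleton visible in this diff:

    from askai.core.askai_messages import msg

    # Hypothetical call site: format the new termination message.
    reason = "user asked to end the session"
    print(msg.terminate_requested(reason))
    # -> AI has terminated the execution => user asked to end the session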
14 changes: 5 additions & 9 deletions src/main/askai/core/processors/splitter/splitter_executor.py

@@ -16,14 +16,14 @@
 from textwrap import indent
 from threading import Thread
 
-from hspylib.core.decorator.decorators import profiled
 from rich.console import Console
 
 from askai.core.askai_configs import configs
 from askai.core.askai_messages import msg
 from askai.core.enums.acc_color import AccColor
 from askai.core.processors.splitter.splitter_pipeline import SplitterPipeline
 from askai.core.processors.splitter.splitter_states import States
+from askai.core.support.shared_instances import shared
 
 
 class SplitterExecutor(Thread):
@@ -43,12 +43,11 @@ def display(self, message: str) -> None:
         if configs.is_debug:
             self._console.print(message)
 
-    @profiled
     def run(self):
         with self._console.status(msg.wait(), spinner="dots") as spinner:
             while not self.pipeline.state == States.COMPLETE:
                 self.pipeline.track_previous()
-                spinner.update(f"[green]{self.pipeline.state.value}[/green]")
+                spinner.update(f"{shared.nickname_spinner}[green]{self.pipeline.state.value}[/green]")
                 if 0 < configs.max_router_retries < self.pipeline.failures[self.pipeline.state.value]:
                     self.display(f"\n[red] Max retries exceeded: {configs.max_router_retries}[/red]\n")
                     break
@@ -64,8 +63,8 @@ def run(self):
                         self.pipeline.ev_model_selected()
                     case States.TASK_SPLIT:
                         if self.pipeline.st_task_split():
-                            if self.pipeline.is_direct:
-                                spinner.update("[yellow] AI decided to respond directly[/yellow]")
+                            if self.pipeline.is_direct():
+                                self.display("[yellow] AI decided to respond directly[/yellow]")
                                 self.pipeline.ev_direct_answer()
                             else:
                                 spinner.update("[green] Executing action plan[/green]")
@@ -76,7 +75,7 @@ def run(self):
                     case States.ACCURACY_CHECK:
                         acc_color: AccColor = self.pipeline.st_accuracy_check()
                         c_name: str = acc_color.color.casefold()
-                        spinner.update(f"[green] Accuracy check: [{c_name}]{c_name.upper()}[/{c_name}][/green]")
+                        self.display(f"[green] Accuracy check: [{c_name}]{c_name.upper()}[/{c_name}][/green]")
                         if acc_color.passed(AccColor.GOOD):
                             self.pipeline.ev_accuracy_passed()
                         elif acc_color.passed(AccColor.MODERATE):
@@ -116,6 +115,3 @@ def run(self):
         if final_state != States.COMPLETE:
             retries: int = self.pipeline.failures[self.pipeline.state.value]
             self.display(f"Failed to generate a response after {retries} retries")
-
-        if self.pipeline.state == States.COMPLETE and self.pipeline.final_answer:
-            print(self.pipeline.final_answer)
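
Note the retry guard in the loop above: 0 < configs.max_router_retries < self.pipeline.failures[...] relies on Python's chained comparison, so it only trips when the limit is positive and the per-state failure count exceeds it. A self-contained sketch of the idiom (names are illustrative):

    from collections import defaultdict

    max_router_retries = 3
    failures: defaultdict[str, int] = defaultdict(int)  # per-state failure counts

    state = "ACCURACY_CHECK"
    failures[state] = 4  # four failures recorded for this state

    # Fires only when 0 < limit AND limit < failures[state]:
    if 0 < max_router_retries < failures[state]:
        print(f"Max retries exceeded: {max_router_retries}")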
87 changes: 60 additions & 27 deletions src/main/askai/core/processors/splitter/splitter_pipeline.py

@@ -14,11 +14,9 @@
 """
 import logging as log
 import os
-import random
 from collections import defaultdict
 from typing import AnyStr, Optional
 
-import pause
 from langchain_core.prompts import PromptTemplate
 from transitions import Machine
 
@@ -33,7 +31,6 @@
 from askai.core.processors.splitter.splitter_transitions import Transition, TRANSITIONS
 from askai.core.router.evaluation import assert_accuracy, EVALUATION_GUIDE
 from askai.core.support.shared_instances import shared
-from askai.exception.exceptions import InterruptionRequest, TerminatingQuery
 
 
 class SplitterPipeline:
@@ -58,9 +55,20 @@ def __init__(self, query: AnyStr):
         self._iteractions: int = 0
         self._query: str = query
         self._plan: ActionPlan | None = None
-        self._final_answer: Optional[str] = None
+        self._direct_answer: Optional[str] = None
         self._model: ModelResult | None = None
         self._resp_history: list[str] = list()
+        self._last_acc_response: AccResponse | None = None
+        self._last_task: str | None = None
+
+    def _invalidate(self) -> None:
+        """TODO"""
+        self._plan = None
+        self._direct_answer = None
+        self._model = None
+        self._resp_history = list()
+        self._last_acc_response = None
+        self._last_task = None
 
     @property
     def iteractions(self) -> int:
@@ -70,6 +78,22 @@ def iteractions(self) -> int:
     def iteractions(self, value: int):
         self._iteractions = value
 
+    @property
+    def last_acc_response(self) -> AccResponse:
+        return self._last_acc_response
+
+    @last_acc_response.setter
+    def last_acc_response(self, value: AccResponse) -> None:
+        self._last_acc_response = value
+
+    @property
+    def last_task(self) -> str:
+        return self._last_task
+
+    @last_task.setter
+    def last_task(self, value: str) -> None:
+        self._last_task = value
+
     @property
     def failures(self) -> dict[str, int]:
         return self._failures
@@ -88,11 +112,19 @@ def previous(self) -> States:
 
     @property
     def query(self) -> str:
-        return self._query
+        if self.last_task is not None:
+            question: str = self.last_task
+        else:
+            question: str = self._query
+        return question
 
     @property
     def final_answer(self) -> Optional[str]:
-        return self._final_answer
+        if self.is_direct():
+            ai_response: str = self._direct_answer
+        else:
+            ai_response: str = os.linesep.join(self._resp_history)
+        return ai_response
 
     @property
     def resp_history(self) -> list[str]:
@@ -104,14 +136,15 @@ def track_previous(self) -> None:
 
     def has_next(self) -> bool:
         """TODO"""
-        return len(self.plan.tasks) > 0 if self.plan and self.plan.tasks else False
+        return len(self.plan.tasks) > 0 if self.plan is not None and self.plan.tasks else False
 
     def is_direct(self) -> bool:
         """TODO"""
-        return self.plan.is_direct if self.plan else True
+        return self.plan.is_direct if self.plan is not None else True
 
     def st_startup(self) -> bool:
         log.info("Task Splitter pipeline has started!")
+        self._invalidate()
         return True
 
     def st_model_select(self) -> bool:
@@ -123,37 +156,34 @@ def st_task_split(self) -> bool:
         log.info("Splitting tasks...")
         self._plan = actions.split(self.query, self.model)
         if self._plan.is_direct:
-            self._final_answer = self._plan.speak or msg.no_output("TaskSplitter")
+            self._direct_answer = self._plan.speak or msg.no_output("TaskSplitter")
         return True
 
     def st_execute_next(self) -> bool:
         _iter_ = self.plan.tasks.copy().__iter__()
         if action := next(_iter_, None):
             if agent_output := actions.process_action(action):
                 self.resp_history.append(agent_output)
-                self.plan.tasks.pop(0)
+                self.last_task = self.plan.tasks.pop(0).task if len(self.plan.tasks) > 0 else None
+                return self.last_task is not None
         return False
 
     def st_accuracy_check(self) -> AccColor:
-
         # FIXME Hardcoded for now
         pass_threshold: AccColor = AccColor.GOOD
 
-        if self.is_direct:
-            ai_response: str = self.final_answer
-        else:
-            ai_response: str = os.linesep.join(self._resp_history)
-
-        acc: AccResponse = assert_accuracy(self.query, ai_response, pass_threshold)
+        acc: AccResponse = assert_accuracy(self.query, self.final_answer, pass_threshold)
 
         if acc.is_interrupt:
-            log.warning(msg.interruption_requested(ai_response))
-            raise InterruptionRequest(ai_response)
+            # AI flags that it can't continue interacting.
+            log.warning(msg.interruption_requested(self.final_answer))
         elif acc.is_terminate:
-            raise TerminatingQuery(ai_response)
+            # AI flags that the user wants to end the session.
+            log.warning(msg.terminate_requested(self.final_answer))
        elif acc.is_pass(pass_threshold):
             # AI provided a good answer.
             log.warning(f"AI provided a final answer: {self.final_answer}")
+            self.resp_history.append(self.final_answer)
             shared.memory.save_context({"input": self.query}, {"output": self.final_answer})
         else:
             acc_template = PromptTemplate(input_variables=["problems"], template=prompt.read_prompt("acc-report"))
@@ -162,15 +192,18 @@ def st_accuracy_check(self) -> AccColor:
         shared.context.push("EVALUATION", EVALUATION_GUIDE)
         shared.context.push("EVALUATION", acc_template.format(problems=acc.details))
 
+        self.last_acc_response = acc
+
         return acc.acc_color
 
     def st_refine_answer(self) -> bool:
-        result = random.choice([True, False])
-        pause.seconds(self.FAKE_SLEEP)
-        return result
+        if self.is_direct():
+            ai_response: str = self.final_answer
+        else:
+            ai_response: str = os.linesep.join(self._resp_history)
+
+        return actions.refine_answer(self.query, ai_response, self.last_acc_response)
 
     def st_final_answer(self) -> bool:
-        self._final_answer = "This is the final answer"
-        result = random.choice([True, False])
-        pause.seconds(self.FAKE_SLEEP)
-        return result
+
+        return actions.wrap_answer(self.query, self.final_answer, self.model)
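
Taken together, these pipeline changes drop the random/pause placeholders and make the answer flow data-driven: each executed task appends its agent output to _resp_history, query rotates to the last pending task, and final_answer is derived on demand instead of being stored. A reduced standalone sketch of that contract (an illustration, not the project's code):

    import os
    from typing import Optional


    class MiniPipeline:
        """Reduced sketch of SplitterPipeline's new answer accumulation."""

        def __init__(self, query: str) -> None:
            self._query = query
            self._direct_answer: Optional[str] = None  # set only for direct plans
            self._resp_history: list[str] = []
            self._last_task: Optional[str] = None

        def is_direct(self) -> bool:
            return self._direct_answer is not None

        @property
        def query(self) -> str:
            # Follow-up evaluations run against the last pending task, if any.
            return self._last_task if self._last_task is not None else self._query

        @property
        def final_answer(self) -> Optional[str]:
            if self.is_direct():
                return self._direct_answer
            return os.linesep.join(self._resp_history)


    pipe = MiniPipeline("list my files and zip them")
    pipe._resp_history += ["listed 3 files", "created files.zip"]
    print(pipe.final_answer)  # both task outputs, joined by os.linesep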
2 changes: 1 addition & 1 deletion src/main/askai/core/processors/splitter/splitter_transitions.py

@@ -32,7 +32,7 @@
     {'trigger': 'ev_task_executed', 'source': States.EXECUTE_TASK, 'dest': States.ACCURACY_CHECK},
 
     {'trigger': 'ev_accuracy_passed', 'source': States.ACCURACY_CHECK, 'dest': States.EXECUTE_TASK, 'conditions': ['has_next']},
-    {'trigger': 'ev_accuracy_passed', 'source': States.ACCURACY_CHECK, 'dest': States.COMPLETE, 'unless': ['has_next']},
+    {'trigger': 'ev_accuracy_passed', 'source': States.ACCURACY_CHECK, 'dest': States.WRAP_ANSWER, 'unless': ['has_next']},
     {'trigger': 'ev_accuracy_failed', 'source': States.ACCURACY_CHECK, 'dest': States.EXECUTE_TASK},
     {'trigger': 'ev_refine_required', 'source': States.ACCURACY_CHECK, 'dest': States.REFINE_ANSWER, 'unless': ['has_next']},
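
For context on the routing change above: the transitions library resolves 'conditions' and 'unless' entries as method names on the model, so ev_accuracy_passed now lands on WRAP_ANSWER rather than COMPLETE whenever has_next() is false. A minimal standalone sketch, with states abbreviated to plain strings:

    from transitions import Machine


    class Model:
        def has_next(self) -> bool:
            return False  # no tasks left, so the 'unless' row fires


    states = ["ACCURACY_CHECK", "EXECUTE_TASK", "WRAP_ANSWER"]
    transitions = [
        {"trigger": "ev_accuracy_passed", "source": "ACCURACY_CHECK",
         "dest": "EXECUTE_TASK", "conditions": ["has_next"]},
        {"trigger": "ev_accuracy_passed", "source": "ACCURACY_CHECK",
         "dest": "WRAP_ANSWER", "unless": ["has_next"]},
    ]

    model = Model()
    Machine(model=model, states=states, transitions=transitions, initial="ACCURACY_CHECK")
    model.ev_accuracy_passed()
    print(model.state)  # WRAP_ANSWER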
4 changes: 4 additions & 0 deletions src/main/askai/core/support/shared_instances.py

@@ -100,6 +100,10 @@ def nickname_md(self) -> str:
     def username_md(self) -> str:
         return f"** {prompt.user.title()}:** "
 
+    @property
+    def nickname_spinner(self) -> str:
+        return f"[green]{self.mode.icon}[bold] Taius:[/bold][/green] "
+
     @property
     def idiom(self) -> str:
         return self._idiom
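
The new nickname_spinner property supplies the prefix the executor now passes to spinner.update() (see splitter_executor.py above), so the status line shows the assistant nickname ahead of the current state. A rough sketch of the effect, with a hardcoded icon standing in for self.mode.icon:

    import time
    from rich.console import Console

    nickname_spinner = "[green]*[bold] Taius:[/bold][/green] "  # stand-in for shared.nickname_spinner
    console = Console()
    with console.status("starting...", spinner="dots") as spinner:
        spinner.update(f"{nickname_spinner}[green]TASK_SPLIT[/green]")
        time.sleep(1)  # keep the status visible briefly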
