Skip to content

Commit

Permalink
Replace actual processor by the pipeline - part3
Browse files Browse the repository at this point in the history
  • Loading branch information
yorevs committed Oct 23, 2024
1 parent 3e0bcce commit 0832160
Show file tree
Hide file tree
Showing 12 changed files with 160 additions and 121 deletions.
2 changes: 1 addition & 1 deletion dependencies.hspd
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ package: html2text, version: 2024.2.26, mode: ge

/* CLI/TUI */
package: rich, version: 13.8.1, mode: ge
package: textual, version: 0.80.1, mode: ge
package: textual, version: 0.80.1, mode: eq

/* Audio */
package: soundfile, version: 0.12.1, mode: ge
Expand Down
2 changes: 2 additions & 0 deletions src/main/askai/__classpath__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@

if not is_a_tty():
log.getLogger().setLevel(log.ERROR)
else:
log.getLogger().setLevel(log.INFO)

if not os.environ.get("USER_AGENT"):
# The AskAI User Agent, required by the langchain framework
Expand Down
2 changes: 1 addition & 1 deletion src/main/askai/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def _main(self, *params, **kwargs) -> ExitStatus:
os.environ["ASKAI_APP"] = RunModes.ASKAI_CMD.value
return self._execute_command(query_string)

log.info(
log.debug(
dedent(
f"""
{os.environ.get("ASKAI_APP")} v{self._app_version}
Expand Down
8 changes: 5 additions & 3 deletions src/main/askai/core/processors/splitter/splitter_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,14 @@ def run(self):
self.display("[yellow] AI decided to respond directly[/yellow]")
self.pipeline.ev_direct_answer()
else:
spinner.update("[green] Executing action plan[/green]")
self.display(f"[green] Executing action plan[/green]")
self.pipeline.ev_plan_created()
self.display(f"[green] Action plan created: {self.pipeline.plan.tasks}[/green]")
case States.EXECUTE_TASK:
if self.pipeline.st_execute_next():
if self.pipeline.st_execute_task():
self.display(f"[green] Task executed: '{self.pipeline.last_answer}'[/green]")
self.pipeline.ev_task_executed()
case States.ACCURACY_CHECK:
case States.ACC_CHECK:
acc_color: AccColor = self.pipeline.st_accuracy_check()
c_name: str = acc_color.color.casefold()
self.display(f"[green] Accuracy check: [{c_name}]{c_name.upper()}[/{c_name}][/green]")
Expand Down
180 changes: 90 additions & 90 deletions src/main/askai/core/processors/splitter/splitter_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
Copyright (c) 2024, HomeSetup
"""
import logging as log
import os
from collections import defaultdict
from typing import AnyStr, Optional
from typing import AnyStr

from hspylib.core.preconditions import check_state
from langchain_core.prompts import PromptTemplate
from transitions import Machine

Expand All @@ -27,9 +27,10 @@
from askai.core.model.action_plan import ActionPlan
from askai.core.model.model_result import ModelResult
from askai.core.processors.splitter.splitter_actions import actions
from askai.core.processors.splitter.splitter_result import SplitterResult, PipelineResponse
from askai.core.processors.splitter.splitter_states import States
from askai.core.processors.splitter.splitter_transitions import Transition, TRANSITIONS
from askai.core.router.evaluation import assert_accuracy, EVALUATION_GUIDE
from askai.core.router.evaluation import eval_response, EVALUATION_GUIDE
from askai.core.support.shared_instances import shared


Expand All @@ -40,7 +41,7 @@ class SplitterPipeline:

FAKE_SLEEP: float = 0.3

def __init__(self, query: AnyStr):
def __init__(self, question: AnyStr):
self._transitions: list[Transition] = [t for t in TRANSITIONS]
self._machine: Machine = Machine(
name="Taius-Coder",
Expand All @@ -50,25 +51,14 @@ def __init__(self, query: AnyStr):
transitions=self._transitions,
auto_transitions=False
)
self._previous: States | None = None
self._failures: dict[str, int] = defaultdict(int)
self._previous: States = States.NOT_STARTED
self._iteractions: int = 0
self._query: str = query
self._plan: ActionPlan | None = None
self._direct_answer: Optional[str] = None
self._model: ModelResult | None = None
self._resp_history: list[str] = list()
self._last_acc_response: AccResponse | None = None
self._last_task: str | None = None

def _invalidate(self) -> None:
"""TODO"""
self._plan = None
self._direct_answer = None
self._model = None
self._resp_history = list()
self._last_acc_response = None
self._last_task = None
self._failures: dict[str, int] = defaultdict(int)
self._result: SplitterResult = SplitterResult(question)

@property
def previous(self) -> States:
return self._previous

@property
def iteractions(self) -> int:
Expand All @@ -79,56 +69,64 @@ def iteractions(self, value: int):
self._iteractions = value

@property
def last_acc_response(self) -> AccResponse:
return self._last_acc_response

@last_acc_response.setter
def last_acc_response(self, value: AccResponse) -> None:
self._last_acc_response = value
def failures(self) -> dict[str, int]:
return self._failures

@property
def last_task(self) -> str:
return self._last_task
def result(self) -> SplitterResult:
return self._result

@last_task.setter
def last_task(self, value: str) -> None:
self._last_task = value
@property
def responses(self) -> list[PipelineResponse]:
return self._result.responses

@property
def failures(self) -> dict[str, int]:
return self._failures
def question(self) -> str:
return self.result.question

@property
def plan(self) -> ActionPlan:
return self._plan
def last_query(self) -> str:
return self.responses[-1].query

@last_query.setter
def last_query(self, value: str) -> None:
self.responses[-1].query = value

@property
def model(self) -> ModelResult:
return self._model
def last_answer(self) -> str:
return self.responses[-1].answer

@last_answer.setter
def last_answer(self, value: str) -> None:
self.responses[-1].answer = value

@property
def previous(self) -> States:
return self._previous
def last_accuracy(self) -> AccResponse:
return self.responses[-1].accuracy

@last_accuracy.setter
def last_accuracy(self, value: AccResponse) -> None:
self.responses[-1].accuracy = value

@property
def query(self) -> str:
if self.last_task is not None:
question: str = self.last_task
else:
question: str = self._query
return question
def plan(self) -> ActionPlan:
return self.result.plan

@plan.setter
def plan(self, value: ActionPlan):
self.result.plan = value

@property
def final_answer(self) -> Optional[str]:
if self.is_direct():
ai_response: str = self._direct_answer
else:
ai_response: str = os.linesep.join(self._resp_history)
return ai_response
def model(self) -> ModelResult:
return self.result.model

@model.setter
def model(self, value: ModelResult):
self.result.model = value

@property
def resp_history(self) -> list[str]:
return self._resp_history
def final_answer(self) -> str:
return self.result.final_response()

def track_previous(self) -> None:
"""TODO"""
Expand All @@ -143,67 +141,69 @@ def is_direct(self) -> bool:
return self.plan.is_direct if self.plan is not None else True

def st_startup(self) -> bool:
"""TODO"""
log.info("Task Splitter pipeline has started!")
self._invalidate()
return True

def st_model_select(self) -> bool:
"""TODO"""
log.info("Selecting response model...")
self._model = ModelResult.default()
# FIXME: Model select is default for now
self.model = ModelResult.default()
return True

def st_task_split(self) -> bool:
"""TODO"""
log.info("Splitting tasks...")
self._plan = actions.split(self.query, self.model)
if self._plan.is_direct:
self._direct_answer = self._plan.speak or msg.no_output("TaskSplitter")
return True
if (plan := actions.split(self.question, self.model)) is not None:
if plan.is_direct:
self.responses.append(PipelineResponse(self.question, plan.speak or msg.no_output("TaskSplitter")))
self.plan = plan
return True
return False

def st_execute_next(self) -> bool:
def st_execute_task(self) -> bool:
"""TODO"""
check_state(self.plan.tasks is not None and len(self.plan.tasks) > 0)
_iter_ = self.plan.tasks.copy().__iter__()
if action := next(_iter_, None):
log.info(f"Executing task '{action}'...")
if agent_output := actions.process_action(action):
self.last_task = self.plan.tasks.pop(0).task if len(self.plan.tasks) > 0 else None
return self.last_task is not None
self.responses.append(PipelineResponse(action.task, agent_output))
return True
return False

def st_accuracy_check(self) -> AccColor:
"""TODO"""

if self.last_query is None or self.last_answer is None:
return AccColor.BAD

# FIXME Hardcoded for now
pass_threshold: AccColor = AccColor.GOOD

acc: AccResponse = assert_accuracy(self.query, self.final_answer, pass_threshold)

if acc.is_interrupt:
# AI flags that it can't continue interacting.
log.warning(msg.interruption_requested(self.final_answer))
elif acc.is_terminate:
# AI flags that the user wants to end the session.
log.warning(msg.terminate_requested(self.final_answer))
elif acc.is_pass(pass_threshold):
# AI provided a good answer.
log.warning(f"AI provided a final answer: {self.final_answer}")
self.resp_history.append(self.final_answer)
shared.memory.save_context({"input": self.query}, {"output": self.final_answer})
pass_threshold: AccColor = AccColor.MODERATE
acc: AccResponse = eval_response(self.last_query, self.last_answer)

if acc.is_interrupt: # AI flags that it can't continue interacting.
log.warning(msg.interruption_requested(self.last_answer))
elif acc.is_terminate: # AI flags that the user wants to end the session.
log.warning(msg.terminate_requested(self.last_answer))
elif acc.is_pass(pass_threshold): # AI provided a good answer.
log.info(f"AI provided a good answer: {self.last_answer}")
if len(self.plan.tasks) > 0:
self.plan.tasks.pop(0)
shared.memory.save_context({"input": self.last_query}, {"output": self.last_answer})
else:
acc_template = PromptTemplate(input_variables=["problems"], template=prompt.read_prompt("acc-report"))
# Include the guidelines for the first mistake.
if not shared.context.get("EVALUATION"):
if not shared.context.get("EVALUATION"): # Include the guidelines for the first mistake.
shared.context.push("EVALUATION", EVALUATION_GUIDE)
shared.context.push("EVALUATION", acc_template.format(problems=acc.details))

self.last_acc_response = acc
self.last_accuracy = acc

return acc.acc_color

def st_refine_answer(self) -> bool:
if self.is_direct:
ai_response: str = self.final_answer
else:
ai_response: str = os.linesep.join(self._resp_history)

return actions.refine_answer(self.query, ai_response, self.last_acc_response)
return actions.refine_answer(self.question, self.final_answer, self.last_accuracy)

def st_final_answer(self) -> bool:

return actions.wrap_answer(self.query, self.final_answer, self.model)
return actions.wrap_answer(self.question, self.final_answer, self.model)
31 changes: 31 additions & 0 deletions src/main/askai/core/processors/splitter/splitter_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import os
from dataclasses import dataclass, field

from askai.core.enums.acc_color import AccColor
from askai.core.model.acc_response import AccResponse
from askai.core.model.action_plan import ActionPlan
from askai.core.model.model_result import ModelResult


@dataclass
class PipelineResponse:
    """A single query/answer exchange produced by the splitter pipeline,
    together with the accuracy evaluation assigned to the answer."""
    # The query (task prompt) sent for this pipeline step.
    query: str
    # The AI-produced answer; None until the step has been executed.
    answer: str | None = None
    # Accuracy assessment of the answer; None until it has been evaluated.
    accuracy: AccResponse | None = None


@dataclass
class SplitterResult:
    """Aggregated outcome of one splitter pipeline run: the original question,
    the per-task responses, the generated action plan and the selected model."""
    # The user's original question that started the pipeline run.
    question: str
    # One PipelineResponse per executed step, in execution order.
    responses: list[PipelineResponse] = field(default_factory=list)
    # The action plan produced by the task splitter; None until created.
    plan: ActionPlan | None = None
    # The response model selected for this run; None until selected.
    model: ModelResult | None = None

    def final_response(self) -> str:
        """Assemble the final answer from all responses whose accuracy passed
        the MODERATE threshold, joining one answer per line.

        :return: The combined answer text; an empty string when no response passed.
        """
        # A generator expression replaces the original map/filter/lambda chain:
        # same filtering (accuracy present AND color passes MODERATE), same order.
        return os.linesep.join(
            resp.answer
            for resp in self.responses
            if resp.accuracy and resp.accuracy.acc_color.passed(AccColor.MODERATE)
        )
17 changes: 9 additions & 8 deletions src/main/askai/core/processors/splitter/splitter_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
class States(Enumeration):
"""Enumeration of possible task splitter states."""
# fmt: off
STARTUP = ' Processing query'
MODEL_SELECT = ' Selecting Model'
TASK_SPLIT = ' Splitting Tasks'
ACCURACY_CHECK = ' Checking Accuracy'
EXECUTE_TASK = ' Executing Task'
REFINE_ANSWER = ' Refining Answer'
WRAP_ANSWER = ' Wrapping final answer'
COMPLETE = 'ﲏ Completed'
NOT_STARTED = 'Not started'
STARTUP = 'Processing query'
MODEL_SELECT = 'Selecting Model'
TASK_SPLIT = 'Splitting Tasks'
ACC_CHECK = 'Checking Accuracy'
EXECUTE_TASK = 'Executing Task'
REFINE_ANSWER = 'Refining Answer'
WRAP_ANSWER = 'Wrapping final answer'
COMPLETE = 'Completed'
# fmt: on
16 changes: 9 additions & 7 deletions src/main/askai/core/processors/splitter/splitter_transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,19 @@
{'trigger': 'ev_pipeline_started', 'source': States.STARTUP, 'dest': States.MODEL_SELECT},
{'trigger': 'ev_model_selected', 'source': States.MODEL_SELECT, 'dest': States.TASK_SPLIT},

{'trigger': 'ev_direct_answer', 'source': States.TASK_SPLIT, 'dest': States.ACCURACY_CHECK},
{'trigger': 'ev_direct_answer', 'source': States.TASK_SPLIT, 'dest': States.ACC_CHECK},
{'trigger': 'ev_plan_created', 'source': States.TASK_SPLIT, 'dest': States.EXECUTE_TASK},

{'trigger': 'ev_accuracy_check', 'source': States.ACCURACY_CHECK, 'dest': States.EXECUTE_TASK},
{'trigger': 'ev_accuracy_check', 'source': States.ACC_CHECK, 'dest': States.EXECUTE_TASK},

{'trigger': 'ev_task_executed', 'source': States.EXECUTE_TASK, 'dest': States.ACCURACY_CHECK},
{'trigger': 'ev_task_executed', 'source': States.EXECUTE_TASK, 'dest': States.ACC_CHECK},

{'trigger': 'ev_accuracy_passed', 'source': States.ACCURACY_CHECK, 'dest': States.EXECUTE_TASK, 'conditions': ['has_next']},
{'trigger': 'ev_accuracy_passed', 'source': States.ACCURACY_CHECK, 'dest': States.WRAP_ANSWER, 'unless': ['has_next']},
{'trigger': 'ev_accuracy_failed', 'source': States.ACCURACY_CHECK, 'dest': States.EXECUTE_TASK},
{'trigger': 'ev_refine_required', 'source': States.ACCURACY_CHECK, 'dest': States.REFINE_ANSWER, 'unless': ['has_next']},
{'trigger': 'ev_accuracy_passed', 'source': States.ACC_CHECK, 'dest': States.EXECUTE_TASK,
'conditions': ['has_next']},
{'trigger': 'ev_accuracy_passed', 'source': States.ACC_CHECK, 'dest': States.WRAP_ANSWER, 'unless': ['has_next']},
{'trigger': 'ev_accuracy_failed', 'source': States.ACC_CHECK, 'dest': States.EXECUTE_TASK},
{'trigger': 'ev_accuracy_failed', 'source': States.ACC_CHECK, 'dest': States.TASK_SPLIT, 'unless': ['has_next']},
{'trigger': 'ev_refine_required', 'source': States.ACC_CHECK, 'dest': States.REFINE_ANSWER, 'unless': ['has_next']},

{'trigger': 'ev_answer_refined', 'source': States.REFINE_ANSWER, 'dest': States.WRAP_ANSWER},
{'trigger': 'ev_final_answer', 'source': States.WRAP_ANSWER, 'dest': States.COMPLETE},
Expand Down
Loading

0 comments on commit 0832160

Please sign in to comment.