Skip to content

Commit

Permalink
Improved splitter pipeline reporting
Browse files Browse the repository at this point in the history
  • Loading branch information
yorevs committed Oct 21, 2024
1 parent 4009b97 commit 9b2caaf
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 30 deletions.
43 changes: 28 additions & 15 deletions src/main/askai/core/processors/splitter/splitter_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
Copyright (c) 2024, HomeSetup
"""
import os
from textwrap import indent
from threading import Thread

from hspylib.core.decorator.decorators import profiled
Expand Down Expand Up @@ -39,19 +41,18 @@ def pipeline(self) -> SplitterPipeline:
def run(self):
with self._console.status("Processing query...", spinner="dots") as spinner:
max_retries: int = configs.max_router_retries
max_interactions: int = configs.max_iteractions
while not self.pipeline.state == States.COMPLETE:
self.pipeline.track_previous()
spinner.update(f"[green]{self.pipeline.state.value}[/green]")
if 0 < max_retries < self.pipeline.failures[self.pipeline.state.value]:
if (0 < max_retries < self.pipeline.failures[self.pipeline.state.value]) \
and (0 < max_interactions < self.pipeline.iteractions):
spinner.update(f"\nMax state retries reached: {max_retries}")
break
match self.pipeline.state:
case States.STARTUP:
if self.pipeline.st_startup():
self.pipeline.ev_pipeline_started()
case States.QUERY_QUEUED:
if self.pipeline.st_query_queued():
self.pipeline.ev_query_queued()
case States.MODEL_SELECT:
if self.pipeline.st_model_select():
self.pipeline.ev_model_selected()
Expand All @@ -63,16 +64,16 @@ def run(self):
else:
self.pipeline.ev_plan_created()
case States.EXECUTE_TASK:
color, has_next = self.pipeline.st_execute_next()
if color.passed:
if has_next:
self.pipeline.st_execute_next()
else:
self.pipeline.ev_task_executed()
if self.pipeline.st_execute_next():
self.pipeline.ev_task_executed()
case States.ACCURACY_CHECK:
color: AccColor = self.pipeline.st_accuracy_check()
if color.passed:
acc_color: AccColor = self.pipeline.st_accuracy_check()
c_name: str = acc_color.color.casefold()
self._console.print(f"[green] Accuracy check: [{c_name}]{c_name.upper()}[/{c_name}][/green]")
if acc_color.passed(AccColor.GOOD):
self.pipeline.ev_accuracy_passed()
elif acc_color.passed(AccColor.MODERATE):
self.pipeline.ev_refine_required()
else:
self.pipeline.ev_accuracy_failed()
case States.REFINE_ANSWER:
Expand All @@ -88,13 +89,25 @@ def run(self):
)
self.pipeline.failures[self.pipeline.state.value] += 1 if not execution_status else 0
self._console.print(f"[green]{execution_status_str}[/green]")
self.pipeline.iteractions += 1

final_state: States = self.pipeline.state
final_state_str: str = '[green]√ Succeeded[/green] ' if final_state == States.COMPLETE else '[red]X Failed [/red]'
self._console.print(f"[cyan]Pipeline Execution {final_state_str} [cyan][{final_state}][/cyan]")
final_state_str: str = '[green]√ Succeeded[/green] ' \
if final_state == States.COMPLETE \
else '[red]X Failed [/red]'
self._console.print(
f"\n[cyan]Pipeline execution {final_state_str} [cyan][{final_state}][/cyan] "
f"after [yellow]{self.pipeline.iteractions}[/yellow] iteractions\n"
)
all_failures: str = indent(
os.linesep.join([f"- {k}: {c}" for k, c in self.pipeline.failures.items()]),
' '
)
self._console.print(f"Failures:\n{all_failures}")

if final_state != States.COMPLETE:
self._console.print(f"Failed to generate a response")
retries: int = self.pipeline.failures[self.pipeline.state.value]
self._console.print(f"Failed to generate a response after {retries} retries")


if __name__ == '__main__':
Expand Down
26 changes: 15 additions & 11 deletions src/main/askai/core/processors/splitter/splitter_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class SplitterPipeline:

state: States

FAKE_SLEEP: int = 1
FAKE_SLEEP: float = 0.0

def __init__(self, query: AnyStr):
self._transitions: list[Transition] = [t for t in TRANSITIONS]
Expand All @@ -52,6 +52,14 @@ def __init__(self, query: AnyStr):
def failures(self) -> dict[str, int]:
return self._failures

@property
def iteractions(self) -> int:
return self._iteractions

@iteractions.setter
def iteractions(self, value: int):
self._iteractions = value

@property
def plan(self) -> ActionPlan:
return self._plan
Expand All @@ -65,22 +73,19 @@ def track_previous(self) -> None:

def has_next(self) -> bool:
"""TODO"""
return len(self.plan.tasks) > 0 if self.plan else False
# return len(self.plan.tasks) > 0 if self.plan else False
return random.choice([True, False])

def is_direct(self) -> bool:
"""TODO"""
return self.plan.is_direct if self.plan else True
# return self.plan.is_direct if self.plan else True
return random.choice([True, False])

def st_startup(self) -> bool:
result = random.choice([True, False])
pause.seconds(self.FAKE_SLEEP)
return result

def st_query_queued(self) -> bool:
result = random.choice([True, False])
pause.seconds(self.FAKE_SLEEP)
return result

def st_model_select(self) -> bool:
result = random.choice([True, False])
pause.seconds(self.FAKE_SLEEP)
Expand All @@ -91,11 +96,10 @@ def st_task_split(self) -> tuple[bool, bool]:
pause.seconds(self.FAKE_SLEEP)
return result1, result2

def st_execute_next(self) -> tuple[AccColor, bool]:
color = AccColor.value_of(random.choice(AccColor.names()))
def st_execute_next(self) -> bool:
result = random.choice([True, False])
pause.seconds(self.FAKE_SLEEP)
return color, result
return result

def st_accuracy_check(self) -> AccColor:
color = AccColor.value_of(random.choice(AccColor.names()))
Expand Down
1 change: 0 additions & 1 deletion src/main/askai/core/processors/splitter/splitter_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ class States(Enumeration):
"""Enumeration of possible task splitter states."""
# fmt: off
STARTUP = 'Processing query'
QUERY_QUEUED = 'Queuing Query'
MODEL_SELECT = 'Selecting Model'
TASK_SPLIT = 'Splitting Task'
ACCURACY_CHECK = 'Checking Accuracy'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@

# Define transitions from the workflow
TRANSITIONS = [
{'trigger': 'ev_pipeline_started', 'source': States.STARTUP, 'dest': States.QUERY_QUEUED},
{'trigger': 'ev_query_queued', 'source': States.QUERY_QUEUED, 'dest': States.MODEL_SELECT},
{'trigger': 'ev_pipeline_started', 'source': States.STARTUP, 'dest': States.MODEL_SELECT},
{'trigger': 'ev_model_selected', 'source': States.MODEL_SELECT, 'dest': States.TASK_SPLIT},

{'trigger': 'ev_direct_answer', 'source': States.TASK_SPLIT, 'dest': States.ACCURACY_CHECK},
Expand All @@ -35,7 +34,7 @@
{'trigger': 'ev_accuracy_passed', 'source': States.ACCURACY_CHECK, 'dest': States.EXECUTE_TASK, 'conditions': ['has_next']},
{'trigger': 'ev_accuracy_passed', 'source': States.ACCURACY_CHECK, 'dest': States.COMPLETE, 'unless': ['has_next']},
{'trigger': 'ev_accuracy_failed', 'source': States.ACCURACY_CHECK, 'dest': States.EXECUTE_TASK, 'conditions': ['has_next']},
{'trigger': 'ev_refine_required', 'source': States.ACCURACY_CHECK, 'dest': States.REFINE_ANSWER, 'conditions': ['is_direct']},
{'trigger': 'ev_refine_required', 'source': States.ACCURACY_CHECK, 'dest': States.REFINE_ANSWER, 'unless': ['has_next']},

{'trigger': 'ev_answer_refined', 'source': States.REFINE_ANSWER, 'dest': States.COMPLETE},

Expand Down

0 comments on commit 9b2caaf

Please sign in to comment.