diff --git a/src/main/askai/core/processors/splitter/splitter_executor.py b/src/main/askai/core/processors/splitter/splitter_executor.py index b7a05440..d7f35d3f 100644 --- a/src/main/askai/core/processors/splitter/splitter_executor.py +++ b/src/main/askai/core/processors/splitter/splitter_executor.py @@ -12,6 +12,8 @@ Copyright (c) 2024, HomeSetup """ +import os +from textwrap import indent from threading import Thread from hspylib.core.decorator.decorators import profiled @@ -39,19 +41,18 @@ def pipeline(self) -> SplitterPipeline: def run(self): with self._console.status("Processing query...", spinner="dots") as spinner: max_retries: int = configs.max_router_retries + max_interactions: int = configs.max_iteractions while not self.pipeline.state == States.COMPLETE: self.pipeline.track_previous() spinner.update(f"[green]{self.pipeline.state.value}[/green]") - if 0 < max_retries < self.pipeline.failures[self.pipeline.state.value]: + if (0 < max_retries < self.pipeline.failures[self.pipeline.state.value]) \ + and (0 < max_interactions < self.pipeline.iteractions): spinner.update(f"\nMax state retries reached: {max_retries}") break match self.pipeline.state: case States.STARTUP: if self.pipeline.st_startup(): self.pipeline.ev_pipeline_started() - case States.QUERY_QUEUED: - if self.pipeline.st_query_queued(): - self.pipeline.ev_query_queued() case States.MODEL_SELECT: if self.pipeline.st_model_select(): self.pipeline.ev_model_selected() @@ -63,16 +64,16 @@ def run(self): else: self.pipeline.ev_plan_created() case States.EXECUTE_TASK: - color, has_next = self.pipeline.st_execute_next() - if color.passed: - if has_next: - self.pipeline.st_execute_next() - else: - self.pipeline.ev_task_executed() + if self.pipeline.st_execute_next(): + self.pipeline.ev_task_executed() case States.ACCURACY_CHECK: - color: AccColor = self.pipeline.st_accuracy_check() - if color.passed: + acc_color: AccColor = self.pipeline.st_accuracy_check() + c_name: str = acc_color.color.casefold() + self._console.print(f"[green] Accuracy check: [{c_name}]{c_name.upper()}[/{c_name}][/green]") + if acc_color.passed(AccColor.GOOD): self.pipeline.ev_accuracy_passed() + elif acc_color.passed(AccColor.MODERATE): + self.pipeline.ev_refine_required() else: self.pipeline.ev_accuracy_failed() case States.REFINE_ANSWER: @@ -88,13 +89,25 @@ def run(self): ) self.pipeline.failures[self.pipeline.state.value] += 1 if not execution_status else 0 self._console.print(f"[green]{execution_status_str}[/green]") + self.pipeline.iteractions += 1 final_state: States = self.pipeline.state - final_state_str: str = '[green]√ Succeeded[/green] ' if final_state == States.COMPLETE else '[red]X Failed [/red]' - self._console.print(f"[cyan]Pipeline Execution {final_state_str} [cyan][{final_state}][/cyan]") + final_state_str: str = '[green]√ Succeeded[/green] ' \ + if final_state == States.COMPLETE \ + else '[red]X Failed [/red]' + self._console.print( + f"\n[cyan]Pipeline execution {final_state_str} [cyan][{final_state}][/cyan] " + f"after [yellow]{self.pipeline.iteractions}[/yellow] iteractions\n" + ) + all_failures: str = indent( + os.linesep.join([f"- {k}: {c}" for k, c in self.pipeline.failures.items()]), + ' ' + ) + self._console.print(f"Failures:\n{all_failures}") if final_state != States.COMPLETE: - self._console.print(f"Failed to generate a response") + retries: int = self.pipeline.failures[self.pipeline.state.value] + self._console.print(f"Failed to generate a response after {retries} retries") if __name__ == '__main__': diff --git a/src/main/askai/core/processors/splitter/splitter_pipeline.py b/src/main/askai/core/processors/splitter/splitter_pipeline.py index 177be609..f8a8a09e 100644 --- a/src/main/askai/core/processors/splitter/splitter_pipeline.py +++ b/src/main/askai/core/processors/splitter/splitter_pipeline.py @@ -30,7 +30,7 @@ class SplitterPipeline: state: States - FAKE_SLEEP: int = 1 + FAKE_SLEEP: float = 0.0 def __init__(self, query: AnyStr): self._transitions: list[Transition] = [t for t in TRANSITIONS] @@ -52,6 +52,14 @@ def __init__(self, query: AnyStr): def failures(self) -> dict[str, int]: return self._failures + @property + def iteractions(self) -> int: + return self._iteractions + + @iteractions.setter + def iteractions(self, value: int): + self._iteractions = value + @property def plan(self) -> ActionPlan: return self._plan @@ -65,22 +73,19 @@ def track_previous(self) -> None: def has_next(self) -> bool: """TODO""" - return len(self.plan.tasks) > 0 if self.plan else False + # return len(self.plan.tasks) > 0 if self.plan else False + return random.choice([True, False]) def is_direct(self) -> bool: """TODO""" - return self.plan.is_direct if self.plan else True + # return self.plan.is_direct if self.plan else True + return random.choice([True, False]) def st_startup(self) -> bool: result = random.choice([True, False]) pause.seconds(self.FAKE_SLEEP) return result - def st_query_queued(self) -> bool: - result = random.choice([True, False]) - pause.seconds(self.FAKE_SLEEP) - return result - def st_model_select(self) -> bool: result = random.choice([True, False]) pause.seconds(self.FAKE_SLEEP) @@ -91,11 +96,10 @@ def st_task_split(self) -> tuple[bool, bool]: pause.seconds(self.FAKE_SLEEP) return result1, result2 - def st_execute_next(self) -> tuple[AccColor, bool]: - color = AccColor.value_of(random.choice(AccColor.names())) + def st_execute_next(self) -> bool: result = random.choice([True, False]) pause.seconds(self.FAKE_SLEEP) - return color, result + return result def st_accuracy_check(self) -> AccColor: color = AccColor.value_of(random.choice(AccColor.names())) diff --git a/src/main/askai/core/processors/splitter/splitter_states.py b/src/main/askai/core/processors/splitter/splitter_states.py index babf7f76..aaded8b3 100644 --- a/src/main/askai/core/processors/splitter/splitter_states.py +++ b/src/main/askai/core/processors/splitter/splitter_states.py @@ -19,7 +19,6 @@ class States(Enumeration): """Enumeration of possible task splitter states.""" # fmt: off STARTUP = 'Processing query' - QUERY_QUEUED = 'Queuing Query' MODEL_SELECT = 'Selecting Model' TASK_SPLIT = 'Splitting Task' ACCURACY_CHECK = 'Checking Accuracy' diff --git a/src/main/askai/core/processors/splitter/splitter_transitions.py b/src/main/askai/core/processors/splitter/splitter_transitions.py index d38a80b4..05cb71bf 100644 --- a/src/main/askai/core/processors/splitter/splitter_transitions.py +++ b/src/main/askai/core/processors/splitter/splitter_transitions.py @@ -21,8 +21,7 @@ # Define transitions from the workflow TRANSITIONS = [ - {'trigger': 'ev_pipeline_started', 'source': States.STARTUP, 'dest': States.QUERY_QUEUED}, - {'trigger': 'ev_query_queued', 'source': States.QUERY_QUEUED, 'dest': States.MODEL_SELECT}, + {'trigger': 'ev_pipeline_started', 'source': States.STARTUP, 'dest': States.MODEL_SELECT}, {'trigger': 'ev_model_selected', 'source': States.MODEL_SELECT, 'dest': States.TASK_SPLIT}, {'trigger': 'ev_direct_answer', 'source': States.TASK_SPLIT, 'dest': States.ACCURACY_CHECK}, @@ -35,7 +34,7 @@ {'trigger': 'ev_accuracy_passed', 'source': States.ACCURACY_CHECK, 'dest': States.EXECUTE_TASK, 'conditions': ['has_next']}, {'trigger': 'ev_accuracy_passed', 'source': States.ACCURACY_CHECK, 'dest': States.COMPLETE, 'unless': ['has_next']}, {'trigger': 'ev_accuracy_failed', 'source': States.ACCURACY_CHECK, 'dest': States.EXECUTE_TASK, 'conditions': ['has_next']}, - {'trigger': 'ev_refine_required', 'source': States.ACCURACY_CHECK, 'dest': States.REFINE_ANSWER, 'conditions': ['is_direct']}, + {'trigger': 'ev_refine_required', 'source': States.ACCURACY_CHECK, 'dest': States.REFINE_ANSWER, 'unless': ['has_next']}, {'trigger': 'ev_answer_refined', 'source': States.REFINE_ANSWER, 'dest': States.COMPLETE},