From 8cc8e5732953a08aaeb8243a3f9dcb88d779d1b8 Mon Sep 17 00:00:00 2001 From: Alejandro Herrera <149527975+sfc-gh-alherrera@users.noreply.github.com> Date: Fri, 6 Dec 2024 09:08:33 -0500 Subject: [PATCH] fix: refactoring parser to address long response bugs and debug log cleanups (#79) * fix: refactoring parser to address long response bugs and adding debug logs cleanup * refactor: adding error handling for empty parser output * chore: clearing redundant error logging * Update agent_gateway/gateway/gateway.py Co-authored-by: Tyler White --------- Co-authored-by: Tyler White --- agent_gateway/gateway/gateway.py | 76 ++++++++++++------------- agent_gateway/gateway/planner.py | 3 +- agent_gateway/gateway/task_processor.py | 1 + agent_gateway/tools/snowflake_tools.py | 1 + 4 files changed, 41 insertions(+), 40 deletions(-) diff --git a/agent_gateway/gateway/gateway.py b/agent_gateway/gateway/gateway.py index d19fec3..5d49d83 100644 --- a/agent_gateway/gateway/gateway.py +++ b/agent_gateway/gateway/gateway.py @@ -37,6 +37,7 @@ class AgentGatewayError(Exception): def __init__(self, message): self.message = message + gateway_logger.log(logging.ERROR, self.message) super().__init__(self.message) @@ -50,20 +51,19 @@ def __init__(self, session, llm) -> None: async def arun(self, prompt: str) -> str: """Run the LLM.""" headers, url, data = self._prepare_llm_request(prompt=prompt) - gateway_logger.log(logging.DEBUG, "Cortex Request URL\n", url, block=True) - gateway_logger.log(logging.DEBUG, "Cortex Request Data\n", data, block=True) - - response_text = await post_cortex_request(url=url, headers=headers, data=data) - gateway_logger.log( - logging.DEBUG, - "Cortex Request Response\n", - response_text, - block=True, - ) + + try: + response_text = await post_cortex_request( + url=url, headers=headers, data=data + ) + except Exception as e: + raise AgentGatewayError( + message=f"Failed Cortex LLM Request. 
See details:{str(e)}" ) from e if "choices" not in response_text: raise AgentGatewayError( - message=f"Failed Cortex LLM Request. Missing choices in response. See details:{response_text}" + message=f"Invalid Cortex LLM Response. See details:{response_text}" ) try: @@ -135,8 +135,8 @@ def __init__( snowflake_connection: Union[Session, SnowflakeConnection], tools: list[Union[Tool, StructuredTool]], max_retries: int = 2, - planner_llm: str = "mistral-large2", # replace basellm - agent_llm: str = "mistral-large2", # replace basellm + planner_llm: str = "mistral-large2", + agent_llm: str = "mistral-large2", planner_example_prompt: str = SNOWFLAKE_PLANNER_PROMPT, planner_example_prompt_replan: Optional[str] = None, planner_stop: Optional[list[str]] = [END_OF_PLAN], @@ -227,37 +227,35 @@ def _parse_fusion_output(self, raw_answer: str) -> str: if is_replan: answer = "We couldn't find the information you're looking for. You can try rephrasing your request or validate that the provided tools contain sufficient information." + if answer is None: + raise AgentGatewayError( + message="Unable to parse final answer. Raw answer is:{raw_answer}" + ) + return thought, answer, is_replan def _extract_answer(self, raw_answer): - start_index = raw_answer.find("Action: Finish(") - replan_index = raw_answer.find("Replan") + start_marker = "Action: Finish(" + end_marker = "<END_OF_RESPONSE>" + end_parens = raw_answer.rfind(")") + + start_index = raw_answer.find(start_marker) if start_index != -1: - start_index += len("Action: Finish(") - parentheses_count = 1 - for i, char in enumerate(raw_answer[start_index:], start_index): - if char == "(": - parentheses_count += 1 - elif char == ")": - parentheses_count -= 1 - if parentheses_count == 0: - end_index = i - break - else: - # If no corresponding closing parenthesis is found - return None - answer = raw_answer[start_index:end_index] - return answer - else: - if replan_index != 1: - gateway_logger.log( - logging.INFO, - "Unable to answer the request. 
Replanning....", - block=True, - ) - return "Replan required. Consider rephrasing your question." + start_index += len(start_marker) + end_index = raw_answer.find(end_marker, start_index) + + if end_index != -1: + return raw_answer[start_index:end_index].strip() + elif end_parens > start_index: + return raw_answer[start_index:end_parens].strip() else: - return None + return raw_answer[start_index:].strip() + + # Handle "Replan" case + if "Replan" in raw_answer: + return "Replan required. Consider rephrasing your question." + + return None def _generate_context_for_replanner( self, tasks: Mapping[int, Task], fusion_thought: str diff --git a/agent_gateway/gateway/planner.py b/agent_gateway/gateway/planner.py index 8183cc7..529b2f1 100644 --- a/agent_gateway/gateway/planner.py +++ b/agent_gateway/gateway/planner.py @@ -39,6 +39,7 @@ class AgentGatewayError(Exception): def __init__(self, message): self.message = message + gateway_logger.log(logging.ERROR, message) super().__init__(self.message) @@ -296,7 +297,7 @@ def _parse_snowflake_response(self, data_str): if "content" in choices["delta"].keys(): completion += choices["delta"]["content"] - gateway_logger.log(logging.DEBUG, f"Planner response:{completion}") + gateway_logger.log(logging.DEBUG, f"LLM Generated Plan:\n{completion}") return completion async def plan(self, inputs: dict, is_replan: bool, **kwargs: Any): diff --git a/agent_gateway/gateway/task_processor.py b/agent_gateway/gateway/task_processor.py index 42dd168..2941b45 100644 --- a/agent_gateway/gateway/task_processor.py +++ b/agent_gateway/gateway/task_processor.py @@ -27,6 +27,7 @@ class AgentGatewayError(Exception): def __init__(self, message): self.message = message + gateway_logger.log(logging.ERROR, self.message) super().__init__(self.message) diff --git a/agent_gateway/tools/snowflake_tools.py b/agent_gateway/tools/snowflake_tools.py index fe9d988..4bf19ed 100644 --- a/agent_gateway/tools/snowflake_tools.py +++ 
b/agent_gateway/tools/snowflake_tools.py @@ -34,6 +34,7 @@ class SnowflakeError(Exception): def __init__(self, message): self.message = message + gateway_logger.log(logging.ERROR, message) super().__init__(self.message)