Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 37 additions & 39 deletions agent_gateway/gateway/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
class AgentGatewayError(Exception):
    """Exception that stores its message and logs it when constructed.

    The message is kept on the ``message`` attribute, written to
    ``gateway_logger`` at ``logging.ERROR`` level, and then handed to
    ``Exception`` so ``str(err)`` yields the same text.
    """

    def __init__(self, message):
        # Logging happens at construction time, i.e. before the exception
        # is actually raised by the caller.
        self.message = message
        gateway_logger.log(logging.ERROR, message)
        super().__init__(message)


Expand All @@ -50,20 +51,19 @@ def __init__(self, session, llm) -> None:
async def arun(self, prompt: str) -> str:
"""Run the LLM."""
headers, url, data = self._prepare_llm_request(prompt=prompt)
gateway_logger.log(logging.DEBUG, "Cortex Request URL\n", url, block=True)
gateway_logger.log(logging.DEBUG, "Cortex Request Data\n", data, block=True)

response_text = await post_cortex_request(url=url, headers=headers, data=data)
gateway_logger.log(
logging.DEBUG,
"Cortex Request Response\n",
response_text,
block=True,
)

try:
response_text = await post_cortex_request(
url=url, headers=headers, data=data
)
except Exception as e:
raise AgentGatewayError(
message=f"Failed Cortex LLM Request. See details:{str(e)}"
) from e

if "choices" not in response_text:
raise AgentGatewayError(
message=f"Failed Cortex LLM Request. Missing choices in response. See details:{response_text}"
message=f"Invalid Cortex LLM Response. See details:{response_text}"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the possibilities here?

Would it only hit this block in the case of Analyst not returning either the text response with suggestions or an SQL expression?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for the Complete calls, not Analyst. There can be cases where Complete is not configured properly (e.g., an invalid model provided by the end user) or the response comes back empty.

)

try:
Expand Down Expand Up @@ -135,8 +135,8 @@ def __init__(
snowflake_connection: Union[Session, SnowflakeConnection],
tools: list[Union[Tool, StructuredTool]],
max_retries: int = 2,
planner_llm: str = "mistral-large2", # replace basellm
agent_llm: str = "mistral-large2", # replace basellm
planner_llm: str = "mistral-large2",
agent_llm: str = "mistral-large2",
planner_example_prompt: str = SNOWFLAKE_PLANNER_PROMPT,
planner_example_prompt_replan: Optional[str] = None,
planner_stop: Optional[list[str]] = [END_OF_PLAN],
Expand Down Expand Up @@ -227,37 +227,35 @@ def _parse_fusion_output(self, raw_answer: str) -> str:
if is_replan:
answer = "We couldn't find the information you're looking for. You can try rephrasing your request or validate that the provided tools contain sufficient information."

if answer is None:
raise AgentGatewayError(
message="Unable to parse final answer. Raw answer is:{raw_answer}"
)

return thought, answer, is_replan

def _extract_answer(self, raw_answer):
    """Extract the final answer from a raw fusion response.

    The answer is the text that follows the ``Action: Finish(`` marker.
    Its end is chosen in this order:

    1. the first ``<END_OF_RESPONSE>`` marker after the start marker;
    2. otherwise the last ``)`` in the response, if it lies at or after
       the start of the answer;
    3. otherwise the end of the response.

    Args:
        raw_answer: Raw response text to search.

    Returns:
        The extracted answer with surrounding whitespace stripped; a fixed
        replan notice when no finish marker exists but the text contains
        ``Replan``; ``None`` when neither is present.
    """
    start_marker = "Action: Finish("
    end_marker = "<END_OF_RESPONSE>"

    start_index = raw_answer.find(start_marker)
    if start_index != -1:
        # Move past the marker so the slice starts at the answer itself.
        start_index += len(start_marker)

        end_index = raw_answer.find(end_marker, start_index)
        if end_index != -1:
            # NOTE(review): everything before the end marker is returned
            # as-is, so a ")" emitted right before the marker is kept.
            # Confirm this matches the prompt's expected output format.
            return raw_answer[start_index:end_index].strip()

        # Only needed when the end marker is missing. Using the LAST ")"
        # keeps parentheses that are part of the answer text.
        end_parens = raw_answer.rfind(")")
        # ">=" so that an empty "Finish()" yields "" rather than ")".
        if end_parens >= start_index:
            return raw_answer[start_index:end_parens].strip()

        # Unterminated answer: take whatever follows the marker.
        return raw_answer[start_index:].strip()

    # Handle "Replan" case
    if "Replan" in raw_answer:
        return "Replan required. Consider rephrasing your question."

    return None

def _generate_context_for_replanner(
self, tasks: Mapping[int, Task], fusion_thought: str
Expand Down
3 changes: 2 additions & 1 deletion agent_gateway/gateway/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
class AgentGatewayError(Exception):
    """Exception that stores its message and logs it when constructed.

    The message is kept on the ``message`` attribute, written to
    ``gateway_logger`` at ``logging.ERROR`` level, and then handed to
    ``Exception`` so ``str(err)`` yields the same text.
    """

    def __init__(self, message):
        # Logging happens at construction time, i.e. before the exception
        # is actually raised by the caller.
        self.message = message
        gateway_logger.log(logging.ERROR, self.message)
        super().__init__(message)


Expand Down Expand Up @@ -296,7 +297,7 @@ def _parse_snowflake_response(self, data_str):
if "content" in choices["delta"].keys():
completion += choices["delta"]["content"]

gateway_logger.log(logging.DEBUG, f"Planner response:{completion}")
gateway_logger.log(logging.DEBUG, f"LLM Generated Plan:\n{completion}")
return completion

async def plan(self, inputs: dict, is_replan: bool, **kwargs: Any):
Expand Down
1 change: 1 addition & 0 deletions agent_gateway/gateway/task_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
class AgentGatewayError(Exception):
    """Exception that stores its message and logs it when constructed.

    The message is kept on the ``message`` attribute, written to
    ``gateway_logger`` at ``logging.ERROR`` level, and then handed to
    ``Exception`` so ``str(err)`` yields the same text.
    """

    def __init__(self, message):
        # Logging happens at construction time, i.e. before the exception
        # is actually raised by the caller.
        self.message = message
        gateway_logger.log(logging.ERROR, message)
        super().__init__(message)


Expand Down
1 change: 1 addition & 0 deletions agent_gateway/tools/snowflake_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
class SnowflakeError(Exception):
    """Exception that stores its message and logs it when constructed.

    The message is kept on the ``message`` attribute, written to
    ``gateway_logger`` at ``logging.ERROR`` level, and then handed to
    ``Exception`` so ``str(err)`` yields the same text.
    """

    def __init__(self, message):
        # Logging happens at construction time, i.e. before the exception
        # is actually raised by the caller.
        self.message = message
        gateway_logger.log(logging.ERROR, self.message)
        super().__init__(message)


Expand Down