def _get_query_from_tool_call(
    self, tool_name: str, arguments: dict
) -> Optional[str]:
    """
    Build the de-duplication key for a tool call.

    The key identifies "the same query" for the repeated-query guard. It is
    made of the tool's primary argument, followed by any date components
    the tool accepts, formatted as ``name:value`` pairs joined by ``_``
    (for example ``entity:Python_year:2020_month:5``).

    Supported tools: google_search, wiki_get_page_content,
    search_wiki_revision, search_archived_webpage and scrape_website.

    Args:
        tool_name: Name of the tool being invoked.
        arguments: Arguments supplied for the tool call.

    Returns:
        The key string, or None when the tool is not tracked, when
        ``arguments`` is not a mapping, or when the primary argument is
        missing. Returning None instead of raising lets the caller forward
        a malformed call to the tool unchanged.
    """
    # tool name -> (primary argument, optional date-like arguments)
    key_fields = {
        "google_search": ("q", ()),
        "wiki_get_page_content": ("entity", ()),
        "search_wiki_revision": ("entity", ("year", "month")),
        "search_archived_webpage": ("url", ("year", "month", "day")),
        "scrape_website": ("url", ()),
    }

    spec = key_fields.get(tool_name)
    # Tolerate None / non-mapping arguments coming from a malformed call.
    getter = getattr(arguments, "get", None)
    if spec is None or getter is None:
        return None

    primary_key, extra_keys = spec
    primary_value = getter(primary_key)
    if primary_value is None:
        # Previously `"q:" + None` raised TypeError here.
        return None

    # f-strings stringify non-str values; missing date parts render as
    # "None", exactly as the original str(arguments.get(...)) did.
    parts = [f"{primary_key}:{primary_value}"]
    parts.extend(f"{key}:{getter(key)}" for key in extra_keys)
    return "_".join(parts)
= self._get_query_from_tool_call(tool_name, arguments) + if query_str: + cache_name = sub_agent_name + "_" + tool_name + self.used_queries.setdefault( + cache_name, defaultdict(lambda: [0, ""]) + ) + count = self.used_queries[cache_name][query_str][0] + cache_result = self.used_queries[cache_name][query_str][1] + if count > 0: + tool_result = { + "server_name": server_name, + "tool_name": tool_name, + "result": f"{cache_result}.\nNotice: This query has already been used in previous {tool_name}. Please try a different query or keyword.", + } + if count >= self.max_repeat_queries: + should_hard_stop = True + self.used_queries[cache_name][query_str][0] += 1 + else: + tool_result = await self.sub_agent_tool_managers[ + sub_agent_name + ].execute_tool_call(server_name, tool_name, arguments) + if "error" not in tool_result: + self.used_queries[cache_name][query_str][1] = ( + tool_result["result"] + ) + self.used_queries[cache_name][query_str][0] += 1 + else: + tool_result = await self.sub_agent_tool_managers[ + sub_agent_name + ].execute_tool_call(server_name, tool_name, arguments) call_end_time = time.time() call_duration_ms = int((call_end_time - call_start_time) * 1000) @@ -603,6 +670,15 @@ async def run_sub_agent( message_history, all_tool_results_content_with_id, tool_calls_exceeded ) + if should_hard_stop: + task_failed = True + self.task_log.log_step( + "too_many_repeated_queries_in_sub_agent", + f"{self.max_repeat_queries} repeated queries in sub agent {sub_agent_name}, stopping the task", + "warning", + ) + break + # Continue execution logger.debug( f"\n=== Sub Agent {sub_agent_name} Completed ({turn_count} turns) ===" @@ -793,6 +869,7 @@ async def run_main_agent( max_turns = sys.maxsize turn_count = 0 task_failed = False # Track whether task failed + should_hard_stop = False while turn_count < max_turns: turn_count += 1 logger.debug(f"\n--- Main Agent Turn {turn_count} ---") @@ -877,13 +954,42 @@ async def run_main_agent( "result": sub_agent_result, } else: - 
tool_result = ( - await self.main_agent_tool_manager.execute_tool_call( - server_name=server_name, - tool_name=tool_name, - arguments=arguments, + query_str = self._get_query_from_tool_call(tool_name, arguments) + if query_str: + cache_name = "main_" + tool_name + self.used_queries.setdefault( + cache_name, defaultdict(lambda: [0, ""]) + ) + count = self.used_queries[cache_name][query_str][0] + cache_result = self.used_queries[cache_name][query_str][1] + if count > 0: + tool_result = { + "server_name": server_name, + "tool_name": tool_name, + "result": f"{cache_result}.\nNotice: This query has already been used in previous {tool_name}. Please try a different query or keyword.", + } + if count >= self.max_repeat_queries: + should_hard_stop = True + self.used_queries[cache_name][query_str][0] += 1 + else: + tool_result = await self.main_agent_tool_manager.execute_tool_call( + server_name=server_name, + tool_name=tool_name, + arguments=arguments, + ) + if "error" not in tool_result: + self.used_queries[cache_name][query_str][1] = ( + tool_result["result"] + ) + self.used_queries[cache_name][query_str][0] += 1 + else: + tool_result = ( + await self.main_agent_tool_manager.execute_tool_call( + server_name=server_name, + tool_name=tool_name, + arguments=arguments, + ) ) - ) call_end_time = time.time() call_duration_ms = int((call_end_time - call_start_time) * 1000) @@ -959,6 +1065,15 @@ async def run_main_agent( message_history, all_tool_results_content_with_id, tool_calls_exceeded ) + if should_hard_stop: + task_failed = True + self.task_log.log_step( + "too_many_repeated_queries", + f"{self.max_repeat_queries} repeated queries, stopping the task", + "warning", + ) + break + # Record main loop end if turn_count >= max_turns: if ( diff --git a/src/llm/providers/mirothinker_sglang_client.py b/src/llm/providers/mirothinker_sglang_client.py index a4dcca0..c18ec18 100644 --- a/src/llm/providers/mirothinker_sglang_client.py +++ b/src/llm/providers/mirothinker_sglang_client.py 
@@ -142,6 +142,22 @@ async def _create_message( ) raise Exception("LLM finish_reason is 'stop', but content is empty") + # identify repeated messages and retry + # Check if the last 100 characters of the response appear more than 5 times in the response content. + # If so, treat it as a severe repeat and trigger a retry. + resp_content = response.choices[0].message.content or "" + + if resp_content and len(resp_content) >= 50: + tail_50 = resp_content[-50:] + repeat_count = resp_content.count(tail_50) + if repeat_count > 5: + self.task_log.log_step( + "warning", + "LLM | Repeat Detected", + "Severe repeat: the last 50 chars appeared over 5 times, retrying...", + ) + raise Exception("Severe repeat detected in response, please retry.") + logger.debug( f"LLM call finish_reason: {getattr(response.choices[0], 'finish_reason', 'N/A')}" )