Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 125 additions & 10 deletions src/core/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import Any, Optional
import importlib
from config.agent_prompts.base_agent_prompt import BaseAgentPrompt

from collections import defaultdict
from omegaconf import DictConfig


Expand Down Expand Up @@ -124,6 +124,45 @@ def __init__(
):
self.sub_agent_llm_client.task_log = task_log

# Record used subtask / q / Query
self.max_repeat_queries = 3
self.used_queries = {}

def _get_query_from_tool_call(
self, tool_name: str, arguments: dict
) -> Optional[str]:
"""
Extracts the query from tool call arguments based on tool_name.
Supports google_search, wiki_get_page_content, search_wiki_revision, search_archived_webpage, and scrape_website.
"""
if tool_name == "google_search":
return "q:" + arguments.get("q")
elif tool_name == "wiki_get_page_content":
return "entity:" + arguments.get("entity")
elif tool_name == "search_wiki_revision":
return (
"entity:"
+ arguments.get("entity")
+ "_year:"
+ str(arguments.get("year"))
+ "_month:"
+ str(arguments.get("month"))
)
elif tool_name == "search_archived_webpage":
return (
"url:"
+ arguments.get("url")
+ "_year:"
+ str(arguments.get("year"))
+ "_month:"
+ str(arguments.get("month"))
+ "_day:"
+ str(arguments.get("day"))
)
elif tool_name == "scrape_website":
return "url:" + arguments.get("url")
return None

async def _handle_llm_call_with_logging(
self,
system_prompt,
Expand Down Expand Up @@ -430,6 +469,7 @@ async def run_sub_agent(
turn_count = 0
all_tool_results_content_with_id = []
task_failed = False # Track whether task failed
should_hard_stop = False

while turn_count < max_turns:
turn_count += 1
Expand Down Expand Up @@ -521,9 +561,36 @@ async def run_sub_agent(

call_start_time = time.time()
try:
tool_result = await self.sub_agent_tool_managers[
sub_agent_name
].execute_tool_call(server_name, tool_name, arguments)
query_str = self._get_query_from_tool_call(tool_name, arguments)
if query_str:
cache_name = sub_agent_name + "_" + tool_name
self.used_queries.setdefault(
cache_name, defaultdict(lambda: [0, ""])
)
count = self.used_queries[cache_name][query_str][0]
cache_result = self.used_queries[cache_name][query_str][1]
if count > 0:
tool_result = {
"server_name": server_name,
"tool_name": tool_name,
"result": f"{cache_result}.\nNotice: This query has already been used in previous {tool_name}. Please try a different query or keyword.",
}
if count >= self.max_repeat_queries:
should_hard_stop = True
self.used_queries[cache_name][query_str][0] += 1
else:
tool_result = await self.sub_agent_tool_managers[
sub_agent_name
].execute_tool_call(server_name, tool_name, arguments)
if "error" not in tool_result:
self.used_queries[cache_name][query_str][1] = (
tool_result["result"]
)
self.used_queries[cache_name][query_str][0] += 1
else:
tool_result = await self.sub_agent_tool_managers[
sub_agent_name
].execute_tool_call(server_name, tool_name, arguments)

call_end_time = time.time()
call_duration_ms = int((call_end_time - call_start_time) * 1000)
Expand Down Expand Up @@ -603,6 +670,15 @@ async def run_sub_agent(
message_history, all_tool_results_content_with_id, tool_calls_exceeded
)

if should_hard_stop:
task_failed = True
self.task_log.log_step(
"too_many_repeated_queries_in_sub_agent",
f"{self.max_repeat_queries} repeated queries in sub agent {sub_agent_name}, stopping the task",
"warning",
)
break

# Continue execution
logger.debug(
f"\n=== Sub Agent {sub_agent_name} Completed ({turn_count} turns) ==="
Expand Down Expand Up @@ -793,6 +869,7 @@ async def run_main_agent(
max_turns = sys.maxsize
turn_count = 0
task_failed = False # Track whether task failed
should_hard_stop = False
while turn_count < max_turns:
turn_count += 1
logger.debug(f"\n--- Main Agent Turn {turn_count} ---")
Expand Down Expand Up @@ -877,13 +954,42 @@ async def run_main_agent(
"result": sub_agent_result,
}
else:
tool_result = (
await self.main_agent_tool_manager.execute_tool_call(
server_name=server_name,
tool_name=tool_name,
arguments=arguments,
query_str = self._get_query_from_tool_call(tool_name, arguments)
if query_str:
cache_name = "main_" + tool_name
self.used_queries.setdefault(
cache_name, defaultdict(lambda: [0, ""])
)
count = self.used_queries[cache_name][query_str][0]
cache_result = self.used_queries[cache_name][query_str][1]
if count > 0:
tool_result = {
"server_name": server_name,
"tool_name": tool_name,
"result": f"{cache_result}.\nNotice: This query has already been used in previous {tool_name}. Please try a different query or keyword.",
}
if count >= self.max_repeat_queries:
should_hard_stop = True
self.used_queries[cache_name][query_str][0] += 1
else:
tool_result = await self.main_agent_tool_manager.execute_tool_call(
server_name=server_name,
tool_name=tool_name,
arguments=arguments,
)
if "error" not in tool_result:
self.used_queries[cache_name][query_str][1] = (
tool_result["result"]
)
self.used_queries[cache_name][query_str][0] += 1
else:
tool_result = (
await self.main_agent_tool_manager.execute_tool_call(
server_name=server_name,
tool_name=tool_name,
arguments=arguments,
)
)
)

call_end_time = time.time()
call_duration_ms = int((call_end_time - call_start_time) * 1000)
Expand Down Expand Up @@ -959,6 +1065,15 @@ async def run_main_agent(
message_history, all_tool_results_content_with_id, tool_calls_exceeded
)

if should_hard_stop:
task_failed = True
self.task_log.log_step(
"too_many_repeated_queries",
f"{self.max_repeat_queries} repeated queries, stopping the task",
"warning",
)
break

# Record main loop end
if turn_count >= max_turns:
if (
Expand Down
16 changes: 16 additions & 0 deletions src/llm/providers/mirothinker_sglang_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,22 @@ async def _create_message(
)
raise Exception("LLM finish_reason is 'stop', but content is empty")

# identify repeated messages and retry
# Check if the last 100 characters of the response appear more than 5 times in the response content.
# If so, treat it as a severe repeat and trigger a retry.
resp_content = response.choices[0].message.content or ""

if resp_content and len(resp_content) >= 50:
tail_50 = resp_content[-50:]
repeat_count = resp_content.count(tail_50)
if repeat_count > 5:
self.task_log.log_step(
"warning",
"LLM | Repeat Detected",
"Severe repeat: the last 50 chars appeared over 5 times, retrying...",
)
raise Exception("Severe repeat detected in response, please retry.")

logger.debug(
f"LLM call finish_reason: {getattr(response.choices[0], 'finish_reason', 'N/A')}"
)
Expand Down
Loading