diff --git a/python/valuecell/core/agent/card.py b/python/valuecell/core/agent/card.py index 4bdf79efd..ab36ff5f6 100644 --- a/python/valuecell/core/agent/card.py +++ b/python/valuecell/core/agent/card.py @@ -21,9 +21,10 @@ def parse_local_agent_card_dict(agent_card_dict: dict) -> Optional[AgentCard]: if not isinstance(agent_card_dict, dict): return None # Defined by us, remove fields that are not part of AgentCard - for field in FIELDS_UNDEFINED_IN_AGENT_CARD_MODEL: - if field in agent_card_dict: - del agent_card_dict[field] + for field in FIELDS_UNDEFINED_IN_AGENT_CARD_MODEL.intersection( + agent_card_dict.keys() + ): + del agent_card_dict[field] # Requested fields as per AgentCard model if "description" not in agent_card_dict: @@ -73,14 +74,12 @@ def find_local_agent_card_by_agent_name( with open(json_file, "r", encoding="utf-8") as f: agent_card_dict = json.load(f) - # Check if this agent config has the matching name - if not isinstance(agent_card_dict, dict): - continue - if agent_card_dict.get("name") != agent_name: - continue - if not agent_card_dict.get("enabled", True): - continue - return parse_local_agent_card_dict(agent_card_dict) + if ( + isinstance(agent_card_dict, dict) + and agent_card_dict.get("name") == agent_name + and agent_card_dict.get("enabled", True) + ): + return parse_local_agent_card_dict(agent_card_dict) except (json.JSONDecodeError, IOError): # Skip files that can't be read or parsed diff --git a/python/valuecell/core/plan/planner.py b/python/valuecell/core/plan/planner.py index 081a1f7dc..26e9a8ef3 100644 --- a/python/valuecell/core/plan/planner.py +++ b/python/valuecell/core/plan/planner.py @@ -86,10 +86,11 @@ class ExecutionPlanner: def __init__( self, agent_connections: RemoteConnections, + agent_name: str = "super_agent", ): self.agent_connections = agent_connections # Fetch model via utils module reference so tests can monkeypatch it reliably - model = model_utils_mod.get_model_for_agent("super_agent") + model = model_utils_mod.get_model_for_agent(agent_name) self.agent = Agent( model=model, tools=[ diff --git a/python/valuecell/core/plan/service.py b/python/valuecell/core/plan/service.py index 48b0729b2..0eac5576c 100644 --- a/python/valuecell/core/plan/service.py +++ b/python/valuecell/core/plan/service.py @@ -48,8 +48,11 @@ def __init__( agent_connections: RemoteConnections, execution_planner: ExecutionPlanner | None = None, user_input_registry: UserInputRegistry | None = None, + agent_name: str = "super_agent", ) -> None: - self._planner = execution_planner or ExecutionPlanner(agent_connections) + self._planner = execution_planner or ExecutionPlanner( + agent_connections, agent_name=agent_name + ) self._input_registry = user_input_registry or UserInputRegistry() @property diff --git a/python/valuecell/core/plan/tests/test_planner.py b/python/valuecell/core/plan/tests/test_planner.py index 01bce7a2e..e6a433138 100644 --- a/python/valuecell/core/plan/tests/test_planner.py +++ b/python/valuecell/core/plan/tests/test_planner.py @@ -239,3 +239,46 @@ def __init__(self): # Not found branch missing = planner.tool_get_agent_description("MissingAgent") assert "could not be found" in missing + + +def test_execution_planner_uses_custom_agent_name(monkeypatch: pytest.MonkeyPatch): + """ExecutionPlanner uses the agent_name passed to its constructor.""" + + agent_name = "custom_agent" + + class FakeAgent: + def __init__(self, *args, **kwargs): + self.model = SimpleNamespace(id="fake-model", provider="fake-provider") + + def run(self, *args, **kwargs): + return SimpleNamespace( + is_paused=False, + tools_requiring_user_input=[], + tools=[], + content=PlannerResponse.model_validate( + { + "adequate": True, + "reason": "ok", + "tasks": [], + "guidance_message": None, + } + ), + ) + + monkeypatch.setattr(planner_mod, "Agent", FakeAgent) + + called_with_agent_name = None + + def fake_get_model_for_agent(name): + nonlocal called_with_agent_name + called_with_agent_name = name + return "stub-model" + + monkeypatch.setattr( + model_utils_mod, "get_model_for_agent", fake_get_model_for_agent + ) + monkeypatch.setattr(planner_mod, "agent_debug_mode_enabled", lambda: False) + + ExecutionPlanner(StubConnections(), agent_name=agent_name) + + assert called_with_agent_name == agent_name