diff --git a/python/valuecell/core/coordinate/orchestrator.py b/python/valuecell/core/coordinate/orchestrator.py index 85e218d79..2c2b78635 100644 --- a/python/valuecell/core/coordinate/orchestrator.py +++ b/python/valuecell/core/coordinate/orchestrator.py @@ -216,7 +216,6 @@ async def close_session(self, session_id: str): session_id, Role.SYSTEM, f"Session closed. {cancelled_count} tasks were cancelled.", - agent_name="orchestrator", ) async def get_session_history(self, session_id: str): @@ -289,6 +288,9 @@ async def _handle_new_request( """Handle a new user request""" session_id = user_input.meta.session_id thread_id = generate_thread_id() + yield self._response_factory.thread_started( + conversation_id=session_id, thread_id=thread_id + ) # Add user message to session await self.session_manager.add_message( @@ -384,6 +386,9 @@ async def _continue_planning( original_user_input = context.get_metadata("original_user_input") thread_id = generate_thread_id() context.thread_id = thread_id + yield self._response_factory.thread_started( + conversation_id=session_id, thread_id=thread_id + ) if not all([planning_task, original_user_input]): yield self._response_factory.plan_failed( @@ -508,7 +513,6 @@ async def _execute_plan_with_input_support( session_id, Role.AGENT, agent_responses[task.agent_name], - agent_name=task.agent_name, ) agent_responses[task.agent_name] = "" @@ -614,7 +618,7 @@ async def _save_remaining_responses(self, session_id: str, agent_responses: dict for agent_name, full_response in agent_responses.items(): if full_response.strip(): await self.session_manager.add_message( - session_id, Role.AGENT, full_response, agent_name=agent_name + session_id, Role.AGENT, full_response ) diff --git a/python/valuecell/core/coordinate/response.py b/python/valuecell/core/coordinate/response.py index 399af08c2..ab9f73190 100644 --- a/python/valuecell/core/coordinate/response.py +++ b/python/valuecell/core/coordinate/response.py @@ -16,6 +16,7 @@ SystemFailedResponse, TaskCompletedResponse, TaskFailedResponse, + ThreadStartedResponse, ToolCallPayload, ToolCallResponse, UnifiedResponseData, @@ -28,6 +29,15 @@ def conversation_started(self, conversation_id: str) -> ConversationStartedRespo data=UnifiedResponseData(conversation_id=conversation_id) ) + def thread_started( + self, conversation_id: str, thread_id: str + ) -> ThreadStartedResponse: + return ThreadStartedResponse( + data=UnifiedResponseData( + conversation_id=conversation_id, thread_id=thread_id + ) + ) + def system_failed(self, conversation_id: str, content: str) -> SystemFailedResponse: return SystemFailedResponse( data=UnifiedResponseData( diff --git a/python/valuecell/core/coordinate/tests/test_orchestrator.py b/python/valuecell/core/coordinate/tests/test_orchestrator.py index 30caa4bbd..ff5e711ce 100644 --- a/python/valuecell/core/coordinate/tests/test_orchestrator.py +++ b/python/valuecell/core/coordinate/tests/test_orchestrator.py @@ -308,9 +308,9 @@ async def test_planner_error( async for chunk in orchestrator.process_user_input(sample_user_input): out.append(chunk) - assert len(out) == 1 - assert "(Error)" in out[0].data.payload.content - assert "Planning failed" in out[0].data.payload.content + assert len(out) == 2 + assert "(Error)" in out[1].data.payload.content + assert "Planning failed" in out[1].data.payload.content @pytest.mark.asyncio @@ -328,7 +328,7 @@ async def test_agent_connection_error( async for chunk in orchestrator.process_user_input(sample_user_input): out.append(chunk) - assert any("(Error)" in c.data.payload.content for c in out) + assert any("(Error)" in c.data.payload.content for c in out if c.data.payload) @pytest.mark.asyncio diff --git a/python/valuecell/core/types.py b/python/valuecell/core/types.py index 89a640adc..ab06e2dc5 100644 --- a/python/valuecell/core/types.py +++ b/python/valuecell/core/types.py @@ -49,6 +49,7 @@ def clear_desired_agent(self) -> None: class SystemResponseEvent(str, Enum): CONVERSATION_STARTED = "conversation_started" + THREAD_STARTED = "thread_started" PLAN_REQUIRE_USER_INPUT = "plan_require_user_input" PLAN_FAILED = "plan_failed" TASK_FAILED = "task_failed" @@ -172,6 +173,13 @@ class ConversationStartedResponse(BaseResponse): ) +class ThreadStartedResponse(BaseResponse): + event: Literal[SystemResponseEvent.THREAD_STARTED] = Field( + SystemResponseEvent.THREAD_STARTED, + description="The event type of the response", + ) + + class PlanRequireUserInputResponse(BaseResponse): event: Literal[SystemResponseEvent.PLAN_REQUIRE_USER_INPUT] = Field( SystemResponseEvent.PLAN_REQUIRE_USER_INPUT,