diff --git a/python/valuecell/agents/sec_agent.py b/python/valuecell/agents/sec_agent.py index 485d1088f..1c13aaaec 100644 --- a/python/valuecell/agents/sec_agent.py +++ b/python/valuecell/agents/sec_agent.py @@ -4,16 +4,15 @@ import os from datetime import datetime from enum import Enum -from typing import Dict, Iterator, AsyncGenerator +from typing import AsyncGenerator, Dict, Iterator -from agno.agent import Agent, RunResponse, RunResponseEvent # noqa +from agno.agent import Agent, RunResponseEvent from agno.models.openrouter import OpenRouter from edgar import Company, set_identity from pydantic import BaseModel, Field, field_validator - -from valuecell.core.agent.responses import streaming, notification -from valuecell.core.types import BaseAgent, StreamResponse from valuecell.core.agent.decorator import create_wrapped_agent +from valuecell.core.agent.responses import notification, streaming +from valuecell.core.types import BaseAgent, StreamResponse # Configure logging logging.basicConfig(level=logging.INFO) @@ -358,8 +357,6 @@ async def _process_financial_data_query( yield streaming.tool_call_completed( event.tool.result, event.tool.tool_call_id, event.tool.tool_name ) - elif event.event == "ReasoningStep": - yield streaming.reasoning(event.reasoning_content) logger.info("Financial data analysis completed") yield streaming.done() @@ -454,8 +451,6 @@ async def _process_fund_holdings_query( yield streaming.tool_call_completed( event.tool.result, event.tool.tool_call_id, event.tool.tool_name ) - elif event.event == "ReasoningStep": - yield streaming.reasoning(event.reasoning_content) logger.info("Financial data analysis completed") streaming.done() diff --git a/python/valuecell/core/agent/responses.py b/python/valuecell/core/agent/responses.py index 8ef129d2f..4a515818b 100644 --- a/python/valuecell/core/agent/responses.py +++ b/python/valuecell/core/agent/responses.py @@ -8,7 +8,6 @@ NotifyResponseEvent, StreamResponse, StreamResponseEvent, - SystemResponseEvent, TaskStatusEvent, ToolCallPayload, ) @@ -63,7 +62,7 @@ def done(self, content: Optional[str] = None) -> StreamResponse: def failed(self, content: Optional[str] = None) -> StreamResponse: return StreamResponse( content=content, - event=SystemResponseEvent.TASK_FAILED, + event=TaskStatusEvent.TASK_FAILED, ) @@ -95,7 +94,7 @@ def done(self, content: Optional[str] = None) -> NotifyResponse: def failed(self, content: Optional[str] = None) -> NotifyResponse: return NotifyResponse( content=content, - event=SystemResponseEvent.TASK_FAILED, + event=TaskStatusEvent.TASK_FAILED, ) @@ -118,7 +117,7 @@ def is_task_completed(response_type) -> bool: @staticmethod def is_task_failed(response_type) -> bool: return response_type in { - SystemResponseEvent.TASK_FAILED, + TaskStatusEvent.TASK_FAILED, } @staticmethod diff --git a/python/valuecell/core/coordinate/orchestrator.py b/python/valuecell/core/coordinate/orchestrator.py index 7e4a61510..ec387cb97 100644 --- a/python/valuecell/core/coordinate/orchestrator.py +++ b/python/valuecell/core/coordinate/orchestrator.py @@ -172,6 +172,8 @@ async def process_user_input( session_id, f"(Error) Error processing request: {str(e)}", ) + finally: + yield self._response_factory.done(session_id) async def provide_user_input(self, session_id: str, response: str): """ @@ -497,8 +499,6 @@ async def _execute_plan_with_input_support( error_msg, ) - yield self._response_factory.done(session_id, thread_id) - async def _execute_task_with_input_support( self, task: Task, thread_id: str, metadata: Optional[dict] = None ) -> AsyncGenerator[BaseResponse, None]: @@ -512,7 +512,9 @@ async def _execute_task_with_input_support( """ try: # Start task execution - await self.task_manager.start_task(task.task_id) + task_id = task.task_id + conversation_id = task.session_id + await self.task_manager.start_task(task_id) # Get agent connection agent_name = task.agent_name @@ -532,7 +534,7 @@ async def _execute_task_with_input_support( # Send message to agent remote_response = await client.send_message( task.query, - session_id=task.session_id, + session_id=conversation_id, metadata=metadata, streaming=agent_card.capabilities.streaming, ) @@ -541,6 +543,11 @@ async def _execute_task_with_input_support( async for remote_task, event in remote_response: if event is None and remote_task.status.state == TaskState.submitted: task.remote_task_ids.append(remote_task.id) + yield self._response_factory.task_completed( + conversation_id=conversation_id, + thread_id=thread_id, + task_id=task_id, + ) continue if isinstance(event, TaskStatusUpdateEvent): @@ -553,43 +560,41 @@ async def _execute_task_with_input_support( # Apply side effects for eff in result.side_effects: if eff.kind == SideEffectKind.FAIL_TASK: - await self.task_manager.fail_task( - task.task_id, eff.reason or "" - ) + await self.task_manager.fail_task(task_id, eff.reason or "") if result.done: return continue if isinstance(event, TaskArtifactUpdateEvent): logger.info( - f"Received unexpected artifact update for task {task.task_id}: {event}" + f"Received unexpected artifact update for task {task_id}: {event}" ) continue # Complete task successfully - await self.task_manager.complete_task(task.task_id) + await self.task_manager.complete_task(task_id) yield self._response_factory.task_completed( - conversation_id=task.session_id, + conversation_id=conversation_id, thread_id=thread_id, - task_id=task.task_id, + task_id=task_id, ) # Finalize buffered aggregates for this task (explicit flush at task end) items = self._response_buffer.flush_task( - conversation_id=task.session_id, + conversation_id=conversation_id, thread_id=thread_id, - task_id=task.task_id, + task_id=task_id, ) await self._persist_items(items) except Exception as e: # On failure, finalize any buffered aggregates for this task items = self._response_buffer.flush_task( - conversation_id=task.session_id, + conversation_id=conversation_id, thread_id=thread_id, - task_id=task.task_id, + task_id=task_id, ) await self._persist_items(items) - await self.task_manager.fail_task(task.task_id, str(e)) + await self.task_manager.fail_task(task_id, str(e)) raise e async def _persist_from_buffer(self, response: BaseResponse): diff --git a/python/valuecell/core/coordinate/response.py b/python/valuecell/core/coordinate/response.py index 285a1adce..2ad39a86c 100644 --- a/python/valuecell/core/coordinate/response.py +++ b/python/valuecell/core/coordinate/response.py @@ -20,6 +20,7 @@ SystemResponseEvent, TaskCompletedResponse, TaskFailedResponse, + TaskStartedResponse, TaskStatusEvent, ThreadStartedResponse, ToolCallPayload, @@ -159,7 +160,9 @@ def system_failed(self, conversation_id: str, content: str) -> SystemFailedRespo ) ) - def done(self, conversation_id: str, thread_id: str) -> DoneResponse: + def done( + self, conversation_id: str, thread_id: Optional[str] = None + ) -> DoneResponse: return DoneResponse( data=UnifiedResponseData( conversation_id=conversation_id, @@ -209,6 +212,21 @@ def task_failed( ) ) + def task_started( + self, + conversation_id: str, + thread_id: str, + task_id: str, + ) -> TaskStartedResponse: + return TaskStartedResponse( + data=UnifiedResponseData( + conversation_id=conversation_id, + thread_id=thread_id, + task_id=task_id, + role=Role.AGENT, + ), + ) + def task_completed( self, conversation_id: str, diff --git a/python/valuecell/core/coordinate/tests/test_orchestrator.py b/python/valuecell/core/coordinate/tests/test_orchestrator.py index 434d06839..f8473a10b 100644 --- a/python/valuecell/core/coordinate/tests/test_orchestrator.py +++ b/python/valuecell/core/coordinate/tests/test_orchestrator.py @@ -308,7 +308,7 @@ async def test_planner_error( async for chunk in orchestrator.process_user_input(sample_user_input): out.append(chunk) - assert len(out) == 2 + assert len(out) == 3 assert "(Error)" in out[1].data.payload.content assert "Planning failed" in out[1].data.payload.content diff --git a/python/valuecell/core/types.py b/python/valuecell/core/types.py index b5149f8f7..62c562f65 100644 --- a/python/valuecell/core/types.py +++ b/python/valuecell/core/types.py @@ -37,7 +37,6 @@ class SystemResponseEvent(str, Enum): THREAD_STARTED = "thread_started" PLAN_REQUIRE_USER_INPUT = "plan_require_user_input" PLAN_FAILED = "plan_failed" - TASK_FAILED = "task_failed" SYSTEM_FAILED = "system_failed" DONE = "done" @@ -45,6 +44,7 @@ class SystemResponseEvent(str, Enum): class TaskStatusEvent(str, Enum): TASK_STARTED = "task_started" TASK_COMPLETED = "task_completed" + TASK_FAILED = "task_failed" TASK_CANCELLED = "task_cancelled" @@ -253,9 +253,16 @@ class PlanFailedResponse(BaseResponse): data: UnifiedResponseData = Field(..., description="The plan data payload") +class TaskStartedResponse(BaseResponse): + event: Literal[TaskStatusEvent.TASK_STARTED] = Field( + TaskStatusEvent.TASK_STARTED, description="The event type of the response" + ) + data: UnifiedResponseData = Field(..., description="The task data payload") + + class TaskFailedResponse(BaseResponse): - event: Literal[SystemResponseEvent.TASK_FAILED] = Field( - SystemResponseEvent.TASK_FAILED, description="The event type of the response" + event: Literal[TaskStatusEvent.TASK_FAILED] = Field( + TaskStatusEvent.TASK_FAILED, description="The event type of the response" ) data: UnifiedResponseData = Field(..., description="The task data payload")