diff --git a/python/valuecell/core/agent/decorator.py b/python/valuecell/core/agent/decorator.py index c3eb91e56..613ba73db 100644 --- a/python/valuecell/core/agent/decorator.py +++ b/python/valuecell/core/agent/decorator.py @@ -30,7 +30,6 @@ from valuecell.core.types import ( BaseAgent, NotifyResponse, - NotifyResponseEvent, StreamResponse, StreamResponseEvent, ) @@ -39,6 +38,7 @@ get_next_available_port, parse_host_port, ) +from .responses import EventPredicates logger = logging.getLogger(__name__) @@ -184,9 +184,14 @@ async def _add_chunk( if not response.content: return - response_event = response.event parts = [Part(root=TextPart(text=response.content))] - metadata = {"response_event": response_event.value} + response_event = response.event + metadata = { + "response_event": response_event.value, + "subtask_id": response.subtask_id, + } + if response_event == StreamResponseEvent.COMPONENT_GENERATOR: + metadata["component_type"] = response.metadata.get("component_type") await updater.add_artifact( parts=parts, artifact_id=artifact_id, @@ -213,21 +218,32 @@ async def _add_chunk( ) response_event = response.event - if is_task_failed(response_event): + if EventPredicates.is_task_failed(response_event): raise RuntimeError( f"Agent {agent_name} reported failure: {response.content}" ) - is_complete = is_task_completed(response_event) - if is_tool_call(response_event): + is_complete = EventPredicates.is_task_completed(response_event) + if EventPredicates.is_tool_call(response_event): await updater.update_status( TaskState.working, - message=message, + message=new_agent_text_message(response.content or ""), metadata={ "event": response_event.value, "tool_call_id": response.metadata.get("tool_call_id"), "tool_name": response.metadata.get("tool_name"), "tool_result": response.metadata.get("content"), + "subtask_id": response.subtask_id, + }, + ) + continue + if EventPredicates.is_reasoning(response_event): + await updater.update_status( + TaskState.working, + 
message=new_agent_text_message(response.content or ""), + metadata={ + "event": response_event.value, + "subtask_id": response.subtask_id, }, ) continue @@ -251,29 +267,6 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None raise ServerError(error=UnsupportedOperationError()) -def is_task_completed(response_type: str) -> bool: - return response_type in { - StreamResponseEvent.TASK_DONE, - StreamResponseEvent.TASK_FAILED, - NotifyResponseEvent.TASK_DONE, - NotifyResponseEvent.TASK_FAILED, - } - - -def is_task_failed(response_type: str) -> bool: - return response_type in { - StreamResponseEvent.TASK_FAILED, - NotifyResponseEvent.TASK_FAILED, - } - - -def is_tool_call(response_type: str) -> bool: - return response_type in { - StreamResponseEvent.TOOL_CALL_STARTED, - StreamResponseEvent.TOOL_CALL_COMPLETED, - } - - def _create_agent_executor(agent_instance): return GenericAgentExecutor(agent_instance) diff --git a/python/valuecell/core/agent/responses.py b/python/valuecell/core/agent/responses.py index c44e94992..74bc6ed3c 100644 --- a/python/valuecell/core/agent/responses.py +++ b/python/valuecell/core/agent/responses.py @@ -1,24 +1,3 @@ -"""User-facing response constructors under valuecell.core.agent. - -Prefer importing from here if you're already working inside the core.agent -namespace. For a stable top-level import, you can also use -`valuecell.responses` which provides the same API. 
- -Example: - from valuecell.core.agent.responses import stream, notify - # Or explicit aliases for clarity: - from valuecell.core.agent.responses import streaming, notification - - yield stream.message_chunk("Thinking…") - yield stream.reasoning("Plan: 1) fetch 2) analyze") - yield stream.tool_call_start("call_1", "search") - yield stream.tool_call_result('{"items": 12}', "call_1", "search") - yield stream.done() - - send(notify.message("Task submitted")) - send(notify.done("OK")) -""" - from __future__ import annotations from typing import Optional @@ -28,50 +7,92 @@ NotifyResponseEvent, StreamResponse, StreamResponseEvent, - ToolCallContent, + SystemResponseEvent, + ToolCallPayload, + _TaskResponseEvent, ) class _StreamResponseNamespace: """Factory methods for streaming responses.""" - def message_chunk(self, content: str) -> StreamResponse: - return StreamResponse(event=StreamResponseEvent.MESSAGE_CHUNK, content=content) + def message_chunk( + self, content: str, subtask_id: str | None = None + ) -> StreamResponse: + return StreamResponse( + event=StreamResponseEvent.MESSAGE_CHUNK, + content=content, + subtask_id=subtask_id, + ) - def tool_call_started(self, tool_call_id: str, tool_name: str) -> StreamResponse: + def tool_call_started( + self, tool_call_id: str, tool_name: str, subtask_id: str | None = None + ) -> StreamResponse: return StreamResponse( event=StreamResponseEvent.TOOL_CALL_STARTED, - metadata=ToolCallContent( - tool_call_id=tool_call_id, tool_name=tool_name + metadata=ToolCallPayload( + tool_call_id=tool_call_id, + tool_name=tool_name, ).model_dump(), + subtask_id=subtask_id, ) def tool_call_completed( - self, tool_result: str, tool_call_id: str, tool_name: str + self, + tool_result: str, + tool_call_id: str, + tool_name: str, + subtask_id: str | None = None, ) -> StreamResponse: return StreamResponse( event=StreamResponseEvent.TOOL_CALL_COMPLETED, - metadata=ToolCallContent( - tool_call_id=tool_call_id, tool_name=tool_name, 
tool_result=tool_result + metadata=ToolCallPayload( + tool_call_id=tool_call_id, + tool_name=tool_name, + tool_result=tool_result, ).model_dump(), + subtask_id=subtask_id, ) - def reasoning(self, content: str) -> StreamResponse: + def reasoning_started(self, subtask_id: str | None = None) -> StreamResponse: + return StreamResponse( + event=StreamResponseEvent.REASONING_STARTED, + subtask_id=subtask_id, + ) + + def reasoning(self, content: str, subtask_id: str | None = None) -> StreamResponse: return StreamResponse( event=StreamResponseEvent.REASONING, content=content, + subtask_id=subtask_id, + ) + + def reasoning_completed(self, subtask_id: str | None = None) -> StreamResponse: + return StreamResponse( + event=StreamResponseEvent.REASONING_COMPLETED, + subtask_id=subtask_id, + ) + + def component_generator( + self, content: str, component_type: str, subtask_id: str | None = None + ) -> StreamResponse: + return StreamResponse( + event=StreamResponseEvent.COMPONENT_GENERATOR, + content=content, + metadata={"component_type": component_type}, + subtask_id=subtask_id, ) def done(self, content: Optional[str] = None) -> StreamResponse: return StreamResponse( content=content, - event=StreamResponseEvent.TASK_DONE, + event=_TaskResponseEvent.TASK_COMPLETED, ) def failed(self, content: Optional[str] = None) -> StreamResponse: return StreamResponse( content=content, - event=StreamResponseEvent.TASK_FAILED, + event=SystemResponseEvent.TASK_FAILED, ) @@ -90,22 +111,56 @@ def message(self, content: str) -> NotifyResponse: def done(self, content: Optional[str] = None) -> NotifyResponse: return NotifyResponse( content=content, - event=NotifyResponseEvent.TASK_DONE, + event=_TaskResponseEvent.TASK_COMPLETED, ) def failed(self, content: Optional[str] = None) -> NotifyResponse: return NotifyResponse( content=content, - event=NotifyResponse.TASK_FAILED, + event=SystemResponseEvent.TASK_FAILED, ) notification = _NotifyResponseNamespace() +class EventPredicates: + """Utilities to 
classify response event types. + + These mirror the helper predicates previously defined in decorator.py + and centralize them next to response event definitions. + """ + + @staticmethod + def is_task_completed(response_type) -> bool: + return response_type in { + _TaskResponseEvent.TASK_COMPLETED, + } + + @staticmethod + def is_task_failed(response_type) -> bool: + return response_type in { + SystemResponseEvent.TASK_FAILED, + } + + @staticmethod + def is_tool_call(response_type) -> bool: + return response_type in { + StreamResponseEvent.TOOL_CALL_STARTED, + StreamResponseEvent.TOOL_CALL_COMPLETED, + } + + @staticmethod + def is_reasoning(response_type) -> bool: + return response_type in { + StreamResponseEvent.REASONING_STARTED, + StreamResponseEvent.REASONING, + StreamResponseEvent.REASONING_COMPLETED, + } + + __all__ = [ "streaming", "notification", - "StreamResponse", - "NotifyResponse", + "EventPredicates", ] diff --git a/python/valuecell/core/coordinate/orchestrator.py b/python/valuecell/core/coordinate/orchestrator.py index a711a0e9c..85e218d79 100644 --- a/python/valuecell/core/coordinate/orchestrator.py +++ b/python/valuecell/core/coordinate/orchestrator.py @@ -4,21 +4,25 @@ from typing import AsyncGenerator, Dict, Optional from a2a.types import TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent -from a2a.utils import get_message_text from valuecell.core.agent.connect import get_default_remote_connections -from valuecell.core.agent.decorator import is_tool_call, is_task_completed +from valuecell.core.agent.responses import EventPredicates +from valuecell.core.coordinate.response import ResponseFactory +from valuecell.core.coordinate.response_router import ( + RouteResult, + SideEffectKind, + handle_artifact_update, + handle_status_update, +) from valuecell.core.session import Role, SessionStatus, get_default_session_manager from valuecell.core.task import Task, get_default_task_manager from valuecell.core.task.models import TaskPattern from 
valuecell.core.types import ( + BaseResponse, NotifyResponseEvent, - ProcessMessage, - ProcessMessageData, StreamResponseEvent, - ToolCallContent, UserInput, ) -from valuecell.utils.uuid import generate_message_id +from valuecell.utils.uuid import generate_thread_id from .callback import store_task_in_session from .models import ExecutionPlan @@ -34,9 +38,10 @@ class ExecutionContext: """Manages the state of an interrupted execution for resumption""" - def __init__(self, stage: str, session_id: str, user_id: str): + def __init__(self, stage: str, session_id: str, thread_id: str, user_id: str): self.stage = stage self.session_id = session_id + self.thread_id = thread_id self.user_id = user_id self.created_at = asyncio.get_event_loop().time() self.metadata: Dict = {} @@ -120,11 +125,13 @@ def __init__(self): # Initialize planner self.planner = ExecutionPlanner(self.agent_connections) + self._response_factory = ResponseFactory() + # ==================== Public API Methods ==================== async def process_user_input( self, user_input: UserInput - ) -> AsyncGenerator[ProcessMessage, None]: + ) -> AsyncGenerator[BaseResponse, None]: """ Main entry point for processing user requests with Human-in-the-Loop support. 
@@ -144,20 +151,29 @@ async def process_user_input( try: # Ensure session exists - session = await self._ensure_session_exists(session_id, user_id) + session = await self.session_manager.get_session(session_id) + if not session: + await self.session_manager.create_session( + user_id, session_id=session_id + ) + session = await self.session_manager.get_session(session_id) + yield self._response_factory.conversation_started( + conversation_id=session_id + ) # Handle session continuation vs new request if session.status == SessionStatus.REQUIRE_USER_INPUT: - async for message in self._handle_session_continuation(user_input): - yield message + async for response in self._handle_session_continuation(user_input): + yield response else: - async for message in self._handle_new_request(user_input): - yield message + async for response in self._handle_new_request(user_input): + yield response except Exception as e: logger.exception(f"Error processing user input for session {session_id}") - yield self._create_error_message( - f"Error processing request: {str(e)}", session_id + yield self._response_factory.system_failed( + session_id, + f"(Error) Error processing request: {str(e)}", ) async def provide_user_input(self, session_id: str, response: str): @@ -218,8 +234,6 @@ async def cleanup(self): # ==================== Private Helper Methods ==================== - # ==================== Private Helper Methods ==================== - async def _handle_user_input_request(self, request: UserInputRequest): """Handle user input request from planner""" # Extract session_id from request context @@ -227,26 +241,18 @@ async def _handle_user_input_request(self, request: UserInputRequest): if session_id: self.user_input_manager.add_request(session_id, request) - async def _ensure_session_exists(self, session_id: str, user_id: str): - """Ensure a session exists, creating it if necessary""" - session = await self.session_manager.get_session(session_id) - if not session: - await 
self.session_manager.create_session(user_id, session_id=session_id) - session = await self.session_manager.get_session(session_id) - return session - async def _handle_session_continuation( self, user_input: UserInput - ) -> AsyncGenerator[ProcessMessage, None]: + ) -> AsyncGenerator[BaseResponse, None]: """Handle continuation of an interrupted session""" session_id = user_input.meta.session_id user_id = user_input.meta.user_id # Validate execution context exists if session_id not in self._execution_contexts: - yield self._create_error_message( - "No execution context found for this session. The session may have expired.", + yield self._response_factory.system_failed( session_id, + "No execution context found for this session. The session may have expired.", ) return @@ -254,9 +260,9 @@ async def _handle_session_continuation( # Validate context integrity and user consistency if not self._validate_execution_context(context, user_id): - yield self._create_error_message( - "Invalid execution context or user mismatch.", + yield self._response_factory.system_failed( session_id, + "Invalid execution context or user mismatch.", ) await self._cancel_execution(session_id) return @@ -270,18 +276,19 @@ async def _handle_session_continuation( if context.stage == "planning": async for chunk in self._continue_planning(session_id, context): yield chunk - # TODO: Add support for resuming execution stage if needed + # Resuming execution stage is not yet supported else: - yield self._create_error_message( - "Resuming execution stage is not yet supported.", + yield self._response_factory.system_failed( session_id, + "Resuming execution stage is not yet supported.", ) async def _handle_new_request( self, user_input: UserInput - ) -> AsyncGenerator[ProcessMessage, None]: + ) -> AsyncGenerator[BaseResponse, None]: """Handle a new user request""" session_id = user_input.meta.session_id + thread_id = generate_thread_id() # Add user message to session await 
self.session_manager.add_message( @@ -297,7 +304,7 @@ async def _handle_new_request( # Monitor planning progress async for chunk in self._monitor_planning_task( - planning_task, user_input, context_aware_callback + planning_task, thread_id, user_input, context_aware_callback ): yield chunk @@ -311,8 +318,12 @@ async def context_aware_handle(request): return context_aware_handle async def _monitor_planning_task( - self, planning_task, user_input: UserInput, callback - ) -> AsyncGenerator[ProcessMessage, None]: + self, + planning_task: asyncio.Task, + thread_id: str, + user_input: UserInput, + callback, + ) -> AsyncGenerator[BaseResponse, None]: """Monitor planning task and handle user input interruptions""" session_id = user_input.meta.session_id user_id = user_input.meta.user_id @@ -321,7 +332,7 @@ async def _monitor_planning_task( while not planning_task.done(): if self.has_pending_user_input(session_id): # Save planning context - context = ExecutionContext("planning", session_id, user_id) + context = ExecutionContext("planning", session_id, thread_id, user_id) context.add_metadata( original_user_input=user_input, planning_task=planning_task, @@ -331,8 +342,8 @@ async def _monitor_planning_task( # Update session status and send user input request await self._request_user_input(session_id) - yield self._create_user_input_request( - self.get_user_input_prompt(session_id), session_id + yield self._response_factory.plan_require_user_input( + session_id, thread_id, self.get_user_input_prompt(session_id) ) return @@ -340,7 +351,7 @@ async def _monitor_planning_task( # Planning completed, execute plan plan = await planning_task - async for chunk in self._execute_plan_with_input_support(plan): + async for chunk in self._execute_plan_with_input_support(plan, thread_id): yield chunk async def _request_user_input(self, session_id: str): @@ -365,78 +376,20 @@ def _validate_execution_context( return True - def _create_message( - self, - content: str, - conversation_id: str, - 
event: ( - StreamResponseEvent | NotifyResponseEvent - ) = StreamResponseEvent.MESSAGE_CHUNK, - message_id: Optional[str] = None, - ) -> ProcessMessage: - """Create a ProcessMessage for plain text content using the new schema.""" - return ProcessMessage( - event=event, - data=ProcessMessageData( - conversation_id=conversation_id, - message_id=message_id or generate_message_id(), - content=content, - ), - ) - - def _create_tool_message( - self, - event: StreamResponseEvent | NotifyResponseEvent, - conversation_id: str, - tool_call_id: str, - tool_name: str, - tool_result: Optional[str] = None, - ) -> ProcessMessage: - """Create a ProcessMessage for tool call events with ToolCallContent.""" - return ProcessMessage( - event=event, - data=ProcessMessageData( - conversation_id=conversation_id, - message_id=generate_message_id(), - content=ToolCallContent( - tool_call_id=tool_call_id, - tool_name=tool_name, - tool_result=tool_result, - ), - ), - ) - - def _create_error_message(self, error_msg: str, session_id: str) -> ProcessMessage: - """Create an error ProcessMessage with standardized format (TASK_FAILED).""" - return self._create_message( - content=f"(Error): {error_msg}", - conversation_id=session_id, - event=StreamResponseEvent.TASK_FAILED, - ) - - def _create_user_input_request( - self, - prompt: str, - session_id: str, - ) -> ProcessMessage: - """Create a user input request ProcessMessage. 
The consumer should parse the prefix.""" - return self._create_message( - content=f"USER_INPUT_REQUIRED:{prompt}", - conversation_id=session_id, - event=StreamResponseEvent.MESSAGE_CHUNK, - ) - async def _continue_planning( self, session_id: str, context: ExecutionContext - ) -> AsyncGenerator[ProcessMessage, None]: + ) -> AsyncGenerator[BaseResponse, None]: """Resume planning stage execution""" planning_task = context.get_metadata("planning_task") original_user_input = context.get_metadata("original_user_input") + thread_id = generate_thread_id() + context.thread_id = thread_id if not all([planning_task, original_user_input]): - yield self._create_error_message( - "Invalid planning context - missing required data", + yield self._response_factory.plan_failed( session_id, + thread_id, + "Invalid planning context - missing required data", ) await self._cancel_execution(session_id) return @@ -448,7 +401,9 @@ async def _continue_planning( prompt = self.get_user_input_prompt(session_id) # Ensure session is set to require user input again for repeated prompts await self._request_user_input(session_id) - yield self._create_user_input_request(prompt, session_id) + yield self._response_factory.plan_require_user_input( + session_id, thread_id, prompt + ) return await asyncio.sleep(ASYNC_SLEEP_INTERVAL) @@ -457,7 +412,7 @@ async def _continue_planning( plan = await planning_task del self._execution_contexts[session_id] - async for message in self._execute_plan_with_input_support(plan): + async for message in self._execute_plan_with_input_support(plan, thread_id): yield message async def _cancel_execution(self, session_id: str): @@ -501,8 +456,8 @@ async def _cleanup_expired_contexts( # ==================== Plan and Task Execution Methods ==================== async def _execute_plan_with_input_support( - self, plan: ExecutionPlan, metadata: Optional[dict] = None - ) -> AsyncGenerator[ProcessMessage, None]: + self, plan: ExecutionPlan, thread_id: str, metadata: Optional[dict] = 
None + ) -> AsyncGenerator[BaseResponse, None]: """ Execute an execution plan with Human-in-the-Loop support. @@ -516,8 +471,8 @@ async def _execute_plan_with_input_support( session_id = plan.session_id if not plan.tasks: - yield self._create_error_message( - "No tasks found for this request.", session_id + yield self._response_factory.plan_failed( + session_id, thread_id, "No tasks found for this request." ) return @@ -530,20 +485,22 @@ async def _execute_plan_with_input_support( await self.task_manager.store.save_task(task) # Execute task with input support - async for message in self._execute_task_with_input_support( - task, metadata + async for response in self._execute_task_with_input_support( + task, thread_id, metadata ): # Accumulate based on event - if message.event in { + if response.event in { StreamResponseEvent.MESSAGE_CHUNK, StreamResponseEvent.REASONING, NotifyResponseEvent.MESSAGE, - } and isinstance(message.data.content, str): - agent_responses[task.agent_name] += message.data.content - yield message + } and isinstance(response.data.payload.content, str): + agent_responses[task.agent_name] += ( + response.data.payload.content + ) + yield response if ( - is_task_completed(message.event) + EventPredicates.is_task_completed(response.event) or task.pattern == TaskPattern.RECURRING ): if agent_responses[task.agent_name].strip(): @@ -556,16 +513,23 @@ async def _execute_plan_with_input_support( agent_responses[task.agent_name] = "" except Exception as e: - error_msg = f"Error executing {task.agent_name}: {str(e)}" + error_msg = f"(Error) Error executing {task.agent_name}: {str(e)}" logger.exception(f"Task execution failed: {error_msg}") - yield self._create_error_message(error_msg, session_id) + yield self._response_factory.task_failed( + session_id, + thread_id, + task.task_id, + _generate_task_default_subtask_id(task.task_id), + error_msg, + ) # Save any remaining agent responses await self._save_remaining_responses(session_id, agent_responses) + yield 
self._response_factory.done(session_id, thread_id) async def _execute_task_with_input_support( - self, task: Task, metadata: Optional[dict] = None - ) -> AsyncGenerator[ProcessMessage, None]: + self, task: Task, thread_id: str, metadata: Optional[dict] = None + ) -> AsyncGenerator[BaseResponse, None]: """ Execute a single task with user input interruption support. @@ -609,44 +573,36 @@ async def _execute_task_with_input_support( continue if isinstance(event, TaskStatusUpdateEvent): - state = event.status.state - logger.info(f"Task {task.task_id} status update: {state}") - if state in {TaskState.submitted, TaskState.completed}: - continue - # Handle task failure - if state == TaskState.failed: - err_msg = get_message_text(event.status.message) - await self.task_manager.fail_task(task.task_id, err_msg) - yield self._create_error_message(err_msg, task.session_id) + result: RouteResult = await handle_status_update( + self._response_factory, task, thread_id, event + ) + for r in result.responses: + yield r + # Apply side effects + for eff in result.side_effects: + if eff.kind == SideEffectKind.FAIL_TASK: + await self.task_manager.fail_task( + task.task_id, eff.reason or "" + ) + if result.done: return - # if state == TaskState.input_required: - # Handle tool call start - if not event.metadata: - continue - response_event = event.metadata.get("response_event") - if state == TaskState.working and is_tool_call(response_event): - yield self._create_tool_message( - response_event, - task.session_id, - tool_call_id=event.metadata.get("tool_call_id", ""), - tool_name=event.metadata.get("tool_name", ""), - tool_result=event.metadata.get("tool_result"), - ) - continue + continue - elif isinstance(event, TaskArtifactUpdateEvent): - yield self._create_message( - get_message_text(event.artifact, ""), - task.session_id, - event=StreamResponseEvent.MESSAGE_CHUNK, + if isinstance(event, TaskArtifactUpdateEvent): + responses = await handle_artifact_update( + self._response_factory, 
task, thread_id, event ) + for r in responses: + yield r + continue # Complete task successfully await self.task_manager.complete_task(task.task_id) - yield self._create_message( - "", - task.session_id, - event=StreamResponseEvent.TASK_DONE, + yield self._response_factory.task_completed( + conversation_id=task.session_id, + thread_id=thread_id, + task_id=task.task_id, + subtask_id=_generate_task_default_subtask_id(task.task_id), ) except Exception as e: @@ -662,6 +618,11 @@ async def _save_remaining_responses(self, session_id: str, agent_responses: dict ) +def _generate_task_default_subtask_id(task_id: str) -> str: + """Generate a default subtask ID based on the main task ID""" + return f"{task_id}-default_subtask" + + # ==================== Module-level Factory Function ==================== _orchestrator = AgentOrchestrator() diff --git a/python/valuecell/core/coordinate/planner.py b/python/valuecell/core/coordinate/planner.py index 21cb98ffc..33a65bb1e 100644 --- a/python/valuecell/core/coordinate/planner.py +++ b/python/valuecell/core/coordinate/planner.py @@ -131,6 +131,7 @@ async def _analyze_input_and_create_tasks( for field in input_schema: if user_input_callback: # Use callback for async user input + # TODO: prompt options if available request = UserInputRequest(field.description) await user_input_callback(request) user_value = await request.wait_for_response() diff --git a/python/valuecell/core/coordinate/response.py b/python/valuecell/core/coordinate/response.py new file mode 100644 index 000000000..399af08c2 --- /dev/null +++ b/python/valuecell/core/coordinate/response.py @@ -0,0 +1,196 @@ +from typing import Optional + +from typing_extensions import Literal +from valuecell.core.types import ( + BaseResponseDataPayload, + ComponentGeneratorResponse, + ComponentGeneratorResponseDataPayload, + ConversationStartedResponse, + DoneResponse, + MessageResponse, + NotifyResponseEvent, + PlanFailedResponse, + PlanRequireUserInputResponse, + ReasoningResponse, + 
StreamResponseEvent, + SystemFailedResponse, + TaskCompletedResponse, + TaskFailedResponse, + ToolCallPayload, + ToolCallResponse, + UnifiedResponseData, +) + + +class ResponseFactory: + def conversation_started(self, conversation_id: str) -> ConversationStartedResponse: + return ConversationStartedResponse( + data=UnifiedResponseData(conversation_id=conversation_id) + ) + + def system_failed(self, conversation_id: str, content: str) -> SystemFailedResponse: + return SystemFailedResponse( + data=UnifiedResponseData( + conversation_id=conversation_id, + payload=BaseResponseDataPayload(content=content), + ) + ) + + def done(self, conversation_id: str, thread_id: str) -> DoneResponse: + return DoneResponse( + data=UnifiedResponseData( + conversation_id=conversation_id, + thread_id=thread_id, + ) + ) + + def plan_require_user_input( + self, conversation_id: str, thread_id: str, content: str + ) -> PlanRequireUserInputResponse: + return PlanRequireUserInputResponse( + data=UnifiedResponseData( + conversation_id=conversation_id, + thread_id=thread_id, + payload=BaseResponseDataPayload(content=content), + ) + ) + + def plan_failed( + self, conversation_id: str, thread_id: str, content: str + ) -> PlanFailedResponse: + return PlanFailedResponse( + data=UnifiedResponseData( + conversation_id=conversation_id, + thread_id=thread_id, + payload=BaseResponseDataPayload(content=content), + ) + ) + + def task_failed( + self, + conversation_id: str, + thread_id: str, + task_id: str, + subtask_id: str | None, + content: str, + ) -> TaskFailedResponse: + return TaskFailedResponse( + data=UnifiedResponseData( + conversation_id=conversation_id, + thread_id=thread_id, + task_id=task_id, + subtask_id=subtask_id, + payload=BaseResponseDataPayload(content=content), + ) + ) + + def task_completed( + self, + conversation_id: str, + thread_id: str, + task_id: str, + subtask_id: str | None, + ) -> TaskCompletedResponse: + return TaskCompletedResponse( + data=UnifiedResponseData( + 
conversation_id=conversation_id, + thread_id=thread_id, + task_id=task_id, + subtask_id=subtask_id, + ), + ) + + def tool_call( + self, + conversation_id: str, + thread_id: str, + task_id: str, + subtask_id: str, + event: Literal[ + StreamResponseEvent.TOOL_CALL_STARTED, + StreamResponseEvent.TOOL_CALL_COMPLETED, + ], + tool_call_id: str, + tool_name: str, + tool_result: Optional[str] = None, + ) -> ToolCallResponse: + return ToolCallResponse( + event=event, + data=UnifiedResponseData( + conversation_id=conversation_id, + thread_id=thread_id, + task_id=task_id, + subtask_id=subtask_id, + payload=ToolCallPayload( + tool_call_id=tool_call_id, + tool_name=tool_name, + tool_result=tool_result, + ), + ), + ) + + def message_response_general( + self, + event: Literal[StreamResponseEvent.MESSAGE_CHUNK, NotifyResponseEvent.MESSAGE], + conversation_id: str, + thread_id: str, + task_id: str, + subtask_id: str, + content: str, + ) -> MessageResponse: + return MessageResponse( + event=event, + data=UnifiedResponseData( + conversation_id=conversation_id, + thread_id=thread_id, + task_id=task_id, + subtask_id=subtask_id, + payload=BaseResponseDataPayload(content=content), + ), + ) + + def reasoning( + self, + conversation_id: str, + thread_id: str, + task_id: str, + subtask_id: str, + event: Literal[ + StreamResponseEvent.REASONING, + StreamResponseEvent.REASONING_STARTED, + StreamResponseEvent.REASONING_COMPLETED, + ], + content: Optional[str] = None, + ) -> ReasoningResponse: + return ReasoningResponse( + event=event, + data=UnifiedResponseData( + conversation_id=conversation_id, + thread_id=thread_id, + task_id=task_id, + subtask_id=subtask_id, + payload=BaseResponseDataPayload(content=content) if content else None, + ), + ) + + def component_generator( + self, + conversation_id: str, + thread_id: str, + task_id: str, + subtask_id: str, + content: str, + component_type: str, + ) -> ComponentGeneratorResponse: + return ComponentGeneratorResponse( + data=UnifiedResponseData( + 
conversation_id=conversation_id, + thread_id=thread_id, + task_id=task_id, + subtask_id=subtask_id, + payload=ComponentGeneratorResponseDataPayload( + content=content, + component_type=component_type, + ), + ), + ) diff --git a/python/valuecell/core/coordinate/response_router.py b/python/valuecell/core/coordinate/response_router.py new file mode 100644 index 000000000..7b97b13ea --- /dev/null +++ b/python/valuecell/core/coordinate/response_router.py @@ -0,0 +1,156 @@ +import logging +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional + +from a2a.types import TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent +from a2a.utils import get_message_text +from valuecell.core.agent.responses import EventPredicates +from valuecell.core.coordinate.response import ResponseFactory +from valuecell.core.task import Task +from valuecell.core.types import BaseResponse, StreamResponseEvent + +logger = logging.getLogger(__name__) + + +class SideEffectKind(Enum): + FAIL_TASK = "fail_task" + + +@dataclass +class SideEffect: + kind: SideEffectKind + reason: Optional[str] = None + + +@dataclass +class RouteResult: + responses: List[BaseResponse] + done: bool = False + side_effects: List[SideEffect] = None + + def __post_init__(self): + if self.side_effects is None: + self.side_effects = [] + + +def _default_subtask_id(task_id: str) -> str: + return f"{task_id}-default_subtask" + + +async def handle_status_update( + response_factory: ResponseFactory, + task: Task, + thread_id: str, + event: TaskStatusUpdateEvent, +) -> RouteResult: + responses: List[BaseResponse] = [] + state = event.status.state + logger.info(f"Task {task.task_id} status update: {state}") + + if state in {TaskState.submitted, TaskState.completed}: + return RouteResult(responses) + + if state == TaskState.failed: + err_msg = get_message_text(event.status.message) + responses.append( + response_factory.task_failed( + conversation_id=task.session_id, + thread_id=thread_id,
task_id=task.task_id, + subtask_id=_default_subtask_id(task.task_id), + content=err_msg, + ) + ) + return RouteResult( + responses=responses, + done=True, + side_effects=[SideEffect(kind=SideEffectKind.FAIL_TASK, reason=err_msg)], + ) + + if not event.metadata: + return RouteResult(responses) + + response_event = event.metadata.get("response_event") + subtask_id = event.metadata.get("subtask_id") + if not subtask_id: + subtask_id = _default_subtask_id(task.task_id) + + # Tool call events + if state == TaskState.working and EventPredicates.is_tool_call(response_event): + tool_call_id = event.metadata.get("tool_call_id", "unknown_tool_call_id") + tool_name = event.metadata.get("tool_name", "unknown_tool_name") + + tool_result = None + if "tool_result" in event.metadata: + tool_result = event.metadata.get("tool_result") + responses.append( + response_factory.tool_call( + conversation_id=task.session_id, + thread_id=thread_id, + task_id=task.task_id, + subtask_id=subtask_id, + event=response_event, + tool_call_id=tool_call_id, + tool_name=tool_name, + tool_result=tool_result, + ) + ) + return RouteResult(responses) + + # Reasoning messages + if state == TaskState.working and EventPredicates.is_reasoning(response_event): + responses.append( + response_factory.reasoning( + conversation_id=task.session_id, + thread_id=thread_id, + task_id=task.task_id, + subtask_id=subtask_id, + event=response_event, + content=get_message_text(event.status.message, ""), + ) + ) + return RouteResult(responses) + + return RouteResult(responses) + + +async def handle_artifact_update( + response_factory: ResponseFactory, + task: Task, + thread_id: str, + event: TaskArtifactUpdateEvent, +) -> List[BaseResponse]: + responses: List[BaseResponse] = [] + artifact = event.artifact + subtask_id = artifact.metadata.get("subtask_id") if artifact.metadata else None + if not subtask_id: + subtask_id = _default_subtask_id(task.task_id) + response_event =
artifact.metadata.get("response_event") + content = get_message_text(artifact, "") + + if response_event == StreamResponseEvent.COMPONENT_GENERATOR: + component_type = artifact.metadata.get("component_type", "unknown") + responses.append( + response_factory.component_generator( + conversation_id=task.session_id, + thread_id=thread_id, + task_id=task.task_id, + subtask_id=subtask_id, + content=content, + component_type=component_type, + ) + ) + return responses + + responses.append( + response_factory.message_response_general( + event=response_event, + conversation_id=task.session_id, + thread_id=thread_id, + task_id=task.task_id, + subtask_id=subtask_id, + content=content, + ) + ) + return responses diff --git a/python/valuecell/core/coordinate/tests/test_orchestrator.py b/python/valuecell/core/coordinate/tests/test_orchestrator.py index e78370a54..30caa4bbd 100644 --- a/python/valuecell/core/coordinate/tests/test_orchestrator.py +++ b/python/valuecell/core/coordinate/tests/test_orchestrator.py @@ -309,8 +309,8 @@ async def test_planner_error( out.append(chunk) assert len(out) == 1 - assert "(Error)" in out[0].data.content - assert "Planning failed" in out[0].data.content + assert "(Error)" in out[0].data.payload.content + assert "Planning failed" in out[0].data.payload.content @pytest.mark.asyncio @@ -328,7 +328,7 @@ async def test_agent_connection_error( async for chunk in orchestrator.process_user_input(sample_user_input): out.append(chunk) - assert any("(Error)" in c.data.content for c in out) + assert any("(Error)" in c.data.payload.content for c in out) @pytest.mark.asyncio diff --git a/python/valuecell/core/types.py b/python/valuecell/core/types.py index f5463637d..89a640adc 100644 --- a/python/valuecell/core/types.py +++ b/python/valuecell/core/types.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import AsyncGenerator, Callable, Optional +from typing import AsyncGenerator, Callable, Literal, Optional, Union from 
a2a.types import Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent from pydantic import BaseModel, Field @@ -47,19 +47,33 @@ def clear_desired_agent(self) -> None: self.desired_agent_name = None +class SystemResponseEvent(str, Enum): + CONVERSATION_STARTED = "conversation_started" + PLAN_REQUIRE_USER_INPUT = "plan_require_user_input" + PLAN_FAILED = "plan_failed" + TASK_FAILED = "task_failed" + SYSTEM_FAILED = "system_failed" + DONE = "done" + + +class _TaskResponseEvent(str, Enum): + TASK_STARTED = "task_started" + TASK_COMPLETED = "task_completed" + TASK_CANCELLED = "task_cancelled" + + class StreamResponseEvent(str, Enum): MESSAGE_CHUNK = "message_chunk" + COMPONENT_GENERATOR = "component_generator" TOOL_CALL_STARTED = "tool_call_started" TOOL_CALL_COMPLETED = "tool_call_completed" + REASONING_STARTED = "reasoning_started" REASONING = "reasoning" - TASK_DONE = "task_done" - TASK_FAILED = "task_failed" + REASONING_COMPLETED = "reasoning_completed" class NotifyResponseEvent(str, Enum): MESSAGE = "message" - TASK_DONE = "task_done" - TASK_FAILED = "task_failed" class StreamResponse(BaseModel): @@ -69,7 +83,7 @@ class StreamResponse(BaseModel): None, description="The content of the stream response, typically a chunk of data or message.", ) - event: StreamResponseEvent = Field( + event: StreamResponseEvent | _TaskResponseEvent = Field( ..., description="The type of stream response, indicating its purpose or content nature.", ) @@ -77,6 +91,10 @@ class StreamResponse(BaseModel): None, description="Optional metadata providing additional context about the response", ) + subtask_id: Optional[str] = Field( + None, + description="Optional subtask ID if the response is related to a specific subtask", + ) class NotifyResponse(BaseModel): @@ -86,13 +104,13 @@ class NotifyResponse(BaseModel): ..., description="The content of the notification response", ) - event: NotifyResponseEvent = Field( + event: NotifyResponseEvent | _TaskResponseEvent = Field( ..., description="The 
type of notification response", ) -class ToolCallContent(BaseModel): +class ToolCallPayload(BaseModel): tool_call_id: str = Field(..., description="Unique ID for the tool call") tool_name: str = Field(..., description="Name of the tool being called") tool_result: Optional[str] = Field( @@ -101,24 +119,136 @@ class ToolCallContent(BaseModel): ) -class ProcessMessageData(BaseModel): - conversation_id: str = Field(..., description="Conversation ID for this request") - message_id: str = Field(..., description="Message ID for this request") - content: str | ToolCallContent = Field( - ..., description="Content of the message chunk" +class BaseResponseDataPayload(BaseModel, ABC): + content: Optional[str] = Field(None, description="The message content") + + +class ComponentGeneratorResponseDataPayload(BaseResponseDataPayload): + component_type: str = Field(..., description="The component type") + + +ResponsePayload = Union[ + BaseResponseDataPayload, + ComponentGeneratorResponseDataPayload, + ToolCallPayload, +] + + +class UnifiedResponseData(BaseModel): + """Unified response data structure with optional hierarchy fields. + + Field names are preserved to maintain JSON compatibility when using + model_dump(exclude_none=True). 
+ """ + + conversation_id: str = Field(..., description="Unique ID for the conversation") + thread_id: Optional[str] = Field( + None, description="Unique ID for the message thread" + ) + task_id: Optional[str] = Field(None, description="Unique ID for the task") + subtask_id: Optional[str] = Field( + None, description="Unique ID for the subtask, if any" + ) + payload: Optional[ResponsePayload] = Field( + None, description="The message data payload" + ) + + +class BaseResponse(BaseModel, ABC): + """Top-level response envelope used for all events.""" + + event: StreamResponseEvent | NotifyResponseEvent | SystemResponseEvent = Field( + ..., description="The event type of the response" + ) + data: UnifiedResponseData = Field( + ..., description="The data payload of the response" + ) + + +class ConversationStartedResponse(BaseResponse): + event: Literal[SystemResponseEvent.CONVERSATION_STARTED] = Field( + SystemResponseEvent.CONVERSATION_STARTED, + description="The event type of the response", + ) + + +class PlanRequireUserInputResponse(BaseResponse): + event: Literal[SystemResponseEvent.PLAN_REQUIRE_USER_INPUT] = Field( + SystemResponseEvent.PLAN_REQUIRE_USER_INPUT, + description="The event type of the response", + ) + data: UnifiedResponseData = Field(..., description="The plan data payload") + + +class MessageResponse(BaseResponse): + event: Literal[ + StreamResponseEvent.MESSAGE_CHUNK, + NotifyResponseEvent.MESSAGE, + ] = Field(..., description="The event type of the response") + data: UnifiedResponseData = Field(..., description="The complete message content") + + +class ComponentGeneratorResponse(BaseResponse): + event: Literal[StreamResponseEvent.COMPONENT_GENERATOR] = Field( + StreamResponseEvent.COMPONENT_GENERATOR + ) + data: UnifiedResponseData = Field(..., description="The component generator data") + + +class ToolCallResponse(BaseResponse): + event: Literal[ + StreamResponseEvent.TOOL_CALL_STARTED, StreamResponseEvent.TOOL_CALL_COMPLETED + ] = Field( + ..., 
+ description="The event type of the response", ) + data: UnifiedResponseData = Field(..., description="The task data payload") -class ProcessMessage(BaseModel): - """Chunk of a message, useful for streaming responses""" +class ReasoningResponse(BaseResponse): + event: Literal[ + StreamResponseEvent.REASONING_STARTED, + StreamResponseEvent.REASONING, + StreamResponseEvent.REASONING_COMPLETED, + ] = Field(..., description="The event type of the response") + data: UnifiedResponseData = Field(..., description="The reasoning message content") + + +class DoneResponse(BaseResponse): + event: Literal[SystemResponseEvent.DONE] = Field( + SystemResponseEvent.DONE, description="The event type of the response" + ) + data: UnifiedResponseData = Field(..., description="The thread data payload") + + +class PlanFailedResponse(BaseResponse): + event: Literal[SystemResponseEvent.PLAN_FAILED] = Field( + SystemResponseEvent.PLAN_FAILED, description="The event type of the response" + ) + data: UnifiedResponseData = Field(..., description="The plan data payload") + + +class TaskFailedResponse(BaseResponse): + event: Literal[SystemResponseEvent.TASK_FAILED] = Field( + SystemResponseEvent.TASK_FAILED, description="The event type of the response" + ) + data: UnifiedResponseData = Field(..., description="The task data payload") + + +class TaskCompletedResponse(BaseResponse): + event: Literal[_TaskResponseEvent.TASK_COMPLETED] = Field( + _TaskResponseEvent.TASK_COMPLETED, description="The event type of the response" + ) + data: UnifiedResponseData = Field(..., description="The task data payload") + - event: StreamResponseEvent | NotifyResponseEvent = Field( - ..., description="The event type of the message chunk" +class SystemFailedResponse(BaseResponse): + event: Literal[SystemResponseEvent.SYSTEM_FAILED] = Field( + SystemResponseEvent.SYSTEM_FAILED, description="The event type of the response" ) - data: ProcessMessageData = Field(..., description="Content of the message chunk") + data: 
UnifiedResponseData = Field(..., description="The conversation data payload") -# TODO: keep only essential parameters class BaseAgent(ABC): """ Abstract base class for all agents. diff --git a/python/valuecell/server/services/agent_stream_service.py b/python/valuecell/server/services/agent_stream_service.py index 2c67efbec..209f5968d 100644 --- a/python/valuecell/server/services/agent_stream_service.py +++ b/python/valuecell/server/services/agent_stream_service.py @@ -48,7 +48,7 @@ async def stream_query_agent( async for response_chunk in self.orchestrator.process_user_input( user_input ): - yield response_chunk.model_dump() + yield response_chunk.model_dump(exclude_none=True) except Exception as e: logger.error(f"Error in stream_query_agent: {str(e)}") diff --git a/python/valuecell/utils/uuid.py b/python/valuecell/utils/uuid.py index 501c95f8b..4e29a9306 100644 --- a/python/valuecell/utils/uuid.py +++ b/python/valuecell/utils/uuid.py @@ -10,3 +10,7 @@ def generate_uuid(prefix: str = None) -> str: def generate_message_id() -> str: return generate_uuid("msg") + + +def generate_thread_id() -> str: + return generate_uuid("th")