diff --git a/python/pyproject.toml b/python/pyproject.toml index 49341043e..c96f87b94 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "agno[openai]>=1.8.2,<2.0", "edgartools>=4.12.2", "sqlalchemy>=2.0.43", + "aiosqlite>=0.19.0", ] [project.optional-dependencies] diff --git a/python/third_party/ai-hedge-fund/adapter/__main__.py b/python/third_party/ai-hedge-fund/adapter/__main__.py index 91663ad92..28b3f26d2 100644 --- a/python/third_party/ai-hedge-fund/adapter/__main__.py +++ b/python/third_party/ai-hedge-fund/adapter/__main__.py @@ -62,7 +62,7 @@ async def stream( logger.info( f"Parsing query: {query}. Task ID: {task_id}, Session ID: {session_id}" ) - run_response = self.agno_agent.run( + run_response = await self.agno_agent.arun( f"Parse the following hedge fund analysis request and extract the parameters: {query}" ) hedge_fund_request = run_response.content @@ -101,7 +101,7 @@ async def stream( } logger.info(f"Start analyzing. Task ID: {task_id}, Session ID: {session_id}") - for _, chunk in run_hedge_fund_stream( + async for _, chunk in run_hedge_fund_stream( tickers=hedge_fund_request.tickers, start_date=start_date, end_date=end_date, @@ -116,7 +116,7 @@ async def stream( yield streaming.done() -def run_hedge_fund_stream( +async def run_hedge_fund_stream( tickers: list[str], start_date: str, end_date: str, @@ -153,7 +153,8 @@ def run_hedge_fund_stream( "model_provider": model_provider, }, } - yield from _agent.stream(inputs, stream_mode=["custom", "messages"]) + async for res in _agent.astream(inputs, stream_mode=["custom", "messages"]): + yield res finally: # Stop progress tracking progress.stop() diff --git a/python/third_party/ai-hedge-fund/uv.lock b/python/third_party/ai-hedge-fund/uv.lock index 4623ce045..3d8de0abe 100644 --- a/python/third_party/ai-hedge-fund/uv.lock +++ b/python/third_party/ai-hedge-fund/uv.lock @@ -215,6 +215,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] +[[package]] +name = "aiosqlite" +version = "0.21.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/13/7d/8bca2bf9a247c2c5dfeec1d7a5f40db6518f88d314b8bca9da29670d2671/aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3", size = 13454, upload-time = "2025-02-03T07:30:16.235Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0", size = 15792, upload-time = "2025-02-03T07:30:13.6Z" }, +] + [[package]] name = "akracer" version = "0.0.14" @@ -3083,6 +3095,7 @@ source = { editable = "../../" } dependencies = [ { name = "a2a-sdk", extra = ["http-server"] }, { name = "agno", extra = ["openai"] }, + { name = "aiosqlite" }, { name = "akshare" }, { name = "edgartools" }, { name = "fastapi" }, @@ -3100,6 +3113,7 @@ dependencies = [ requires-dist = [ { name = "a2a-sdk", extras = ["http-server"], specifier = ">=0.3.4" }, { name = "agno", extras = ["openai"], specifier = ">=1.8.2,<2.0" }, + { name = "aiosqlite", specifier = ">=0.19.0" }, { name = "akshare", specifier = ">=1.17.44" }, { name = "edgartools", specifier = ">=4.12.2" }, { name = "fastapi", specifier = ">=0.104.0" }, diff --git a/python/uv.lock b/python/uv.lock index 810761da9..95e6de56f 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -140,6 +140,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] +[[package]] +name = "aiosqlite" +version = "0.21.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/13/7d/8bca2bf9a247c2c5dfeec1d7a5f40db6518f88d314b8bca9da29670d2671/aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3", size = 13454, upload-time = "2025-02-03T07:30:16.235Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0", size = 15792, upload-time = "2025-02-03T07:30:13.6Z" }, +] + [[package]] name = "akracer" version = "0.0.13" @@ -2094,6 +2106,7 @@ source = { editable = "." } dependencies = [ { name = "a2a-sdk", extra = ["http-server"] }, { name = "agno", extra = ["openai"] }, + { name = "aiosqlite" }, { name = "akshare" }, { name = "edgartools" }, { name = "fastapi" }, @@ -2135,6 +2148,7 @@ test = [ requires-dist = [ { name = "a2a-sdk", extras = ["http-server"], specifier = ">=0.3.4" }, { name = "agno", extras = ["openai"], specifier = ">=1.8.2,<2.0" }, + { name = "aiosqlite", specifier = ">=0.19.0" }, { name = "akshare", specifier = ">=1.17.44" }, { name = "edgartools", specifier = ">=4.12.2" }, { name = "fastapi", specifier = ">=0.104.0" }, diff --git a/python/valuecell/core/__init__.py b/python/valuecell/core/__init__.py index 5a2b86241..4a78bda58 100644 --- a/python/valuecell/core/__init__.py +++ b/python/valuecell/core/__init__.py @@ -2,15 +2,17 @@ from .agent.decorator import create_wrapped_agent from .agent.responses import notification, streaming from .session import ( - InMemoryMessageStore, InMemorySessionStore, Message, - MessageStore, Role, Session, SessionManager, SessionStatus, SessionStore, +) +from .session.message_store import ( + InMemoryMessageStore, + MessageStore, SQLiteMessageStore, ) diff --git a/python/valuecell/core/agent/decorator.py b/python/valuecell/core/agent/decorator.py index 65805c783..99e064bad 100644 --- a/python/valuecell/core/agent/decorator.py +++ b/python/valuecell/core/agent/decorator.py @@ -13,7 +13,7 @@ InMemoryTaskStore, TaskUpdater, ) -from a2a.types import AgentCard, Part, TaskState, TextPart, UnsupportedOperationError +from a2a.types import AgentCard, TaskState, UnsupportedOperationError from a2a.utils import new_agent_text_message, new_task from a2a.utils.errors import ServerError from valuecell.core.agent.card import find_local_agent_card_by_agent_name @@ -21,7 +21,7 @@ BaseAgent, NotifyResponse, StreamResponse, - StreamResponseEvent, + CommonResponseEvent, ) from valuecell.utils import parse_host_port from .responses import EventPredicates @@ -118,34 +118,6 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non task_id = task.id session_id = task.context_id updater = TaskUpdater(event_queue, task_id, session_id) - artifact_id = f"artifact-{agent_name}-{session_id}-{task_id}" - chunk_idx = -1 - - # Local helper to add a chunk - async def _add_chunk( - response: StreamResponse | NotifyResponse, is_complete: bool - ): - nonlocal chunk_idx - - chunk_idx += 1 - if not response.content: - return - - parts = [Part(root=TextPart(text=response.content))] - response_event = response.event - metadata = { - "response_event": response_event.value, - "subtask_id": response.subtask_id, - } - if response_event == StreamResponseEvent.COMPONENT_GENERATOR: - metadata["component_type"] = response.metadata.get("component_type") - await updater.add_artifact( - parts=parts, - artifact_id=artifact_id, - append=chunk_idx > 0, - last_chunk=is_complete, - metadata=metadata, - ) # Stream from the user agent and update task incrementally await updater.update_status( @@ -170,35 +142,34 @@ async def _add_chunk( f"Agent {agent_name} reported failure: {response.content}" ) - is_complete = EventPredicates.is_task_completed(response_event) + metadata = {"response_event": response_event.value} if EventPredicates.is_tool_call(response_event): + metadata["tool_call_id"] = response.metadata.get("tool_call_id") + metadata["tool_name"] = response.metadata.get("tool_name") + metadata["tool_result"] = response.metadata.get("content") await updater.update_status( TaskState.working, message=new_agent_text_message(response.content or ""), - metadata={ - "event": response_event.value, - "tool_call_id": response.metadata.get("tool_call_id"), - "tool_name": response.metadata.get("tool_name"), - "tool_result": response.metadata.get("content"), - "subtask_id": response.subtask_id, - }, + metadata=metadata, ) continue if EventPredicates.is_reasoning(response_event): await updater.update_status( TaskState.working, message=new_agent_text_message(response.content or ""), - metadata={ - "event": response_event.value, - "subtask_id": response.subtask_id, - }, + metadata=metadata, ) continue - await _add_chunk(response, is_complete=is_complete) - if is_complete: - await updater.complete() - break + if not response.content: + continue + if response_event == CommonResponseEvent.COMPONENT_GENERATOR: + metadata["component_type"] = response.metadata.get("component_type") + await updater.update_status( + TaskState.working, + message=new_agent_text_message(response.content or ""), + metadata=metadata, + ) except Exception as e: message = f"Error during {agent_name} agent execution: {e}" @@ -206,8 +177,9 @@ async def _add_chunk( await updater.update_status( TaskState.failed, message=new_agent_text_message(message, session_id, task_id), - final=True, ) + finally: + await updater.complete() async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: # Default cancel operation diff --git a/python/valuecell/core/agent/responses.py b/python/valuecell/core/agent/responses.py index 74bc6ed3c..8ef129d2f 100644 --- a/python/valuecell/core/agent/responses.py +++ b/python/valuecell/core/agent/responses.py @@ -3,38 +3,33 @@ from typing import Optional from valuecell.core.types import ( + CommonResponseEvent, NotifyResponse, NotifyResponseEvent, StreamResponse, StreamResponseEvent, SystemResponseEvent, + TaskStatusEvent, ToolCallPayload, - _TaskResponseEvent, ) class _StreamResponseNamespace: """Factory methods for streaming responses.""" - def message_chunk( - self, content: str, subtask_id: str | None = None - ) -> StreamResponse: + def message_chunk(self, content: str) -> StreamResponse: return StreamResponse( event=StreamResponseEvent.MESSAGE_CHUNK, content=content, - subtask_id=subtask_id, ) - def tool_call_started( - self, tool_call_id: str, tool_name: str, subtask_id: str | None = None - ) -> StreamResponse: + def tool_call_started(self, tool_call_id: str, tool_name: str) -> StreamResponse: return StreamResponse( event=StreamResponseEvent.TOOL_CALL_STARTED, metadata=ToolCallPayload( tool_call_id=tool_call_id, tool_name=tool_name, ).model_dump(), - subtask_id=subtask_id, ) def tool_call_completed( @@ -42,7 +37,6 @@ def tool_call_completed( tool_result: str, tool_call_id: str, tool_name: str, - subtask_id: str | None = None, ) -> StreamResponse: return StreamResponse( event=StreamResponseEvent.TOOL_CALL_COMPLETED, @@ -51,42 +45,19 @@ def tool_call_completed( tool_name=tool_name, tool_result=tool_result, ).model_dump(), - subtask_id=subtask_id, - ) - - def reasoning_started(self, subtask_id: str | None = None) -> StreamResponse: - return StreamResponse( - event=StreamResponseEvent.REASONING_STARTED, - subtask_id=subtask_id, - ) - - def reasoning(self, content: str, subtask_id: str | None = None) -> StreamResponse: - return StreamResponse( - event=StreamResponseEvent.REASONING, - content=content, - subtask_id=subtask_id, - ) - - def reasoning_completed(self, subtask_id: str | None = None) -> StreamResponse: - return StreamResponse( - event=StreamResponseEvent.REASONING_COMPLETED, - subtask_id=subtask_id, ) - def component_generator( - self, content: str, component_type: str, subtask_id: str | None = None - ) -> StreamResponse: + def component_generator(self, content: str, component_type: str) -> StreamResponse: return StreamResponse( - event=StreamResponseEvent.COMPONENT_GENERATOR, + event=CommonResponseEvent.COMPONENT_GENERATOR, content=content, metadata={"component_type": component_type}, - subtask_id=subtask_id, ) def done(self, content: Optional[str] = None) -> StreamResponse: return StreamResponse( content=content, - event=_TaskResponseEvent.TASK_COMPLETED, + event=TaskStatusEvent.TASK_COMPLETED, ) def failed(self, content: Optional[str] = None) -> StreamResponse: @@ -108,10 +79,17 @@ def message(self, content: str) -> NotifyResponse: event=NotifyResponseEvent.MESSAGE, ) + def component_generator(self, content: str, component_type: str) -> StreamResponse: + return StreamResponse( + event=CommonResponseEvent.COMPONENT_GENERATOR, + content=content, + metadata={"component_type": component_type}, + ) + def done(self, content: Optional[str] = None) -> NotifyResponse: return NotifyResponse( content=content, - event=_TaskResponseEvent.TASK_COMPLETED, + event=TaskStatusEvent.TASK_COMPLETED, ) def failed(self, content: Optional[str] = None) -> NotifyResponse: @@ -134,7 +112,7 @@ class EventPredicates: @staticmethod def is_task_completed(response_type) -> bool: return response_type in { - _TaskResponseEvent.TASK_COMPLETED, + TaskStatusEvent.TASK_COMPLETED, } @staticmethod @@ -158,6 +136,13 @@ def is_reasoning(response_type) -> bool: StreamResponseEvent.REASONING_COMPLETED, } + @staticmethod + def is_message(response_type) -> bool: + return response_type in { + StreamResponseEvent.MESSAGE_CHUNK, + NotifyResponseEvent.MESSAGE, + } + __all__ = [ "streaming", diff --git a/python/valuecell/core/coordinate/__init__.py b/python/valuecell/core/coordinate/__init__.py index 571324943..9c1c53bcb 100644 --- a/python/valuecell/core/coordinate/__init__.py +++ b/python/valuecell/core/coordinate/__init__.py @@ -1,7 +1,6 @@ from .models import ExecutionPlan from .orchestrator import AgentOrchestrator, get_default_orchestrator from .planner import ExecutionPlanner -from .callback import store_task_in_session __all__ = [ @@ -9,5 +8,4 @@ "get_default_orchestrator", "ExecutionPlanner", "ExecutionPlan", - "store_task_in_session", ] diff --git a/python/valuecell/core/coordinate/callback.py b/python/valuecell/core/coordinate/callback.py deleted file mode 100644 index bd67ca260..000000000 --- a/python/valuecell/core/coordinate/callback.py +++ /dev/null @@ -1,16 +0,0 @@ -from a2a.types import Task -from valuecell.core.session import get_default_session_manager, Role - - -async def store_task_in_session(task: Task) -> None: - session_id = task.metadata.get("session_id") - if not session_id: - return - - session_manager = get_default_session_manager() - if not task.artifacts: - return - if not task.artifacts[-1].parts: - return - content = task.artifacts[-1].parts[-1].root.text - await session_manager.add_message(session_id, Role.AGENT, content) diff --git a/python/valuecell/core/coordinate/orchestrator.py b/python/valuecell/core/coordinate/orchestrator.py index 2c2b78635..a7a0f8089 100644 --- a/python/valuecell/core/coordinate/orchestrator.py +++ b/python/valuecell/core/coordinate/orchestrator.py @@ -1,30 +1,22 @@ import asyncio import logging -from collections import defaultdict from typing import AsyncGenerator, Dict, Optional from a2a.types import TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent from valuecell.core.agent.connect import get_default_remote_connections -from valuecell.core.agent.responses import EventPredicates from valuecell.core.coordinate.response import ResponseFactory +from valuecell.core.coordinate.response_buffer import ResponseBuffer, SaveItem from valuecell.core.coordinate.response_router import ( RouteResult, SideEffectKind, - handle_artifact_update, handle_status_update, ) -from valuecell.core.session import Role, SessionStatus, get_default_session_manager +from valuecell.core.session import SessionStatus, get_default_session_manager from valuecell.core.task import Task, get_default_task_manager from valuecell.core.task.models import TaskPattern -from valuecell.core.types import ( - BaseResponse, - NotifyResponseEvent, - StreamResponseEvent, - UserInput, -) +from valuecell.core.types import BaseResponse, UserInput from valuecell.utils.uuid import generate_thread_id -from .callback import store_task_in_session from .models import ExecutionPlan from .planner import ExecutionPlanner, UserInputRequest @@ -126,6 +118,8 @@ def __init__(self): self.planner = ExecutionPlanner(self.agent_connections) self._response_factory = ResponseFactory() + # Buffer for streaming responses -> persisted ConversationItems + self._response_buffer = ResponseBuffer() # ==================== Public API Methods ==================== @@ -199,32 +193,18 @@ def get_user_input_prompt(self, session_id: str) -> Optional[str]: """Get the user input prompt for a specific session""" return self.user_input_manager.get_request_prompt(session_id) - async def create_session(self, user_id: str, title: str = None): - """Create a new session for the user""" - return await self.session_manager.create_session(user_id, title) - async def close_session(self, session_id: str): """Close an existing session and clean up resources""" # Cancel any running tasks for this session - cancelled_count = await self.task_manager.cancel_session_tasks(session_id) + await self.task_manager.cancel_session_tasks(session_id) # Clean up execution context await self._cancel_execution(session_id) - # Add system message to mark session as closed - await self.session_manager.add_message( - session_id, - Role.SYSTEM, - f"Session closed. {cancelled_count} tasks were cancelled.", - ) - - async def get_session_history(self, session_id: str): + async def get_session_history(self, session_id: str) -> list[BaseResponse]: """Get session message history""" - return await self.session_manager.get_session_messages(session_id) - - async def get_user_sessions(self, user_id: str, limit: int = 100, offset: int = 0): - """Get all sessions for a user""" - return await self.session_manager.list_user_sessions(user_id, limit, offset) + items = await self.session_manager.get_session_messages(session_id) + return [self._response_factory.from_conversation_item(it) for it in items] async def cleanup(self): """Cleanup resources and expired contexts""" @@ -271,10 +251,20 @@ async def _handle_session_continuation( context.add_metadata(pending_response=user_input.query) await self.provide_user_input(session_id, user_input.query) + thread_id = generate_thread_id() + response = self._response_factory.thread_started( + conversation_id=session_id, thread_id=thread_id, user_query=user_input.query + ) + await self._persist_from_buffer(response) + yield response + context.thread_id = thread_id + # Resume based on execution stage if context.stage == "planning": - async for chunk in self._continue_planning(session_id, context): - yield chunk + async for response in self._continue_planning( + session_id, thread_id, context + ): + yield response # Resuming execution stage is not yet supported else: yield self._response_factory.system_failed( @@ -288,14 +278,11 @@ async def _handle_new_request( """Handle a new user request""" session_id = user_input.meta.session_id thread_id = generate_thread_id() - yield self._response_factory.thread_started( - conversation_id=session_id, thread_id=thread_id - ) - - # Add user message to session - await self.session_manager.add_message( - session_id, Role.USER, user_input.query, user_id=user_input.meta.user_id + response = self._response_factory.thread_started( + conversation_id=session_id, thread_id=thread_id, user_query=user_input.query ) + await self._persist_from_buffer(response) + yield response # Create planning task with user input callback context_aware_callback = self._create_context_aware_callback(session_id) @@ -305,10 +292,10 @@ async def _handle_new_request( ) # Monitor planning progress - async for chunk in self._monitor_planning_task( + async for response in self._monitor_planning_task( planning_task, thread_id, user_input, context_aware_callback ): - yield chunk + yield response def _create_context_aware_callback(self, session_id: str): """Create a callback that adds session context to user input requests""" @@ -344,17 +331,19 @@ async def _monitor_planning_task( # Update session status and send user input request await self._request_user_input(session_id) - yield self._response_factory.plan_require_user_input( + response = self._response_factory.plan_require_user_input( session_id, thread_id, self.get_user_input_prompt(session_id) ) + await self._persist_from_buffer(response) + yield response return await asyncio.sleep(ASYNC_SLEEP_INTERVAL) # Planning completed, execute plan plan = await planning_task - async for chunk in self._execute_plan_with_input_support(plan, thread_id): - yield chunk + async for response in self._execute_plan_with_input_support(plan, thread_id): + yield response async def _request_user_input(self, session_id: str): """Set session to require user input and send the request""" @@ -379,16 +368,11 @@ def _validate_execution_context( return True async def _continue_planning( - self, session_id: str, context: ExecutionContext + self, session_id: str, thread_id: str, context: ExecutionContext ) -> AsyncGenerator[BaseResponse, None]: """Resume planning stage execution""" planning_task = context.get_metadata("planning_task") original_user_input = context.get_metadata("original_user_input") - thread_id = generate_thread_id() - context.thread_id = thread_id - yield self._response_factory.thread_started( - conversation_id=session_id, thread_id=thread_id - ) if not all([planning_task, original_user_input]): yield self._response_factory.plan_failed( @@ -406,9 +390,11 @@ async def _continue_planning( prompt = self.get_user_input_prompt(session_id) # Ensure session is set to require user input again for repeated prompts await self._request_user_input(session_id) - yield self._response_factory.plan_require_user_input( + response = self._response_factory.plan_require_user_input( session_id, thread_id, prompt ) + await self._persist_from_buffer(response) + yield response return await asyncio.sleep(ASYNC_SLEEP_INTERVAL) @@ -417,8 +403,8 @@ async def _continue_planning( plan = await planning_task del self._execution_contexts[session_id] - async for message in self._execute_plan_with_input_support(plan, thread_id): - yield message + async for response in self._execute_plan_with_input_support(plan, thread_id): + yield response async def _cancel_execution(self, session_id: str): """Cancel execution and clean up all related resources""" @@ -481,9 +467,6 @@ async def _execute_plan_with_input_support( ) return - # Track agent responses for session storage - agent_responses = defaultdict(str) - for task in plan.tasks: try: # Register the task with TaskManager @@ -493,42 +476,24 @@ async def _execute_plan_with_input_support( async for response in self._execute_task_with_input_support( task, thread_id, metadata ): + # Ensure buffered events carry a stable paragraph item_id + annotated = self._response_buffer.annotate(response) # Accumulate based on event - if response.event in { - StreamResponseEvent.MESSAGE_CHUNK, - StreamResponseEvent.REASONING, - NotifyResponseEvent.MESSAGE, - } and isinstance(response.data.payload.content, str): - agent_responses[task.agent_name] += ( - response.data.payload.content - ) - yield response + yield annotated - if ( - EventPredicates.is_task_completed(response.event) - or task.pattern == TaskPattern.RECURRING - ): - if agent_responses[task.agent_name].strip(): - await self.session_manager.add_message( - session_id, - Role.AGENT, - agent_responses[task.agent_name], - ) - agent_responses[task.agent_name] = "" + # Persist via ResponseBuffer + await self._persist_from_buffer(annotated) except Exception as e: - error_msg = f"(Error) Error executing {task.agent_name}: {str(e)}" + error_msg = f"(Error) Error executing {task.task_id}: {str(e)}" logger.exception(f"Task execution failed: {error_msg}") yield self._response_factory.task_failed( session_id, thread_id, task.task_id, - _generate_task_default_subtask_id(task.task_id), error_msg, ) - # Save any remaining agent responses - await self._save_remaining_responses(session_id, agent_responses) yield self._response_factory.done(session_id, thread_id) async def _execute_task_with_input_support( @@ -551,7 +516,6 @@ async def _execute_task_with_input_support( agent_card = await self.agent_connections.start_agent( agent_name, with_listener=False, - notification_callback=store_task_in_session, ) client = await self.agent_connections.get_client(agent_name) if not client: @@ -581,6 +545,7 @@ async def _execute_task_with_input_support( self._response_factory, task, thread_id, event ) for r in result.responses: + r = self._response_buffer.annotate(r) yield r # Apply side effects for eff in result.side_effects: @@ -593,11 +558,9 @@ async def _execute_task_with_input_support( continue if isinstance(event, TaskArtifactUpdateEvent): - responses = await handle_artifact_update( - self._response_factory, task, thread_id, event + logger.info( + f"Received unexpected artifact update for task {task.task_id}: {event}" ) - for r in responses: - yield r continue # Complete task successfully @@ -606,25 +569,42 @@ async def _execute_task_with_input_support( conversation_id=task.session_id, thread_id=thread_id, task_id=task.task_id, - subtask_id=_generate_task_default_subtask_id(task.task_id), ) + # Finalize buffered aggregates for this task (explicit flush at task end) + items = self._response_buffer.flush_task( + conversation_id=task.session_id, + thread_id=thread_id, + task_id=task.task_id, + ) + await self._persist_items(items) except Exception as e: + # On failure, finalize any buffered aggregates for this task + items = self._response_buffer.flush_task( + conversation_id=task.session_id, + thread_id=thread_id, + task_id=task.task_id, + ) + await self._persist_items(items) await self.task_manager.fail_task(task.task_id, str(e)) raise e - async def _save_remaining_responses(self, session_id: str, agent_responses: dict): - """Save any remaining agent responses to the session""" - for agent_name, full_response in agent_responses.items(): - if full_response.strip(): - await self.session_manager.add_message( - session_id, Role.AGENT, full_response - ) - - -def _generate_task_default_subtask_id(task_id: str) -> str: - """Generate a default subtask ID based on the main task ID""" - return f"{task_id}-default_subtask" + async def _persist_from_buffer(self, response: BaseResponse): + """Ingest a response into the buffer and persist any SaveMessages produced.""" + items = self._response_buffer.ingest(response) + await self._persist_items(items) + + async def _persist_items(self, items: list[SaveItem]): + for it in items: + await self.session_manager.add_message( + role=it.role, + event=it.event, + conversation_id=it.conversation_id, + thread_id=it.thread_id, + task_id=it.task_id, + payload=it.payload, + item_id=it.item_id, + ) # ==================== Module-level Factory Function ==================== diff --git a/python/valuecell/core/coordinate/planner.py b/python/valuecell/core/coordinate/planner.py index b38d3d6a4..17c2b2569 100644 --- a/python/valuecell/core/coordinate/planner.py +++ b/python/valuecell/core/coordinate/planner.py @@ -118,7 +118,7 @@ async def _analyze_input_and_create_tasks( # Execute planning with the agent run_response = agent.run( message=PlannerInput( - desired_agent_name=user_input.get_desired_agent(), + desired_agent_name=user_input.desired_agent_name, query=user_input.query, ) ) diff --git a/python/valuecell/core/coordinate/response.py b/python/valuecell/core/coordinate/response.py index ab9f73190..285a1adce 100644 --- a/python/valuecell/core/coordinate/response.py +++ b/python/valuecell/core/coordinate/response.py @@ -3,8 +3,10 @@ from typing_extensions import Literal from valuecell.core.types import ( BaseResponseDataPayload, + CommonResponseEvent, ComponentGeneratorResponse, ComponentGeneratorResponseDataPayload, + ConversationItem, ConversationStartedResponse, DoneResponse, MessageResponse, @@ -12,29 +14,139 @@ PlanFailedResponse, PlanRequireUserInputResponse, ReasoningResponse, + Role, StreamResponseEvent, SystemFailedResponse, + SystemResponseEvent, TaskCompletedResponse, TaskFailedResponse, + TaskStatusEvent, ThreadStartedResponse, ToolCallPayload, ToolCallResponse, UnifiedResponseData, ) +from valuecell.utils.uuid import generate_item_id, generate_uuid class ResponseFactory: + def from_conversation_item(self, item: ConversationItem): + """Reconstruct a BaseResponse from a persisted ConversationItem. + + - Maps the stored event to the appropriate Response subtype + - Parses payload JSON back into the right payload model when possible + - Preserves the original item_id so callers can correlate history items + """ + + # Coerce enums that may have been persisted as strings + ev = item.event + if isinstance(ev, str): + for enum_cls in ( + SystemResponseEvent, + StreamResponseEvent, + NotifyResponseEvent, + CommonResponseEvent, + TaskStatusEvent, + ): + try: + ev = enum_cls(ev) # type: ignore[arg-type] + break + except Exception: + continue + + role = item.role + if isinstance(role, str): + try: + role = Role(role) + except Exception: + role = Role.AGENT + + # Helpers for payload parsing + def parse_payload_as(model_cls): + raw = item.payload + if raw is None: + return None + try: + return model_cls.model_validate_json(raw) + except Exception: + # Fallback to plain text payload + try: + return BaseResponseDataPayload(content=str(raw)) + except Exception: + return None + + # Base UnifiedResponseData builder + def make_data(payload=None): + return UnifiedResponseData( + conversation_id=item.conversation_id, + thread_id=item.thread_id, + task_id=item.task_id, + payload=payload, + role=role, + item_id=item.item_id, + ) + + # ----- System-level events ----- + if ev == SystemResponseEvent.THREAD_STARTED: + payload = parse_payload_as(BaseResponseDataPayload) + return ThreadStartedResponse(data=make_data(payload)) + + if ev == SystemResponseEvent.PLAN_REQUIRE_USER_INPUT: + payload = parse_payload_as(BaseResponseDataPayload) + return PlanRequireUserInputResponse(data=make_data(payload)) + + # ----- Stream/notify/common events ----- + if ev == StreamResponseEvent.MESSAGE_CHUNK: + payload = parse_payload_as(BaseResponseDataPayload) + return MessageResponse( + event=StreamResponseEvent.MESSAGE_CHUNK, data=make_data(payload) + ) + + if ev == NotifyResponseEvent.MESSAGE: + payload = parse_payload_as(BaseResponseDataPayload) + return MessageResponse( + event=NotifyResponseEvent.MESSAGE, data=make_data(payload) + ) + + if ev in ( + StreamResponseEvent.REASONING, + StreamResponseEvent.REASONING_STARTED, + StreamResponseEvent.REASONING_COMPLETED, + ): + payload = parse_payload_as(BaseResponseDataPayload) + # ReasoningResponse accepts optional payload + return ReasoningResponse(event=ev, data=make_data(payload)) + + if ev == CommonResponseEvent.COMPONENT_GENERATOR: + payload = parse_payload_as(ComponentGeneratorResponseDataPayload) + return ComponentGeneratorResponse(data=make_data(payload)) + + if ev in ( + StreamResponseEvent.TOOL_CALL_STARTED, + StreamResponseEvent.TOOL_CALL_COMPLETED, + ): + payload = parse_payload_as(ToolCallPayload) + return ToolCallResponse(event=ev, data=make_data(payload)) + + raise ValueError( + f"Unsupported event type: {ev} when processing conversation item." + ) + def conversation_started(self, conversation_id: str) -> ConversationStartedResponse: return ConversationStartedResponse( - data=UnifiedResponseData(conversation_id=conversation_id) + data=UnifiedResponseData(conversation_id=conversation_id, role=Role.SYSTEM) ) def thread_started( - self, conversation_id: str, thread_id: str + self, conversation_id: str, thread_id: str, user_query: str ) -> ThreadStartedResponse: return ThreadStartedResponse( data=UnifiedResponseData( - conversation_id=conversation_id, thread_id=thread_id + conversation_id=conversation_id, + thread_id=thread_id, + task_id=generate_uuid("ask"), + payload=BaseResponseDataPayload(content=user_query), + role=Role.USER, ) ) @@ -43,6 +155,7 @@ def system_failed(self, conversation_id: str, content: str) -> SystemFailedRespo data=UnifiedResponseData( conversation_id=conversation_id, payload=BaseResponseDataPayload(content=content), + role=Role.SYSTEM, ) ) @@ -51,6 +164,7 @@ def done(self, conversation_id: str, thread_id: str) -> DoneResponse: data=UnifiedResponseData( conversation_id=conversation_id, thread_id=thread_id, + role=Role.SYSTEM, ) ) @@ -62,6 +176,7 @@ def plan_require_user_input( conversation_id=conversation_id, thread_id=thread_id, payload=BaseResponseDataPayload(content=content), + role=Role.SYSTEM, ) ) @@ -73,6 +188,7 @@ def plan_failed( conversation_id=conversation_id, thread_id=thread_id, payload=BaseResponseDataPayload(content=content), + role=Role.SYSTEM, ) ) @@ -81,7 +197,6 @@ def task_failed( conversation_id: str, thread_id: str, task_id: str, - subtask_id: str | None, content: str, ) -> TaskFailedResponse: return TaskFailedResponse( @@ -89,8 +204,8 @@ def task_failed( conversation_id=conversation_id, thread_id=thread_id, task_id=task_id, - subtask_id=subtask_id, payload=BaseResponseDataPayload(content=content), + role=Role.AGENT, ) ) @@ -99,14 +214,13 @@ def task_completed( conversation_id: str, thread_id: str, task_id: str, - subtask_id: str | None, ) -> TaskCompletedResponse: return TaskCompletedResponse( data=UnifiedResponseData( conversation_id=conversation_id, thread_id=thread_id, task_id=task_id, - subtask_id=subtask_id, + role=Role.AGENT, ), ) @@ -115,7 +229,6 @@ def tool_call( conversation_id: str, thread_id: str, task_id: str, - subtask_id: str, event: Literal[ StreamResponseEvent.TOOL_CALL_STARTED, StreamResponseEvent.TOOL_CALL_COMPLETED, @@ -130,12 +243,12 @@ def tool_call( conversation_id=conversation_id, thread_id=thread_id, task_id=task_id, - subtask_id=subtask_id, payload=ToolCallPayload( tool_call_id=tool_call_id, tool_name=tool_name, tool_result=tool_result, ), + role=Role.AGENT, ), ) @@ -145,8 +258,8 @@ def message_response_general( conversation_id: str, thread_id: str, task_id: str, - subtask_id: str, content: str, + item_id: Optional[str] = None, ) -> MessageResponse: return MessageResponse( event=event, @@ -154,8 +267,11 @@ def message_response_general( conversation_id=conversation_id, thread_id=thread_id, task_id=task_id, - subtask_id=subtask_id, - payload=BaseResponseDataPayload(content=content), + payload=BaseResponseDataPayload( + content=content, + ), + role=Role.AGENT, + item_id=item_id or generate_item_id(), ), ) @@ -164,7 +280,6 @@ def reasoning( conversation_id: str, thread_id: str, task_id: str, - subtask_id: str, event: Literal[ StreamResponseEvent.REASONING, StreamResponseEvent.REASONING_STARTED, @@ -178,8 +293,8 @@ def reasoning( conversation_id=conversation_id, thread_id=thread_id, task_id=task_id, - subtask_id=subtask_id, - payload=BaseResponseDataPayload(content=content) if content else None, + payload=(BaseResponseDataPayload(content=content) if content else None), + role=Role.AGENT, ), ) @@ -188,7 +303,6 @@ def component_generator( conversation_id: str, thread_id: str, task_id: str, - subtask_id: str, content: str, component_type: str, ) -> ComponentGeneratorResponse: @@ -197,10 +311,10 @@ def component_generator( conversation_id=conversation_id, thread_id=thread_id, task_id=task_id, - subtask_id=subtask_id, payload=ComponentGeneratorResponseDataPayload( content=content, component_type=component_type, ), + role=Role.AGENT, ), ) diff --git a/python/valuecell/core/coordinate/response_buffer.py b/python/valuecell/core/coordinate/response_buffer.py new file mode 100644 index 000000000..06790f435 --- /dev/null +++ b/python/valuecell/core/coordinate/response_buffer.py @@ -0,0 +1,271 @@ +import time +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +from pydantic import BaseModel +from valuecell.core.types import ( + BaseResponse, + BaseResponseDataPayload, + CommonResponseEvent, + NotifyResponseEvent, + Role, + StreamResponseEvent, + SystemResponseEvent, + UnifiedResponseData, +) +from valuecell.utils.uuid import generate_item_id + + +@dataclass +class SaveItem: + item_id: str + event: object # ConversationItemEvent union; keep generic to avoid circular typing + conversation_id: str + thread_id: Optional[str] + task_id: Optional[str] + payload: Optional[BaseModel] + role: Role = Role.AGENT + + +# conversation_id, thread_id, task_id, event +BufferKey = Tuple[str, Optional[str], Optional[str], object] + + +class BufferEntry: + def __init__(self, item_id: Optional[str] = None, role: Optional[Role] = None): + self.parts: List[str] = [] + self.last_updated: float = time.monotonic() + # Stable paragraph id for this buffer entry. Reused across streamed chunks + # until this entry is flushed (debounce/boundary). On size-based flush, + # we rotate to a new paragraph id for subsequent chunks. + self.item_id: str = item_id or generate_item_id() + self.role: Optional[Role] = role + + def append(self, text: str): + if text: + self.parts.append(text) + self.last_updated = time.monotonic() + + def snapshot_payload(self) -> Optional[BaseResponseDataPayload]: + """Return current aggregate content without clearing the buffer.""" + if not self.parts: + return None + content = "".join(self.parts) + return BaseResponseDataPayload(content=content) + + +class ResponseBuffer: + """Buffers streaming responses and emits SaveMessage at suitable boundaries. + + Simplified rules (no debounce, no size-based rotation): + - Immediate write-through: tool_call_completed, component_generator, message, plan_require_user_input + - Buffered: message_chunk, reasoning + - Maintain a stable paragraph item_id per (conversation, thread, task, event) + - On every chunk, update the aggregate and return a SaveItem for upsert + - Buffer key = (conversation_id, thread_id, task_id, event) + """ + + def __init__(self): + self._buffers: Dict[BufferKey, BufferEntry] = {} + + self._immediate_events = { + StreamResponseEvent.TOOL_CALL_COMPLETED, + CommonResponseEvent.COMPONENT_GENERATOR, + NotifyResponseEvent.MESSAGE, + SystemResponseEvent.PLAN_REQUIRE_USER_INPUT, + SystemResponseEvent.THREAD_STARTED, + } + self._buffered_events = { + StreamResponseEvent.MESSAGE_CHUNK, + StreamResponseEvent.REASONING, + } + + def annotate(self, resp: BaseResponse) -> BaseResponse: + """Ensure buffered events carry a stable paragraph item_id on the response. + + For buffered events (message_chunk, reasoning), we assign a stable + paragraph id per (conversation, thread, task, event) key and stamp it + into resp.data.item_id so the frontend can correlate chunks and the + final persisted SaveItem. Immediate and boundary events are left as-is. + """ + data: UnifiedResponseData = resp.data + ev = resp.event + if ev in self._buffered_events: + key: BufferKey = ( + data.conversation_id, + data.thread_id, + data.task_id, + ev, + ) + entry = self._buffers.get(key) + if not entry: + # Start a new paragraph buffer with a fresh paragraph item_id + entry = BufferEntry(role=data.role) + self._buffers[key] = entry + # Stamp the response with the stable paragraph id + data.item_id = entry.item_id + resp.data = data + return resp + + def ingest(self, resp: BaseResponse) -> List[SaveItem]: + data: UnifiedResponseData = resp.data + ev = resp.event + + ctx = ( + data.conversation_id, + data.thread_id, + data.task_id, + ) + out: List[SaveItem] = [] + + # Immediate: write-through, but treat as paragraph boundary for buffered keys + if ev in self._immediate_events: + # Flush buffered aggregates for this context before the immediate item + conv_id, th_id, tk_id = ctx + keys_to_flush = self._collect_task_keys(conv_id, th_id, tk_id) + out.extend(self._finalize_keys(keys_to_flush)) + # Now write the immediate item + out.append(self._make_save_item_from_response(resp)) + return out + + # Buffered: accumulate by (ctx + event) + if ev in self._buffered_events: + key: BufferKey = (*ctx, ev) + entry = self._buffers.get(key) + if not entry: + # If annotate() wasn't called, create an entry now. + entry = BufferEntry(role=data.role) + self._buffers[key] = entry + + # Extract text content from payload + payload = data.payload + text = None + if isinstance(payload, BaseResponseDataPayload): + text = payload.content or "" + elif isinstance(payload, BaseModel): + # Fallback: serialize whole payload + text = payload.model_dump_json(exclude_none=True) + elif isinstance(payload, str): + text = payload + else: + text = "" + + if text: + entry.append(text) + # Always upsert current aggregate (no size-based rotation) + snap = entry.snapshot_payload() + if snap is not None: + out.append( + self._make_save_item( + event=ev, + data=data, + payload=snap, + item_id=entry.item_id, + ) + ) + return out + + # Other events: ignore for storage by default + return out + + # No flush API: paragraph boundaries are triggered by immediate events only + + def _collect_task_keys( + self, + conversation_id: str, + thread_id: Optional[str], + task_id: Optional[str], + ) -> List[BufferKey]: + keys: List[BufferKey] = [] + for key in list(self._buffers.keys()): + k_conv, k_thread, k_task, k_event = key + if ( + k_conv == conversation_id + and (thread_id is None or k_thread == thread_id) + and (task_id is None or k_task == task_id) + and k_event in self._buffered_events + ): + keys.append(key) + return keys + + def _finalize_keys(self, keys: List[BufferKey]) -> List[SaveItem]: + out: List[SaveItem] = [] + for key in keys: + entry = self._buffers.get(key) + if not entry: + continue + payload = entry.snapshot_payload() + if payload is not None: + out.append( + SaveItem( + item_id=entry.item_id, + event=key[3], + conversation_id=key[0], + thread_id=key[1], + task_id=key[2], + payload=payload, + role=entry.role or Role.AGENT, + ) + ) + if key in self._buffers: + del self._buffers[key] + return out + + def flush_task( + self, + conversation_id: str, + thread_id: Optional[str], + task_id: Optional[str], + ) -> List[SaveItem]: + """Finalize and emit all buffered aggregates for a given task context. + + This writes current aggregates (using their stable paragraph item_id) + and clears the corresponding buffers. Use at task end (success or fail). + """ + keys_to_flush = self._collect_task_keys(conversation_id, thread_id, task_id) + return self._finalize_keys(keys_to_flush) + + def _make_save_item_from_response(self, resp: BaseResponse) -> SaveItem: + data: UnifiedResponseData = resp.data + payload = data.payload + + # Ensure payload is BaseModel for SessionManager + if isinstance(payload, BaseModel): + bm = payload + elif isinstance(payload, str): + bm = BaseResponseDataPayload(content=payload) + elif payload is None: + bm = BaseResponseDataPayload(content=None) + else: + # Fallback to JSON string + try: + bm = BaseResponseDataPayload(content=str(payload)) + except Exception: + bm = BaseResponseDataPayload(content=None) + + return SaveItem( + item_id=data.item_id, + event=resp.event, + conversation_id=data.conversation_id, + thread_id=data.thread_id, + task_id=data.task_id, + payload=bm, + role=data.role, + ) + + def _make_save_item( + self, + event: object, + data: UnifiedResponseData, + payload: BaseModel, + item_id: str | None = None, + ) -> SaveItem: + return SaveItem( + item_id=item_id or generate_item_id(), + event=event, + conversation_id=data.conversation_id, + thread_id=data.thread_id, + task_id=data.task_id, + payload=payload, + role=data.role, + ) diff --git a/python/valuecell/core/coordinate/response_router.py b/python/valuecell/core/coordinate/response_router.py index 7b97b13ea..bca1b6fd9 100644 --- a/python/valuecell/core/coordinate/response_router.py +++ b/python/valuecell/core/coordinate/response_router.py @@ -3,12 +3,15 @@ from enum import Enum from typing import List, Optional -from a2a.types import TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent +from a2a.types import TaskState, TaskStatusUpdateEvent from a2a.utils import get_message_text from valuecell.core.agent.responses import EventPredicates from valuecell.core.coordinate.response import ResponseFactory from valuecell.core.task import Task -from valuecell.core.types import BaseResponse, StreamResponseEvent +from valuecell.core.types import ( + BaseResponse, + CommonResponseEvent, +) logger = logging.getLogger(__name__) @@ -34,10 +37,6 @@ def __post_init__(self): self.side_effects = [] -def _default_subtask_id(task_id: str) -> str: - return f"{task_id}_default-subtask" - - async def handle_status_update( response_factory: ResponseFactory, task: Task, @@ -58,7 +57,6 @@ async def handle_status_update( conversation_id=task.session_id, thread_id=thread_id, task_id=task.task_id, - subtask_id=_default_subtask_id(task.task_id), content=err_msg, ) ) @@ -72,9 +70,6 @@ async def handle_status_update( return RouteResult(responses) response_event = event.metadata.get("response_event") - subtask_id = event.metadata.get("subtask_id") - if not subtask_id: - subtask_id = _default_subtask_id(task.task_id) # Tool call events if state == TaskState.working and EventPredicates.is_tool_call(response_event): @@ -89,7 +84,6 @@ async def handle_status_update( conversation_id=task.session_id, thread_id=thread_id, task_id=task.task_id, - subtask_id=subtask_id, event=response_event, tool_call_id=tool_call_id, tool_name=tool_name, @@ -99,58 +93,47 @@ async def handle_status_update( return RouteResult(responses) # Reasoning messages + content = get_message_text(event.status.message, "") if state == TaskState.working and EventPredicates.is_reasoning(response_event): responses.append( response_factory.reasoning( conversation_id=task.session_id, thread_id=thread_id, task_id=task.task_id, - subtask_id=subtask_id, event=response_event, - content=get_message_text(event.status.message, ""), + content=content, ) ) return RouteResult(responses) - return RouteResult(responses) - - -async def handle_artifact_update( - response_factory: ResponseFactory, - task: Task, - thread_id: str, - event: TaskArtifactUpdateEvent, -) -> List[BaseResponse]: - responses: List[BaseResponse] = [] - artifact = event.artifact - subtask_id = artifact.metadata.get("subtask_id") if artifact.metadata else None - if not subtask_id: - subtask_id = _default_subtask_id(task.task_id) - response_event = artifact.metadata.get("response_event") - content = get_message_text(artifact, "") - - if response_event == StreamResponseEvent.COMPONENT_GENERATOR: - component_type = artifact.metadata.get("component_type", "unknown") + # component generator + if ( + state == TaskState.working + and response_event == CommonResponseEvent.COMPONENT_GENERATOR + ): + component_type = event.metadata.get("component_type", "unknown") responses.append( response_factory.component_generator( conversation_id=task.session_id, thread_id=thread_id, task_id=task.task_id, - subtask_id=subtask_id, content=content, component_type=component_type, ) ) - return responses - - responses.append( - response_factory.message_response_general( - event=response_event, - conversation_id=task.session_id, - thread_id=thread_id, - task_id=task.task_id, - subtask_id=subtask_id, - content=content, + return RouteResult(responses) + + # general messages + if state == TaskState.working and EventPredicates.is_message(response_event): + responses.append( + response_factory.message_response_general( + event=response_event, + conversation_id=task.session_id, + thread_id=thread_id, + task_id=task.task_id, + content=content, + ) ) - ) - return responses + return RouteResult(responses) + + return RouteResult(responses) diff --git a/python/valuecell/core/coordinate/tests/test_e2e_persistence.py b/python/valuecell/core/coordinate/tests/test_e2e_persistence.py new file mode 100644 index 000000000..c222546cd --- /dev/null +++ b/python/valuecell/core/coordinate/tests/test_e2e_persistence.py @@ -0,0 +1,47 @@ +import asyncio + +import pytest + +from valuecell.core.coordinate.orchestrator import AgentOrchestrator +from valuecell.core.types import UserInput, UserInputMetadata + + +@pytest.mark.asyncio +async def test_orchestrator_buffer_store_e2e(tmp_path, monkeypatch): + # Point default SessionManager to a temp sqlite file + db_path = tmp_path / "e2e_valuecell.db" + monkeypatch.setenv("VALUECELL_SQLITE_DB", str(db_path)) + + orch = AgentOrchestrator() + + # Prepare a session and a simple query; orchestrator will create the session if missing + session_id = "e2e-session" + user_id = "e2e-user" + ui = UserInput( + query="hello world", + desired_agent_name="TestAgent", + meta=UserInputMetadata(session_id=session_id, user_id=user_id), + ) + + # We don't have a live agent, so we expect planner/agent logic to raise; we just want to ensure + # that at least conversation_started and done/error go through the buffer->store path without crashing. + out = [] + try: + async for resp in orch.process_user_input(ui): + out.append(resp) + # allow buffer debounce tick + await asyncio.sleep(0) + except Exception: + # Orchestrator is defensive, should not raise; but in case, we still proceed to check persistence + pass + + # Verify persistence: at least 1 message exists for session + msgs = await orch.session_manager.get_session_messages(session_id) + assert isinstance(msgs, list) + assert len(msgs) >= 1 + + # Also verify we can count and fetch latest + cnt = await orch.session_manager.get_message_count(session_id) + assert cnt == len(msgs) + latest = await orch.session_manager.get_latest_message(session_id) + assert latest is not None diff --git a/python/valuecell/core/coordinate/tests/test_orchestrator.py b/python/valuecell/core/coordinate/tests/test_orchestrator.py index ff5e711ce..434d06839 100644 --- a/python/valuecell/core/coordinate/tests/test_orchestrator.py +++ b/python/valuecell/core/coordinate/tests/test_orchestrator.py @@ -329,29 +329,3 @@ async def test_agent_connection_error( out.append(chunk) assert any("(Error)" in c.data.payload.content for c in out if c.data.payload) - - -@pytest.mark.asyncio -async def test_create_and_close_session( - orchestrator: AgentOrchestrator, user_id: str, session_id: str -): - # create - new_id = await orchestrator.create_session(user_id, "Title") - orchestrator.session_manager.create_session.assert_called_once_with( - user_id, "Title" - ) - assert new_id == "new-session-id" - - # close - orchestrator.task_manager.cancel_session_tasks.return_value = 1 - await orchestrator.close_session(session_id) - orchestrator.task_manager.cancel_session_tasks.assert_called_once_with(session_id) - orchestrator.session_manager.add_message.assert_called_once() - - -@pytest.mark.asyncio -async def test_cleanup(orchestrator: AgentOrchestrator): - orchestrator.agent_connections = Mock() - orchestrator.agent_connections.stop_all = AsyncMock() - await orchestrator.cleanup() - orchestrator.agent_connections.stop_all.assert_called_once() diff --git a/python/valuecell/core/coordinate/tests/test_response_factory.py b/python/valuecell/core/coordinate/tests/test_response_factory.py new file mode 100644 index 000000000..3488e4fe7 --- /dev/null +++ b/python/valuecell/core/coordinate/tests/test_response_factory.py @@ -0,0 +1,102 @@ +import pytest +from valuecell.core.coordinate.response import ResponseFactory +from valuecell.core.types import ( + BaseResponseDataPayload, + CommonResponseEvent, + ComponentGeneratorResponseDataPayload, + ConversationItem, + NotifyResponseEvent, + Role, + StreamResponseEvent, + SystemResponseEvent, + ToolCallPayload, +) + + +@pytest.fixture +def factory() -> ResponseFactory: + return ResponseFactory() + + +def _mk_item( + *, + event: str, + payload: str | None, + role: str | Role = "agent", + item_id: str = "it-1", + conversation_id: str = "sess-1", + thread_id: str | None = "th-1", + task_id: str | None = "tk-1", +) -> ConversationItem: + return ConversationItem( + item_id=item_id, + role=role, # stored as string in SQLite + event=event, # stored as string in SQLite + conversation_id=conversation_id, + thread_id=thread_id, + task_id=task_id, + payload=payload, + ) + + +def test_thread_started_with_payload(factory: ResponseFactory): + payload = BaseResponseDataPayload(content="hello user").model_dump_json() + item = _mk_item( + event=SystemResponseEvent.THREAD_STARTED.value, + payload=payload, + role="user", + ) + resp = factory.from_conversation_item(item) + assert resp.event == SystemResponseEvent.THREAD_STARTED + assert resp.data.payload is not None + assert resp.data.payload.content == "hello user" # type: ignore[attr-defined] + + +def test_message_chunk(factory: ResponseFactory): + payload = BaseResponseDataPayload(content="chunk").model_dump_json() + item = _mk_item(event=StreamResponseEvent.MESSAGE_CHUNK.value, payload=payload) + resp = factory.from_conversation_item(item) + assert resp.event == StreamResponseEvent.MESSAGE_CHUNK + assert resp.data.payload.content == "chunk" # type: ignore[attr-defined] + + +def test_notify_message(factory: ResponseFactory): + payload = BaseResponseDataPayload(content="notify").model_dump_json() + item = _mk_item(event=NotifyResponseEvent.MESSAGE.value, payload=payload) + resp = factory.from_conversation_item(item) + assert resp.event == NotifyResponseEvent.MESSAGE + assert resp.data.payload.content == "notify" # type: ignore[attr-defined] + + +def test_reasoning_with_payload(factory: ResponseFactory): + payload = BaseResponseDataPayload(content="thinking...").model_dump_json() + item = _mk_item(event=StreamResponseEvent.REASONING.value, payload=payload) + resp = factory.from_conversation_item(item) + assert resp.event == StreamResponseEvent.REASONING + assert resp.data.payload.content == "thinking..." # type: ignore[attr-defined] + + +def test_component_generator(factory: ResponseFactory): + payload = ComponentGeneratorResponseDataPayload( + content="render this", component_type="chart" + ).model_dump_json() + item = _mk_item( + event=CommonResponseEvent.COMPONENT_GENERATOR.value, + payload=payload, + ) + resp = factory.from_conversation_item(item) + assert resp.event == CommonResponseEvent.COMPONENT_GENERATOR + assert resp.data.payload.component_type == "chart" # type: ignore[attr-defined] + + +def test_tool_call_completed(factory: ResponseFactory): + payload = ToolCallPayload( + tool_call_id="tc-1", tool_name="search", tool_result="{result}" + ).model_dump_json() + item = _mk_item( + event=StreamResponseEvent.TOOL_CALL_COMPLETED.value, + payload=payload, + ) + resp = factory.from_conversation_item(item) + assert resp.event == StreamResponseEvent.TOOL_CALL_COMPLETED + assert resp.data.payload.tool_name == "search" # type: ignore[attr-defined] diff --git a/python/valuecell/core/session/__init__.py b/python/valuecell/core/session/__init__.py index b9ffa29a6..16bd38c75 100644 --- a/python/valuecell/core/session/__init__.py +++ b/python/valuecell/core/session/__init__.py @@ -4,9 +4,9 @@ SessionManager, get_default_session_manager, ) -from .models import Message, Role, Session, SessionStatus +from valuecell.core.types import ConversationItem as Message, Role +from .models import Session, SessionStatus from .store import InMemorySessionStore, SessionStore -from .message_store import MessageStore, InMemoryMessageStore, SQLiteMessageStore __all__ = [ # Models @@ -20,8 +20,5 @@ # Session storage "SessionStore", "InMemorySessionStore", - # Message storage - "MessageStore", - "InMemoryMessageStore", - "SQLiteMessageStore", + # Message storage (re-exported from core.__init__) ] diff --git a/python/valuecell/core/session/manager.py b/python/valuecell/core/session/manager.py index c5a2a7a27..4aa66d664 100644 --- a/python/valuecell/core/session/manager.py +++ b/python/valuecell/core/session/manager.py @@ -1,11 +1,18 @@ +import os from datetime import datetime from typing import List, Optional +from valuecell.core.types import ( + ConversationItem, + ConversationItemEvent, + ResponsePayload, + Role, +) from valuecell.utils import generate_uuid -from .models import Message, Role, Session, SessionStatus +from .message_store import InMemoryMessageStore, MessageStore, SQLiteMessageStore +from .models import Session, SessionStatus from .store import InMemorySessionStore, SessionStore -from .message_store import MessageStore, InMemoryMessageStore class SessionManager: @@ -63,13 +70,14 @@ async def session_exists(self, session_id: str) -> bool: async def add_message( self, - session_id: str, role: Role, - content: str, - user_id: Optional[str] = None, - agent_name: Optional[str] = None, + event: ConversationItemEvent, + conversation_id: str, + thread_id: Optional[str] = None, task_id: Optional[str] = None, - ) -> Optional[Message]: + payload: ResponsePayload = None, + item_id: Optional[str] = None, + ) -> Optional[ConversationItem]: """Add message to session Args: @@ -81,41 +89,46 @@ async def add_message( task_id: Associated task ID (optional) """ # Verify session exists - session = await self.get_session(session_id) + session = await self.get_session(conversation_id) if not session: return None - # Use provided user_id or get from session - if user_id is None: - user_id = session.user_id - # Create message - message = Message( - message_id=generate_uuid("msg"), - session_id=session_id, - user_id=user_id, - agent_name=agent_name, + # Serialize payload to JSON string if it's a pydantic model + payload_str = None + if payload is not None: + try: + # pydantic BaseModel supports model_dump_json + payload_str = payload.model_dump_json(exclude_none=True) + except Exception: + try: + payload_str = str(payload) + except Exception: + payload_str = None + + item = ConversationItem( + item_id=item_id or generate_uuid("item"), role=role, - content=content, + event=event, + conversation_id=conversation_id, + thread_id=thread_id, task_id=task_id, + payload=payload_str, ) # Save message directly to message store - await self.message_store.save_message(message) + await self.message_store.save_message(item) # Update session timestamp session.touch() await self.session_store.save_session(session) - return message + return item async def get_session_messages( self, session_id: str, - limit: Optional[int] = None, - offset: int = 0, - role: Optional[Role] = None, - ) -> List[Message]: + ) -> List[ConversationItem]: """Get messages for a session with optional filtering and pagination Args: @@ -124,13 +137,13 @@ async def get_session_messages( offset: Number of messages to skip role: Filter by specific role (optional) """ - return await self.message_store.get_messages(session_id, limit, offset, role) + return await self.message_store.get_messages(session_id) - async def get_latest_message(self, session_id: str) -> Optional[Message]: + async def get_latest_message(self, session_id: str) -> Optional[ConversationItem]: """Get latest message in a session""" return await self.message_store.get_latest_message(session_id) - async def get_message(self, message_id: str) -> Optional[Message]: + async def get_message(self, message_id: str) -> Optional[ConversationItem]: """Get a specific message by ID""" return await self.message_store.get_message(message_id) @@ -138,7 +151,9 @@ async def get_message_count(self, session_id: str) -> int: """Get total message count for a session""" return await self.message_store.get_message_count(session_id) - async def get_messages_by_role(self, session_id: str, role: Role) -> List[Message]: + async def get_messages_by_role( + self, session_id: str, role: Role + ) -> List[ConversationItem]: """Get messages filtered by role""" return await self.message_store.get_messages(session_id, role=role) @@ -191,7 +206,23 @@ async def get_sessions_by_status( # Default session manager instance -_session_manager = SessionManager() +def _default_db_path() -> str: + """Resolve repository root and return default DB path valuecell.db. + + Layout assumption: this file is at repo_root/python/valuecell/core/session/manager.py + We walk up 4 levels to reach repo_root. + """ + here = os.path.dirname(__file__) + repo_root = os.path.abspath(os.path.join(here, "..", "..", "..", "..")) + return os.path.join(repo_root, "valuecell.db") + + +def _resolve_db_path() -> str: + return os.environ.get("VALUECELL_SQLITE_DB") or _default_db_path() + + +# Default: use SQLite at repo root valuecell.db (env VALUECELL_SQLITE_DB overrides) +_session_manager = SessionManager(message_store=SQLiteMessageStore(_resolve_db_path())) def get_default_session_manager() -> SessionManager: diff --git a/python/valuecell/core/session/message_store.py b/python/valuecell/core/session/message_store.py index d195e0f56..2a9bc67a5 100644 --- a/python/valuecell/core/session/message_store.py +++ b/python/valuecell/core/session/message_store.py @@ -1,18 +1,17 @@ +from __future__ import annotations + +import asyncio import sqlite3 -import json +import aiosqlite from abc import ABC, abstractmethod -from datetime import datetime -from typing import List, Optional, Dict, Any +from typing import Dict, List, Optional -from .models import Message, Role +from valuecell.core.types import ConversationItem, Role class MessageStore(ABC): - """Abstract base class for message storage""" - @abstractmethod - async def save_message(self, message: Message) -> None: - """Save a single message""" + async def save_message(self, message: ConversationItem) -> None: ... @abstractmethod async def get_messages( @@ -21,45 +20,31 @@ async def get_messages( limit: Optional[int] = None, offset: int = 0, role: Optional[Role] = None, - ) -> List[Message]: - """Get messages for a session with optional filtering and pagination""" + ) -> List[ConversationItem]: ... @abstractmethod - async def get_message(self, message_id: str) -> Optional[Message]: - """Get a specific message by ID""" + async def get_latest_message( + self, session_id: str + ) -> Optional[ConversationItem]: ... @abstractmethod - async def get_latest_message(self, session_id: str) -> Optional[Message]: - """Get the latest message in a session""" + async def get_message(self, message_id: str) -> Optional[ConversationItem]: ... @abstractmethod - async def get_message_count(self, session_id: str) -> int: - """Get total message count for a session""" + async def get_message_count(self, session_id: str) -> int: ... @abstractmethod - async def delete_session_messages(self, session_id: str) -> int: - """Delete all messages for a session, returns count of deleted messages""" - - @abstractmethod - async def delete_message(self, message_id: str) -> bool: - """Delete a specific message""" + async def delete_session_messages(self, session_id: str) -> None: ... class InMemoryMessageStore(MessageStore): - """In-memory message store implementation for testing and development""" - def __init__(self): - self._messages: Dict[str, Message] = {} - self._session_messages: Dict[str, List[str]] = {} - - async def save_message(self, message: Message) -> None: - """Save message to memory""" - self._messages[message.message_id] = message + # session_id -> list[ConversationItem] + self._messages: Dict[str, List[ConversationItem]] = {} - # Maintain session index - if message.session_id not in self._session_messages: - self._session_messages[message.session_id] = [] - self._session_messages[message.session_id].append(message.message_id) + async def save_message(self, message: ConversationItem) -> None: + arr = self._messages.setdefault(message.conversation_id, []) + arr.append(message) async def get_messages( self, @@ -67,173 +52,108 @@ async def get_messages( limit: Optional[int] = None, offset: int = 0, role: Optional[Role] = None, - ) -> List[Message]: - """Get messages for a session""" - message_ids = self._session_messages.get(session_id, []) - messages = [self._messages[msg_id] for msg_id in message_ids] - - # Filter by role if specified - if role: - messages = [msg for msg in messages if msg.role == role] - - # Sort by timestamp - messages.sort(key=lambda m: m.timestamp) - - # Apply pagination - if offset > 0: - messages = messages[offset:] + ) -> List[ConversationItem]: + items = list(self._messages.get(session_id, [])) + if role is not None: + items = [m for m in items if m.role == role] + if offset: + items = items[offset:] if limit is not None: - messages = messages[:limit] - - return messages + items = items[:limit] + return items - async def get_message(self, message_id: str) -> Optional[Message]: - """Get a specific message""" - return self._messages.get(message_id) + async def get_latest_message(self, session_id: str) -> Optional[ConversationItem]: + items = self._messages.get(session_id, []) + return items[-1] if items else None - async def get_latest_message(self, session_id: str) -> Optional[Message]: - """Get the latest message in a session""" - messages = await self.get_messages(session_id) - return messages[-1] if messages else None + async def get_message(self, message_id: str) -> Optional[ConversationItem]: + for arr in self._messages.values(): + for m in arr: + if m.item_id == message_id: + return m + return None async def get_message_count(self, session_id: str) -> int: - """Get message count for a session""" - return len(self._session_messages.get(session_id, [])) - - async def delete_session_messages(self, session_id: str) -> int: - """Delete all messages for a session""" - message_ids = self._session_messages.get(session_id, []) - count = len(message_ids) - - # Remove from messages dict - for msg_id in message_ids: - self._messages.pop(msg_id, None) - - # Remove session index - self._session_messages.pop(session_id, None) + return len(self._messages.get(session_id, [])) - return count - - async def delete_message(self, message_id: str) -> bool: - """Delete a specific message""" - message = self._messages.pop(message_id, None) - if not message: - return False - - # Remove from session index - session_id = message.session_id - if session_id in self._session_messages: - try: - self._session_messages[session_id].remove(message_id) - except ValueError: - pass # Already removed - - return True - - def clear_all(self) -> None: - """Clear all messages (for testing)""" - self._messages.clear() - self._session_messages.clear() + async def delete_session_messages(self, session_id: str) -> None: + self._messages.pop(session_id, None) class SQLiteMessageStore(MessageStore): - """SQLite-based message store implementation""" - - def __init__(self, db_path: Optional[str] = None): - """Initialize SQLite message store - - Args: - db_path: Path to SQLite database file. If None, uses in-memory database. - """ - self.db_path = db_path or ":memory:" - self._init_database() - - def _init_database(self) -> None: - """Initialize database schema""" - with sqlite3.connect(self.db_path) as conn: - conn.execute(""" - CREATE TABLE IF NOT EXISTS messages ( - message_id TEXT PRIMARY KEY, - session_id TEXT NOT NULL, - user_id TEXT NOT NULL, - agent_name TEXT, - role TEXT NOT NULL, - content TEXT NOT NULL, - timestamp TEXT NOT NULL, - task_id TEXT, - metadata TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP + """SQLite-backed message store using aiosqlite for true async I/O.""" + + def __init__(self, db_path: str): + self.db_path = db_path + self._initialized = False + self._init_lock = None # lazy to avoid loop-binding in __init__ + + async def _ensure_initialized(self) -> None: + if self._initialized: + return + if self._init_lock is None: + self._init_lock = asyncio.Lock() + async with self._init_lock: + if self._initialized: + return + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + """ + CREATE TABLE IF NOT EXISTS messages ( + item_id TEXT PRIMARY KEY, + role TEXT NOT NULL, + event TEXT NOT NULL, + conversation_id TEXT NOT NULL, + thread_id TEXT, + task_id TEXT, + payload TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + """ ) - """) - - # Create indexes for common queries - conn.execute(""" - CREATE INDEX IF NOT EXISTS idx_messages_session_id - ON messages(session_id) - """) - - conn.execute(""" - CREATE INDEX IF NOT EXISTS idx_messages_timestamp - ON messages(session_id, timestamp) - """) - - conn.execute(""" - CREATE INDEX IF NOT EXISTS idx_messages_role - ON messages(session_id, role) - """) - - def _message_to_dict(self, message: Message) -> Dict[str, Any]: - """Convert Message object to database record""" - return { - "message_id": message.message_id, - "session_id": message.session_id, - "user_id": message.user_id, - "agent_name": message.agent_name, - "role": message.role.value, - "content": message.content, - "timestamp": message.timestamp.isoformat(), - "task_id": message.task_id, - "metadata": json.dumps(message.metadata) if message.metadata else None, - } - - def _dict_to_message(self, row: Dict[str, Any]) -> Message: - """Convert database record to Message object""" - return Message( - message_id=row["message_id"], - session_id=row["session_id"], - user_id=row["user_id"], - agent_name=row["agent_name"], - role=Role(row["role"]), - content=row["content"], - timestamp=datetime.fromisoformat(row["timestamp"]), + await db.execute( + """ + CREATE INDEX IF NOT EXISTS idx_messages_conv_time + ON messages (conversation_id, created_at); + """ + ) + await db.commit() + self._initialized = True + + @staticmethod + def _row_to_message(row: sqlite3.Row) -> ConversationItem: + return ConversationItem( + item_id=row["item_id"], + role=row["role"], + event=row["event"], + conversation_id=row["conversation_id"], + thread_id=row["thread_id"], task_id=row["task_id"], - metadata=json.loads(row["metadata"]) if row["metadata"] else {}, + payload=row["payload"], ) - async def save_message(self, message: Message) -> None: - """Save message to SQLite database""" - data = self._message_to_dict(message) - - with sqlite3.connect(self.db_path) as conn: - conn.execute( + async def save_message(self, message: ConversationItem) -> None: + await self._ensure_initialized() + role_val = getattr(message.role, "value", str(message.role)) + event_val = getattr(message.event, "value", str(message.event)) + async with aiosqlite.connect(self.db_path) as db: + await db.execute( """ - INSERT OR REPLACE INTO messages - (message_id, session_id, user_id, agent_name, role, content, - timestamp, task_id, metadata) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - """, + INSERT OR REPLACE INTO messages ( + item_id, role, event, conversation_id, thread_id, task_id, payload + ) VALUES (?, ?, ?, ?, ?, ?, ?) + """, ( - data["message_id"], - data["session_id"], - data["user_id"], - data["agent_name"], - data["role"], - data["content"], - data["timestamp"], - data["task_id"], - data["metadata"], + message.item_id, + role_val, + event_val, + message.conversation_id, + message.thread_id, + message.task_id, + message.payload, ), ) + await db.commit() async def get_messages( self, @@ -241,104 +161,65 @@ async def get_messages( limit: Optional[int] = None, offset: int = 0, role: Optional[Role] = None, - ) -> List[Message]: - """Get messages for a session""" - query = """ - SELECT message_id, session_id, user_id, agent_name, role, content, - timestamp, task_id, metadata - FROM messages - WHERE session_id = ? - """ + ) -> List[ConversationItem]: + await self._ensure_initialized() params = [session_id] - - # Add role filter if specified - if role: - query += " AND role = ?" - params.append(role.value) - - # Order by timestamp - query += " ORDER BY timestamp ASC" - - # Add pagination + where = "WHERE conversation_id = ?" + if role is not None: + where += " AND role = ?" + params.append(getattr(role, "value", str(role))) + sql = f"SELECT * FROM messages {where} ORDER BY datetime(created_at) ASC" if limit is not None: - query += " LIMIT ? OFFSET ?" - params.extend([limit, offset]) - - with sqlite3.connect(self.db_path) as conn: - conn.row_factory = sqlite3.Row - cursor = conn.execute(query, params) - rows = cursor.fetchall() - - return [self._dict_to_message(dict(row)) for row in rows] - - async def get_message(self, message_id: str) -> Optional[Message]: - """Get a specific message by ID""" - with sqlite3.connect(self.db_path) as conn: - conn.row_factory = sqlite3.Row - cursor = conn.execute( - """ - SELECT message_id, session_id, user_id, agent_name, role, content, - timestamp, task_id, metadata - FROM messages - WHERE message_id = ? - """, - (message_id,), - ) - - row = cursor.fetchone() - return self._dict_to_message(dict(row)) if row else None - - async def get_latest_message(self, session_id: str) -> Optional[Message]: - """Get the latest message in a session""" - with sqlite3.connect(self.db_path) as conn: - conn.row_factory = sqlite3.Row - cursor = conn.execute( - """ - SELECT message_id, session_id, user_id, agent_name, role, content, - timestamp, task_id, metadata - FROM messages - WHERE session_id = ? - ORDER BY timestamp DESC - LIMIT 1 - """, + sql += " LIMIT ?" + params.append(int(limit)) + if offset: + if limit is None: + sql += " LIMIT -1" + sql += " OFFSET ?" + params.append(int(offset)) + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = sqlite3.Row + cur = await db.execute(sql, params) + rows = await cur.fetchall() + return [self._row_to_message(r) for r in rows] + + async def get_latest_message(self, session_id: str) -> Optional[ConversationItem]: + await self._ensure_initialized() + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = sqlite3.Row + cur = await db.execute( + "SELECT * FROM messages WHERE conversation_id = ? ORDER BY datetime(created_at) DESC LIMIT 1", (session_id,), ) - - row = cursor.fetchone() - return self._dict_to_message(dict(row)) if row else None + row = await cur.fetchone() + return self._row_to_message(row) if row else None + + async def get_message(self, message_id: str) -> Optional[ConversationItem]: + await self._ensure_initialized() + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = sqlite3.Row + cur = await db.execute( + "SELECT * FROM messages WHERE item_id = ?", + (message_id,), + ) + row = await cur.fetchone() + return self._row_to_message(row) if row else None async def get_message_count(self, session_id: str) -> int: - """Get message count for a session""" - with sqlite3.connect(self.db_path) as conn: - cursor = conn.execute( - """ - SELECT COUNT(*) FROM messages WHERE session_id = ? - """, + await self._ensure_initialized() + async with aiosqlite.connect(self.db_path) as db: + cur = await db.execute( + "SELECT COUNT(1) FROM messages WHERE conversation_id = ?", (session_id,), ) - - return cursor.fetchone()[0] - - async def delete_session_messages(self, session_id: str) -> int: - """Delete all messages for a session""" - with sqlite3.connect(self.db_path) as conn: - cursor = conn.execute( - """ - DELETE FROM messages WHERE session_id = ? - """, + row = await cur.fetchone() + return int(row[0] if row else 0) + + async def delete_session_messages(self, session_id: str) -> None: + await self._ensure_initialized() + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + "DELETE FROM messages WHERE conversation_id = ?", (session_id,), ) - - return cursor.rowcount - - async def delete_message(self, message_id: str) -> bool: - """Delete a specific message""" - with sqlite3.connect(self.db_path) as conn: - cursor = conn.execute( - """ - DELETE FROM messages WHERE message_id = ? - """, - (message_id,), - ) - - return cursor.rowcount > 0 + await db.commit() diff --git a/python/valuecell/core/session/models.py b/python/valuecell/core/session/models.py index e3ccb13c5..2835216d4 100644 --- a/python/valuecell/core/session/models.py +++ b/python/valuecell/core/session/models.py @@ -1,18 +1,10 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, Optional +from typing import Optional from pydantic import BaseModel, Field -class Role(str, Enum): - """Message role enumeration""" - - USER = "user" - AGENT = "agent" - SYSTEM = "system" - - class SessionStatus(str, Enum): """Session status enumeration""" @@ -21,27 +13,6 @@ class SessionStatus(str, Enum): REQUIRE_USER_INPUT = "require_user_input" -class Message(BaseModel): - """Message data model""" - - message_id: str = Field(..., description="Unique message identifier") - session_id: str = Field(..., description="Session ID this message belongs to") - user_id: str = Field(..., description="User ID") - agent_name: Optional[str] = Field(None, description="Agent name") - role: Role = Field(..., description="Message role") - content: str = Field(..., description="Message content") - timestamp: datetime = Field( - default_factory=datetime.now, description="Message timestamp" - ) - task_id: Optional[str] = Field(None, description="Associated task ID") - metadata: Dict[str, Any] = Field( - default_factory=dict, description="Message metadata" - ) - - class Config: - json_encoders = {datetime: lambda v: v.isoformat()} - - class Session(BaseModel): """Session data model - lightweight metadata only, messages stored separately""" diff --git a/python/valuecell/core/session/tests/test_sqlite_message_store.py b/python/valuecell/core/session/tests/test_sqlite_message_store.py new file mode 100644 index 000000000..5d6c23b04 --- /dev/null +++ b/python/valuecell/core/session/tests/test_sqlite_message_store.py @@ -0,0 +1,64 @@ +import os +import tempfile + +import pytest +from valuecell.core.session.message_store import SQLiteMessageStore +from valuecell.core.types import ConversationItem, Role, SystemResponseEvent + + +@pytest.mark.asyncio +async def test_sqlite_message_store_basic_crud(): + fd, path = tempfile.mkstemp(suffix=".db") + os.close(fd) + try: + store = SQLiteMessageStore(path) + + # create and save two messages + m1 = ConversationItem( + item_id="i1", + role=Role.SYSTEM, + event=SystemResponseEvent.THREAD_STARTED, + conversation_id="s1", + thread_id="t1", + task_id=None, + payload='{"a":1}', + ) + m2 = ConversationItem( + item_id="i2", + role=Role.SYSTEM, + event=SystemResponseEvent.DONE, + conversation_id="s1", + thread_id="t1", + task_id=None, + payload='{"a":1}', + ) + await store.save_message(m1) + await store.save_message(m2) + + # count + cnt = await store.get_message_count("s1") + assert cnt == 2 + + # get latest + latest = await store.get_latest_message("s1") + assert latest is not None + assert latest.item_id in {"i1", "i2"} + + # list + msgs = await store.get_messages("s1") + assert len(msgs) == 2 + ids = {m.item_id for m in msgs} + assert ids == {"i1", "i2"} + + # get one + one = await store.get_message("i1") + assert one is not None + assert one.item_id == "i1" + + # delete + await store.delete_session_messages("s1") + cnt2 = await store.get_message_count("s1") + assert cnt2 == 0 + finally: + if os.path.exists(path): + os.remove(path) diff --git a/python/valuecell/core/types.py b/python/valuecell/core/types.py index ab06e2dc5..b5149f8f7 100644 --- a/python/valuecell/core/types.py +++ b/python/valuecell/core/types.py @@ -4,6 +4,7 @@ from a2a.types import Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent from pydantic import BaseModel, Field +from valuecell.utils.uuid import generate_item_id class UserInputMetadata(BaseModel): @@ -30,22 +31,6 @@ class Config: frozen = False extra = "forbid" - def has_desired_agent(self) -> bool: - """Check if a specific agent is desired""" - return self.desired_agent_name is not None - - def get_desired_agent(self) -> Optional[str]: - """Get the desired agent name""" - return self.desired_agent_name - - def set_desired_agent(self, agent_name: str) -> None: - """Set the desired agent name""" - self.desired_agent_name = agent_name - - def clear_desired_agent(self) -> None: - """Clear the desired agent name""" - self.desired_agent_name = None - class SystemResponseEvent(str, Enum): CONVERSATION_STARTED = "conversation_started" @@ -57,15 +42,18 @@ class SystemResponseEvent(str, Enum): DONE = "done" -class _TaskResponseEvent(str, Enum): +class TaskStatusEvent(str, Enum): TASK_STARTED = "task_started" TASK_COMPLETED = "task_completed" TASK_CANCELLED = "task_cancelled" +class CommonResponseEvent(str, Enum): + COMPONENT_GENERATOR = "component_generator" + + class StreamResponseEvent(str, Enum): MESSAGE_CHUNK = "message_chunk" - COMPONENT_GENERATOR = "component_generator" TOOL_CALL_STARTED = "tool_call_started" TOOL_CALL_COMPLETED = "tool_call_completed" REASONING_STARTED = "reasoning_started" @@ -84,7 +72,7 @@ class StreamResponse(BaseModel): None, description="The content of the stream response, typically a chunk of data or message.", ) - event: StreamResponseEvent | _TaskResponseEvent = Field( + event: StreamResponseEvent | TaskStatusEvent = Field( ..., description="The type of stream response, indicating its purpose or content nature.", ) @@ -92,10 +80,6 @@ class StreamResponse(BaseModel): None, description="Optional metadata providing additional context about the response", ) - subtask_id: Optional[str] = Field( - None, - description="Optional subtask ID if the response is related to a specific subtask", - ) class NotifyResponse(BaseModel): @@ -105,7 +89,7 @@ class NotifyResponse(BaseModel): ..., description="The content of the notification response", ) - event: NotifyResponseEvent | _TaskResponseEvent = Field( + event: NotifyResponseEvent | TaskStatusEvent = Field( ..., description="The type of notification response", ) @@ -135,6 +119,39 @@ class ComponentGeneratorResponseDataPayload(BaseResponseDataPayload): ] +ConversationItemEvent = Union[ + StreamResponseEvent, + NotifyResponseEvent, + SystemResponseEvent, + CommonResponseEvent, + TaskStatusEvent, +] + + +class Role(str, Enum): + """Message role enumeration""" + + USER = "user" + AGENT = "agent" + SYSTEM = "system" + + +class ConversationItem(BaseModel): + """Message item structure for conversation history""" + + item_id: str = Field(..., description="Unique message identifier") + role: Role = Field(..., description="Role of the message sender") + event: ConversationItemEvent = Field(..., description="Event type of the message") + conversation_id: str = Field( + ..., description="Conversation ID this message belongs to" + ) + thread_id: Optional[str] = Field(None, description="Thread ID if part of a thread") + task_id: Optional[str] = Field( + None, description="Task ID if associated with a task" + ) + payload: str = Field(..., description="The actual message payload") + + class UnifiedResponseData(BaseModel): """Unified response data structure with optional hierarchy fields. @@ -147,18 +164,17 @@ class UnifiedResponseData(BaseModel): None, description="Unique ID for the message thread" ) task_id: Optional[str] = Field(None, description="Unique ID for the task") - subtask_id: Optional[str] = Field( - None, description="Unique ID for the subtask, if any" - ) payload: Optional[ResponsePayload] = Field( None, description="The message data payload" ) + role: Role = Field(..., description="The role of the message sender") + item_id: str = Field(default_factory=generate_item_id) class BaseResponse(BaseModel, ABC): """Top-level response envelope used for all events.""" - event: StreamResponseEvent | NotifyResponseEvent | SystemResponseEvent = Field( + event: ConversationItemEvent = Field( ..., description="The event type of the response" ) data: UnifiedResponseData = Field( @@ -197,8 +213,9 @@ class MessageResponse(BaseResponse): class ComponentGeneratorResponse(BaseResponse): - event: Literal[StreamResponseEvent.COMPONENT_GENERATOR] = Field( - StreamResponseEvent.COMPONENT_GENERATOR + event: Literal[CommonResponseEvent.COMPONENT_GENERATOR] = Field( + CommonResponseEvent.COMPONENT_GENERATOR, + description="The event type of the response", ) data: UnifiedResponseData = Field(..., description="The component generator data") @@ -244,8 +261,8 @@ class TaskFailedResponse(BaseResponse): class TaskCompletedResponse(BaseResponse): - event: Literal[_TaskResponseEvent.TASK_COMPLETED] = Field( - _TaskResponseEvent.TASK_COMPLETED, description="The event type of the response" + event: Literal[TaskStatusEvent.TASK_COMPLETED] = Field( + TaskStatusEvent.TASK_COMPLETED, description="The event type of the response" ) data: UnifiedResponseData = Field(..., description="The task data payload") diff --git a/python/valuecell/server/api/routers/agent_stream.py b/python/valuecell/server/api/routers/agent_stream.py index b9295fdff..5ecc35f74 100644 --- a/python/valuecell/server/api/routers/agent_stream.py +++ b/python/valuecell/server/api/routers/agent_stream.py @@ -35,11 +35,10 @@ async def generate_stream(): return StreamingResponse( generate_stream(), - media_type="text/plain", + media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", - "Content-Type": "text/event-stream", }, ) diff --git a/python/valuecell/utils/uuid.py b/python/valuecell/utils/uuid.py index 4e29a9306..0abaf05dd 100644 --- a/python/valuecell/utils/uuid.py +++ b/python/valuecell/utils/uuid.py @@ -8,8 +8,8 @@ def generate_uuid(prefix: str = None) -> str: return f"{prefix}-{uuid4().hex}" -def generate_message_id() -> str: - return generate_uuid("msg") +def generate_item_id() -> str: + return generate_uuid("item") def generate_thread_id() -> str: