From d0a6962e843cb5810b8cb5904e294a9ca0da3fb7 Mon Sep 17 00:00:00 2001 From: Zhaofeng Zhang <24791380+vcfgv@users.noreply.github.com> Date: Thu, 18 Sep 2025 10:06:37 +0800 Subject: [PATCH 1/6] feat: enhance session management with message handling and status tracking --- python/valuecell/core/__init__.py | 8 + python/valuecell/core/session/__init__.py | 16 +- python/valuecell/core/session/manager.py | 164 ++++++--- .../valuecell/core/session/message_store.py | 344 ++++++++++++++++++ python/valuecell/core/session/models.py | 64 ++-- python/valuecell/core/session/store.py | 5 +- 6 files changed, 515 insertions(+), 86 deletions(-) create mode 100644 python/valuecell/core/session/message_store.py diff --git a/python/valuecell/core/__init__.py b/python/valuecell/core/__init__.py index e2700bd70..ca0a90f49 100644 --- a/python/valuecell/core/__init__.py +++ b/python/valuecell/core/__init__.py @@ -4,8 +4,12 @@ Message, Role, Session, + SessionStatus, SessionManager, SessionStore, + MessageStore, + InMemoryMessageStore, + SQLiteMessageStore, ) # Task management @@ -31,9 +35,13 @@ "Message", "Role", "Session", + "SessionStatus", "SessionManager", "SessionStore", "InMemorySessionStore", + "MessageStore", + "InMemoryMessageStore", + "SQLiteMessageStore", # Task exports "Task", "TaskStatus", diff --git a/python/valuecell/core/session/__init__.py b/python/valuecell/core/session/__init__.py index 8630d024d..b9ffa29a6 100644 --- a/python/valuecell/core/session/__init__.py +++ b/python/valuecell/core/session/__init__.py @@ -1,15 +1,27 @@ """Session module initialization""" -from .manager import SessionManager, get_default_session_manager -from .models import Message, Role, Session +from .manager import ( + SessionManager, + get_default_session_manager, +) +from .models import Message, Role, Session, SessionStatus from .store import InMemorySessionStore, SessionStore +from .message_store import MessageStore, InMemoryMessageStore, SQLiteMessageStore __all__ = [ + # Models "Message", "Role", "Session", + "SessionStatus", + # Session management "SessionManager", "get_default_session_manager", + # Session storage "SessionStore", "InMemorySessionStore", + # Message storage + "MessageStore", + "InMemoryMessageStore", + "SQLiteMessageStore", ] diff --git a/python/valuecell/core/session/manager.py b/python/valuecell/core/session/manager.py index 5db218e98..c5a2a7a27 100644 --- a/python/valuecell/core/session/manager.py +++ b/python/valuecell/core/session/manager.py @@ -3,15 +3,21 @@ from valuecell.utils import generate_uuid -from .models import Message, Role, Session +from .models import Message, Role, Session, SessionStatus from .store import InMemorySessionStore, SessionStore +from .message_store import MessageStore, InMemoryMessageStore class SessionManager: - """Session manager""" + """Session manager - handles both session metadata and messages through separate stores""" - def __init__(self, store: Optional[SessionStore] = None): - self.store = store or InMemorySessionStore() + def __init__( + self, + session_store: Optional[SessionStore] = None, + message_store: Optional[MessageStore] = None, + ): + self.session_store = session_store or InMemorySessionStore() + self.message_store = message_store or InMemoryMessageStore() async def create_session( self, @@ -25,91 +31,116 @@ async def create_session( user_id=user_id, title=title, ) - await self.store.save_session(session) + await self.session_store.save_session(session) return session async def get_session(self, session_id: str) -> Optional[Session]: - """Get session""" - return await self.store.load_session(session_id) + """Get session metadata""" + return await self.session_store.load_session(session_id) async def update_session(self, session: Session) -> None: - """Update session""" + """Update session metadata""" session.updated_at = datetime.now() - await self.store.save_session(session) + await self.session_store.save_session(session) async def delete_session(self, session_id: str) -> bool: - """Delete session""" - return await self.store.delete_session(session_id) + """Delete session and all its messages""" + # First delete all messages for this session + await self.message_store.delete_session_messages(session_id) + + # Then delete the session metadata + return await self.session_store.delete_session(session_id) async def list_user_sessions( self, user_id: str, limit: int = 100, offset: int = 0 ) -> List[Session]: """List user sessions""" - return await self.store.list_sessions(user_id, limit, offset) + return await self.session_store.list_sessions(user_id, limit, offset) async def session_exists(self, session_id: str) -> bool: """Check if session exists""" - return await self.store.session_exists(session_id) + return await self.session_store.session_exists(session_id) async def add_message( - self, session_id: str, role: Role, content: str, task_id: Optional[str] = None + self, + session_id: str, + role: Role, + content: str, + user_id: Optional[str] = None, + agent_name: Optional[str] = None, + task_id: Optional[str] = None, ) -> Optional[Message]: - """Add message to session""" + """Add message to session + + Args: + session_id: Session ID to add message to + role: Message role (USER, AGENT, SYSTEM) + content: Message content + user_id: User ID (will be fetched from session if not provided) + agent_name: Agent name (optional) + task_id: Associated task ID (optional) + """ + # Verify session exists session = await self.get_session(session_id) if not session: return None + # Use provided user_id or get from session + if user_id is None: + user_id = session.user_id + + # Create message message = Message( message_id=generate_uuid("msg"), session_id=session_id, + user_id=user_id, + agent_name=agent_name, role=role, content=content, task_id=task_id, ) - session.add_message(message) - await self.update_session(session) + # Save message directly to message store + await self.message_store.save_message(message) + + # Update session timestamp + session.touch() + await self.session_store.save_session(session) + return message async def get_session_messages( - self, session_id: str, limit: Optional[int] = None + self, + session_id: str, + limit: Optional[int] = None, + offset: int = 0, + role: Optional[Role] = None, ) -> List[Message]: - """Get session messages""" - session = await self.get_session(session_id) - if not session: - return [] - - messages = session.messages - if limit is not None: - messages = messages[-limit:] # Get latest limit messages + """Get messages for a session with optional filtering and pagination - return messages + Args: + session_id: Session ID + limit: Maximum number of messages to return + offset: Number of messages to skip + role: Filter by specific role (optional) + """ + return await self.message_store.get_messages(session_id, limit, offset, role) async def get_latest_message(self, session_id: str) -> Optional[Message]: - """Get latest session message""" - session = await self.get_session(session_id) - if not session: - return None + """Get latest message in a session""" + return await self.message_store.get_latest_message(session_id) - return session.get_latest_message() + async def get_message(self, message_id: str) -> Optional[Message]: + """Get a specific message by ID""" + return await self.message_store.get_message(message_id) - async def update_session_context(self, session_id: str, key: str, value) -> bool: - """Update session context""" - session = await self.get_session(session_id) - if not session: - return False + async def get_message_count(self, session_id: str) -> int: + """Get total message count for a session""" + return await self.message_store.get_message_count(session_id) - session.update_context(key, value) - await self.update_session(session) - return True - - async def get_session_context(self, session_id: str, key: str, default=None): - """Get session context value""" - session = await self.get_session(session_id) - if not session: - return default - - return session.get_context(key, default) + async def get_messages_by_role(self, session_id: str, role: Role) -> List[Message]: + """Get messages filtered by role""" + return await self.message_store.get_messages(session_id, role=role) async def deactivate_session(self, session_id: str) -> bool: """Deactivate session""" @@ -117,8 +148,8 @@ async def deactivate_session(self, session_id: str) -> bool: if not session: return False - session.is_active = False - await self.update_session(session) + session.deactivate() + await self.session_store.save_session(session) return True async def activate_session(self, session_id: str) -> bool: @@ -127,13 +158,42 @@ async def activate_session(self, session_id: str) -> bool: if not session: return False - session.is_active = True - await self.update_session(session) + session.activate() + await self.session_store.save_session(session) + return True + + async def set_session_status(self, session_id: str, status: SessionStatus) -> bool: + """Set session status""" + session = await self.get_session(session_id) + if not session: + return False + + session.set_status(status) + await self.session_store.save_session(session) return True + async def require_user_input(self, session_id: str) -> bool: + """Mark session as requiring user input""" + return await self.set_session_status( + session_id, SessionStatus.REQUIRE_USER_INPUT + ) + + async def get_sessions_by_status( + self, user_id: str, status: SessionStatus, limit: int = 100, offset: int = 0 + ) -> List[Session]: + """Get user sessions filtered by status""" + # Get all user sessions and filter by status + # Note: This could be optimized by adding status filtering to the store interface + all_sessions = await self.session_store.list_sessions( + user_id, limit * 2, offset + ) + return [session for session in all_sessions if session.status == status][:limit] + +# Default session manager instance _session_manager = SessionManager() def get_default_session_manager() -> SessionManager: + """Get the default session manager instance""" return _session_manager diff --git a/python/valuecell/core/session/message_store.py b/python/valuecell/core/session/message_store.py new file mode 100644 index 000000000..d195e0f56 --- /dev/null +++ b/python/valuecell/core/session/message_store.py @@ -0,0 +1,344 @@ +import sqlite3 +import json +from abc import ABC, abstractmethod +from datetime import datetime +from typing import List, Optional, Dict, Any + +from .models import Message, Role + + +class MessageStore(ABC): + """Abstract base class for message storage""" + + @abstractmethod + async def save_message(self, message: Message) -> None: + """Save a single message""" + + @abstractmethod + async def get_messages( + self, + session_id: str, + limit: Optional[int] = None, + offset: int = 0, + role: Optional[Role] = None, + ) -> List[Message]: + """Get messages for a session with optional filtering and pagination""" + + @abstractmethod + async def get_message(self, message_id: str) -> Optional[Message]: + """Get a specific message by ID""" + + @abstractmethod + async def get_latest_message(self, session_id: str) -> Optional[Message]: + """Get the latest message in a session""" + + @abstractmethod + async def get_message_count(self, session_id: str) -> int: + """Get total message count for a session""" + + @abstractmethod + async def delete_session_messages(self, session_id: str) -> int: + """Delete all messages for a session, returns count of deleted messages""" + + @abstractmethod + async def delete_message(self, message_id: str) -> bool: + """Delete a specific message""" + + +class InMemoryMessageStore(MessageStore): + """In-memory message store implementation for testing and development""" + + def __init__(self): + self._messages: Dict[str, Message] = {} + self._session_messages: Dict[str, List[str]] = {} + + async def save_message(self, message: Message) -> None: + """Save message to memory""" + self._messages[message.message_id] = message + + # Maintain session index + if message.session_id not in self._session_messages: + self._session_messages[message.session_id] = [] + self._session_messages[message.session_id].append(message.message_id) + + async def get_messages( + self, + session_id: str, + limit: Optional[int] = None, + offset: int = 0, + role: Optional[Role] = None, + ) -> List[Message]: + """Get messages for a session""" + message_ids = self._session_messages.get(session_id, []) + messages = [self._messages[msg_id] for msg_id in message_ids] + + # Filter by role if specified + if role: + messages = [msg for msg in messages if msg.role == role] + + # Sort by timestamp + messages.sort(key=lambda m: m.timestamp) + + # Apply pagination + if offset > 0: + messages = messages[offset:] + if limit is not None: + messages = messages[:limit] + + return messages + + async def get_message(self, message_id: str) -> Optional[Message]: + """Get a specific message""" + return self._messages.get(message_id) + + async def get_latest_message(self, session_id: str) -> Optional[Message]: + """Get the latest message in a session""" + messages = await self.get_messages(session_id) + return messages[-1] if messages else None + + async def get_message_count(self, session_id: str) -> int: + """Get message count for a session""" + return len(self._session_messages.get(session_id, [])) + + async def delete_session_messages(self, session_id: str) -> int: + """Delete all messages for a session""" + message_ids = self._session_messages.get(session_id, []) + count = len(message_ids) + + # Remove from messages dict + for msg_id in message_ids: + self._messages.pop(msg_id, None) + + # Remove session index + self._session_messages.pop(session_id, None) + + return count + + async def delete_message(self, message_id: str) -> bool: + """Delete a specific message""" + message = self._messages.pop(message_id, None) + if not message: + return False + + # Remove from session index + session_id = message.session_id + if session_id in self._session_messages: + try: + self._session_messages[session_id].remove(message_id) + except ValueError: + pass # Already removed + + return True + + def clear_all(self) -> None: + """Clear all messages (for testing)""" + self._messages.clear() + self._session_messages.clear() + + +class SQLiteMessageStore(MessageStore): + """SQLite-based message store implementation""" + + def __init__(self, db_path: Optional[str] = None): + """Initialize SQLite message store + + Args: + db_path: Path to SQLite database file. If None, uses in-memory database. + """ + self.db_path = db_path or ":memory:" + self._init_database() + + def _init_database(self) -> None: + """Initialize database schema""" + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS messages ( + message_id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + user_id TEXT NOT NULL, + agent_name TEXT, + role TEXT NOT NULL, + content TEXT NOT NULL, + timestamp TEXT NOT NULL, + task_id TEXT, + metadata TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Create indexes for common queries + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_messages_session_id + ON messages(session_id) + """) + + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_messages_timestamp + ON messages(session_id, timestamp) + """) + + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_messages_role + ON messages(session_id, role) + """) + + def _message_to_dict(self, message: Message) -> Dict[str, Any]: + """Convert Message object to database record""" + return { + "message_id": message.message_id, + "session_id": message.session_id, + "user_id": message.user_id, + "agent_name": message.agent_name, + "role": message.role.value, + "content": message.content, + "timestamp": message.timestamp.isoformat(), + "task_id": message.task_id, + "metadata": json.dumps(message.metadata) if message.metadata else None, + } + + def _dict_to_message(self, row: Dict[str, Any]) -> Message: + """Convert database record to Message object""" + return Message( + message_id=row["message_id"], + session_id=row["session_id"], + user_id=row["user_id"], + agent_name=row["agent_name"], + role=Role(row["role"]), + content=row["content"], + timestamp=datetime.fromisoformat(row["timestamp"]), + task_id=row["task_id"], + metadata=json.loads(row["metadata"]) if row["metadata"] else {}, + ) + + async def save_message(self, message: Message) -> None: + """Save message to SQLite database""" + data = self._message_to_dict(message) + + with sqlite3.connect(self.db_path) as conn: + conn.execute( + """ + INSERT OR REPLACE INTO messages + (message_id, session_id, user_id, agent_name, role, content, + timestamp, task_id, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + data["message_id"], + data["session_id"], + data["user_id"], + data["agent_name"], + data["role"], + data["content"], + data["timestamp"], + data["task_id"], + data["metadata"], + ), + ) + + async def get_messages( + self, + session_id: str, + limit: Optional[int] = None, + offset: int = 0, + role: Optional[Role] = None, + ) -> List[Message]: + """Get messages for a session""" + query = """ + SELECT message_id, session_id, user_id, agent_name, role, content, + timestamp, task_id, metadata + FROM messages + WHERE session_id = ? + """ + params = [session_id] + + # Add role filter if specified + if role: + query += " AND role = ?" + params.append(role.value) + + # Order by timestamp + query += " ORDER BY timestamp ASC" + + # Add pagination + if limit is not None: + query += " LIMIT ? OFFSET ?" + params.extend([limit, offset]) + + with sqlite3.connect(self.db_path) as conn: + conn.row_factory = sqlite3.Row + cursor = conn.execute(query, params) + rows = cursor.fetchall() + + return [self._dict_to_message(dict(row)) for row in rows] + + async def get_message(self, message_id: str) -> Optional[Message]: + """Get a specific message by ID""" + with sqlite3.connect(self.db_path) as conn: + conn.row_factory = sqlite3.Row + cursor = conn.execute( + """ + SELECT message_id, session_id, user_id, agent_name, role, content, + timestamp, task_id, metadata + FROM messages + WHERE message_id = ? + """, + (message_id,), + ) + + row = cursor.fetchone() + return self._dict_to_message(dict(row)) if row else None + + async def get_latest_message(self, session_id: str) -> Optional[Message]: + """Get the latest message in a session""" + with sqlite3.connect(self.db_path) as conn: + conn.row_factory = sqlite3.Row + cursor = conn.execute( + """ + SELECT message_id, session_id, user_id, agent_name, role, content, + timestamp, task_id, metadata + FROM messages + WHERE session_id = ? + ORDER BY timestamp DESC + LIMIT 1 + """, + (session_id,), + ) + + row = cursor.fetchone() + return self._dict_to_message(dict(row)) if row else None + + async def get_message_count(self, session_id: str) -> int: + """Get message count for a session""" + with sqlite3.connect(self.db_path) as conn: + cursor = conn.execute( + """ + SELECT COUNT(*) FROM messages WHERE session_id = ? + """, + (session_id,), + ) + + return cursor.fetchone()[0] + + async def delete_session_messages(self, session_id: str) -> int: + """Delete all messages for a session""" + with sqlite3.connect(self.db_path) as conn: + cursor = conn.execute( + """ + DELETE FROM messages WHERE session_id = ? + """, + (session_id,), + ) + + return cursor.rowcount + + async def delete_message(self, message_id: str) -> bool: + """Delete a specific message""" + with sqlite3.connect(self.db_path) as conn: + cursor = conn.execute( + """ + DELETE FROM messages WHERE message_id = ? + """, + (message_id,), + ) + + return cursor.rowcount > 0 diff --git a/python/valuecell/core/session/models.py b/python/valuecell/core/session/models.py index 738d6d822..e3ccb13c5 100644 --- a/python/valuecell/core/session/models.py +++ b/python/valuecell/core/session/models.py @@ -1,6 +1,6 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional from pydantic import BaseModel, Field @@ -13,11 +13,21 @@ class Role(str, Enum): SYSTEM = "system" +class SessionStatus(str, Enum): + """Session status enumeration""" + + ACTIVE = "active" + INACTIVE = "inactive" + REQUIRE_USER_INPUT = "require_user_input" + + class Message(BaseModel): """Message data model""" message_id: str = Field(..., description="Unique message identifier") session_id: str = Field(..., description="Session ID this message belongs to") + user_id: str = Field(..., description="User ID") + agent_name: Optional[str] = Field(None, description="Agent name") role: Role = Field(..., description="Message role") content: str = Field(..., description="Message content") timestamp: datetime = Field( @@ -33,7 +43,7 @@ class Config: class Session(BaseModel): - """Session data model""" + """Session data model - lightweight metadata only, messages stored separately""" session_id: str = Field(..., description="Unique session identifier") user_id: str = Field(..., description="User ID") @@ -44,43 +54,35 @@ class Session(BaseModel): updated_at: datetime = Field( default_factory=datetime.now, description="Last update time" ) - messages: List[Message] = Field( - default_factory=list, description="Session message list" + status: SessionStatus = Field( + default=SessionStatus.ACTIVE, description="Session status" ) - context: Dict[str, Any] = Field( - default_factory=dict, description="Session context data" - ) - is_active: bool = Field(True, description="Whether session is active") class Config: json_encoders = {datetime: lambda v: v.isoformat()} - def add_message(self, message: Message) -> None: - """Add message to session""" - messages = list(self.messages) - messages.append(message) - self.messages = messages + @property + def is_active(self) -> bool: + """Backward compatibility property - returns True if session is active""" + return self.status == SessionStatus.ACTIVE + + def set_status(self, status: SessionStatus) -> None: + """Update session status and timestamp""" + self.status = status self.updated_at = datetime.now() - def get_messages_by_role(self, role: Role) -> List[Message]: - """Get messages by role""" - return [msg for msg in self.messages if msg.role == role] + def activate(self) -> None: + """Set session to active status""" + self.set_status(SessionStatus.ACTIVE) - def get_latest_message(self) -> Optional[Message]: - """Get latest message""" - return self.messages[-1] if self.messages else None + def deactivate(self) -> None: + """Set session to inactive status""" + self.set_status(SessionStatus.INACTIVE) - def get_message_count(self) -> int: - """Get message count""" - return len(self.messages) + def require_user_input(self) -> None: + """Set session to require user input status""" + self.set_status(SessionStatus.REQUIRE_USER_INPUT) - def update_context(self, key: str, value: Any) -> None: - """Update session context""" - context = dict(self.context) - context[key] = value - self.context = context + def touch(self) -> None: + """Update the session's last activity timestamp""" self.updated_at = datetime.now() - - def get_context(self, key: str, default: Any = None) -> Any: - """Get session context value""" - return dict(self.context).get(key, default) diff --git a/python/valuecell/core/session/store.py b/python/valuecell/core/session/store.py index df98d2cb3..99f5c6ae5 100644 --- a/python/valuecell/core/session/store.py +++ b/python/valuecell/core/session/store.py @@ -5,7 +5,10 @@ class SessionStore(ABC): - """Session storage abstract base class""" + """Session storage abstract base class - handles session metadata only. + + Messages are stored separately using MessageStore implementations. + """ @abstractmethod async def save_session(self, session: Session) -> None: From ad67e114cea19eaaa86efe83dd02ce8f3ac54ca9 Mon Sep 17 00:00:00 2001 From: Zhaofeng Zhang <24791380+vcfgv@users.noreply.github.com> Date: Thu, 18 Sep 2025 11:11:54 +0800 Subject: [PATCH 2/6] feat: add agent_name to MessageChunkMetadata and update orchestrator to track agent responses --- .../valuecell/core/coordinate/orchestrator.py | 50 +++++++++++++------ python/valuecell/core/types.py | 1 + 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/python/valuecell/core/coordinate/orchestrator.py b/python/valuecell/core/coordinate/orchestrator.py index 46aeed96b..58d3ffdef 100644 --- a/python/valuecell/core/coordinate/orchestrator.py +++ b/python/valuecell/core/coordinate/orchestrator.py @@ -1,4 +1,5 @@ import logging +from collections import defaultdict from typing import AsyncGenerator from a2a.types import TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent @@ -34,6 +35,7 @@ def _create_message_chunk( content: str, session_id: str, user_id: str, + agent_name: str, kind: MessageDataKind = MessageDataKind.TEXT, is_final: bool = False, status: MessageChunkStatus = MessageChunkStatus.partial, @@ -43,19 +45,23 @@ def _create_message_chunk( content=content, kind=kind, meta=MessageChunkMetadata( - session_id=session_id, user_id=user_id, status=status + session_id=session_id, + user_id=user_id, + agent_name=agent_name, + status=status, ), is_final=is_final, ) def _create_error_message_chunk( - self, error_msg: str, session_id: str, user_id: str + self, error_msg: str, session_id: str, user_id: str, agent_name: str ) -> MessageChunk: """Create an error MessageChunk with standardized format""" return self._create_message_chunk( content=f"(Error): {error_msg}", session_id=session_id, user_id=user_id, + agent_name=agent_name, is_final=True, status=MessageChunkStatus.failure, ) @@ -71,28 +77,38 @@ async def process_user_input( await self.session_manager.create_session( user_input.meta.user_id, session_id=session_id ) - await self.session_manager.add_message(session_id, Role.USER, user_input.query) + await self.session_manager.add_message( + session_id, Role.USER, user_input.query, user_id=user_input.meta.user_id + ) try: # Create execution plan with user_id plan = await self.planner.create_plan(user_input) - # Stream execution results - full_response = "" + # Stream execution results and track agent responses separately + agent_responses = defaultdict(str) # Dict[agent_name, response_content] async for chunk in self._execute_plan(plan, user_input.meta.model_dump()): - full_response += chunk.content + agent_name = chunk.meta.agent_name + + # Track content by agent + agent_responses[agent_name] += chunk.content + yield chunk - # Add final assistant response to session - await self.session_manager.add_message( - session_id, Role.AGENT, full_response - ) + # Add separate messages for each agent's complete response + for agent_name, full_response in agent_responses.items(): + if full_response.strip(): # Only save non-empty responses + await self.session_manager.add_message( + session_id, Role.AGENT, full_response, agent_name=agent_name + ) except Exception as e: error_msg = f"Error processing request: {str(e)}" - await self.session_manager.add_message(session_id, Role.SYSTEM, error_msg) + await self.session_manager.add_message( + session_id, Role.SYSTEM, error_msg, agent_name="__system__" + ) yield self._create_error_message_chunk( - error_msg, session_id, user_input.meta.user_id + error_msg, session_id, user_input.meta.user_id, "__system__" ) async def _execute_plan( @@ -103,7 +119,7 @@ async def _execute_plan( session_id, user_id = metadata["session_id"], metadata["user_id"] if not plan.tasks: yield self._create_error_message_chunk( - "No tasks found for this request.", session_id, user_id + "No tasks found for this request.", session_id, user_id, "__system__" ) return @@ -119,7 +135,9 @@ async def _execute_plan( except Exception as e: error_msg = f"Error executing {task.agent_name}: {str(e)}" - yield self._create_error_message_chunk(error_msg, session_id, user_id) + yield self._create_error_message_chunk( + error_msg, session_id, user_id, task.agent_name + ) async def _execute_task( self, task, query: str, metadata: dict @@ -165,6 +183,7 @@ async def _execute_task( err_msg, task.session_id, task.user_id, + task.agent_name, ) return @@ -174,6 +193,7 @@ async def _execute_task( get_message_text(event.artifact, ""), task.session_id, task.user_id, + task.agent_name, ) # Complete task @@ -182,6 +202,7 @@ async def _execute_task( "", task.session_id, task.user_id, + task.agent_name, is_final=True, status=MessageChunkStatus.success, ) @@ -210,6 +231,7 @@ async def close_session(self, session_id: str): session_id, Role.SYSTEM, f"Session closed. {cancelled_count} tasks were cancelled.", + agent_name="orchestrator", ) async def get_session_history(self, session_id: str): diff --git a/python/valuecell/core/types.py b/python/valuecell/core/types.py index 56fd63713..16ed55a8b 100644 --- a/python/valuecell/core/types.py +++ b/python/valuecell/core/types.py @@ -69,6 +69,7 @@ class MessageChunkMetadata(BaseModel): ) session_id: str = Field(..., description="Session ID for this request") user_id: str = Field(..., description="User ID who made this request") + agent_name: str = Field(..., description="Agent name handling this message") class MessageChunk(BaseModel): From 245e89aa1ec559dc45b4af03de47be6a22377b36 Mon Sep 17 00:00:00 2001 From: Zhaofeng Zhang <24791380+vcfgv@users.noreply.github.com> Date: Thu, 18 Sep 2025 15:43:16 +0800 Subject: [PATCH 3/6] feat: enhance agent execution and notification handling with new task patterns and response models --- python/valuecell/core/agent/decorator.py | 15 +++++---- .../valuecell/core/coordinate/orchestrator.py | 18 +++++++++-- python/valuecell/core/task/models.py | 10 ++++++ python/valuecell/core/types.py | 31 +++++++++++++++++-- 4 files changed, 62 insertions(+), 12 deletions(-) diff --git a/python/valuecell/core/agent/decorator.py b/python/valuecell/core/agent/decorator.py index 71f0ae4fe..59fdb7454 100644 --- a/python/valuecell/core/agent/decorator.py +++ b/python/valuecell/core/agent/decorator.py @@ -150,19 +150,14 @@ def __init__(self, agent: BaseAgent): self.agent = agent async def execute(self, context: RequestContext, event_queue: EventQueue) -> None: - # Ensure agent implements streaming interface - if not hasattr(self.agent, "stream"): - raise NotImplementedError( - f"Agent {self.agent.__class__.__name__} must implement 'stream' method" - ) - # Prepare query and ensure a task exists in the system query = context.get_user_input() task = context.current_task + metadata = context.metadata if not task: message = context.message task = new_task(message) - task.metadata = message.metadata + task.metadata = metadata await event_queue.enqueue_event(task) # Helper state @@ -186,7 +181,10 @@ async def _add_chunk(content: str, last: bool = False): # Stream from the user agent and update task incrementally await updater.update_status(TaskState.working) try: - async for item in self.agent.stream(query, task.context_id, task.id): + query_handler = ( + self.agent.notify if metadata.get("notify") else self.agent.stream + ) + async for item in query_handler(query, task.context_id, task.id): content = item.get("content", "") is_complete = item.get("is_task_complete", True) @@ -195,6 +193,7 @@ async def _add_chunk(content: str, last: bool = False): if is_complete: await updater.complete() break + except Exception as e: message = ( f"Error during {self.agent.__class__.__name__} agent execution: {e}" diff --git a/python/valuecell/core/coordinate/orchestrator.py b/python/valuecell/core/coordinate/orchestrator.py index 58d3ffdef..4f442733c 100644 --- a/python/valuecell/core/coordinate/orchestrator.py +++ b/python/valuecell/core/coordinate/orchestrator.py @@ -6,7 +6,8 @@ from a2a.utils import get_message_text from valuecell.core.agent.connect import get_default_remote_connections from valuecell.core.session import Role, get_default_session_manager -from valuecell.core.task import get_default_task_manager +from valuecell.core.task import Task, get_default_task_manager +from valuecell.core.task.models import TaskPattern from valuecell.core.types import ( MessageChunk, MessageChunkMetadata, @@ -95,6 +96,16 @@ async def process_user_input( yield chunk + if chunk.is_final and agent_responses[agent_name].strip(): + # Save final response to session when final chunk is received + await self.session_manager.add_message( + session_id, + Role.AGENT, + agent_responses[agent_name], + agent_name=agent_name, + ) + agent_responses[agent_name] = "" + # Add separate messages for each agent's complete response for agent_name, full_response in agent_responses.items(): if full_response.strip(): # Only save non-empty responses @@ -140,7 +151,7 @@ async def _execute_plan( ) async def _execute_task( - self, task, query: str, metadata: dict + self, task: Task, query: str, metadata: dict ) -> AsyncGenerator[MessageChunk, None]: """Execute a single task by calling the specified agent - streams results""" @@ -158,6 +169,8 @@ async def _execute_task( if not client: raise RuntimeError(f"Could not connect to agent {task.agent_name}") + if task.pattern != TaskPattern.ONCE: + metadata["notify"] = True response = await client.send_message( query, context_id=task.session_id, @@ -194,6 +207,7 @@ async def _execute_task( task.session_id, task.user_id, task.agent_name, + is_final=metadata.get("notify", False), ) # Complete task diff --git a/python/valuecell/core/task/models.py b/python/valuecell/core/task/models.py index d42c57938..76bae0033 100644 --- a/python/valuecell/core/task/models.py +++ b/python/valuecell/core/task/models.py @@ -16,6 +16,13 @@ class TaskStatus(str, Enum): CANCELLED = "cancelled" # Cancelled by user or system +class TaskPattern(str, Enum): + """Task pattern enumeration""" + + ONCE = "once" # One-time task + RECURRING = "recurring" # Recurring task + + class Task(BaseModel): """Task data model""" @@ -30,6 +37,9 @@ class Task(BaseModel): status: TaskStatus = Field( default=TaskStatus.PENDING, description="Current task status" ) + pattern: TaskPattern = Field( + default=TaskPattern.ONCE, description="Task execution pattern" + ) # Time-related fields created_at: datetime = Field( diff --git a/python/valuecell/core/types.py b/python/valuecell/core/types.py index 16ed55a8b..f00f16e08 100644 --- a/python/valuecell/core/types.py +++ b/python/valuecell/core/types.py @@ -100,6 +100,16 @@ class StreamResponse(BaseModel): ) +class NotifyResponse(BaseModel): + """Response model for notification agent responses""" + + content: str = Field( + ..., + description="The content of the notification response", + ) + + +# TODO: keep only essential parameters class BaseAgent(ABC): """ Abstract base class for all agents. @@ -107,10 +117,10 @@ class BaseAgent(ABC): @abstractmethod async def stream( - self, query, session_id, task_id + self, query: str, session_id: str, task_id: str ) -> AsyncGenerator[StreamResponse, None]: """ - Process user queries and return streaming responses + Process user queries and return streaming responses (user-initiated) Args: query: User query content @@ -122,6 +132,23 @@ async def stream( """ raise NotImplementedError + @abstractmethod + async def notify( + self, query: str, session_id: str, task_id: str + ) -> AsyncGenerator[NotifyResponse, None]: + """ + Send proactive notifications to subscribed users (agent-initiated) + + Args: + query: User query content, can be empty for some agents + session_id: Session ID for the notification + user_id: Target user ID for the notification + + Yields: + StreamResponse: Notification content and status + """ + raise NotImplementedError + # Message response type for agent communication RemoteAgentResponse = tuple[ From 42f90c826745d1e03c0d88ffda4b8e69f3751c37 Mon Sep 17 00:00:00 2001 From: Zhaofeng Zhang <24791380+vcfgv@users.noreply.github.com> Date: Fri, 19 Sep 2025 00:57:39 +0800 Subject: [PATCH 4/6] feat: enhance user input handling and execution planning with new context management and task definitions --- python/configs/agent_cards/sec_agent.json | 80 ++ python/valuecell/core/coordinate/models.py | 65 +- .../valuecell/core/coordinate/orchestrator.py | 715 +++++++++++++++--- python/valuecell/core/coordinate/planner.py | 225 +++++- .../core/coordinate/planner_prompts.py | 51 ++ python/valuecell/core/task/__init__.py | 3 +- python/valuecell/core/task/models.py | 1 + python/valuecell/core/types.py | 1 - 8 files changed, 997 insertions(+), 144 deletions(-) create mode 100644 python/valuecell/core/coordinate/planner_prompts.py diff --git a/python/configs/agent_cards/sec_agent.json b/python/configs/agent_cards/sec_agent.json index 08f0c192e..e7bdad61e 100644 --- a/python/configs/agent_cards/sec_agent.json +++ b/python/configs/agent_cards/sec_agent.json @@ -1,5 +1,85 @@ { "name": "Sec13FundAgent", "url": "http://localhost:10001/", + "description": "Sec13FundAgent can analyze SEC filings like 10-Q, 10-K, 13-F and analyze stock holdings of institutional investment managers. It can chat about stock performance, financial metrics, and market trends or track specific stocks and provide updates.", + "skills": [ + { + "id": "analyze_13f_filings", + "name": "Analyze 13F Filings", + "description": "Analyze 13F filings to extract stock holdings and provide insights on institutional investment managers' portfolios.", + "examples": [ + "What are the top holdings of Berkshire Hathaway in the latest 13F filing?", + "How has the portfolio of Vanguard Group changed over the last four quarters?", + "Can you provide insights on the stock performance of Apple and Microsoft based on recent 13F filings?" + ], + "tags": ["13F", "stock holdings", "institutional investors"] + }, + { + "id": "compare_quarterly_reports", + "name": "Compare Quarterly Financial Reports", + "description": "Compare and analyze quarterly financial reports (10-Q) across multiple periods to identify trends, performance changes, and key financial metrics evolution.", + "examples": [ + "Compare Tesla's last three quarterly 10-Q reports", + "Analyze Apple's Q1 vs Q4 financial performance from their 10-Q filings", + "What are the key changes in Amazon's quarterly revenue and expenses over the past year?" + ], + "tags": ["10-Q", "quarterly reports", "financial analysis"] + }, + { + "id": "summarize_annual_reports", + "name": "Summarize Annual Financial Reports", + "description": "Extract and summarize key information from annual 10-K reports, including business overview, risk factors, financial highlights, and strategic initiatives.", + "examples": [ + "Help me summarize Nvidia's latest 10-K report", + "What are the main highlights from Microsoft's annual 10-K filing?", + "Provide a comprehensive summary of Meta's business segments from their 10-K report" + ], + "tags": ["10-K", "annual reports", "financial summary"] + }, + { + "id": "compare_industry_performance", + "name": "Compare Industry Financial Performance", + "description": "Analyze and compare financial performance metrics across companies within the same industry sector, identifying competitive positioning and market leaders.", + "examples": [ + "How does AMD's financial performance compare to Intel and Nvidia?", + "Compare the profitability metrics of JPMorgan, Bank of America, and Wells Fargo", + "Analyze the revenue growth rates of Netflix, Disney, and Warner Bros Discovery" + ], + "tags": ["industry comparison", "financial metrics", "competitive analysis"] + }, + { + "id": "analyze_revenue_mix_evolution", + "name": "Analyze Revenue Mix Evolution", + "description": "Track and analyze changes in company revenue composition, business segment performance, and revenue diversification strategies over time.", + "examples": [ + "How has Google's revenue mix changed over the past 3 years?", + "Analyze the evolution of Amazon's revenue streams from 2020 to 2023", + "What changes occurred in Apple's product revenue mix in the last 5 years?" + ], + "tags": ["revenue mix", "business segments", "financial evolution"] + }, + { + "id": "assess_financial_health_metrics", + "name": "Assess Financial Health Metrics", + "description": "Evaluate company financial health through key metrics analysis including liquidity ratios, debt levels, profitability margins, and cash flow patterns.", + "examples": [ + "Assess Tesla's current financial health based on their latest filings", + "What do the debt-to-equity ratios tell us about Boeing's financial stability?", + "Analyze the cash flow trends and liquidity position of General Electric" + ], + "tags": ["financial health", "liquidity", "debt analysis"] + }, + { + "id": "identify_risk_factors", + "name": "Identify Financial Risk Factors", + "description": "Extract and analyze risk factors disclosed in SEC filings, categorize risks by type, and assess potential impact on business operations and financial performance.", + "examples": [ + "What are the main risk factors disclosed in Uber's latest 10-K filing?", + "Analyze the regulatory risks mentioned in JPMorgan's annual report", + "Compare the risk disclosures between traditional automakers and EV companies" + ], + "tags": ["risk factors", "financial risk", "SEC filings"] + } + ], "enabled": true } \ No newline at end of file diff --git a/python/valuecell/core/coordinate/models.py b/python/valuecell/core/coordinate/models.py index 6155f4333..aff8496f0 100644 --- a/python/valuecell/core/coordinate/models.py +++ b/python/valuecell/core/coordinate/models.py @@ -1,15 +1,70 @@ -from typing import List -from pydantic import BaseModel, Field +from typing import List, Optional +from pydantic import BaseModel, Field from valuecell.core.task import Task +from valuecell.core.task.models import TaskPattern class ExecutionPlan(BaseModel): - """Execution plan containing multiple tasks""" + """ + Execution plan containing multiple tasks for fulfilling a user request. + + This model represents a structured plan that breaks down a user's request + into executable tasks that can be processed by different agents. + """ plan_id: str = Field(..., description="Unique plan identifier") - session_id: str = Field(..., description="Session ID this plan belongs to") + session_id: Optional[str] = Field( + None, description="Session ID this plan belongs to" + ) user_id: str = Field(..., description="User ID who requested this plan") - query: str = Field(..., description="Original user input") + query: str = Field(..., description="Original user query that generated this plan") tasks: List[Task] = Field(default_factory=list, description="Tasks to execute") created_at: str = Field(..., description="Plan creation timestamp") + + +class _TaskBrief(BaseModel): + """ + Represents a task to be executed by an agent. + + This is a simplified task representation used during the planning phase + before being converted to a full Task object. + """ + + query: str = Field(..., description="The task to be performed") + agent_name: str = Field(..., description="Name of the agent executing this task") + pattern: TaskPattern = Field( + default=TaskPattern.ONCE, description="Task execution pattern" + ) + + +class PlannerInput(BaseModel): + """ + Schema for planner input containing user query and metadata. + + This schema is used by the planning agent to structure its input + when determining what tasks should be executed. + """ + + desired_agent_name: str = Field( + ..., description="The name of the agent the user wants to use for the task" + ) + query: str = Field( + ..., description="The user's input or request which may need clarification" + ) + + +class PlannerResponse(BaseModel): + """ + Schema for planner response containing tasks and planning metadata. + + This schema is used by the planning agent to structure its response + when determining what tasks should be executed. + """ + + tasks: List[_TaskBrief] = Field(..., description="List of tasks to be executed") + adequate: bool = Field( + ..., + description="true if information is adequate for task execution, false if more input is needed", + ) + reason: str = Field(..., description="Reason for the planning decision") diff --git a/python/valuecell/core/coordinate/orchestrator.py b/python/valuecell/core/coordinate/orchestrator.py index 4f442733c..bf326d6f1 100644 --- a/python/valuecell/core/coordinate/orchestrator.py +++ b/python/valuecell/core/coordinate/orchestrator.py @@ -1,11 +1,12 @@ +import asyncio import logging from collections import defaultdict -from typing import AsyncGenerator +from typing import AsyncGenerator, Dict, Optional from a2a.types import TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent from a2a.utils import get_message_text from valuecell.core.agent.connect import get_default_remote_connections -from valuecell.core.session import Role, get_default_session_manager +from valuecell.core.session import Role, SessionStatus, get_default_session_manager from valuecell.core.task import Task, get_default_task_manager from valuecell.core.task.models import TaskPattern from valuecell.core.types import ( @@ -18,19 +19,358 @@ from .callback import store_task_in_session from .models import ExecutionPlan -from .planner import ExecutionPlanner +from .planner import ExecutionPlanner, UserInputRequest logger = logging.getLogger(__name__) +# Constants for configuration +DEFAULT_CONTEXT_TIMEOUT_SECONDS = 3600 # 1 hour +ASYNC_SLEEP_INTERVAL = 0.1 # 100ms + + +class ExecutionContext: + """Manages the state of an interrupted execution for resumption""" + + def __init__(self, stage: str, session_id: str, user_id: str): + self.stage = stage + self.session_id = session_id + self.user_id = user_id + self.created_at = asyncio.get_event_loop().time() + self.metadata: Dict = {} + + def is_expired( + self, max_age_seconds: int = DEFAULT_CONTEXT_TIMEOUT_SECONDS + ) -> bool: + """Check if this context has expired""" + current_time = asyncio.get_event_loop().time() + return current_time - self.created_at > max_age_seconds + + def validate_user(self, user_id: str) -> bool: + """Validate that the user ID matches the original request""" + return self.user_id == user_id + + def add_metadata(self, **kwargs): + """Add metadata to the context""" + self.metadata.update(kwargs) + + def get_metadata(self, key: str, default=None): + """Get metadata value""" + return self.metadata.get(key, default) + + +class UserInputManager: + """Manages pending user input requests and their lifecycle""" + + def __init__(self): + self._pending_requests: Dict[str, UserInputRequest] = {} + + def add_request(self, session_id: str, request: UserInputRequest): + """Add a pending user input request""" + self._pending_requests[session_id] = request + + def has_pending_request(self, session_id: str) -> bool: + """Check if there's a pending request for the session""" + return session_id in self._pending_requests + + def get_request_prompt(self, session_id: str) -> Optional[str]: + """Get the prompt for a pending request""" + request = self._pending_requests.get(session_id) + return request.prompt if request else None + + def provide_response(self, session_id: str, response: str) -> bool: + """Provide a response to a pending request""" + if session_id not in self._pending_requests: + return False + + request = self._pending_requests[session_id] + request.provide_response(response) + del self._pending_requests[session_id] + return True + + def clear_request(self, session_id: str): + """Clear a pending request""" + self._pending_requests.pop(session_id, None) + class AgentOrchestrator: + """ + Orchestrates execution of user requests through multiple agents with Human-in-the-Loop support. + + This class manages the entire lifecycle of user requests including: + - Planning phase with user input collection + - Task execution with interruption support + - Session state management + - Error handling and recovery + """ + def __init__(self): self.session_manager = get_default_session_manager() self.task_manager = get_default_task_manager() self.agent_connections = get_default_remote_connections() + # Initialize user input management + self.user_input_manager = UserInputManager() + + # Initialize execution context management + self._execution_contexts: Dict[str, ExecutionContext] = {} + + # Initialize planner self.planner = ExecutionPlanner(self.agent_connections) + # ==================== Public API Methods ==================== + + async def process_user_input( + self, user_input: UserInput + ) -> AsyncGenerator[MessageChunk, None]: + """ + Main entry point for processing user requests with Human-in-the-Loop support. + + Handles three types of scenarios: + 1. New user requests - starts planning and execution + 2. Continuation of interrupted sessions - resumes from saved state + 3. User input responses - provides input to waiting requests + + Args: + user_input: The user's input containing query and metadata + + Yields: + MessageChunk: Streaming response chunks from agents + """ + session_id = user_input.meta.session_id + user_id = user_input.meta.user_id + + try: + # Ensure session exists + session = await self._ensure_session_exists(session_id, user_id) + + # Handle session continuation vs new request + if session.status == SessionStatus.REQUIRE_USER_INPUT: + async for chunk in self._handle_session_continuation(user_input): + yield chunk + else: + async for chunk in self._handle_new_request(user_input): + yield chunk + + except Exception as e: + logger.exception(f"Error processing user input for session {session_id}") + yield self._create_error_message_chunk( + f"Error processing request: {str(e)}", session_id, user_id, "__system__" + ) + + async def provide_user_input(self, session_id: str, response: str): + """ + Provide user input response for a specific session. + + Args: + session_id: The session ID waiting for input + response: The user's response to the input request + """ + if self.user_input_manager.provide_response(session_id, response): + # Update session status to active + session = await self.session_manager.get_session(session_id) + if session: + session.activate() + await self.session_manager.update_session(session) + + def has_pending_user_input(self, session_id: str) -> bool: + """Check if a session has pending user input request""" + return self.user_input_manager.has_pending_request(session_id) + + def get_user_input_prompt(self, session_id: str) -> Optional[str]: + """Get the user input prompt for a specific session""" + return self.user_input_manager.get_request_prompt(session_id) + + async def create_session(self, user_id: str, title: str = None): + """Create a new session for the user""" + return await self.session_manager.create_session(user_id, title) + + async def close_session(self, session_id: str): + """Close an existing session and clean up resources""" + # Cancel any running tasks for this session + cancelled_count = await self.task_manager.cancel_session_tasks(session_id) + + # Clean up execution context + await self._cancel_execution(session_id) + + # Add system message to mark session as closed + await self.session_manager.add_message( + session_id, + Role.SYSTEM, + f"Session closed. {cancelled_count} tasks were cancelled.", + agent_name="orchestrator", + ) + + async def get_session_history(self, session_id: str): + """Get session message history""" + return await self.session_manager.get_session_messages(session_id) + + async def get_user_sessions(self, user_id: str, limit: int = 100, offset: int = 0): + """Get all sessions for a user""" + return await self.session_manager.list_user_sessions(user_id, limit, offset) + + async def cleanup(self): + """Cleanup resources and expired contexts""" + await self._cleanup_expired_contexts() + await self.agent_connections.stop_all() + + # ==================== Private Helper Methods ==================== + + # ==================== Private Helper Methods ==================== + + async def _handle_user_input_request(self, request: UserInputRequest): + """Handle user input request from planner""" + # Extract session_id from request context + session_id = getattr(request, "session_id", None) + if session_id: + self.user_input_manager.add_request(session_id, request) + + async def _ensure_session_exists(self, session_id: str, user_id: str): + """Ensure a session exists, creating it if necessary""" + session = await self.session_manager.get_session(session_id) + if not session: + await self.session_manager.create_session(user_id, session_id=session_id) + session = await self.session_manager.get_session(session_id) + return session + + async def _handle_session_continuation( + self, user_input: UserInput + ) -> AsyncGenerator[MessageChunk, None]: + """Handle continuation of an interrupted session""" + session_id = user_input.meta.session_id + user_id = user_input.meta.user_id + + # Validate execution context exists + if session_id not in self._execution_contexts: + yield self._create_error_message_chunk( + "No execution context found for this session. The session may have expired.", + session_id, + user_id, + "__system__", + ) + return + + context = self._execution_contexts[session_id] + + # Validate context integrity and user consistency + if not self._validate_execution_context(context, user_id): + yield self._create_error_message_chunk( + "Invalid execution context or user mismatch.", + session_id, + user_id, + "__system__", + ) + await self._cancel_execution(session_id) + return + + # Provide user response and resume execution + # If we are in an execution stage, store the pending response for resume + context.add_metadata(pending_response=user_input.query) + await self.provide_user_input(session_id, user_input.query) + + # Resume based on execution stage + if context.stage == "planning": + async for chunk in self._continue_planning(session_id, context): + yield chunk + # TODO: Add support for resuming execution stage if needed + else: + yield self._create_error_message_chunk( + "Resuming execution stage is not yet supported.", + session_id, + user_id, + "__system__", + ) + + async def _handle_new_request( + self, user_input: UserInput + ) -> AsyncGenerator[MessageChunk, None]: + """Handle a new user request""" + session_id = user_input.meta.session_id + + # Add user message to session + await self.session_manager.add_message( + session_id, Role.USER, user_input.query, user_id=user_input.meta.user_id + ) + + # Create planning task with user input callback + context_aware_callback = self._create_context_aware_callback(session_id) + + planning_task = asyncio.create_task( + self.planner.create_plan(user_input, context_aware_callback) + ) + + # Monitor planning progress + async for chunk in self._monitor_planning_task( + planning_task, user_input, context_aware_callback + ): + yield chunk + + def _create_context_aware_callback(self, session_id: str): + """Create a callback that adds session context to user input requests""" + + async def context_aware_handle(request): + request.session_id = session_id + await self._handle_user_input_request(request) + + return context_aware_handle + + async def _monitor_planning_task( + self, planning_task, user_input: UserInput, callback + ) -> AsyncGenerator[MessageChunk, None]: + """Monitor planning task and handle user input interruptions""" + session_id = user_input.meta.session_id + user_id = user_input.meta.user_id + + # Wait for planning completion or user input request + while not planning_task.done(): + if self.has_pending_user_input(session_id): + # Save planning context + context = ExecutionContext("planning", session_id, user_id) + context.add_metadata( + original_user_input=user_input, + planning_task=planning_task, + planner_callback=callback, + ) + self._execution_contexts[session_id] = context + + # Update session status and send user input request + await self._request_user_input(session_id, user_id) + yield self._create_user_input_request_chunk( + self.get_user_input_prompt(session_id), session_id, context.user_id + ) + return + + await asyncio.sleep(ASYNC_SLEEP_INTERVAL) + + # Planning completed, execute plan + plan = await planning_task + async for chunk in self._execute_plan_with_input_support( + plan, user_input.meta.model_dump() + ): + yield chunk + + async def _request_user_input(self, session_id: str, _user_id: str): + """Set session to require user input and send the request""" + # Note: _user_id parameter kept for potential future use in user validation + session = await self.session_manager.get_session(session_id) + if session: + session.require_user_input() + await self.session_manager.update_session(session) + + def _validate_execution_context( + self, context: ExecutionContext, user_id: str + ) -> bool: + """Validate execution context integrity""" + if not hasattr(context, "stage") or not context.stage: + return False + + if not context.validate_user(user_id): + return False + + if context.is_expired(): + return False + + return True + def _create_message_chunk( self, content: str, @@ -41,7 +381,7 @@ def _create_message_chunk( is_final: bool = False, status: MessageChunkStatus = MessageChunkStatus.partial, ) -> MessageChunk: - """Create a MessageChunk with common metadata""" + """Create a MessageChunk with standardized metadata""" return MessageChunk( content=content, kind=kind, @@ -67,110 +407,322 @@ def _create_error_message_chunk( status=MessageChunkStatus.failure, ) - async def process_user_input( - self, user_input: UserInput + def _create_user_input_request_chunk( + self, + prompt: str, + session_id: str, + user_id: str, + agent_name: str = "__planner__", + ) -> MessageChunk: + """Create a user input request MessageChunk""" + return self._create_message_chunk( + content=f"USER_INPUT_REQUIRED:{prompt}", + session_id=session_id, + user_id=user_id, + agent_name=agent_name, + kind=MessageDataKind.COMMAND, + is_final=True, + status=MessageChunkStatus.partial, + ) + + async def _continue_planning( + self, session_id: str, context: ExecutionContext ) -> AsyncGenerator[MessageChunk, None]: - """Main entry point for processing user input - streams results""" + """Resume planning stage execution""" + planning_task = context.get_metadata("planning_task") + original_user_input = context.get_metadata("original_user_input") - session_id = user_input.meta.session_id - # Add user message to session - if not await self.session_manager.session_exists(session_id): - await self.session_manager.create_session( - user_input.meta.user_id, session_id=session_id + if not all([planning_task, original_user_input]): + yield self._create_error_message_chunk( + "Invalid planning context - missing required data", + session_id, + context.user_id, + "__planner__", ) - await self.session_manager.add_message( - session_id, Role.USER, user_input.query, user_id=user_input.meta.user_id - ) + await self._cancel_execution(session_id) + return + + # Continue monitoring planning task + while not planning_task.done(): + if self.has_pending_user_input(session_id): + # Still need more user input, send request + prompt = self.get_user_input_prompt(session_id) + # Ensure session is set to require user input again for repeated prompts + await self._request_user_input(session_id, context.user_id) + yield self._create_user_input_request_chunk( + prompt, session_id, context.user_id + ) + return + + await asyncio.sleep(ASYNC_SLEEP_INTERVAL) + + # Planning completed, execute plan and clean up context + plan = await planning_task + del self._execution_contexts[session_id] + + async for chunk in self._execute_plan_with_input_support( + plan, original_user_input.meta.model_dump() + ): + yield chunk + + async def _cancel_execution(self, session_id: str): + """Cancel execution and clean up all related resources""" + # Clean up execution context + if session_id in self._execution_contexts: + context = self._execution_contexts[session_id] + + # Cancel planning task if it exists and is not done + planning_task = context.get_metadata("planning_task") + if planning_task and not planning_task.done(): + planning_task.cancel() + + del self._execution_contexts[session_id] + + # Clear pending user input + self.user_input_manager.clear_request(session_id) + + # Reset session status + session = await self.session_manager.get_session(session_id) + if session: + session.activate() + await self.session_manager.update_session(session) + + async def _cleanup_expired_contexts( + self, max_age_seconds: int = DEFAULT_CONTEXT_TIMEOUT_SECONDS + ): + """Clean up execution contexts that have been idle for too long""" + expired_sessions = [ + session_id + for session_id, context in self._execution_contexts.items() + if context.is_expired(max_age_seconds) + ] + + for session_id in expired_sessions: + await self._cancel_execution(session_id) + logger.warning( + f"Cleaned up expired execution context for session {session_id}" + ) + + # ==================== Plan and Task Execution Methods ==================== + + async def _execute_plan_with_input_support( + self, plan: ExecutionPlan, metadata: dict + ) -> AsyncGenerator[MessageChunk, None]: + """ + Execute an execution plan with Human-in-the-Loop support. + + This method streams execution results and handles user input interruptions + during task execution. + + Args: + plan: The execution plan containing tasks to execute + metadata: Execution metadata containing session and user info + """ + session_id, user_id = metadata["session_id"], metadata["user_id"] + + if not plan.tasks: + yield self._create_error_message_chunk( + "No tasks found for this request.", session_id, user_id, "__system__" + ) + return + + # Track agent responses for session storage + agent_responses = defaultdict(str) + + for task in plan.tasks: + try: + # Register the task with TaskManager + await self.task_manager.store.save_task(task) + # Execute task with input support + async for chunk in self._execute_task_with_input_support( + task, plan.query, metadata + ): + # Accumulate agent responses + agent_name = chunk.meta.agent_name + agent_responses[agent_name] += chunk.content + yield chunk + + # Save complete responses to session + if chunk.is_final and agent_responses[agent_name].strip(): + await self.session_manager.add_message( + session_id, + Role.AGENT, + agent_responses[agent_name], + agent_name=agent_name, + ) + agent_responses[agent_name] = "" + + except Exception as e: + error_msg = f"Error executing {task.agent_name}: {str(e)}" + logger.exception(f"Task execution failed: {error_msg}") + yield self._create_error_message_chunk( + error_msg, session_id, user_id, task.agent_name + ) + + # Save any remaining agent responses + await self._save_remaining_responses(session_id, agent_responses) + + async def _execute_task_with_input_support( + self, task: Task, query: str, metadata: dict + ) -> AsyncGenerator[MessageChunk, None]: + """ + Execute a single task with user input interruption support. + + Args: + task: The task to execute + query: The query/prompt for the task + metadata: Execution metadata + """ try: - # Create execution plan with user_id - plan = await self.planner.create_plan(user_input) + # Start task execution + await self.task_manager.start_task(task.task_id) - # Stream execution results and track agent responses separately - agent_responses = defaultdict(str) # Dict[agent_name, response_content] - async for chunk in self._execute_plan(plan, user_input.meta.model_dump()): - agent_name = chunk.meta.agent_name + # Get agent connection + agent_card = await self.agent_connections.start_agent( + task.agent_name, + with_listener=False, + notification_callback=store_task_in_session, + ) + client = await self.agent_connections.get_client(task.agent_name) - # Track content by agent - agent_responses[agent_name] += chunk.content + if not client: + raise RuntimeError(f"Could not connect to agent {task.agent_name}") - yield chunk + # Configure metadata for notifications + if task.pattern != TaskPattern.ONCE: + metadata["notify"] = True - if chunk.is_final and agent_responses[agent_name].strip(): - # Save final response to session when final chunk is received - await self.session_manager.add_message( - session_id, - Role.AGENT, - agent_responses[agent_name], - agent_name=agent_name, - ) - agent_responses[agent_name] = "" + # Send message to agent + response = await client.send_message( + query, + context_id=task.session_id, + metadata=metadata, + streaming=agent_card.capabilities.streaming, + ) + + # Process streaming responses + async for remote_task, event in response: + if event is None and remote_task.status.state == TaskState.submitted: + task.remote_task_ids.append(remote_task.id) + continue + + if isinstance(event, TaskStatusUpdateEvent): + await self._handle_task_status_update(event, task) + + # TODO: Check for user input requirement + # Handle task failure + if event.status.state == TaskState.failed: + err_msg = get_message_text(event.status.message) + await self.task_manager.fail_task(task.task_id, err_msg) + yield self._create_error_message_chunk( + err_msg, task.session_id, task.user_id, task.agent_name + ) + return - # Add separate messages for each agent's complete response - for agent_name, full_response in agent_responses.items(): - if full_response.strip(): # Only save non-empty responses - await self.session_manager.add_message( - session_id, Role.AGENT, full_response, agent_name=agent_name + elif isinstance(event, TaskArtifactUpdateEvent): + yield self._create_message_chunk( + get_message_text(event.artifact, ""), + task.session_id, + task.user_id, + task.agent_name, + is_final=metadata.get("notify", False), ) - except Exception as e: - error_msg = f"Error processing request: {str(e)}" - await self.session_manager.add_message( - session_id, Role.SYSTEM, error_msg, agent_name="__system__" - ) - yield self._create_error_message_chunk( - error_msg, session_id, user_input.meta.user_id, "__system__" + # Complete task successfully + await self.task_manager.complete_task(task.task_id) + yield self._create_message_chunk( + "", + task.session_id, + task.user_id, + task.agent_name, + is_final=True, + status=MessageChunkStatus.success, ) - async def _execute_plan( + except Exception as e: + await self.task_manager.fail_task(task.task_id, str(e)) + raise e + + async def _handle_task_status_update( + self, event: TaskStatusUpdateEvent, task: Task + ): + """Handle task status update events""" + logger.info(f"Task {task.task_id} status update: {event.status.state}") + + # Add any additional status-specific handling here + if event.status.state == TaskState.submitted: + # Task was submitted successfully + pass + elif event.status.state == TaskState.completed: + # Task completed successfully + pass + + async def _save_remaining_responses(self, session_id: str, agent_responses: dict): + """Save any remaining agent responses to the session""" + for agent_name, full_response in agent_responses.items(): + if full_response.strip(): + await self.session_manager.add_message( + session_id, Role.AGENT, full_response, agent_name=agent_name + ) + + # ==================== Legacy Task Execution (No HIL Support) ==================== + + async def _execute_plan_legacy( self, plan: ExecutionPlan, metadata: dict ) -> AsyncGenerator[MessageChunk, None]: - """Execute an execution plan - streams results""" + """ + Execute an execution plan without Human-in-the-Loop support. + This is a simplified version for backwards compatibility. + """ session_id, user_id = metadata["session_id"], metadata["user_id"] + if not plan.tasks: yield self._create_error_message_chunk( "No tasks found for this request.", session_id, user_id, "__system__" ) return - # Execute tasks (simple sequential execution for now) + # Execute tasks sequentially for task in plan.tasks: try: - # Register the task with TaskManager await self.task_manager.store.save_task(task) - - # Stream task execution results with user_id context - async for chunk in self._execute_task(task, plan.query, metadata): + async for chunk in self._execute_task_legacy( + task, plan.query, metadata + ): yield chunk - except Exception as e: error_msg = f"Error executing {task.agent_name}: {str(e)}" yield self._create_error_message_chunk( error_msg, session_id, user_id, task.agent_name ) - async def _execute_task( + async def _execute_task_legacy( self, task: Task, query: str, metadata: dict ) -> AsyncGenerator[MessageChunk, None]: - """Execute a single task by calling the specified agent - streams results""" + """ + Execute a single task without user input interruption support. + This is a simplified version for backwards compatibility. + """ try: - # Start task await self.task_manager.start_task(task.task_id) - # Get agent client + # Get agent connection agent_card = await self.agent_connections.start_agent( task.agent_name, with_listener=False, notification_callback=store_task_in_session, ) client = await self.agent_connections.get_client(task.agent_name) + if not client: raise RuntimeError(f"Could not connect to agent {task.agent_name}") if task.pattern != TaskPattern.ONCE: metadata["notify"] = True + response = await client.send_message( query, context_id=task.session_id, @@ -184,23 +736,18 @@ async def _execute_task( task.remote_task_ids.append(remote_task.id) continue - if ( - isinstance(event, TaskStatusUpdateEvent) - # and event.status.state == TaskState.input_required - ): + if isinstance(event, TaskStatusUpdateEvent): logger.info(f"Task status update: {event.status.state}") + if event.status.state == TaskState.failed: err_msg = get_message_text(event.status.message) await self.task_manager.fail_task(task.task_id, err_msg) yield self._create_error_message_chunk( - err_msg, - task.session_id, - task.user_id, - task.agent_name, + err_msg, task.session_id, task.user_id, task.agent_name ) return - continue + if isinstance(event, TaskArtifactUpdateEvent): yield self._create_message_chunk( get_message_text(event.artifact, ""), @@ -222,47 +769,15 @@ async def _execute_task( ) except Exception as e: - # Fail task await self.task_manager.fail_task(task.task_id, str(e)) raise e - async def create_session(self, user_id: str, title: str = None): - """Create a new session for the user""" - return await self.session_manager.create_session(user_id, title) - - async def close_session(self, session_id: str): - """Close an existing session""" - # In a more sophisticated implementation, you might want to: - # 1. Cancel any ongoing tasks in this session - # 2. Save session metadata - # 3. Clean up resources - - # Cancel any running tasks for this session - cancelled_count = await self.task_manager.cancel_session_tasks(session_id) - - # Add a system message to mark the session as closed - await self.session_manager.add_message( - session_id, - Role.SYSTEM, - f"Session closed. {cancelled_count} tasks were cancelled.", - agent_name="orchestrator", - ) - - async def get_session_history(self, session_id: str): - """Get session message history""" - return await self.session_manager.get_session_messages(session_id) - - async def get_user_sessions(self, user_id: str, limit: int = 100, offset: int = 0): - """Get all sessions for a user""" - return await self.session_manager.list_user_sessions(user_id, limit, offset) - - async def cleanup(self): - """Cleanup resources""" - await self.agent_connections.stop_all() +# ==================== Module-level Factory Function ==================== _orchestrator = AgentOrchestrator() def get_default_orchestrator() -> AgentOrchestrator: + """Get the default singleton instance of AgentOrchestrator""" return _orchestrator diff --git a/python/valuecell/core/coordinate/planner.py b/python/valuecell/core/coordinate/planner.py index 16ae57f43..b3ad8727c 100644 --- a/python/valuecell/core/coordinate/planner.py +++ b/python/valuecell/core/coordinate/planner.py @@ -1,70 +1,221 @@ +import asyncio +import logging from datetime import datetime -from typing import List +from typing import Callable, List, Optional -from valuecell.utils import generate_uuid +from agno.agent import Agent +from agno.models.openrouter import OpenRouter +from agno.tools.user_control_flow import UserControlFlowTools from valuecell.core.agent.connect import RemoteConnections -from valuecell.core.task import Task, TaskStatus +from valuecell.core.coordinate.planner_prompts import ( + PLANNER_INSTRUCTIONS, + create_prompt_with_datetime, +) +from valuecell.core.task import Task, TaskPattern, TaskStatus from valuecell.core.types import UserInput +from valuecell.utils import generate_uuid + +from .models import ExecutionPlan, PlannerInput, PlannerResponse + +logger = logging.getLogger(__name__) + + +class UserInputRequest: + """ + Represents a request for user input during plan creation or execution. + + This class uses asyncio.Event to enable non-blocking waiting for user responses + in the Human-in-the-Loop workflow. + """ + + def __init__(self, prompt: str): + self.prompt = prompt + self.response: Optional[str] = None + self.event = asyncio.Event() -from .models import ExecutionPlan + async def wait_for_response(self) -> str: + """Wait for user response asynchronously""" + await self.event.wait() + return self.response + + def provide_response(self, response: str): + """Provide the user's response and signal completion""" + self.response = response + self.event.set() class ExecutionPlanner: - """Simple execution planner that analyzes user input and creates execution plans""" + """ + Creates execution plans by analyzing user input and determining appropriate agent tasks. + + This planner uses AI to understand user requests and break them down into + executable tasks that can be handled by specific agents. It supports + Human-in-the-Loop interactions when additional clarification is needed. + """ - def __init__(self, agent_connections: RemoteConnections): + def __init__( + self, + agent_connections: RemoteConnections, + ): self.agent_connections = agent_connections - async def create_plan(self, user_input: UserInput) -> ExecutionPlan: - """Create an execution plan from user input""" + async def create_plan( + self, user_input: UserInput, user_input_callback: Optional[Callable] = None + ) -> ExecutionPlan: + """ + Create an execution plan from user input. + Args: + user_input: The user's request to be planned + + Returns: + ExecutionPlan: A structured plan with tasks for execution + """ plan = ExecutionPlan( plan_id=generate_uuid("plan"), session_id=user_input.meta.session_id, user_id=user_input.meta.user_id, - query=user_input.query, + query=user_input.query, # Store the original query created_at=datetime.now().isoformat(), ) - # Simple planning logic - create tasks directly with user_id context - tasks = await self._analyze_input_and_create_tasks(user_input) + # Analyze input and create appropriate tasks + tasks = await self._analyze_input_and_create_tasks( + user_input, user_input_callback + ) plan.tasks = tasks return plan async def _analyze_input_and_create_tasks( - self, user_input: UserInput + self, user_input: UserInput, user_input_callback: Optional[Callable] = None ) -> List[Task]: - """Analyze user input and create tasks for appropriate agents""" - - # Check if user specified a desired agent - if user_input.has_desired_agent(): - desired_agent = user_input.get_desired_agent() - available_agents = self.agent_connections.list_available_agents() - - # If the desired agent exists, use it directly - if desired_agent in available_agents: - return [ - self._create_task( - user_input.meta.session_id, - user_input.meta.user_id, - desired_agent, - ) - ] - - raise ValueError("No suitable agent found for the request.") - - def _create_task(self, session_id: str, user_id: str, agent_name: str) -> Task: - """Create a new task for the specified agent""" + """ + Analyze user input and create tasks for appropriate agents. + + This method uses an AI agent to understand the user's request and determine + what agents should be involved and what tasks they should perform. + """ + # Create planning agent with appropriate tools and instructions + agent = Agent( + model=OpenRouter(id="openai/gpt-4o-mini"), + tools=[ + UserControlFlowTools(), + self.get_agent_card, + ], + markdown=False, + debug_mode=True, + instructions=[ + create_prompt_with_datetime(PLANNER_INSTRUCTIONS), + ], + ) + + # Execute planning with the agent + run_response = agent.run( + message=PlannerInput( + desired_agent_name=user_input.get_desired_agent(), + query=user_input.query, + ) + ) + + # Handle user input requests through Human-in-the-Loop workflow + while run_response.is_paused: + for tool in run_response.tools_requiring_user_input: + input_schema = tool.user_input_schema + + for field in input_schema: + if user_input_callback: + # Use callback for async user input + request = UserInputRequest(field.description) + await user_input_callback(request) + user_value = await request.wait_for_response() + else: + # Fallback to synchronous input for testing/simple scenarios + user_value = input(f"{field.description}: ") + + field.value = user_value + + # Continue agent execution with updated inputs + run_response = agent.continue_run( + run_id=run_response.run_id, updated_tools=run_response.tools + ) + + if not run_response.is_paused: + break + + # Parse planning result and create tasks + try: + plan_raw = PlannerResponse.model_validate_json(run_response.content) + except Exception as e: + raise ValueError( + f"Planner produced invalid JSON for PlannerResponse: {e}. " + f"Raw content: {run_response.content}" + ) from e + logger.info(f"Planner produced plan: {plan_raw}") + if not plan_raw.adequate or not plan_raw.tasks: + # If information is still inadequate, return empty task list + raise ValueError( + "Planner indicated information is inadequate or produced no tasks." + f" Reason: {plan_raw.reason}" + ) + return [ + self._create_task( + user_input.meta.session_id, + user_input.meta.user_id, + task.agent_name, + task.query, + task.pattern, + ) + for task in plan_raw.tasks + ] + + def _create_task( + self, + session_id: str, + user_id: str, + agent_name: str, + query: str, + pattern: TaskPattern = TaskPattern.ONCE, + ) -> Task: + """ + Create a new task for the specified agent. + + Args: + session_id: Session this task belongs to + user_id: User who requested this task + agent_name: Name of the agent to execute the task + query: Query/prompt for the agent + pattern: Execution pattern (once or recurring) + + Returns: + Task: Configured task ready for execution + """ return Task( task_id=generate_uuid("task"), session_id=session_id, user_id=user_id, agent_name=agent_name, status=TaskStatus.PENDING, + query=query, + pattern=pattern, ) - async def add_task(self, plan: ExecutionPlan, agent_name: str) -> None: - """Add a task to an existing plan""" - task = self._create_task(plan.session_id, plan.user_id, agent_name) - plan.tasks.append(task) + def get_agent_card(self, agent_name: str) -> str: + """ + Get the capabilities description of a specified agent by name. + + This function returns capability information for agents that can be used + in the planning process to determine if an agent is suitable for a task. + + Args: + agent_name: The name of the agent whose capabilities are to be retrieved + + Returns: + str: A description of the agent's capabilities and supported operations + """ + self.agent_connections.list_remote_agents() + if card := self.agent_connections.get_remote_agent_card(agent_name): + # Note: Returning a plain string for now; consider structured return in future + return str(card) + + return "The requested agent could not be found or is not available." diff --git a/python/valuecell/core/coordinate/planner_prompts.py b/python/valuecell/core/coordinate/planner_prompts.py new file mode 100644 index 000000000..534ca236f --- /dev/null +++ b/python/valuecell/core/coordinate/planner_prompts.py @@ -0,0 +1,51 @@ +from datetime import datetime +from textwrap import dedent + + +def create_prompt_with_datetime(base_prompt: str) -> str: + now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + return dedent( + f""" + {base_prompt} + + **Other Important Context** + - Current date and time: {now} + """ + ) + + +# noqa: E501 +PLANNER_INSTRUCTIONS = """ +You are an AI Agent execution planner. Your role is to analyze user requests and create executable task plans using available agents. + +**Process:** +1. Call `get_agent_card` with the desired agent name to understand its capabilities +2. Analyze the user input for completeness and clarity +3. If information is insufficient or unclear, call `get_user_input` for clarification +4. Generate a structured execution plan when sufficient information is available + +**Task Pattern:** +- **ONCE**: Single execution with immediate results (default for most requests) +- **RECURRING**: Periodic execution with scheduled updates (for tracking, monitoring, notifications, or ongoing updates) + +**Guidelines:** +- Accept stock symbols as provided unless obviously ambiguous +- Ask only one clarification question at a time +- Wait for user response before asking additional questions +- Generate clear, specific prompts suitable for AI model execution +- Output must be valid JSON following the Response Format +- Output will be parsed programmatically, so ensure strict adherence to the format and do not include any extra text + +**Response Format:** +{ + "tasks": [ + { + "query": "Clear, specific task description", + "agent_name": "target_agent_name", + "pattern": "once" | "recurring" + } + ], + "adequate": boolean, # true if information is adequate for task execution, false if more input is needed + "reason": "Explanation of planning decision" +} +""" diff --git a/python/valuecell/core/task/__init__.py b/python/valuecell/core/task/__init__.py index 4582ad8f3..ea768f605 100644 --- a/python/valuecell/core/task/__init__.py +++ b/python/valuecell/core/task/__init__.py @@ -1,12 +1,13 @@ """Task module initialization""" from .manager import TaskManager, get_default_task_manager -from .models import Task, TaskStatus +from .models import Task, TaskStatus, TaskPattern from .store import InMemoryTaskStore, TaskStore __all__ = [ "Task", "TaskStatus", + "TaskPattern", "TaskManager", "TaskStore", "InMemoryTaskStore", diff --git a/python/valuecell/core/task/models.py b/python/valuecell/core/task/models.py index 76bae0033..66718258b 100644 --- a/python/valuecell/core/task/models.py +++ b/python/valuecell/core/task/models.py @@ -31,6 +31,7 @@ class Task(BaseModel): default_factory=list, description="Task identifier determined by the remote agent after submission", ) + query: str = Field(..., description="The task to be performed") session_id: str = Field(..., description="Session ID this task belongs to") user_id: str = Field(..., description="User ID who initiated this task") agent_name: str = Field(..., description="Name of the agent executing this task") diff --git a/python/valuecell/core/types.py b/python/valuecell/core/types.py index f00f16e08..f80a5ee74 100644 --- a/python/valuecell/core/types.py +++ b/python/valuecell/core/types.py @@ -132,7 +132,6 @@ async def stream( """ raise NotImplementedError - @abstractmethod async def notify( self, query: str, session_id: str, task_id: str ) -> AsyncGenerator[NotifyResponse, None]: From 562933f9111bf3592bd78be48947b24be3ae986a Mon Sep 17 00:00:00 2001 From: Zhaofeng Zhang <24791380+vcfgv@users.noreply.github.com> Date: Fri, 19 Sep 2025 14:19:57 +0800 Subject: [PATCH 5/6] feat: add AGENT_DEBUG_MODE configuration and update ExecutionPlanner to utilize it --- python/.env.example | 4 +- python/valuecell/core/coordinate/planner.py | 55 +- .../coordinate/tests/test_orchestrator.py | 1396 +++-------------- 3 files changed, 307 insertions(+), 1148 deletions(-) diff --git a/python/.env.example b/python/.env.example index 240bb76f8..2972c2341 100644 --- a/python/.env.example +++ b/python/.env.example @@ -22,11 +22,13 @@ API_HOST=localhost API_PORT=8000 API_DEBUG=false - # Database settings DB_CHARSET=utf8mb4 DB_COLLATION=utf8mb4_unicode_ci +# Agent Settings +AGENT_DEBUG_MODE=false + # SEC Agent Configuration # Email address for SEC API requests (required by SEC) SEC_EMAIL=your.name@example.com diff --git a/python/valuecell/core/coordinate/planner.py b/python/valuecell/core/coordinate/planner.py index b3ad8727c..07b89b142 100644 --- a/python/valuecell/core/coordinate/planner.py +++ b/python/valuecell/core/coordinate/planner.py @@ -1,8 +1,10 @@ import asyncio import logging +import os from datetime import datetime from typing import Callable, List, Optional +from a2a.types import AgentCard from agno.agent import Agent from agno.models.openrouter import OpenRouter from agno.tools.user_control_flow import UserControlFlowTools @@ -101,10 +103,10 @@ async def _analyze_input_and_create_tasks( model=OpenRouter(id="openai/gpt-4o-mini"), tools=[ UserControlFlowTools(), - self.get_agent_card, + self.tool_get_agent_description, ], markdown=False, - debug_mode=True, + debug_mode=os.getenv("AGENT_DEBUG_MODE", "false").lower() == "true", instructions=[ create_prompt_with_datetime(PLANNER_INSTRUCTIONS), ], @@ -200,7 +202,7 @@ def _create_task( pattern=pattern, ) - def get_agent_card(self, agent_name: str) -> str: + def tool_get_agent_description(self, agent_name: str) -> str: """ Get the capabilities description of a specified agent by name. @@ -215,7 +217,50 @@ def get_agent_card(self, agent_name: str) -> str: """ self.agent_connections.list_remote_agents() if card := self.agent_connections.get_remote_agent_card(agent_name): - # Note: Returning a plain string for now; consider structured return in future - return str(card) + return agentcard_to_prompt(card) return "The requested agent could not be found or is not available." + + +def agentcard_to_prompt(card: AgentCard): + """ + Convert AgentCard JSON structure to LLM-friendly prompt string. + + Args: + agentcard (AgentCard): The agentcard JSON structure + + Returns: + str: Formatted prompt string for LLM processing + """ + + # Start with basic agent information + prompt = f"# Agent: {card.name}\n\n" + + # Add description + prompt += f"**Description:** {card.description}\n\n" + + # Add skills section + if card.skills: + prompt += "## Available Skills\n\n" + + for i, skill in enumerate(card.skills, 1): + prompt += f"### {i}. {skill.name} (`{skill.id}`)\n\n" + prompt += f"**Description:** {skill.description}\n\n" + + # Add examples if available + if skill.examples: + prompt += "**Examples:**\n" + for example in skill.examples: + prompt += f"- {example}\n" + prompt += "\n" + + # Add tags if available + if skill.tags: + tags_str = ", ".join([f"`{tag}`" for tag in skill.tags]) + prompt += f"**Tags:** {tags_str}\n\n" + + # Add separator between skills (except for the last one) + if i < len(card.skills): + prompt += "---\n\n" + + return prompt.strip() diff --git a/python/valuecell/core/coordinate/tests/test_orchestrator.py b/python/valuecell/core/coordinate/tests/test_orchestrator.py index a1add2029..d9a3003c1 100644 --- a/python/valuecell/core/coordinate/tests/test_orchestrator.py +++ b/python/valuecell/core/coordinate/tests/test_orchestrator.py @@ -1,139 +1,85 @@ """ -Comprehensive pytest tests for AgentOrchestrator. - -This test suite covers the 2x2 matrix of agent capabilities: -- streaming: True/False -- push_notifications: True/False - -Test coverage includes: -- Core flow processing with different agent capabilities -- Session management and message handling -- Task lifecycle management -- Error handling and edge cases -- Metadata propagation -- Resource management +Lean pytest tests for AgentOrchestrator. + +Focus on essential behavior without over-engineering: +- Happy path (streaming and non-streaming) +- Planner error and agent connection error +- Session create/close and cleanup """ +from types import SimpleNamespace +from typing import Any, AsyncGenerator from unittest.mock import AsyncMock, Mock -from typing import AsyncGenerator, Any import pytest from a2a.types import ( AgentCapabilities, AgentCard, AgentSkill, - TaskState, - TaskStatusUpdateEvent, + Artifact, + Part, TaskArtifactUpdateEvent, + TaskState, TaskStatus, - Part, + TaskStatusUpdateEvent, TextPart, - Artifact, ) -from valuecell.core.coordinate.orchestrator import AgentOrchestrator from valuecell.core.coordinate.models import ExecutionPlan -from valuecell.core.session import Role +from valuecell.core.coordinate.orchestrator import AgentOrchestrator +from valuecell.core.session import SessionStatus from valuecell.core.task import Task, TaskStatus as CoreTaskStatus -from valuecell.core.types import ( - UserInput, - UserInputMetadata, - MessageChunk, - MessageDataKind, -) +from valuecell.core.types import UserInput, UserInputMetadata + + +# ------------------------- +# Fixtures +# ------------------------- @pytest.fixture def session_id() -> str: - """Sample session ID for testing.""" return "test-session-123" @pytest.fixture def user_id() -> str: - """Sample user ID for testing.""" return "test-user-456" @pytest.fixture def sample_query() -> str: - """Sample user query for testing.""" return "What is the latest stock price for AAPL?" @pytest.fixture -def user_input_metadata(session_id: str, user_id: str) -> UserInputMetadata: - """Sample user input metadata.""" - return UserInputMetadata(session_id=session_id, user_id=user_id) - - -@pytest.fixture -def sample_user_input( - sample_query: str, user_input_metadata: UserInputMetadata -) -> UserInput: - """Sample user input for testing.""" +def sample_user_input(session_id: str, user_id: str, sample_query: str) -> UserInput: return UserInput( - query=sample_query, desired_agent_name="TestAgent", meta=user_input_metadata + query=sample_query, + desired_agent_name="TestAgent", + meta=UserInputMetadata(session_id=session_id, user_id=user_id), ) @pytest.fixture -def sample_task(session_id: str, user_id: str) -> Task: - """Sample task for testing.""" +def sample_task(session_id: str, user_id: str, sample_query: str) -> Task: return Task( - task_id="test-task-789", + task_id="task-1", session_id=session_id, user_id=user_id, agent_name="TestAgent", + query=sample_query, status=CoreTaskStatus.PENDING, remote_task_ids=[], ) -@pytest.fixture( - params=[ - (True, True), # streaming + push_notifications - (True, False), # streaming only - (False, True), # push_notifications only - (False, False), # basic agent - ] -) -def agent_capabilities(request) -> AgentCapabilities: - """Parametrized fixture for different agent capability combinations.""" - streaming, push_notifications = request.param - return AgentCapabilities(streaming=streaming, push_notifications=push_notifications) - - -@pytest.fixture -def mock_agent_card(agent_capabilities: AgentCapabilities) -> AgentCard: - """Mock agent card with different capabilities.""" - return AgentCard( - name="TestAgent", - description="Test agent for unit testing", - url="http://localhost:8000/", - version="1.0.0", - default_input_modes=["text"], - default_output_modes=["text"], - capabilities=agent_capabilities, - skills=[ - AgentSkill( - id="test_skill_1", - name="test_skill", - description="Test skill", - tags=["test", "demo"], - ) - ], - supports_authenticated_extended_card=False, - ) - - @pytest.fixture -def sample_execution_plan( +def sample_plan( session_id: str, user_id: str, sample_query: str, sample_task: Task ) -> ExecutionPlan: - """Sample execution plan with one task.""" return ExecutionPlan( - plan_id="test-plan-123", + plan_id="plan-1", session_id=session_id, user_id=user_id, query=sample_query, @@ -142,1104 +88,270 @@ def sample_execution_plan( ) -@pytest.fixture -def mock_session_manager() -> Mock: - """Mock session manager.""" - mock = Mock() - mock.add_message = AsyncMock() - mock.create_session = AsyncMock(return_value="new-session-id") - mock.get_session_messages = AsyncMock(return_value=[]) - mock.list_user_sessions = AsyncMock(return_value=[]) - mock.session_exists = AsyncMock(return_value=True) - return mock +def _stub_session(status: Any = SessionStatus.ACTIVE): + # Minimal session stub with status and basic methods used by orchestrator + s = SimpleNamespace(status=status) + def activate(): + s.status = SessionStatus.ACTIVE -@pytest.fixture -def mock_task_manager() -> Mock: - """Mock task manager.""" - mock = Mock() - mock.store = Mock() - mock.store.save_task = AsyncMock() - mock.start_task = AsyncMock() - mock.complete_task = AsyncMock() - mock.fail_task = AsyncMock() - mock.cancel_session_tasks = AsyncMock(return_value=0) - return mock + def require_user_input(): + s.status = SessionStatus.REQUIRE_USER_INPUT - -@pytest.fixture -def mock_agent_client() -> Mock: - """Mock agent client for different response types.""" - mock = Mock() - mock.send_message = AsyncMock() - return mock + s.activate = activate + s.require_user_input = require_user_input + return s @pytest.fixture -def mock_agent_connections(mock_agent_card: AgentCard, mock_agent_client: Mock) -> Mock: - """Mock agent connections.""" - mock = Mock() - mock.start_agent = AsyncMock(return_value=mock_agent_card) - mock.get_client = AsyncMock(return_value=mock_agent_client) - mock.list_available_agents = Mock(return_value=["TestAgent"]) - mock.stop_all = AsyncMock() - return mock +def mock_session_manager() -> Mock: + m = Mock() + m.add_message = AsyncMock() + m.create_session = AsyncMock(return_value="new-session-id") + m.get_session_messages = AsyncMock(return_value=[]) + m.list_user_sessions = AsyncMock(return_value=[]) + m.get_session = AsyncMock(return_value=_stub_session()) + m.update_session = AsyncMock() + return m @pytest.fixture -def mock_planner(sample_execution_plan: ExecutionPlan) -> Mock: - """Mock execution planner.""" - mock = Mock() - mock.create_plan = AsyncMock(return_value=sample_execution_plan) - return mock +def mock_task_manager() -> Mock: + m = Mock() + m.store = Mock() + m.store.save_task = AsyncMock() + m.start_task = AsyncMock() + m.complete_task = AsyncMock() + m.fail_task = AsyncMock() + m.cancel_session_tasks = AsyncMock(return_value=0) + return m @pytest.fixture -def orchestrator( - mock_session_manager: Mock, - mock_task_manager: Mock, - mock_agent_connections: Mock, - mock_planner: Mock, -) -> AgentOrchestrator: - """AgentOrchestrator instance with mocked dependencies.""" - orchestrator = AgentOrchestrator() - orchestrator.session_manager = mock_session_manager - orchestrator.task_manager = mock_task_manager - orchestrator.agent_connections = mock_agent_connections - orchestrator.planner = mock_planner - return orchestrator - - -def create_mock_remote_task(task_id: str = "remote-task-123") -> Mock: - """Create a mock remote task.""" - remote_task = Mock() - remote_task.id = task_id - remote_task.status = Mock() - remote_task.status.state = TaskState.submitted - return remote_task - - -async def create_streaming_response( - content_chunks: list[str], remote_task_id: str = "remote-task-123" -) -> AsyncGenerator[tuple[Mock, Any], None]: - """Create a mock streaming response.""" - remote_task = create_mock_remote_task(remote_task_id) - - # First yield the task submission with None event (matching new logic) - yield remote_task, None - - # Then yield content chunks - for i, chunk in enumerate(content_chunks): - # Create proper Artifact with Part and TextPart - text_part = TextPart(text=chunk) - part = Part(root=text_part) - artifact = Artifact(artifactId=f"test-artifact-{i}", parts=[part]) - - artifact_event = TaskArtifactUpdateEvent( - artifact=artifact, - contextId="test-context", - taskId=remote_task_id, - final=False, - ) - yield remote_task, artifact_event - - -async def create_non_streaming_response( - content: str, remote_task_id: str = "remote-task-123" -) -> AsyncGenerator[tuple[Mock, Any], None]: - """Create a mock non-streaming response.""" - remote_task = create_mock_remote_task(remote_task_id) - - # First yield the task submission with None event - yield remote_task, None - - # For non-streaming, just yield a final status update - yield ( - remote_task, - TaskStatusUpdateEvent( - status=TaskStatus(state=TaskState.completed), - contextId="test-context", - taskId=remote_task_id, - final=True, - ), +def mock_agent_card_streaming() -> AgentCard: + return AgentCard( + name="TestAgent", + description="", + url="http://localhost", + version="1.0", + default_input_modes=["text"], + default_output_modes=["text"], + capabilities=AgentCapabilities(streaming=True, push_notifications=False), + skills=[AgentSkill(id="s1", name="n", description="d", tags=[])], + supports_authenticated_extended_card=False, ) -async def create_failed_response( - error_message: str, remote_task_id: str = "remote-task-123" -) -> AsyncGenerator[tuple[Mock, Any], None]: - """Create a mock failed response.""" - remote_task = create_mock_remote_task(remote_task_id) - - # First yield the task submission with None event - yield remote_task, None - - # Then yield a failed status update - yield ( - remote_task, - TaskStatusUpdateEvent( - status=TaskStatus(state=TaskState.failed, message=error_message), - contextId="test-context", - taskId=remote_task_id, - final=True, - ), +@pytest.fixture +def mock_agent_card_non_streaming() -> AgentCard: + return AgentCard( + name="TestAgent", + description="", + url="http://localhost", + version="1.0", + default_input_modes=["text"], + default_output_modes=["text"], + capabilities=AgentCapabilities(streaming=False, push_notifications=False), + skills=[AgentSkill(id="s1", name="n", description="d", tags=[])], + supports_authenticated_extended_card=False, ) -class TestCoreFlow: - """Test core orchestrator flow with different agent capabilities.""" - - @pytest.mark.asyncio - async def test_process_user_input_success( - self, - orchestrator: AgentOrchestrator, - sample_user_input: UserInput, - mock_agent_client: Mock, - mock_agent_card: AgentCard, - session_id: str, - user_id: str, - sample_query: str, - ): - """Test successful user input processing with different agent capabilities.""" - # Setup mock responses based on agent capabilities - if mock_agent_card.capabilities.streaming: - mock_response = create_streaming_response(["Hello", " World", "!"]) - else: - mock_response = create_non_streaming_response("Hello World!") - - mock_agent_client.send_message.return_value = mock_response - - # Execute - chunks = [] - async for chunk in orchestrator.process_user_input(sample_user_input): - chunks.append(chunk) - - # Verify session messages - orchestrator.session_manager.add_message.assert_any_call( - session_id, Role.USER, sample_query - ) - - # Verify task operations - orchestrator.task_manager.store.save_task.assert_called_once() - orchestrator.task_manager.start_task.assert_called_once() - - # Verify agent interactions - orchestrator.agent_connections.start_agent.assert_called_once() - orchestrator.agent_connections.get_client.assert_called_once_with("TestAgent") - - # Verify send_message call with correct streaming parameter - expected_streaming = mock_agent_card.capabilities.streaming - mock_agent_client.send_message.assert_called_once() - call_args = mock_agent_client.send_message.call_args - assert call_args.kwargs["streaming"] == expected_streaming - - # Verify chunks based on agent capabilities - if mock_agent_card.capabilities.streaming: - # Streaming agents should produce content chunks plus a final empty chunk - assert len(chunks) >= 1 - for chunk in chunks: - assert isinstance(chunk, MessageChunk) - assert chunk.kind == MessageDataKind.TEXT - assert chunk.meta.session_id == session_id - assert chunk.meta.user_id == user_id - - # The last chunk should be final and empty (task completion marker) - final_chunk = chunks[-1] - assert final_chunk.is_final is True - assert final_chunk.content == "" - else: - # Non-streaming agents should still produce a final completion chunk - assert len(chunks) >= 1 - final_chunk = chunks[-1] - assert final_chunk.is_final is True - assert final_chunk.content == "" - - @pytest.mark.asyncio - async def test_streaming_agent_chunk_processing( - self, - orchestrator: AgentOrchestrator, - sample_user_input: UserInput, - mock_agent_client: Mock, - mock_agent_card: AgentCard, - session_id: str, - user_id: str, - ): - """Test streaming agent chunk processing specifically.""" - # Skip test for non-streaming agents or push notification agents - if ( - not mock_agent_card.capabilities.streaming - or mock_agent_card.capabilities.push_notifications - ): - pytest.skip("Test only for streaming agents without push notifications") - - # Setup streaming response - content_chunks = ["Hello", " from", " streaming", " agent!"] - mock_agent_client.send_message.return_value = create_streaming_response( - content_chunks - ) - - # Execute - chunks = [] - async for chunk in orchestrator.process_user_input(sample_user_input): - chunks.append(chunk) - - # Verify we got chunks (content chunks + final empty chunk) - assert len(chunks) >= len(content_chunks) + 1 - - # Verify chunk content and metadata - content_received = [] - for chunk in chunks[:-1]: # Exclude the final empty chunk - assert isinstance(chunk, MessageChunk) - assert chunk.kind == MessageDataKind.TEXT - assert chunk.meta.session_id == session_id - assert chunk.meta.user_id == user_id - content_received.append(chunk.content) - - # Verify final chunk is empty and marked as final - final_chunk = chunks[-1] - assert final_chunk.is_final is True - assert final_chunk.content == "" - - # Verify all content was received - full_content = "".join(content_received) - assert "Hello from streaming agent!" in full_content - - @pytest.mark.asyncio - async def test_non_push_notification_agent_processing( - self, - orchestrator: AgentOrchestrator, - sample_user_input: UserInput, - mock_agent_client: Mock, - mock_agent_card: AgentCard, - ): - """Test that non-push notification agents continue with normal processing.""" - # Skip test for push notification agents - if mock_agent_card.capabilities.push_notifications: - pytest.skip("Test only for non-push notification agents") - - # Setup response based on streaming capability - if mock_agent_card.capabilities.streaming: - mock_agent_client.send_message.return_value = create_streaming_response( - ["Processing", " normally"] - ) - else: - mock_agent_client.send_message.return_value = create_non_streaming_response( - "Processing normally" - ) - - # Execute - chunks = [] - async for chunk in orchestrator.process_user_input(sample_user_input): - chunks.append(chunk) - - # Verify normal processing continues for non-push notification agents - # All agents should now produce at least one final chunk - assert len(chunks) >= 1 - - # The final chunk should be empty and marked as final - final_chunk = chunks[-1] - assert final_chunk.is_final is True - assert final_chunk.content == "" - - # Task should be completed - orchestrator.task_manager.complete_task.assert_called_once() - - if mock_agent_card.capabilities.streaming: - # Streaming agents should produce content chunks + final chunk - assert len(chunks) >= 2 # At least content chunks + final chunk - - @pytest.mark.asyncio - async def test_push_notifications_early_return( - self, - orchestrator: AgentOrchestrator, - sample_user_input: UserInput, - mock_agent_client: Mock, - mock_agent_card: AgentCard, - ): - """Test that push notification agents return early.""" - # Skip test for non-push notification agents - if not mock_agent_card.capabilities.push_notifications: - pytest.skip("Test only for push notification agents") - - # Setup response - mock_agent_client.send_message.return_value = create_streaming_response( - ["Should not be processed"] - ) - - # Execute - chunks = [] - async for chunk in orchestrator.process_user_input(sample_user_input): - chunks.append(chunk) - - # For push notification agents, no chunks should be yielded from streaming - # since they return early and rely on notifications - # The only chunks should be final session messages - - # Verify agent is started with notification callback - orchestrator.agent_connections.start_agent.assert_called_once() - call_args = orchestrator.agent_connections.start_agent.call_args - assert "notification_callback" in call_args.kwargs - - -class TestSessionManagement: - """Test session management functionality.""" +@pytest.fixture +def mock_agent_client() -> Mock: + c = Mock() + c.send_message = AsyncMock() + return c - @pytest.mark.asyncio - async def test_create_session(self, orchestrator: AgentOrchestrator, user_id: str): - """Test session creation.""" - session_id = await orchestrator.create_session(user_id, "Test Session") - orchestrator.session_manager.create_session.assert_called_once_with( - user_id, "Test Session" - ) - assert session_id == "new-session-id" - - @pytest.mark.asyncio - async def test_close_session( - self, orchestrator: AgentOrchestrator, session_id: str - ): - """Test session closure with task cancellation.""" - orchestrator.task_manager.cancel_session_tasks.return_value = 2 +@pytest.fixture +def mock_planner(sample_plan: ExecutionPlan) -> Mock: + p = Mock() + p.create_plan = AsyncMock(return_value=sample_plan) + return p - await orchestrator.close_session(session_id) - orchestrator.task_manager.cancel_session_tasks.assert_called_once_with( - session_id - ) - orchestrator.session_manager.add_message.assert_called_once() - - # Verify system message was added - call_args = orchestrator.session_manager.add_message.call_args - assert call_args[0][1] == Role.SYSTEM # Role - assert "2 tasks were cancelled" in call_args[0][2] # Message content - - @pytest.mark.asyncio - async def test_get_session_history( - self, orchestrator: AgentOrchestrator, session_id: str - ): - """Test getting session history.""" - await orchestrator.get_session_history(session_id) - - orchestrator.session_manager.get_session_messages.assert_called_once_with( - session_id - ) +@pytest.fixture +def orchestrator( + mock_session_manager: Mock, mock_task_manager: Mock, mock_planner: Mock +) -> AgentOrchestrator: + o = AgentOrchestrator() + o.session_manager = mock_session_manager + o.task_manager = mock_task_manager + o.planner = mock_planner + return o - @pytest.mark.asyncio - async def test_get_user_sessions( - self, orchestrator: AgentOrchestrator, user_id: str - ): - """Test getting user sessions with pagination.""" - await orchestrator.get_user_sessions(user_id, limit=50, offset=10) - orchestrator.session_manager.list_user_sessions.assert_called_once_with( - user_id, 50, 10 - ) +# ------------------------- +# Helpers +# ------------------------- - @pytest.mark.asyncio - async def test_session_message_lifecycle( - self, - orchestrator: AgentOrchestrator, - sample_user_input: UserInput, - mock_agent_client: Mock, - session_id: str, - sample_query: str, - ): - """Test that user and agent messages are properly added to session.""" - # Setup mock response - mock_agent_client.send_message.return_value = create_streaming_response( - ["Response"] - ) - # Execute - chunks = [] - async for chunk in orchestrator.process_user_input(sample_user_input): - chunks.append(chunk) - - # Verify user message was added first - calls = orchestrator.session_manager.add_message.call_args_list - assert len(calls) >= 2 - - # First call should be user message - user_call = calls[0] - assert user_call[0][0] == session_id - assert user_call[0][1] == Role.USER - assert user_call[0][2] == sample_query - - # Last call should be agent response - agent_call = calls[-1] - assert agent_call[0][0] == session_id - assert agent_call[0][1] == Role.AGENT - - -class TestTaskManagement: - """Test task lifecycle management.""" - - @pytest.mark.asyncio - async def test_task_lifecycle_success( - self, - orchestrator: AgentOrchestrator, - sample_user_input: UserInput, - mock_agent_client: Mock, - mock_agent_card: AgentCard, - sample_task: Task, - ): - """Test successful task lifecycle: register -> start -> complete.""" - # Setup response based on agent capabilities - if mock_agent_card.capabilities.streaming: - mock_agent_client.send_message.return_value = create_streaming_response( - ["Done"] - ) - else: - mock_agent_client.send_message.return_value = create_non_streaming_response( - "Done" +def _make_streaming_response( + chunks: list[str], remote_task_id: str = "rt-1" +) -> AsyncGenerator[tuple[Mock, Any], None]: + async def gen(): + rt = Mock() + rt.id = remote_task_id + rt.status = Mock(state=TaskState.submitted) + # First yield submission with None event + yield rt, None + for i, text in enumerate(chunks): + part = Part(root=TextPart(text=text)) + artifact = Artifact(artifactId=f"a-{i}", parts=[part]) + yield ( + rt, + TaskArtifactUpdateEvent( + artifact=artifact, + contextId="ctx", + taskId=remote_task_id, + final=False, + ), ) - # Execute - async for _ in orchestrator.process_user_input(sample_user_input): - pass - - # Verify task lifecycle calls - orchestrator.task_manager.store.save_task.assert_called_once() - orchestrator.task_manager.start_task.assert_called_once() - orchestrator.task_manager.complete_task.assert_called_once() - - # Verify agent connections - orchestrator.agent_connections.start_agent.assert_called_once() - start_agent_call = orchestrator.agent_connections.start_agent.call_args - assert start_agent_call.kwargs["with_listener"] is False - assert "notification_callback" in start_agent_call.kwargs - - @pytest.mark.asyncio - async def test_task_failure_handling( - self, - orchestrator: AgentOrchestrator, - sample_user_input: UserInput, - mock_agent_client: Mock, - ): - """Test task failure handling with proper cleanup.""" - # Setup a failed response - error_message = "Task processing failed" - mock_agent_client.send_message.return_value = create_failed_response( - error_message - ) - - # Execute - chunks = [] - async for chunk in orchestrator.process_user_input(sample_user_input): - chunks.append(chunk) - - # Verify task failure was handled - orchestrator.task_manager.start_task.assert_called_once() - orchestrator.task_manager.fail_task.assert_called_once() - - # Verify error message was yielded - error_chunks = [chunk for chunk in chunks if error_message in chunk.content] - assert len(error_chunks) >= 1 - assert error_chunks[0].is_final is True - - @pytest.mark.asyncio - async def test_remote_task_id_tracking( - self, - orchestrator: AgentOrchestrator, - sample_user_input: UserInput, - mock_agent_client: Mock, - sample_task: Task, - ): - """Test that remote task IDs are properly tracked.""" - remote_task_id = "test-remote-task-456" - mock_agent_client.send_message.return_value = create_streaming_response( - ["Content"], remote_task_id - ) - - # Execute - async for _ in orchestrator.process_user_input(sample_user_input): - pass - - # Verify remote task ID was tracked - # Note: In the actual test, we'd need to inspect the task object - # This is more of an integration test aspect - orchestrator.task_manager.store.save_task.assert_called_once() - orchestrator.task_manager.start_task.assert_called_once() - orchestrator.task_manager.complete_task.assert_called_once() - - # Verify task was started before completion - start_call = orchestrator.task_manager.start_task.call_args - complete_call = orchestrator.task_manager.complete_task.call_args - assert start_call[0][0] == complete_call[0][0] # Same task_id - - -class TestErrorHandling: - """Test error handling scenarios.""" - - @pytest.mark.asyncio - async def test_planner_error( - self, - orchestrator: AgentOrchestrator, - sample_user_input: UserInput, - session_id: str, - ): - """Test error handling when planner fails.""" - # Setup planner to fail - orchestrator.planner.create_plan.side_effect = ValueError("Planning failed") - - # Execute - chunks = [] - async for chunk in orchestrator.process_user_input(sample_user_input): - chunks.append(chunk) - - # Verify error handling - assert len(chunks) == 1 - assert "(Error)" in chunks[0].content - assert "Planning failed" in chunks[0].content - - # Verify error message added to session - orchestrator.session_manager.add_message.assert_any_call( - session_id, Role.SYSTEM, "Error processing request: Planning failed" - ) - - @pytest.mark.asyncio - async def test_agent_connection_error( - self, - orchestrator: AgentOrchestrator, - sample_user_input: UserInput, - mock_agent_connections: Mock, - ): - """Test error handling when agent connection fails.""" - # Setup agent connections to fail - mock_agent_connections.get_client.return_value = None - - # Execute - chunks = [] - async for chunk in orchestrator.process_user_input(sample_user_input): - chunks.append(chunk) - - # Verify error was yielded - error_chunks = [c for c in chunks if "(Error)" in c.content] - assert len(error_chunks) >= 1 - assert "Could not connect to agent" in error_chunks[0].content - - @pytest.mark.asyncio - async def test_empty_execution_plan( - self, - orchestrator: AgentOrchestrator, - sample_user_input: UserInput, - session_id: str, - user_id: str, - ): - """Test handling of empty execution plan.""" - # Setup empty execution plan - from valuecell.core.coordinate.models import ExecutionPlan - - empty_plan = ExecutionPlan( - plan_id="empty-plan", - session_id=session_id, - user_id=user_id, - query=sample_user_input.query, - tasks=[], - created_at="2025-09-16T10:00:00", - ) - orchestrator.planner.create_plan.return_value = empty_plan - - # Execute - chunks = [] - async for chunk in orchestrator.process_user_input(sample_user_input): - chunks.append(chunk) - - # Verify appropriate message was yielded - assert len(chunks) >= 1 - assert "No tasks found for this request" in chunks[0].content - - -class TestEdgeCases: - """Test edge cases and boundary conditions.""" - - @pytest.mark.asyncio - async def test_metadata_propagation( - self, - orchestrator: AgentOrchestrator, - sample_user_input: UserInput, - mock_agent_client: Mock, - session_id: str, - user_id: str, - ): - """Test that metadata is properly propagated through the system.""" - mock_agent_client.send_message.return_value = create_streaming_response( - ["Test"] - ) - - # Execute - chunks = [] - async for chunk in orchestrator.process_user_input(sample_user_input): - chunks.append(chunk) - - # Verify metadata in all chunks - for chunk in chunks: - assert chunk.meta.session_id == session_id - assert chunk.meta.user_id == user_id - - # Verify metadata passed to agent - mock_agent_client.send_message.assert_called_once() - call_args = mock_agent_client.send_message.call_args - metadata = call_args.kwargs["metadata"] - assert metadata["session_id"] == session_id - assert metadata["user_id"] == user_id - - @pytest.mark.asyncio - async def test_cleanup_resources(self, orchestrator: AgentOrchestrator): - """Test resource cleanup.""" - await orchestrator.cleanup() - - # Verify agent connections are stopped - orchestrator.agent_connections.stop_all.assert_called_once() - - -class TestIntegration: - """Integration tests that test component interactions.""" - - @pytest.mark.asyncio - async def test_full_flow_integration( - self, - orchestrator: AgentOrchestrator, - sample_user_input: UserInput, - mock_agent_client: Mock, - mock_agent_card: AgentCard, - session_id: str, - user_id: str, - sample_query: str, - ): - """Test the complete flow from user input to response.""" - # Setup streaming response - content_chunks = ["Integration", " test", " response"] - mock_agent_client.send_message.return_value = create_streaming_response( - content_chunks - ) - - # Execute the full flow - all_chunks = [] - full_response = "" - async for chunk in orchestrator.process_user_input(sample_user_input): - all_chunks.append(chunk) - full_response += chunk.content - - # Verify the complete flow - # 1. User message added to session - orchestrator.session_manager.add_message.assert_any_call( - session_id, Role.USER, sample_query - ) - - # 2. Plan created - orchestrator.planner.create_plan.assert_called_once_with(sample_user_input) - - # 3. Task saved and started - orchestrator.task_manager.store.save_task.assert_called_once() - orchestrator.task_manager.start_task.assert_called_once() + return gen() - # 4. Agent started and message sent - orchestrator.agent_connections.start_agent.assert_called_once() - mock_agent_client.send_message.assert_called_once() - # 5. Task completed - if not mock_agent_card.capabilities.push_notifications: - orchestrator.task_manager.complete_task.assert_called_once() - - # 6. Final response added to session - orchestrator.session_manager.add_message.assert_any_call( - session_id, Role.AGENT, full_response - ) - - # 7. Verify response content - if mock_agent_card.capabilities.streaming: - content_received = "".join( - [chunk.content for chunk in all_chunks[:-1]] - ) # Exclude final empty chunk - assert "Integration test response" in content_received - - @pytest.mark.asyncio - async def test_task_failed_status_handling( - self, - orchestrator: AgentOrchestrator, - sample_user_input: UserInput, - mock_agent_client: Mock, - ): - """Test handling of TaskState.failed status updates.""" - error_message = "Remote task failed" - mock_agent_client.send_message.return_value = create_failed_response( - error_message - ) - - # Execute - chunks = [] - async for chunk in orchestrator.process_user_input(sample_user_input): - chunks.append(chunk) - - # Verify task was marked as failed - orchestrator.task_manager.fail_task.assert_called_once() - fail_call_args = orchestrator.task_manager.fail_task.call_args - assert ( - error_message in fail_call_args[0][1] - ) # Error message passed to fail_task - - # Verify error message was yielded to user - error_chunks = [c for c in chunks if error_message in c.content and c.is_final] - assert len(error_chunks) >= 1 - - @pytest.mark.asyncio - async def test_agent_start_with_correct_parameters( - self, - orchestrator: AgentOrchestrator, - sample_user_input: UserInput, - mock_agent_client: Mock, - ): - """Test that agent is started with correct parameters.""" - mock_agent_client.send_message.return_value = create_streaming_response( - ["Test"] - ) - - # Execute - async for _ in orchestrator.process_user_input(sample_user_input): - pass - - # Verify agent started with correct parameters - orchestrator.agent_connections.start_agent.assert_called_once() - call_args = orchestrator.agent_connections.start_agent.call_args - - # Check the new parameters - assert call_args.kwargs["with_listener"] is False - assert "notification_callback" in call_args.kwargs - assert call_args.kwargs["notification_callback"] is not None - - @pytest.mark.asyncio - async def test_agent_connection_error( - self, - orchestrator: AgentOrchestrator, - sample_user_input: UserInput, - mock_agent_connections: Mock, - ): - """Test error handling when agent connection fails.""" - # Setup agent connection to fail - orchestrator.agent_connections.get_client.return_value = None - - # Execute - chunks = [] - async for chunk in orchestrator.process_user_input(sample_user_input): - chunks.append(chunk) - - # Verify error was handled - error_chunks = [c for c in chunks if "(Error)" in c.content] - assert len(error_chunks) >= 1 - assert "Could not connect to agent" in error_chunks[0].content - - @pytest.mark.asyncio - async def test_empty_execution_plan( - self, - orchestrator: AgentOrchestrator, - sample_user_input: UserInput, - session_id: str, - user_id: str, - ): - """Test handling of empty execution plan.""" - # Setup empty plan - empty_plan = ExecutionPlan( - plan_id="empty-plan", - session_id=session_id, - user_id=user_id, - query=sample_user_input.query, - tasks=[], - created_at="2025-09-16T10:00:00", - ) - orchestrator.planner.create_plan.return_value = empty_plan - - # Execute - chunks = [] - async for chunk in orchestrator.process_user_input(sample_user_input): - chunks.append(chunk) - - # Verify appropriate message - assert len(chunks) == 1 - assert "No tasks found for this request" in chunks[0].content - - -class TestConcurrency: - """Test concurrent access scenarios.""" - - @pytest.mark.asyncio - async def test_concurrent_requests_same_agent( - self, - orchestrator: AgentOrchestrator, - mock_agent_client: Mock, - mock_agent_card: AgentCard, - session_id: str, - user_id: str, - ): - """Comprehensive test for concurrent process_user_input calls with same agent_name.""" - import asyncio - - # Create two different user inputs but with same agent_name - user_input_1 = UserInput( - query="First request to TestAgent", - desired_agent_name="TestAgent", - meta=UserInputMetadata( - session_id=f"{session_id}_1", user_id=f"{user_id}_1" - ), - ) - - user_input_2 = UserInput( - query="Second request to TestAgent", - desired_agent_name="TestAgent", - meta=UserInputMetadata( - session_id=f"{session_id}_2", user_id=f"{user_id}_2" - ), - ) - - # Setup mock responses for both requests - response_1 = create_streaming_response( - ["Response from request 1"], "remote-task-1" - ) - response_2 = create_streaming_response( - ["Response from request 2"], "remote-task-2" - ) - - # Use side_effect to return different responses for each call - mock_agent_client.send_message.side_effect = [response_1, response_2] - - # Track agent start calls to verify concurrent handling - start_agent_call_count = 0 - original_start_agent = orchestrator.agent_connections.start_agent - - async def track_start_agent(*args, **kwargs): - nonlocal start_agent_call_count - start_agent_call_count += 1 - # Remove the artificial delay - use asyncio.sleep(0) to yield control - await asyncio.sleep(0) # Just yield control, no actual delay - return await original_start_agent(*args, **kwargs) - - orchestrator.agent_connections.start_agent.side_effect = track_start_agent - - # Execute both requests concurrently - async def process_request_1(): - chunks_1 = [] - async for chunk in orchestrator.process_user_input(user_input_1): - chunks_1.append(chunk) - return chunks_1 - - async def process_request_2(): - chunks_2 = [] - async for chunk in orchestrator.process_user_input(user_input_2): - chunks_2.append(chunk) - return chunks_2 - - # Run both requests concurrently - results = await asyncio.gather( - process_request_1(), process_request_2(), return_exceptions=True - ) - - chunks_1, chunks_2 = results - - # Verify both requests completed successfully - assert not isinstance(chunks_1, Exception), f"Request 1 failed: {chunks_1}" - assert not isinstance(chunks_2, Exception), f"Request 2 failed: {chunks_2}" - assert len(chunks_1) > 0, "Request 1 should have produced chunks" - assert len(chunks_2) > 0, "Request 2 should have produced chunks" - - # Verify agent was started (possibly multiple times due to concurrent access) - # The exact number depends on the locking mechanism in RemoteConnections - assert start_agent_call_count >= 1, "Agent should be started at least once" - - # Verify both requests got different session contexts - session_1_chunks = [ - c for c in chunks_1 if c.meta.session_id == f"{session_id}_1" - ] - session_2_chunks = [ - c for c in chunks_2 if c.meta.session_id == f"{session_id}_2" - ] - - assert len(session_1_chunks) > 0, ( - "Request 1 should have chunks with correct session_id" - ) - assert len(session_2_chunks) > 0, ( - "Request 2 should have chunks with correct session_id" - ) - - # Verify both requests called the task manager - assert orchestrator.task_manager.store.save_task.call_count >= 2 - assert orchestrator.task_manager.start_task.call_count >= 2 - - # Verify session messages were added for both sessions - session_add_calls = orchestrator.session_manager.add_message.call_args_list - session_1_calls = [ - call for call in session_add_calls if call[0][0] == f"{session_id}_1" - ] - session_2_calls = [ - call for call in session_add_calls if call[0][0] == f"{session_id}_2" - ] - - assert len(session_1_calls) >= 2, ( - "Session 1 should have user and agent messages" - ) - assert len(session_2_calls) >= 2, ( - "Session 2 should have user and agent messages" - ) - - @pytest.mark.asyncio - async def test_concurrent_requests_different_agents( - self, - orchestrator: AgentOrchestrator, - mock_agent_client: Mock, - session_id: str, - user_id: str, - ): - """Test concurrent requests to different agents work independently.""" - import asyncio - - # Create user inputs for different agents - user_input_agent_1 = UserInput( - query="Request to Agent1", - desired_agent_name="Agent1", - meta=UserInputMetadata( - session_id=f"{session_id}_a1", user_id=f"{user_id}_a1" +def _make_non_streaming_response( + remote_task_id: str = "rt-1", +) -> AsyncGenerator[tuple[Mock, Any], None]: + async def gen(): + rt = Mock() + rt.id = remote_task_id + rt.status = Mock(state=TaskState.submitted) + yield rt, None + yield ( + rt, + TaskStatusUpdateEvent( + status=TaskStatus(state=TaskState.completed), + contextId="ctx", + taskId=remote_task_id, + final=True, ), ) - user_input_agent_2 = UserInput( - query="Request to Agent2", - desired_agent_name="Agent2", - meta=UserInputMetadata( - session_id=f"{session_id}_a2", user_id=f"{user_id}_a2" - ), - ) + return gen() - # Setup different execution plans for different agents - from valuecell.core.task import Task, TaskStatus as CoreTaskStatus - task_1 = Task( - task_id="task-agent1", - session_id=f"{session_id}_a1", - user_id=f"{user_id}_a1", - agent_name="Agent1", - status=CoreTaskStatus.PENDING, - remote_task_ids=[], - ) +# ------------------------- +# Tests +# ------------------------- - task_2 = Task( - task_id="task-agent2", - session_id=f"{session_id}_a2", - user_id=f"{user_id}_a2", - agent_name="Agent2", - status=CoreTaskStatus.PENDING, - remote_task_ids=[], - ) - plan_1 = ExecutionPlan( - plan_id="plan-agent1", - session_id=f"{session_id}_a1", - user_id=f"{user_id}_a1", - query="Request to Agent1", - tasks=[task_1], - created_at="2025-09-16T10:00:00", - ) +@pytest.mark.asyncio +async def test_happy_path_streaming( + orchestrator: AgentOrchestrator, + mock_agent_client: Mock, + mock_agent_card_streaming: AgentCard, + sample_user_input: UserInput, +): + # Inject agent connections mock + ac = Mock() + ac.start_agent = AsyncMock(return_value=mock_agent_card_streaming) + ac.get_client = AsyncMock(return_value=mock_agent_client) + ac.stop_all = AsyncMock() + orchestrator.agent_connections = ac - plan_2 = ExecutionPlan( - plan_id="plan-agent2", - session_id=f"{session_id}_a2", - user_id=f"{user_id}_a2", - query="Request to Agent2", - tasks=[task_2], - created_at="2025-09-16T10:00:00", - ) + mock_agent_client.send_message.return_value = _make_streaming_response( + ["Hello", " World"] + ) - # Setup planner to return different plans - orchestrator.planner.create_plan.side_effect = [plan_1, plan_2] - - # Setup agent responses - response_1 = create_streaming_response(["Agent1 response"], "remote-task-a1") - response_2 = create_streaming_response(["Agent2 response"], "remote-task-a2") - mock_agent_client.send_message.side_effect = [response_1, response_2] - - # Execute both requests concurrently - async def process_agent_1(): - chunks = [] - async for chunk in orchestrator.process_user_input(user_input_agent_1): - chunks.append(chunk) - return chunks - - async def process_agent_2(): - chunks = [] - async for chunk in orchestrator.process_user_input(user_input_agent_2): - chunks.append(chunk) - return chunks - - results = await asyncio.gather( - process_agent_1(), process_agent_2(), return_exceptions=True - ) + # Execute + out = [] + async for chunk in orchestrator.process_user_input(sample_user_input): + out.append(chunk) + + # Minimal assertions + orchestrator.task_manager.store.save_task.assert_called_once() + orchestrator.task_manager.start_task.assert_called_once() + ac.start_agent.assert_called_once() + ac.get_client.assert_called_once_with("TestAgent") + mock_agent_client.send_message.assert_called_once() + # Should at least yield something (content or final) + assert len(out) >= 1 + + +@pytest.mark.asyncio +async def test_happy_path_non_streaming( + orchestrator: AgentOrchestrator, + mock_agent_client: Mock, + mock_agent_card_non_streaming: AgentCard, + sample_user_input: UserInput, +): + ac = Mock() + ac.start_agent = AsyncMock(return_value=mock_agent_card_non_streaming) + ac.get_client = AsyncMock(return_value=mock_agent_client) + ac.stop_all = AsyncMock() + orchestrator.agent_connections = ac + + mock_agent_client.send_message.return_value = _make_non_streaming_response() + + out = [] + async for chunk in orchestrator.process_user_input(sample_user_input): + out.append(chunk) + + orchestrator.task_manager.start_task.assert_called_once() + orchestrator.task_manager.complete_task.assert_called_once() + assert len(out) >= 1 + + +@pytest.mark.asyncio +async def test_planner_error( + orchestrator: AgentOrchestrator, sample_user_input: UserInput +): + orchestrator.planner.create_plan.side_effect = RuntimeError("Planning failed") + + # Need agent connections to exist but won't be used + orchestrator.agent_connections = Mock() + + out = [] + async for chunk in orchestrator.process_user_input(sample_user_input): + out.append(chunk) + + assert len(out) == 1 + assert "(Error)" in out[0].content + assert "Planning failed" in out[0].content + + +@pytest.mark.asyncio +async def test_agent_connection_error( + orchestrator: AgentOrchestrator, + sample_user_input: UserInput, + mock_agent_card_streaming: AgentCard, +): + ac = Mock() + ac.start_agent = AsyncMock(return_value=mock_agent_card_streaming) + ac.get_client = AsyncMock(return_value=None) # Simulate connection failure + orchestrator.agent_connections = ac + + out = [] + async for chunk in orchestrator.process_user_input(sample_user_input): + out.append(chunk) + + assert any("(Error)" in c.content for c in out) + + +@pytest.mark.asyncio +async def test_create_and_close_session( + orchestrator: AgentOrchestrator, user_id: str, session_id: str +): + # create + new_id = await orchestrator.create_session(user_id, "Title") + orchestrator.session_manager.create_session.assert_called_once_with( + user_id, "Title" + ) + assert new_id == "new-session-id" - chunks_1, chunks_2 = results - - # Verify both requests completed successfully - assert not isinstance(chunks_1, Exception), f"Agent1 request failed: {chunks_1}" - assert not isinstance(chunks_2, Exception), f"Agent2 request failed: {chunks_2}" - assert len(chunks_1) > 0, "Agent1 request should produce chunks" - assert len(chunks_2) > 0, "Agent2 request should produce chunks" - - # Verify different agents were started - assert orchestrator.agent_connections.start_agent.call_count == 2 - start_calls = orchestrator.agent_connections.start_agent.call_args_list - agent_names = [call[0][0] for call in start_calls] - assert "Agent1" in agent_names - assert "Agent2" in agent_names - - @pytest.mark.asyncio - async def test_concurrent_requests_same_session( - self, - orchestrator: AgentOrchestrator, - mock_agent_client: Mock, - session_id: str, - user_id: str, - ): - """Test concurrent requests in the same session.""" - import asyncio - - # Create two requests for the same session but different queries - user_input_1 = UserInput( - query="First query in session", - desired_agent_name="TestAgent", - meta=UserInputMetadata(session_id=session_id, user_id=user_id), - ) + # close + orchestrator.task_manager.cancel_session_tasks.return_value = 1 + await orchestrator.close_session(session_id) + orchestrator.task_manager.cancel_session_tasks.assert_called_once_with(session_id) + orchestrator.session_manager.add_message.assert_called_once() - user_input_2 = UserInput( - query="Second query in session", - desired_agent_name="TestAgent", - meta=UserInputMetadata(session_id=session_id, user_id=user_id), - ) - # Setup responses - response_1 = create_streaming_response(["First response"], "remote-task-1") - response_2 = create_streaming_response(["Second response"], "remote-task-2") - mock_agent_client.send_message.side_effect = [response_1, response_2] - - # Execute concurrently - async def process_1(): - chunks = [] - async for chunk in orchestrator.process_user_input(user_input_1): - chunks.append(chunk) - return chunks - - async def process_2(): - chunks = [] - async for chunk in orchestrator.process_user_input(user_input_2): - chunks.append(chunk) - return chunks - - results = await asyncio.gather(process_1(), process_2(), return_exceptions=True) - - chunks_1, chunks_2 = results - - # Verify both completed successfully - assert not isinstance(chunks_1, Exception) - assert not isinstance(chunks_2, Exception) - assert len(chunks_1) > 0 - assert len(chunks_2) > 0 - - # Verify session messages were added for both queries - session_calls = orchestrator.session_manager.add_message.call_args_list - user_calls = [call for call in session_calls if call[0][1] == Role.USER] - assert len(user_calls) >= 2, "Both user queries should be added to session" +@pytest.mark.asyncio +async def test_cleanup(orchestrator: AgentOrchestrator): + orchestrator.agent_connections = Mock() + orchestrator.agent_connections.stop_all = AsyncMock() + await orchestrator.cleanup() + orchestrator.agent_connections.stop_all.assert_called_once() From 6ab82a8c289d7996f7b1e706df4ed22256f2e9cc Mon Sep 17 00:00:00 2001 From: Zhaofeng Zhang <24791380+vcfgv@users.noreply.github.com> Date: Fri, 19 Sep 2025 15:00:58 +0800 Subject: [PATCH 6/6] fix: reorder import statements for consistency in connection.py --- python/valuecell/server/db/connection.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/valuecell/server/db/connection.py b/python/valuecell/server/db/connection.py index f70c4efcf..05d5000cb 100644 --- a/python/valuecell/server/db/connection.py +++ b/python/valuecell/server/db/connection.py @@ -1,8 +1,9 @@ """Database connection and session management for ValueCell Server.""" from typing import Generator -from sqlalchemy import create_engine, Engine -from sqlalchemy.orm import sessionmaker, Session + +from sqlalchemy import Engine, create_engine +from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.pool import StaticPool from ..config.settings import get_settings