diff --git a/python/.env.example b/python/.env.example index 45d6c8d2d..240bb76f8 100644 --- a/python/.env.example +++ b/python/.env.example @@ -7,6 +7,7 @@ APP_ENVIRONMENT=development # You can generate a strong key using `openssl rand -base64 42`. # Alternatively you can set it with `SECRET_KEY` environment variable. SECRET_KEY="" +OPENROUTER_API_KEY= # Ensure UTF-8 encoding # i18n settings, different locales can be set here. @@ -33,6 +34,7 @@ SEC_EMAIL=your.name@example.com # Model IDs for OpenRouter SEC_PARSER_MODEL_ID=openai/gpt-4o-mini SEC_ANALYSIS_MODEL_ID=deepseek/deepseek-chat-v3-0324 +AI_HEDGE_FUND_PARSER_MODEL_ID=openai/gpt-4o-mini # SEC Agent Settings SEC_MAX_FILINGS=5 diff --git a/python/third_party/ai-hedge-fund/adapter/__main__.py b/python/third_party/ai-hedge-fund/adapter/__main__.py index 41db88bf6..820f8d2bc 100644 --- a/python/third_party/ai-hedge-fund/adapter/__main__.py +++ b/python/third_party/ai-hedge-fund/adapter/__main__.py @@ -1,5 +1,6 @@ import asyncio import logging +import os from datetime import datetime from typing import List @@ -61,19 +62,26 @@ class AIHedgeFundAgent(BaseAgent): def __init__(self): super().__init__() self.agno_agent = Agent( - model=OpenRouter(id="openai/gpt-4o-mini"), + model=OpenRouter( + id=os.getenv("AI_HEDGE_FUND_PARSER_MODEL_ID") or "openai/gpt-4o-mini" + ), response_model=HedgeFundRequest, markdown=True, ) async def stream(self, query, session_id, task_id): - logger.info(f"Parsing query: {query}. Task ID: {task_id}, Session ID: {session_id}") + logger.info( + f"Parsing query: {query}. Task ID: {task_id}, Session ID: {session_id}" + ) run_response = self.agno_agent.run( f"Parse the following hedge fund analysis request and extract the parameters: {query}" ) hedge_fund_request = run_response.content if not isinstance(hedge_fund_request, HedgeFundRequest): - raise ValueError(f"Unable to parse query: {query}") + logger.error(f"Unable to parse query: {query}") + raise ValueError( + f"Unable to parse your query. Please provide allowed tickers: {allowed_tickers}" + ) end_date = datetime.now().strftime("%Y-%m-%d") end_date_obj = datetime.strptime(end_date, "%Y-%m-%d") @@ -103,9 +111,7 @@ async def stream(self, query, session_id, task_id): }, } - logger.info( - f"Start analyzing. Task ID: {task_id}, Session ID: {session_id}" - ) + logger.info(f"Start analyzing. Task ID: {task_id}, Session ID: {session_id}") for stream_type, chunk in run_hedge_fund_stream( tickers=hedge_fund_request.tickers, start_date=start_date, diff --git a/python/valuecell/core/coordinate/orchestrator.py b/python/valuecell/core/coordinate/orchestrator.py index df07b1b3f..4d9c6080f 100644 --- a/python/valuecell/core/coordinate/orchestrator.py +++ b/python/valuecell/core/coordinate/orchestrator.py @@ -62,6 +62,10 @@ async def process_user_input( session_id = user_input.meta.session_id # Add user message to session + if not await self.session_manager.session_exists(session_id): + await self.session_manager.create_session( + user_input.meta.user_id, session_id=session_id + ) await self.session_manager.add_message(session_id, Role.USER, user_input.query) try: diff --git a/python/valuecell/core/coordinate/tests/test_orchestrator.py b/python/valuecell/core/coordinate/tests/test_orchestrator.py index 1cf1f53f4..a1add2029 100644 --- a/python/valuecell/core/coordinate/tests/test_orchestrator.py +++ b/python/valuecell/core/coordinate/tests/test_orchestrator.py @@ -150,6 +150,7 @@ def mock_session_manager() -> Mock: mock.create_session = AsyncMock(return_value="new-session-id") mock.get_session_messages = AsyncMock(return_value=[]) mock.list_user_sessions = AsyncMock(return_value=[]) + mock.session_exists = AsyncMock(return_value=True) return mock diff --git a/python/valuecell/core/session/manager.py b/python/valuecell/core/session/manager.py index 83c1e9bae..5db218e98 100644 --- a/python/valuecell/core/session/manager.py +++ b/python/valuecell/core/session/manager.py @@ -14,11 +14,16 @@ def __init__(self, store: Optional[SessionStore] = None): self.store = store or InMemorySessionStore() async def create_session( - self, user_id: str, title: Optional[str] = None + self, + user_id: str, + title: Optional[str] = None, + session_id: Optional[str] = None, ) -> Session: """Create new session""" session = Session( - session_id=generate_uuid("session"), user_id=user_id, title=title + session_id=session_id or generate_uuid("session"), + user_id=user_id, + title=title, ) await self.store.save_session(session) return session diff --git a/python/valuecell/core/types.py b/python/valuecell/core/types.py index 9148e60e8..f252fe987 100644 --- a/python/valuecell/core/types.py +++ b/python/valuecell/core/types.py @@ -9,7 +9,7 @@ class UserInputMetadata(BaseModel): """Metadata associated with user input""" - session_id: str = Field(..., description="Session ID for this request") + session_id: Optional[str] = Field(None, description="Session ID for this request") user_id: str = Field(..., description="User ID who made this request") diff --git a/python/valuecell/examples/ai_hedge_fund_websocket_example.html b/python/valuecell/examples/ai_hedge_fund_websocket_example.html new file mode 100644 index 000000000..30649ba9c --- /dev/null +++ b/python/valuecell/examples/ai_hedge_fund_websocket_example.html @@ -0,0 +1,286 @@ + + + + + + AI Hedge Fund WebSocket Client + + + +
+

AI Hedge Fund Analysis

+ +
+ Status: Disconnected +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ + + + + + +
+
+ + + + \ No newline at end of file diff --git a/python/valuecell/examples/ai_hedge_fund_websocket_example.py b/python/valuecell/examples/ai_hedge_fund_websocket_example.py new file mode 100644 index 000000000..9482ccecf --- /dev/null +++ b/python/valuecell/examples/ai_hedge_fund_websocket_example.py @@ -0,0 +1,174 @@ +import json +import logging +from typing import Optional + +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from pydantic import BaseModel, Field +from valuecell.core.coordinate.orchestrator import get_default_orchestrator +from valuecell.core.types import UserInput, UserInputMetadata + +logger = logging.getLogger(__name__) + +AGENT_ANALYST_MAP = { + "aswath_damodaran_agent": ("Aswath Damodaran", "aswath_damodaran"), + "ben_graham_agent": ("Ben Graham", "ben_graham"), + "bill_ackman_agent": ("Bill Ackman", "bill_ackman"), + "cathie_wood_agent": ("Cathie Wood", "cathie_wood"), + "charlie_munger_agent": ("Charlie Munger", "charlie_munger"), + "michael_burry_agent": ("Michael Burry", "michael_burry"), + "mohnish_pabrai_agent": ("Mohnish Pabrai", "mohnish_pabrai"), + "peter_lynch_agent": ("Peter Lynch", "peter_lynch"), + "phil_fisher_agent": ("Phil Fisher", "phil_fisher"), + "rakesh_jhunjhunwala_agent": ("Rakesh Jhunjhunwala", "rakesh_jhunjhunwala"), + "stanley_druckenmiller_agent": ("Stanley Druckenmiller", "stanley_druckenmiller"), + "warren_buffett_agent": ("Warren Buffett", "warren_buffett"), + "technical_analyst_agent": ("Technical Analyst", "technical_analyst"), + "fundamentals_analyst_agent": ("Fundamentals Analyst", "fundamentals_analyst"), + "sentiment_analyst_agent": ("Sentiment Analyst", "sentiment_analyst"), + "valuation_analyst_agent": ("Valuation Analyst", "valuation_analyst"), +} + + +class AnalysisRequest(BaseModel): + agent_name: str = Field(..., description="The name of the agent to use") + query: str = Field(..., description="The user's query for the agent") + session_id: Optional[str] = Field( + None, description="Session ID, will be auto-generated if not provided" + ) + user_id: str = Field("default_user", description="User ID") + + +def _parse_user_input(request: AnalysisRequest) -> UserInput: + """Parse user input into internal format""" + session_id = request.session_id or f"{request.agent_name}_session_{request.user_id}" + + meta = UserInputMetadata( + session_id=session_id, + user_id=request.user_id, + ) + + query = request.query + selected_analyst = AGENT_ANALYST_MAP.get(request.agent_name) + if selected_analyst: + query += f"\n\n**Hint**: Use {selected_analyst[0]} ({selected_analyst[1]}) in your analysis." + + return UserInput(desired_agent_name="AIHedgeFundAgent", query=query, meta=meta) + + +app = FastAPI( + title="AI Hedge Fund WebSocket API", + description="Real-time stock analysis via WebSocket", + version="1.0.0", +) + + +@app.get("/") +async def root(): + """Health check endpoint""" + return { + "message": "AI Hedge Fund WebSocket API is running", + "version": "1.0.0", + "websocket_endpoint": "/ws", + } + + +@app.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket): + """WebSocket endpoint for real-time stock analysis""" + await websocket.accept() + logger.info("WebSocket connection established") + + try: + orchestrator = get_default_orchestrator() + + while True: + # Receive message from client + data = await websocket.receive_text() + logger.info(f"Received message: {data}") + + try: + # Parse the incoming message + message_data = json.loads(data) + + # Validate agent name + agent_name = message_data.get("agent_name") + if agent_name not in AGENT_ANALYST_MAP: + await websocket.send_text( + json.dumps( + { + "type": "error", + "message": f"Unsupported agent: {agent_name}. Available agents: {list(AGENT_ANALYST_MAP.keys())}", + } + ) + ) + continue + + # Create analysis request + request = AnalysisRequest(**message_data) + user_input = _parse_user_input(request) + + # Send analysis start notification + await websocket.send_text( + json.dumps( + { + "type": "analysis_started", + "agent_name": request.agent_name, + } + ) + ) + + # Stream analysis results + async for message_chunk in orchestrator.process_user_input(user_input): + response = { + "type": "analysis_chunk", + "message": str(message_chunk), + "agent_name": request.agent_name, + } + await websocket.send_text(json.dumps(response)) + logger.info(f"Sent message chunk: {message_chunk}") + + # Send completion notification + await websocket.send_text( + json.dumps( + { + "type": "analysis_completed", + "agent_name": request.agent_name, + } + ) + ) + + except json.JSONDecodeError: + await websocket.send_text( + json.dumps({"type": "error", "message": "Invalid JSON format"}) + ) + except Exception as e: + logger.error(f"Error processing request: {e}") + await websocket.send_text( + json.dumps( + {"type": "error", "message": f"Analysis failed: {str(e)}"} + ) + ) + + except WebSocketDisconnect: + logger.info("WebSocket connection closed") + except Exception as e: + logger.error(f"WebSocket error: {e}") + + +if __name__ == "__main__": + import uvicorn + + # Configure logging + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + # Start server + uvicorn.run( + "ai_hedge_fund_websocket_example:app", + host="0.0.0.0", + port=8000, + reload=True, + log_level="info", + )