diff --git a/agentpress/__init__.py b/agentpress/__init__.py new file mode 100644 index 0000000..3ebeb22 --- /dev/null +++ b/agentpress/__init__.py @@ -0,0 +1 @@ +# Empty file to mark as package \ No newline at end of file diff --git a/agentpress/examples/.env.example b/agentpress/agents/.env.example similarity index 100% rename from agentpress/examples/.env.example rename to agentpress/agents/.env.example diff --git a/agentpress/examples/simple_web_dev/agent.py b/agentpress/agents/simple_web_dev/agent.py similarity index 52% rename from agentpress/examples/simple_web_dev/agent.py rename to agentpress/agents/simple_web_dev/agent.py index 16c0f15..2a7be66 100644 --- a/agentpress/examples/simple_web_dev/agent.py +++ b/agentpress/agents/simple_web_dev/agent.py @@ -8,23 +8,22 @@ - Use either XML or Standard tool calling patterns """ -import os import asyncio import json from agentpress.thread_manager import ThreadManager -from tools.files_tool import FilesTool +from example.tools.files_tool import FilesTool from agentpress.state_manager import StateManager -from tools.terminal_tool import TerminalTool -from agentpress.api_factory import register_api_endpoint +from example.tools.terminal_tool import TerminalTool import logging from typing import AsyncGenerator, Optional, Dict, Any import sys +from agentpress.api.api_factory import register_thread_task_api + BASE_SYSTEM_MESSAGE = """ You are a world-class web developer who can create, edit, and delete files, and execute terminal commands. You write clean, well-structured code. Keep iterating on existing files, continue working on this existing codebase - do not omit previous progress; instead, keep iterating. - Available tools: - create_file: Create new files with specified content - delete_file: Remove existing files @@ -69,7 +68,6 @@ } } Think deeply and step by step. - """ XML_FORMAT = """ @@ -88,87 +86,77 @@ + """ -def get_anthropic_api_key(): - """Get Anthropic API key from environment or prompt user.""" - api_key = os.getenv("ANTHROPIC_API_KEY") - if not api_key: - api_key = input("\n🔑 Please enter your Anthropic API key: ").strip() - if not api_key: - print("❌ No API key provided. Please set ANTHROPIC_API_KEY environment variable or enter a key.") - sys.exit(1) - os.environ["ANTHROPIC_API_KEY"] = api_key - return api_key - -@register_api_endpoint("/main_agent") +@register_thread_task_api("/agent") async def run_agent( thread_id: str, - use_xml: bool = True, max_iterations: int = 5, - project_description: Optional[str] = None + user_input: Optional[str] = None, ) -> Dict[str, Any]: - """Run the development agent with specified configuration.""" - # Initialize managers - thread_manager = ThreadManager() - await thread_manager.initialize() + """Run the development agent with specified configuration. - state_manager = StateManager(thread_id) - await state_manager.initialize() - - # Register tools - thread_manager.add_tool(FilesTool, thread_id=thread_id) - thread_manager.add_tool(TerminalTool, thread_id=thread_id) + Args: + thread_id (str): The ID of the thread. + max_iterations (int, optional): The maximum number of iterations. Defaults to 5. + user_input (Optional[str], optional): The user input. Defaults to None. + """ + thread_manager = ThreadManager() + state_manager = StateManager(thread_id) - # Add initial project description if provided - if project_description: + if user_input: await thread_manager.add_message( thread_id, { "role": "user", - "content": project_description + "content": user_input } ) - # Set up system message with appropriate format + thread_manager.add_tool(FilesTool, thread_id=thread_id) + thread_manager.add_tool(TerminalTool, thread_id=thread_id) + system_message = { "role": "system", - "content": BASE_SYSTEM_MESSAGE + (XML_FORMAT if use_xml else "") + "content": BASE_SYSTEM_MESSAGE + XML_FORMAT } - # Create initial event to track agent loop - await thread_manager.create_event( - thread_id=thread_id, - event_type="agent_loop_started", - content={ - "max_iterations": max_iterations, - "use_xml": use_xml, - "project_description": project_description - }, - include_in_llm_message_history=False - ) - - results = [] iteration = 0 while iteration < max_iterations: iteration += 1 - files_tool = FilesTool(thread_id) - await files_tool._init_workspace_state() + files_tool = FilesTool(thread_id=thread_id) - state = await state_manager.get_latest_state() - - state_message = { - "role": "user", - "content": f""" -Current development environment workspace state: - -{json.dumps(state, indent=2)} - - """ + state = await state_manager.export_store() + + temporary_message_content = f""" + You are tasked to complete the LATEST USER REQUEST! + + {user_input} + + + Current development environment workspace state: + + {json.dumps(state, indent=2) if state else "{}"} + + + CONTINUE WITH THE TASK! USE THE SESSION TOOL TO STOP THE SESSION IF THE TASK IS COMPLETE. + """ + + await thread_manager.add_message( + thread_id=thread_id, + message_data=temporary_message_content, + message_type="temporary_message", + include_in_llm_message_history=False + ) + + temporary_message = { + "role": "user", + "content": temporary_message_content } - model_name = "anthropic/claude-3-5-sonnet-latest" + model_name = "anthropic/claude-3-5-sonnet-latest" response = await thread_manager.run_thread( thread_id=thread_id, @@ -177,51 +165,57 @@ async def run_agent( temperature=0.1, max_tokens=8096, tool_choice="auto", - temporary_message=state_message, - native_tool_calling=not use_xml, - xml_tool_calling=use_xml, + temporary_message=temporary_message, + native_tool_calling=False, + xml_tool_calling=True, stream=True, - execute_tools_on_stream=False, - parallel_tool_execution=True + execute_tools_on_stream=True, + parallel_tool_execution=True, ) - # Handle both streaming and regular responses - if hasattr(response, '__aiter__'): - chunks = [] + if isinstance(response, AsyncGenerator): + print("\n🤖 Assistant is responding:") try: async for chunk in response: - chunks.append(chunk) + if hasattr(chunk.choices[0], 'delta'): + delta = chunk.choices[0].delta + + if hasattr(delta, 'content') and delta.content is not None: + content = delta.content + print(content, end='', flush=True) + + # Check for open_files_in_editor tag and continue if found + if '' in content: + print("\n📂 Opening files in editor, continuing to next iteration...") + continue + + if hasattr(delta, 'tool_calls') and delta.tool_calls: + for tool_call in delta.tool_calls: + if tool_call.function: + if tool_call.function.name: + print(f"\n🛠️ Tool Call: {tool_call.function.name}", flush=True) + if tool_call.function.arguments: + print(f" {tool_call.function.arguments}", end='', flush=True) + + print("\n✨ Response completed\n") + except Exception as e: + print(f"\n❌ Error processing stream: {e}", file=sys.stderr) logging.error(f"Error processing stream: {e}") - raise - response = chunks - - results.append({ - "iteration": iteration, - "response": response - }) + else: + print("\nNon-streaming response received:", response) - # Create iteration completion event - await thread_manager.create_event( - thread_id=thread_id, - event_type="iteration_complete", - content={ - "iteration_number": iteration, - "max_iterations": max_iterations, - # "state": state - }, - include_in_llm_message_history=False - ) + # # Get latest assistant message and check for stop_session + # latest_msg = await thread_manager.get_llm_history_messages( + # thread_id=thread_id, + # only_latest_assistant=True + # ) + # if latest_msg and '' in latest_msg: + # break - return { - "thread_id": thread_id, - "iterations": results, - } if __name__ == "__main__": - print("\n🚀 Welcome to AgentPress Web Developer Example!") - - get_anthropic_api_key() + print("\n🚀 Welcome to AgentPress!") project_description = input("What would you like to build? (default: Create a modern, responsive landing page)\n> ") if not project_description.strip(): @@ -241,10 +235,27 @@ async def run_agent( print(f"\n{'XML-based' if use_xml else 'Standard'} agent will help you build: {project_description}") print("Use Ctrl+C to stop the agent at any time.") - async def async_main(): + async def test_agent(): thread_manager = ThreadManager() thread_id = await thread_manager.create_thread() logging.info(f"Created new thread: {thread_id}") - await run_agent(thread_id, use_xml, project_description=project_description) - - asyncio.run(async_main()) \ No newline at end of file + + try: + result = await run_agent( + thread_id=thread_id, + max_iterations=5, + user_input=project_description, + ) + print("\n✅ Test completed successfully!") + + except Exception as e: + print(f"\n❌ Test failed: {str(e)}") + raise + + try: + asyncio.run(test_agent()) + except KeyboardInterrupt: + print("\n⚠️ Test interrupted by user") + except Exception as e: + print(f"\n❌ Test failed with error: {str(e)}") + raise \ No newline at end of file diff --git a/agentpress/examples/simple_web_dev/tools/files_tool.py b/agentpress/agents/simple_web_dev/tools/files_tool.py similarity index 97% rename from agentpress/examples/simple_web_dev/tools/files_tool.py rename to agentpress/agents/simple_web_dev/tools/files_tool.py index 9dca71a..0d5948d 100644 --- a/agentpress/examples/simple_web_dev/tools/files_tool.py +++ b/agentpress/agents/simple_web_dev/tools/files_tool.py @@ -60,13 +60,8 @@ def __init__(self, thread_id: Optional[str] = None): os.makedirs(self.workspace, exist_ok=True) if thread_id: self.state_manager = StateManager(thread_id) - asyncio.create_task(self._init_state()) - self.SNIPPET_LINES = 4 # Number of context lines to show around edits - - async def _init_state(self): - """Initialize state manager and workspace state.""" - await self.state_manager.initialize() - await self._init_workspace_state() + asyncio.create_task(self._init_workspace_state()) + self.SNIPPET_LINES = 4 def _should_exclude_file(self, rel_path: str) -> bool: """Check if a file should be excluded based on path, name, or extension""" @@ -264,6 +259,9 @@ async def str_replace(self, file_path: str, old_str: str, new_str: str) -> ToolR new_content = content.replace(old_str, new_str) full_path.write_text(new_content) + # Update state after file modification + await self._update_workspace_state() + # Show snippet around the edit replacement_line = content.split(old_str)[0].count('\n') start_line = max(0, replacement_line - self.SNIPPET_LINES) diff --git a/agentpress/examples/simple_web_dev/tools/terminal_tool.py b/agentpress/agents/simple_web_dev/tools/terminal_tool.py similarity index 69% rename from agentpress/examples/simple_web_dev/tools/terminal_tool.py rename to agentpress/agents/simple_web_dev/tools/terminal_tool.py index 3c5b176..c9736dc 100644 --- a/agentpress/examples/simple_web_dev/tools/terminal_tool.py +++ b/agentpress/agents/simple_web_dev/tools/terminal_tool.py @@ -2,7 +2,6 @@ import asyncio import subprocess from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema -from agentpress.state_manager import StateManager from typing import Optional class TerminalTool(Tool): @@ -12,24 +11,6 @@ def __init__(self, thread_id: Optional[str] = None): super().__init__() self.workspace = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'workspace') os.makedirs(self.workspace, exist_ok=True) - if thread_id: - self.state_manager = StateManager(thread_id) - asyncio.create_task(self._init_state()) - - async def _init_state(self): - """Initialize state manager.""" - await self.state_manager.initialize() - - async def _update_command_history(self, command: str, output: str, success: bool): - """Update command history in state""" - history = await self.state_manager.get("terminal_history") or [] - history.append({ - "command": command, - "output": output, - "success": success, - "cwd": os.path.relpath(os.getcwd(), self.workspace) - }) - await self.state_manager.set("terminal_history", history) @openapi_schema({ "type": "function", @@ -76,12 +57,6 @@ async def execute_command(self, command: str) -> ToolResult: error = stderr.decode() if stderr else "" success = process.returncode == 0 - await self._update_command_history( - command=command, - output=output + error, - success=success - ) - if success: return self.success_response({ "output": output, @@ -93,11 +68,6 @@ async def execute_command(self, command: str) -> ToolResult: return self.fail_response(f"Command failed with exit code {process.returncode}: {error}") except Exception as e: - await self._update_command_history( - command=command, - output=str(e), - success=False - ) return self.fail_response(f"Error executing command: {str(e)}") finally: os.chdir(original_dir) diff --git a/agentpress/api.py b/agentpress/api.py deleted file mode 100644 index e52bdb4..0000000 --- a/agentpress/api.py +++ /dev/null @@ -1,253 +0,0 @@ -from contextlib import asynccontextmanager -from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect, BackgroundTasks -from fastapi.middleware.cors import CORSMiddleware -from typing import Optional, List, Dict, Any -from pydantic import BaseModel -from agentpress.thread_manager import ThreadManager -import asyncio -import json -import uvicorn -import logging -import importlib -from agentpress.api_factory import app as api_app, discover_tasks - -# Configure logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -# Global managers -thread_manager: Optional[ThreadManager] = None - -@asynccontextmanager -async def lifespan(app: FastAPI): - """Lifespan context manager for FastAPI application.""" - # Startup - global thread_manager - thread_manager = ThreadManager() - await thread_manager.initialize() - - yield - - # Shutdown - # Add any cleanup code here if needed - -# Create FastAPI app -app = FastAPI(title="AgentPress API", lifespan=lifespan) - -# Enable CORS -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -# Import and mount the API Factory app -try: - # Run task discovery - discover_tasks() - logger.info("Task discovery completed") - - # Mount the API Factory app at /tasks instead of root - app.mount("/tasks", api_app) - logger.info("Mounted API Factory app at /tasks") -except Exception as e: - logger.error(f"Error setting up API Factory: {e}") - raise - -# WebSocket connection manager -class WebSocketManager: - """Manages WebSocket connections for real-time thread updates.""" - - def __init__(self): - self.active_connections: Dict[str, List[WebSocket]] = {} - - async def connect(self, websocket: WebSocket, thread_id: str): - """Connect a WebSocket to a thread.""" - await websocket.accept() - if thread_id not in self.active_connections: - self.active_connections[thread_id] = [] - self.active_connections[thread_id].append(websocket) - - def disconnect(self, websocket: WebSocket, thread_id: str): - """Disconnect a WebSocket from a thread.""" - if thread_id in self.active_connections: - self.active_connections[thread_id].remove(websocket) - if not self.active_connections[thread_id]: - del self.active_connections[thread_id] - - async def broadcast_to_thread(self, thread_id: str, message: dict): - """Broadcast a message to all connections in a thread.""" - if thread_id in self.active_connections: - for connection in self.active_connections[thread_id]: - try: - await connection.send_json(message) - except WebSocketDisconnect: - self.disconnect(connection, thread_id) - -# Initialize WebSocket manager -ws_manager = WebSocketManager() - -# Pydantic models for request/response validation -class EventCreate(BaseModel): - event_type: str - content: Dict[str, Any] - include_in_llm_message_history: bool = False - llm_message: Optional[Dict[str, Any]] = None - -class EventUpdate(BaseModel): - content: Optional[Dict[str, Any]] = None - include_in_llm_message_history: Optional[bool] = None - llm_message: Optional[Dict[str, Any]] = None - -class ThreadEvents(BaseModel): - only_llm_messages: bool = False - event_types: Optional[List[str]] = None - order_by: str = "created_at" - order: str = "ASC" - -# REST API Endpoints -@app.post("/threads", response_model=dict, status_code=201) -async def create_thread(): - """Create a new thread.""" - thread_id = await thread_manager.create_thread() - return {"thread_id": thread_id} - -@app.delete("/threads/{thread_id}", status_code=204) -async def delete_thread(thread_id: str): - """Delete a thread and all its events.""" - success = await thread_manager.delete_thread(thread_id) - if not success: - raise HTTPException(status_code=404, detail="Thread not found") - return None - -@app.post("/threads/{thread_id}/events", response_model=dict, status_code=201) -async def create_event(thread_id: str, event: EventCreate, background_tasks: BackgroundTasks): - """Create a new event in a thread.""" - # First verify thread exists - if not await thread_manager.thread_exists(thread_id): - raise HTTPException(status_code=404, detail="Thread not found") - - try: - event_id = await thread_manager.create_event( - thread_id=thread_id, - event_type=event.event_type, - content=event.content, - include_in_llm_message_history=event.include_in_llm_message_history, - llm_message=event.llm_message - ) - # Broadcast to WebSocket connections - background_tasks.add_task( - ws_manager.broadcast_to_thread, - thread_id, - {"type": "event_created", "event_id": event_id, "event": event.dict()} - ) - return {"event_id": event_id} - except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) - -@app.delete("/threads/{thread_id}/events/{event_id}", status_code=204) -async def delete_event(thread_id: str, event_id: str, background_tasks: BackgroundTasks): - """Delete a specific event.""" - # First verify thread exists - if not await thread_manager.thread_exists(thread_id): - raise HTTPException(status_code=404, detail="Thread not found") - - # Then verify event exists and belongs to thread - if not await thread_manager.event_belongs_to_thread(event_id, thread_id): - raise HTTPException(status_code=404, detail="Event not found in this thread") - - success = await thread_manager.delete_event(event_id) - if not success: - raise HTTPException(status_code=500, detail="Failed to delete event") - - # Broadcast to WebSocket connections - background_tasks.add_task( - ws_manager.broadcast_to_thread, - thread_id, - {"type": "event_deleted", "event_id": event_id} - ) - return None - -@app.patch("/threads/{thread_id}/events/{event_id}", status_code=200) -async def update_event(thread_id: str, event_id: str, event: EventUpdate, background_tasks: BackgroundTasks): - """Update an existing event.""" - # First verify thread exists - if not await thread_manager.thread_exists(thread_id): - raise HTTPException(status_code=404, detail="Thread not found") - - # Then verify event exists and belongs to thread - if not await thread_manager.event_belongs_to_thread(event_id, thread_id): - raise HTTPException(status_code=404, detail="Event not found in this thread") - - success = await thread_manager.update_event( - event_id=event_id, - thread_id=thread_id, - content=event.content, - include_in_llm_message_history=event.include_in_llm_message_history, - llm_message=event.llm_message - ) - if not success: - raise HTTPException(status_code=500, detail="Failed to update event") - - # Broadcast to WebSocket connections - background_tasks.add_task( - ws_manager.broadcast_to_thread, - thread_id, - {"type": "event_updated", "event_id": event_id, "updates": event.dict(exclude_unset=True)} - ) - return {"status": "success"} - -@app.get("/threads/{thread_id}/events") -async def get_thread_events( - thread_id: str, - only_llm_messages: bool = False, - event_types: Optional[List[str]] = None, - order_by: str = "created_at", - order: str = "ASC" -): - """Get events from a thread with filtering options.""" - if not await thread_manager.thread_exists(thread_id): - raise HTTPException(status_code=404, detail="Thread not found") - - events = await thread_manager.get_thread_events( - thread_id=thread_id, - only_llm_messages=only_llm_messages, - event_types=event_types, - order_by=order_by, - order=order - ) - return {"events": events} - -@app.get("/threads/{thread_id}/messages") -async def get_thread_messages(thread_id: str): - """Get LLM-formatted messages from thread events.""" - if not await thread_manager.thread_exists(thread_id): - raise HTTPException(status_code=404, detail="Thread not found") - - messages = await thread_manager.get_thread_llm_messages(thread_id) - return {"messages": messages} - -# WebSocket Endpoint -@app.websocket("/ws/threads/{thread_id}") -async def websocket_endpoint(websocket: WebSocket, thread_id: str): - """WebSocket endpoint for real-time thread updates.""" - # Verify thread exists before accepting connection - if not await thread_manager.thread_exists(thread_id): - await websocket.close(code=4004, reason="Thread not found") - return - - await ws_manager.connect(websocket, thread_id) - try: - while True: - await websocket.receive_json() - - except WebSocketDisconnect: - ws_manager.disconnect(websocket, thread_id) - except Exception as e: - await websocket.send_json({"type": "error", "detail": str(e)}) - ws_manager.disconnect(websocket, thread_id) - -if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file diff --git a/agentpress/api/__init__.py b/agentpress/api/__init__.py new file mode 100644 index 0000000..3ebeb22 --- /dev/null +++ b/agentpress/api/__init__.py @@ -0,0 +1 @@ +# Empty file to mark as package \ No newline at end of file diff --git a/agentpress/api/api.py b/agentpress/api/api.py new file mode 100644 index 0000000..480e3d5 --- /dev/null +++ b/agentpress/api/api.py @@ -0,0 +1,255 @@ +from contextlib import asynccontextmanager +from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect, BackgroundTasks +from fastapi.middleware.cors import CORSMiddleware +from typing import Optional, List, Dict, Any, Union +from pydantic import BaseModel +from agentpress.thread_manager import ThreadManager +import asyncio +import uvicorn +import logging +from agentpress.api.ws import ws_manager +from agentpress.api.api_factory import ( + app as thread_task_app, + register_thread_task_api, + discover_tasks, + thread_manager as task_thread_manager +) +# from agentpress.api_factory import app as api_app, discover_tasks + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Global managers +thread_manager: Optional[ThreadManager] = None + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Lifespan context manager for FastAPI application.""" + # Startup + global thread_manager + thread_manager = ThreadManager() + + # Share thread_manager with task API + global task_thread_manager + task_thread_manager = thread_manager + + # Wait for DB initialization + db = thread_manager.db + if db._initialization_task: + await db._initialization_task + + # Run task discovery during startup + discover_tasks() + + yield + + # Shutdown + # Add any cleanup code here if needed + +# Create FastAPI app +app = FastAPI(title="AgentPress API", lifespan=lifespan) + +# Enable CORS +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# # Import and mount the API Factory app +# try: +# # Run task discovery +# # discover_tasks() +# logger.info("Task discovery completed") + +# # Mount the API Factory app at /tasks instead of root +# app.mount("/tasks", api_app) +# logger.info("Mounted API Factory app at /tasks") +# except Exception as e: +# logger.error(f"Error setting up API Factory: {e}") +# raise + +# Pydantic models for request/response validation +class MessageCreate(BaseModel): + """Model for creating messages in a thread.""" + message_data: Union[str, Dict[str, Any]] + images: Optional[List[Dict[str, Any]]] = None + include_in_llm_message_history: bool = True + message_type: Optional[str] = None + +# REST API Endpoints +@app.post("/threads", response_model=dict, status_code=201) +async def create_thread(): + """Create a new thread.""" + thread_id = await thread_manager.create_thread() + return {"thread_id": thread_id} + +@app.post("/threads/{thread_id}/messages", response_model=dict, status_code=201) +async def create_message(thread_id: str, message: MessageCreate, background_tasks: BackgroundTasks): + """Create a new message in a thread.""" + if not await thread_manager.thread_exists(thread_id): + raise HTTPException(status_code=404, detail="Thread not found") + + try: + await thread_manager.add_message( + thread_id=thread_id, + message_data=message.message_data, + images=message.images, + include_in_llm_message_history=message.include_in_llm_message_history, + message_type=message.message_type + ) + + # Broadcast to WebSocket connections + background_tasks.add_task( + ws_manager.broadcast_to_thread, + thread_id, + {"type": "message_created", "message": message.dict()} + ) + return {"status": "success"} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + +# TODO: BROKEN FOR SOME REASON – RETURNS [] SHOULD RETURN, LLM MESSAGE STYLE +@app.get("/threads/{thread_id}/llm_history_messages") +async def get_thread_llm_messages( + thread_id: str, + hide_tool_msgs: bool = False, + only_latest_assistant: bool = False, +): + """Get messages from a thread with filtering options.""" + if not await thread_manager.thread_exists(thread_id): + raise HTTPException(status_code=404, detail="Thread not found") + + messages = await thread_manager.get_llm_history_messages( + thread_id=thread_id, + hide_tool_msgs=hide_tool_msgs, + only_latest_assistant=only_latest_assistant, + ) + return {"messages": messages} + +@app.get("/threads/{thread_id}/messages") +async def get_thread_messages( + thread_id: str, + message_types: Optional[List[str]] = None, + limit: Optional[int] = 50, + offset: Optional[int] = 0, + before_timestamp: Optional[str] = None, + after_timestamp: Optional[str] = None, + include_in_llm_message_history: Optional[bool] = None, + order: str = "asc" +): + """ + Get messages from a thread with comprehensive filtering options. + + Args: + thread_id: Thread identifier + message_types: Optional list of message types to filter by + limit: Maximum number of messages to return (default: 50) + offset: Number of messages to skip for pagination + before_timestamp: Optional filter for messages before timestamp + after_timestamp: Optional filter for messages after timestamp + include_in_llm_message_history: Optional filter for LLM history inclusion + order: Sort order - "asc" or "desc" + """ + if not await thread_manager.thread_exists(thread_id): + raise HTTPException(status_code=404, detail="Thread not found") + + try: + messages = await thread_manager.get_messages( + thread_id=thread_id, + message_types=message_types, + limit=limit, + offset=offset, + before_timestamp=before_timestamp, + after_timestamp=after_timestamp, + include_in_llm_message_history=include_in_llm_message_history, + order=order + ) + return {"messages": messages} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + +# TODO ONLY SEND POLLING UPDATES (IN EVEN HIGHER FREQUENCY THEN 1per sec) - IF THEY ARE ANY ACTIVE TASKS FOR THAT THREAD. AS LONG AS THEY ARE ACTIVE TASKS START & STOP THE POLLING BASED ON WHETHER THERE IS AN ACTIVE TASK FOR THE THREAD. IMPLEMENT in API_FACTORY as well to broadcast this ofc & trigger/disable the polling. + +# WebSocket Endpoint +@app.websocket("/threads/{thread_id}") +async def websocket_endpoint( + websocket: WebSocket, + thread_id: str, + message_types: Optional[List[str]] = None, + limit: Optional[int] = 50, + offset: Optional[int] = 0, + before_timestamp: Optional[str] = None, + after_timestamp: Optional[str] = None, + include_in_llm_message_history: Optional[bool] = None, + order: str = "desc" +): + """ + WebSocket endpoint for real-time thread updates with filtering and pagination. + + Query Parameters: + message_types: Optional list of message types to filter by + limit: Maximum number of messages to return (default: 50) + offset: Number of messages to skip (for pagination) + before_timestamp: Optional timestamp to filter messages before + after_timestamp: Optional timestamp to filter messages after + include_in_llm_message_history: Optional bool to filter messages by LLM history inclusion + order: Sort order - "asc" or "desc" (default: desc) + """ + try: + if not await thread_manager.thread_exists(thread_id): + await websocket.close(code=4004, reason="Thread not found") + return + + await ws_manager.connect(websocket, thread_id) + + while True: + try: + # Get messages with all filters + result = await thread_manager.get_messages( + thread_id=thread_id, + message_types=message_types, + limit=limit, + offset=offset, + before_timestamp=before_timestamp, + after_timestamp=after_timestamp, + include_in_llm_message_history=include_in_llm_message_history, + order=order + ) + + # Send messages and pagination info + await websocket.send_json({ + "type": "messages", + "data": result + }) + + # Poll every second + await asyncio.sleep(1) + + except WebSocketDisconnect: + ws_manager.disconnect(websocket, thread_id) + break + except Exception as e: + logging.error(f"WebSocket error: {e}") + await websocket.send_json({ + "type": "error", + "data": str(e) + }) + ws_manager.disconnect(websocket, thread_id) + break + + except Exception as e: + logging.error(f"WebSocket connection error: {e}") + try: + await websocket.close(code=1011, reason=str(e)) + except: + pass + +# Update the mounting of thread_task_app +app.mount("/tasks", thread_task_app) + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file diff --git a/agentpress/api/api_factory.py b/agentpress/api/api_factory.py new file mode 100644 index 0000000..70d37c9 --- /dev/null +++ b/agentpress/api/api_factory.py @@ -0,0 +1,349 @@ +""" +Thread Task API Factory for registering and managing thread-associated long-running tasks. +""" + +import sys +import os +import inspect +import uuid +import asyncio +import logging +import importlib +from functools import wraps +from typing import Callable, Dict, Any, Optional, List, ForwardRef +from fastapi import FastAPI, HTTPException, Request +from pydantic import create_model +from agentpress.thread_manager import ThreadManager +from contextlib import asynccontextmanager + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Initialize managers at module level +thread_manager: Optional[ThreadManager] = None +_decorated_functions: Dict[str, Callable] = {} +_running_tasks: Dict[str, Dict[str, Any]] = {} + +def find_project_root(): + """Find the project root by looking for pyproject.toml""" + current = os.path.abspath(os.path.dirname(__file__)) + while current != '/': + if os.path.exists(os.path.join(current, 'pyproject.toml')): + return current + current = os.path.dirname(current) + return None + +def discover_tasks(): + """ + Discover all decorated functions in the project. + Scans from the project root (where pyproject.toml is located). + """ + logger.info("Starting task discovery") + + # Find project root + project_root = find_project_root() + logger.info(f"Project root found at: {project_root}") + + # Add project root to Python path if not already there + if project_root not in sys.path: + sys.path.insert(0, project_root) + + # Walk through all Python files in the project + for root, _, files in os.walk(project_root): + for file in files: + if file.endswith('.py'): + module_path = os.path.join(root, file) + module_name = os.path.relpath(module_path, project_root) + module_name = os.path.splitext(module_name)[0].replace(os.path.sep, '.') + + try: + logger.info(f"Attempting to import module: {module_name}") + module = importlib.import_module(module_name) + + # Inspect all module members + for name, obj in inspect.getmembers(module): + if inspect.isfunction(obj): + # Check if this function has been decorated with register_thread_task_api + if hasattr(obj, '__closure__') and obj.__closure__: + for cell in obj.__closure__: + if cell.cell_contents in _decorated_functions.values(): + path = next( + p for p, f in _decorated_functions.items() + if f == cell.cell_contents + ) + _decorated_functions[path] = obj + logger.info(f"Registered function: {obj.__name__} at path: {path}") + + except Exception as e: + logger.error(f"Error importing {module_name}: {e}") + + logger.info(f"Task discovery complete. Registered paths: {list(_decorated_functions.keys())}") + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Lifespan context manager for FastAPI application.""" + global thread_manager + + # Initialize ThreadManager if not already initialized + if thread_manager is None: + thread_manager = ThreadManager() + # Wait for DB initialization + if thread_manager.db._initialization_task: + await thread_manager.db._initialization_task + + # Run task discovery during startup + discover_tasks() + + yield + # Cleanup if needed + +# Create FastAPI app +app = FastAPI( + title="Thread Task API", + description="API for managing thread-associated long-running tasks", + openapi_tags=[{ + "name": "Generated Thread Tasks", + "description": "Dynamically generated endpoints for thread-associated tasks" + }], + lifespan=lifespan +) + +# Add middleware to ensure thread_manager is available +@app.middleware("http") +async def ensure_thread_manager(request: Request, call_next): + """Ensure thread_manager is initialized before handling requests.""" + global thread_manager + if thread_manager is None: + thread_manager = ThreadManager() + if thread_manager.db._initialization_task: + await thread_manager.db._initialization_task + return await call_next(request) + +def register_thread_task_api(path: str): + """ + Decorator to register a function as a thread-associated task API endpoint. + The decorated function must have thread_id as its first parameter. + """ + def decorator(func: Callable): + logger.info(f"Registering thread task API endpoint: {path} for function {func.__name__}") + _decorated_functions[path] = func + + # Validate that thread_id is the first parameter + params = inspect.signature(func).parameters + if 'thread_id' not in params: + raise ValueError(f"Function {func.__name__} must have thread_id as a parameter") + + # Create Pydantic model for function parameters + model_fields = {} + for name, param in params.items(): + if name == 'self': # Skip self parameter for methods + continue + + annotation = param.annotation + if annotation == inspect.Parameter.empty: + annotation = Any + + # Convert string annotations to ForwardRef + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + + default = ... if param.default == inspect.Parameter.empty else param.default + model_fields[name] = (annotation, default) + + RequestModel = create_model(f'{func.__name__}Request', **model_fields) + + # Register the start endpoint + @app.post( + f"{path}/start", + response_model=dict, + summary=f"Start {func.__name__}", + description=f"Start a new {func.__name__} task associated with a thread", + tags=["Generated Thread Tasks"] + ) + async def start_task(params: RequestModel): + logger.info(f"Starting task at {path}/start") + + # Validate thread exists + if not await thread_manager.thread_exists(params.thread_id): + raise HTTPException(status_code=404, detail="Thread not found") + + task_id = str(uuid.uuid4()) + kwargs = params.dict() + + # Create the task + task = asyncio.create_task(func(**kwargs)) + + # Store task with thread association + _running_tasks[task_id] = { + "thread_id": params.thread_id, + "task": task, + "status": "running", + "path": path, + "started_at": asyncio.get_event_loop().time() + } + + # Add task info to thread messages + await thread_manager.add_message( + thread_id=params.thread_id, + message_data={ + "type": "task_started", + "task_id": task_id, + "path": path, + "status": "running" + }, + message_type="task_status", + include_in_llm_message_history=False + ) + + return {"task_id": task_id} + + # Register the stop endpoint + @app.post( + f"{path}/stop/{{task_id}}", + response_model=dict, + summary=f"Stop {func.__name__}", + description=f"Stop a running {func.__name__} task", + tags=["Generated Thread Tasks"] + ) + async def stop_task(task_id: str): + if task_id not in _running_tasks: + raise HTTPException(status_code=404, detail="Task not found") + + task_info = _running_tasks[task_id] + task_info["task"].cancel() + task_info["status"] = "cancelled" + + # Update thread with task cancellation + await thread_manager.add_message( + thread_id=task_info["thread_id"], + message_data={ + "type": "task_stopped", + "task_id": task_id, + "path": task_info["path"], + "status": "cancelled" + }, + message_type="task_status", + include_in_llm_message_history=False + ) + + return {"status": "stopped"} + + # Register the status endpoint + @app.get( + f"{path}/status/{{task_id}}", + response_model=dict, + summary=f"Get {func.__name__} status", + description=f"Get the status of a {func.__name__} task", + tags=["Generated Thread Tasks"] + ) + async def get_status(task_id: str): + if task_id not in _running_tasks: + raise HTTPException(status_code=404, detail="Task not found") + + task_info = _running_tasks[task_id] + task = task_info["task"] + + if task.done(): + try: + result = task.result() + status = "completed" + if hasattr(result, '__aiter__'): + status = "streaming" + + # Update thread with task completion + await thread_manager.add_message( + thread_id=task_info["thread_id"], + message_data={ + "type": "task_completed", + "task_id": task_id, + "path": task_info["path"], + "status": status, + "result": result if status == "completed" else None + }, + message_type="task_status", + include_in_llm_message_history=False + ) + + return { + "status": status, + "result": result if status == "completed" else None + } + + except asyncio.CancelledError: + return {"status": "cancelled"} + except Exception as e: + error_status = { + "status": "failed", + "error": str(e) + } + + # Update thread with task failure + await thread_manager.add_message( + thread_id=task_info["thread_id"], + message_data={ + "type": "task_failed", + "task_id": task_id, + "path": task_info["path"], + "status": "failed", + "error": str(e) + }, + message_type="task_status", + include_in_llm_message_history=False + ) + + return error_status + + return {"status": "running"} + + @wraps(func) + async def wrapper(*args, **kwargs): + return await func(*args, **kwargs) + return wrapper + return decorator + +@app.get("/threads/{thread_id}/tasks") +async def get_thread_tasks(thread_id: str): + """Get all tasks associated with a thread.""" + if not await thread_manager.thread_exists(thread_id): + raise HTTPException(status_code=404, detail="Thread not found") + + thread_tasks = { + task_id: { + "path": info["path"], + "status": info["status"], + "started_at": info["started_at"] + } + for task_id, info in _running_tasks.items() + if info["thread_id"] == thread_id + } + + return {"tasks": thread_tasks} + +@app.delete("/threads/{thread_id}/tasks") +async def stop_thread_tasks(thread_id: str): + """Stop all tasks associated with a thread.""" + if not await thread_manager.thread_exists(thread_id): + raise HTTPException(status_code=404, detail="Thread not found") + + stopped_tasks = [] + for task_id, info in list(_running_tasks.items()): + if info["thread_id"] == thread_id: + info["task"].cancel() + info["status"] = "cancelled" + stopped_tasks.append(task_id) + + # Update thread with task cancellations + if stopped_tasks: + await thread_manager.add_message( + thread_id=thread_id, + message_data={ + "type": "tasks_stopped", + "task_ids": stopped_tasks, + "status": "cancelled" + }, + message_type="task_status", + include_in_llm_message_history=False + ) + + return {"stopped_tasks": stopped_tasks} \ No newline at end of file diff --git a/agentpress/api/ws.py b/agentpress/api/ws.py new file mode 100644 index 0000000..2ee5a28 --- /dev/null +++ b/agentpress/api/ws.py @@ -0,0 +1,45 @@ +"""WebSocket management system for real-time updates.""" + +import logging +from typing import Dict, List +from fastapi import WebSocket, WebSocketDisconnect + +class WebSocketManager: + """Manages WebSocket connections for real-time thread updates.""" + + def __init__(self): + self.active_connections: Dict[str, List[WebSocket]] = {} + + async def connect(self, websocket: WebSocket, thread_id: str): + """Connect a WebSocket to a thread.""" + await websocket.accept() + if thread_id not in self.active_connections: + self.active_connections[thread_id] = [] + self.active_connections[thread_id].append(websocket) + + def disconnect(self, websocket: WebSocket, thread_id: str): + """Disconnect a WebSocket from a thread.""" + if thread_id in self.active_connections: + self.active_connections[thread_id].remove(websocket) + if not self.active_connections[thread_id]: + del self.active_connections[thread_id] + + async def broadcast_to_thread(self, thread_id: str, message: dict): + """Broadcast a message to all connections in a thread.""" + if thread_id in self.active_connections: + disconnected = [] + for connection in self.active_connections[thread_id]: + try: + await connection.send_json(message) + except WebSocketDisconnect: + disconnected.append(connection) + except Exception as e: + logging.warning(f"Failed to send message to websocket: {e}") + disconnected.append(connection) + + # Clean up disconnected connections + for connection in disconnected: + self.disconnect(connection, thread_id) + +# Global WebSocket manager instance +ws_manager = WebSocketManager() \ No newline at end of file diff --git a/agentpress/api_factory.py b/agentpress/api_factory.py deleted file mode 100644 index 69d0257..0000000 --- a/agentpress/api_factory.py +++ /dev/null @@ -1,170 +0,0 @@ -""" -API Factory for registering and managing FastAPI endpoints. -""" - -import sys -import inspect -import importlib -import pkgutil -import uuid -import asyncio -import logging -import os -from functools import wraps -from typing import Callable, Dict, Any, Optional, List -from fastapi import FastAPI, BackgroundTasks, HTTPException -from pydantic import create_model, BaseModel - -# Configure logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -app = FastAPI() -_decorated_functions: Dict[str, Callable] = {} -_running_tasks: Dict[str, asyncio.Task] = {} - -def register_api_endpoint(path: str): - """Decorator to register a function as an API endpoint with task management.""" - def decorator(func: Callable): - logger.info(f"Registering API endpoint: {path} for function {func.__name__}") - _decorated_functions[path] = func - - # Create Pydantic model for function parameters - params = inspect.signature(func).parameters - model_fields = { - name: (param.annotation if param.annotation != inspect.Parameter.empty else Any, ... if param.default == inspect.Parameter.empty else param.default) - for name, param in params.items() - if name != 'self' # Skip self parameter for methods - } - RequestModel = create_model(f'{func.__name__}Request', **model_fields) - - # Register the start endpoint - @app.post(f"{path}/start", response_model=dict) - async def start_task(params: Optional[RequestModel] = None): - logger.info(f"Starting task at {path}/start") - task_id = str(uuid.uuid4()) - kwargs = params.dict() if params else {} - _running_tasks[task_id] = asyncio.create_task(func(**kwargs)) - return {"task_id": task_id} - - # Register the stop endpoint - @app.post(f"{path}/stop/{{task_id}}") - async def stop_task(task_id: str): - if task_id not in _running_tasks: - raise HTTPException(status_code=404, detail="Task not found") - _running_tasks[task_id].cancel() - return {"status": "stopped"} - - # Register the status endpoint - @app.get(f"{path}/status/{{task_id}}") - async def get_status(task_id: str): - if task_id not in _running_tasks: - raise HTTPException(status_code=404, detail="Task not found") - task = _running_tasks[task_id] - - if task.done(): - try: - result = task.result() - # Check if this is a streaming response - if hasattr(result, '__aiter__') or ( - isinstance(result, dict) and ( - any(hasattr(v, '__aiter__') for v in result.values()) or - # Check for streaming responses in iterations - ( - 'iterations' in result and - result['iterations'] and - any(hasattr(r.get('response'), '__aiter__') for r in result['iterations']) - ) - ) - ): - return {"status": "streaming"} - return {"status": "completed", "result": result} - except asyncio.CancelledError: - return {"status": "cancelled"} - except Exception as e: - return {"status": "failed", "error": str(e)} - return {"status": "running"} - - # Also register a direct endpoint for simple calls - @app.post(path) - async def direct_call(background_tasks: BackgroundTasks, params: Optional[RequestModel] = None): - kwargs = params.dict() if params else {} - task_id = str(uuid.uuid4()) - task = asyncio.create_task(func(**kwargs)) - _running_tasks[task_id] = task - return {"task_id": task_id} - - logger.info(f"Successfully registered all endpoints for {path}") - - @wraps(func) - async def wrapper(*args, **kwargs): - return await func(*args, **kwargs) - return wrapper - return decorator - -def find_project_root() -> str: - """Find the project root by looking for pyproject.toml.""" - current = os.path.abspath(os.path.dirname(__file__)) - while current != '/': - if os.path.exists(os.path.join(current, 'pyproject.toml')): - return current - current = os.path.dirname(current) - return os.path.dirname(__file__) # Fallback to current directory - -def discover_tasks(): - """ - Discover all decorated functions in the project. - Scans from the project root (where pyproject.toml is located). - """ - logger.info("Starting task discovery") - - # Find project root - project_root = find_project_root() - logger.info(f"Project root found at: {project_root}") - - # Add project root to Python path if not already there - if project_root not in sys.path: - sys.path.insert(0, project_root) - - # Walk through all Python files in the project - for root, _, files in os.walk(project_root): - for file in files: - if file.endswith('.py'): - module_path = os.path.join(root, file) - module_name = os.path.relpath(module_path, project_root) - module_name = os.path.splitext(module_name)[0].replace(os.path.sep, '.') - - try: - logger.info(f"Attempting to import module: {module_name}") - module = importlib.import_module(module_name) - - # Inspect all module members - for name, obj in inspect.getmembers(module): - if inspect.isfunction(obj): - # Check if this function has been decorated with @task - if any( - path for path, func in _decorated_functions.items() - if func.__name__ == obj.__name__ and func.__module__ == obj.__module__ - ): - logger.info(f"Found already registered function: {obj.__name__}") - continue - - # Check for our decorator in the function's closure - if hasattr(obj, '__closure__') and obj.__closure__: - for cell in obj.__closure__: - if cell.cell_contents in _decorated_functions.values(): - # Found a decorated function that wasn't registered - path = next( - p for p, f in _decorated_functions.items() - if f == cell.cell_contents - ) - _decorated_functions[path] = obj - logger.info(f"Registered function: {obj.__name__} at path: {path}") - - except Exception as e: - logger.error(f"Error importing {module_name}: {e}") - - logger.info(f"Task discovery complete. Registered paths: {list(_decorated_functions.keys())}") - -# Auto-discover tasks on import -discover_tasks() \ No newline at end of file diff --git a/agentpress/cli.py b/agentpress/cli.py index 3300cb5..471fa7f 100644 --- a/agentpress/cli.py +++ b/agentpress/cli.py @@ -2,99 +2,44 @@ import shutil import click import questionary -from typing import List, Dict, Optional, Tuple +from typing import Dict import time import pkg_resources import requests from packaging import version -import re -MODULES = { - "llm": { - "required": True, - "files": ["llm.py"], - "description": "LLM Interface - Core module for interacting with large language models (OpenAI, Anthropic, 100+ LLMs using the OpenAI Input/Output Format powered by LiteLLM). Handles API calls, response streaming, and model-specific configurations." - }, - "tool": { - "required": True, - "files": [ - "tool.py", - "tool_registry.py" - ], - "description": "Tool System Foundation - Defines the base architecture for creating and managing tools. Includes the tool registry for registering, organizing, and accessing tool functions." - }, - "processors": { - "required": True, - "files": [ - "processor/base_processors.py", - "processor/llm_response_processor.py", - "processor/standard/standard_tool_parser.py", - "processor/standard/standard_tool_executor.py", - "processor/standard/standard_results_adder.py", - "processor/xml/xml_tool_parser.py", - "processor/xml/xml_tool_executor.py", - "processor/xml/xml_results_adder.py" - ], - "description": "Response Processing System - Handles parsing and executing LLM responses, managing tool calls, and processing results. Supports both standard OpenAI-style function calling and XML-based tool execution patterns." - }, - "thread_management": { - "required": True, - "files": [ - "thread_manager.py", - "thread_viewer_ui.py" - ], - "description": "Conversation Management System - Handles message threading, conversation history, and provides a UI for viewing conversation threads. Manages the flow of messages between the user, LLM, and tools." - }, - "state_management": { - "required": True, - "files": ["state_manager.py"], - "description": "State Persistence System - Provides thread-safe storage and retrieval of conversation state, tool data, and other persistent information. Enables maintaining context across sessions and managing shared state between components." - }, - "db_connection": { - "required": True, - "files": ["db_connection.py"], - "description": "Database Connection - Provides a connection to a SQLite database for storing and retrieving conversation state, tool data, and other persistent information." - } -} +PACKAGE_NAME = "agentpress" +PYPI_URL = f"https://pypi.org/pypi/{PACKAGE_NAME}/json" STARTER_EXAMPLES = { "simple_web_dev_example_agent": { "description": "Interactive web development agent with file and terminal manipulation capabilities. Demonstrates both standard and XML-based tool calling patterns.", "files": { - "agent.py": "examples/simple_web_dev/agent.py", - "tools/files_tool.py": "examples/simple_web_dev/tools/files_tool.py", - "tools/terminal_tool.py": "examples/simple_web_dev/tools/terminal_tool.py", - ".env.example": "examples/.env.example" + "agent.py": "agents/simple_web_dev/agent.py", + "tools/files_tool.py": "agents/simple_web_dev/tools/files_tool.py", + "tools/terminal_tool.py": "agents/simple_web_dev/tools/terminal_tool.py", + ".env.example": "agents/.env.example" } } } -PACKAGE_NAME = "agentpress" -PYPI_URL = f"https://pypi.org/pypi/{PACKAGE_NAME}/json" - -def check_for_updates() -> Tuple[Optional[str], Optional[str], bool]: - """ - Check if there's a newer version available on PyPI - Returns: (current_version, latest_version, update_available) - """ +def check_for_updates(): + """Check if there's a newer version available on PyPI""" try: current_version = pkg_resources.get_distribution(PACKAGE_NAME).version response = requests.get(PYPI_URL, timeout=2) - response.raise_for_status() # Raise exception for bad status codes + response.raise_for_status() latest_version = response.json()["info"]["version"] - # Compare versions properly using packaging.version current_ver = version.parse(current_version) latest_ver = version.parse(latest_version) return current_version, latest_version, latest_ver > current_ver except requests.RequestException: - # Handle network-related errors silently return None, None, False except Exception as e: - # Log other unexpected errors but don't break the CLI click.echo(f"Warning: Failed to check for updates: {str(e)}", err=True) return None, None, False @@ -102,7 +47,6 @@ def show_welcome(): """Display welcome message with ASCII art""" click.clear() - # Check for updates current_version, latest_version, update_available = check_for_updates() click.echo(""" @@ -122,16 +66,17 @@ def show_welcome(): time.sleep(1) -def copy_module_files(src_dir: str, dest_dir: str, files: List[str]): - """Copy module files from package to destination""" +def copy_package_files(src_dir: str, dest_dir: str): + """Copy all package files except agents folder to destination""" os.makedirs(dest_dir, exist_ok=True) - with click.progressbar(files, label='Copying files') as file_list: - for file in file_list: - src = os.path.join(src_dir, file) - dst = os.path.join(dest_dir, file) - os.makedirs(os.path.dirname(dst), exist_ok=True) - shutil.copy2(src, dst) + def ignore_patterns(path, names): + # Ignore the agents directory and any __pycache__ directories + return [n for n in names if n == 'agents' or n == '__pycache__'] + + with click.progressbar(length=1, label='Copying files') as bar: + shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True, ignore=ignore_patterns) + bar.update(1) def copy_example_files(src_dir: str, dest_dir: str, files: Dict[str, str]): """Copy example files from package to destination""" @@ -142,19 +87,6 @@ def copy_example_files(src_dir: str, dest_dir: str, files: Dict[str, str]): shutil.copy2(src, dst) click.echo(f" ✓ Created {dest_path}") -def update_file_paths(file_path: str, replacements: Dict[str, str]): - """Update file paths in the given file""" - with open(file_path, 'r') as f: - content = f.read() - - for old, new in replacements.items(): - # Escape special characters in the old string - escaped_old = re.escape(old) - content = re.sub(escaped_old, new, content) - - with open(file_path, 'w') as f: - f.write(content) - @click.group() def cli(): """AgentPress CLI - Initialize your AgentPress modules""" @@ -165,7 +97,6 @@ def init(): """Initialize AgentPress modules in your project""" show_welcome() - # Set components directory name to 'agentpress' components_dir = "agentpress" if os.path.exists(components_dir): @@ -195,41 +126,20 @@ def init(): # Get package directory package_dir = os.path.dirname(os.path.abspath(__file__)) - # Show all modules status - click.echo("\n🔧 AgentPress Modules Configuration\n") - - # Show required modules including state_manager - click.echo("📦 Required Modules (pre-selected):") - required_modules = {name: module for name, module in MODULES.items() - if module["required"] or name == "state_management"} - for name, module in required_modules.items(): - click.echo(f" ✓ {click.style(name, fg='green')} - {module['description']}") - - # Create selections dict with required modules pre-selected - selections = {name: True for name in required_modules.keys()} - click.echo("\n🚀 Setting up your AgentPress...") time.sleep(0.5) try: - # Copy selected modules - selected_modules = [name for name, selected in selections.items() if selected] - all_files = [] - for module in selected_modules: - all_files.extend(MODULES[module]["files"]) - - # Create components directory and copy module files + # Create components directory and copy all files except agents folder components_dir_path = os.path.abspath(components_dir) - copy_module_files(package_dir, components_dir_path, all_files) - + copy_package_files(package_dir, components_dir_path) - - # Copy example only if a valid example (not None) was selected + # Copy example if selected if selected_example and selected_example in STARTER_EXAMPLES: click.echo(f"\n📝 Creating {selected_example}...") copy_example_files( package_dir, - os.getcwd(), # Use current working directory + os.getcwd(), STARTER_EXAMPLES[selected_example]["files"] ) @@ -246,7 +156,6 @@ def init(): click.echo(f"\nRun the example agent:") click.echo(" python agent.py") - except Exception as e: click.echo(f"\n❌ Error during setup: {str(e)}", err=True) return diff --git a/agentpress/db_connection.py b/agentpress/db_connection.py index 8802984..cc52f1c 100644 --- a/agentpress/db_connection.py +++ b/agentpress/db_connection.py @@ -7,105 +7,122 @@ from contextlib import asynccontextmanager import os import asyncio -import json class DBConnection: """Singleton database connection manager.""" + _instance = None _initialized = False - _db_path = "ap.db" + _db_path = "/Users/markokraemer/Projects/softgen/softgen-core/main.db" + _init_lock = asyncio.Lock() + _initialization_task = None def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) + # Start initialization when instance is first created + cls._initialization_task = asyncio.create_task(cls._instance._initialize()) return cls._instance - async def initialize(self): - """Initialize the database connection and schema.""" - if self._initialized: - return + def __init__(self): + """No initialization needed in __init__ as it's handled in __new__""" + pass - try: - # Ensure the database directory exists - os.makedirs(os.path.dirname(os.path.abspath(self._db_path)), exist_ok=True) + @classmethod + async def _initialize(cls): + """Internal initialization method.""" + if cls._initialized: + return - # Initialize database and create schema - async with aiosqlite.connect(self._db_path) as db: - await db.execute("PRAGMA foreign_keys = ON") + async with cls._init_lock: + if cls._initialized: # Double-check after acquiring lock + return - # Create threads table - await db.execute(""" - CREATE TABLE IF NOT EXISTS threads ( - id TEXT PRIMARY KEY, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) - """) + try: + async with aiosqlite.connect(cls._db_path) as db: + # Threads table + await db.execute(""" + CREATE TABLE IF NOT EXISTS threads ( + id TEXT PRIMARY KEY, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Messages table + await db.execute(""" + CREATE TABLE IF NOT EXISTS messages ( + id TEXT PRIMARY KEY, + thread_id TEXT, + type TEXT, + content TEXT, + include_in_llm_message_history BOOLEAN DEFAULT TRUE, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (thread_id) REFERENCES threads (id) + ) + """) + + await db.commit() + cls._initialized = True + logging.info("Database schema initialized") + except Exception as e: + logging.error(f"Database initialization error: {e}") + raise - # Create events table - await db.execute(""" - CREATE TABLE IF NOT EXISTS events ( - id TEXT PRIMARY KEY, - thread_id TEXT, - type TEXT, - content TEXT, - include_in_llm_message_history INTEGER DEFAULT 0, - llm_message TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (thread_id) REFERENCES threads(id) ON DELETE CASCADE - ) - """) + @classmethod + def set_db_path(cls, db_path: str): + """Set custom database path.""" + if cls._initialized: + raise RuntimeError("Cannot change database path after initialization") + cls._db_path = db_path + logging.info(f"Updated database path to: {db_path}") - await db.commit() - logging.info("Database initialized successfully") - self._initialized = True - - except Exception as e: - logging.error(f"Failed to initialize database: {e}") - raise + @asynccontextmanager + async def connection(self): + """Get a database connection.""" + # Wait for initialization to complete if it hasn't already + if self._initialization_task and not self._initialized: + await self._initialization_task + + async with aiosqlite.connect(self._db_path) as conn: + try: + yield conn + except Exception as e: + logging.error(f"Database error: {e}") + raise @asynccontextmanager async def transaction(self): - """Get a database connection with transaction support.""" - if not self._initialized: - raise Exception("Database not initialized. Call initialize() first.") - - async with aiosqlite.connect(self._db_path) as db: - await db.execute("PRAGMA foreign_keys = ON") + """Execute operations in a transaction.""" + async with self.connection() as db: try: yield db await db.commit() - logging.debug("Transaction committed successfully") except Exception as e: await db.rollback() - logging.error(f"Transaction failed, rolling back: {e}") + logging.error(f"Transaction error: {e}") raise async def execute(self, query: str, params: tuple = ()): - """Execute a query and return the cursor.""" - async with aiosqlite.connect(self._db_path) as db: - await db.execute("PRAGMA foreign_keys = ON") - return await db.execute(query, params) - - async def fetch_all(self, query: str, params: tuple = ()): - """Execute a query and fetch all results.""" - async with aiosqlite.connect(self._db_path) as db: - await db.execute("PRAGMA foreign_keys = ON") - cursor = await db.execute(query, params) - return await cursor.fetchall() + """Execute a single query.""" + async with self.connection() as db: + try: + result = await db.execute(query, params) + await db.commit() + return result + except Exception as e: + logging.error(f"Query execution error: {e}") + raise async def fetch_one(self, query: str, params: tuple = ()): - """Execute a query and fetch one result.""" - async with aiosqlite.connect(self._db_path) as db: - await db.execute("PRAGMA foreign_keys = ON") - cursor = await db.execute(query, params) - return await cursor.fetchone() + """Fetch a single row.""" + async with self.connection() as db: + async with db.execute(query, params) as cursor: + return await cursor.fetchone() - def _serialize_json(self, data): - """Serialize data to JSON string.""" - return json.dumps(data) if data is not None else None - - def _deserialize_json(self, data): - """Deserialize JSON string to data.""" - return json.loads(data) if data is not None else None \ No newline at end of file + async def fetch_all(self, query: str, params: tuple = ()): + """Fetch all rows.""" + async with self.connection() as db: + async with db.execute(query, params) as cursor: + return await cursor.fetchall() \ No newline at end of file diff --git a/agentpress/llm.py b/agentpress/llm.py index 0ad4432..57194fd 100644 --- a/agentpress/llm.py +++ b/agentpress/llm.py @@ -1,4 +1,4 @@ -from typing import Union, Dict, Any +from typing import Union, Dict, Any, Optional, List import litellm import os import json @@ -11,6 +11,14 @@ ANTHROPIC_API_KEY = os.environ.get('ANTHROPIC_API_KEY', None) GROQ_API_KEY = os.environ.get('GROQ_API_KEY', None) AGENTOPS_API_KEY = os.environ.get('AGENTOPS_API_KEY', None) +FIREWORKS_API_KEY = os.environ.get('FIREWORKS_AI_API_KEY', None) +DEEPSEEK_API_KEY = os.environ.get('DEEPSEEK_API_KEY', None) +OPENROUTER_API_KEY = os.environ.get('OPENROUTER_API_KEY', None) +GEMINI_API_KEY = os.environ.get('GEMINI_API_KEY', None) + +AWS_ACCESS_KEY_ID = os.environ.get('AWS_ACCESS_KEY_ID', None) +AWS_SECRET_ACCESS_KEY = os.environ.get('AWS_SECRET_ACCESS_KEY', None) +AWS_REGION_NAME = os.environ.get('AWS_REGION_NAME', None) if OPENAI_API_KEY: os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY @@ -18,6 +26,22 @@ os.environ['ANTHROPIC_API_KEY'] = ANTHROPIC_API_KEY if GROQ_API_KEY: os.environ['GROQ_API_KEY'] = GROQ_API_KEY +if FIREWORKS_API_KEY: + os.environ['FIREWORKS_AI_API_KEY'] = FIREWORKS_API_KEY +if DEEPSEEK_API_KEY: + os.environ['DEEPSEEK_API_KEY'] = DEEPSEEK_API_KEY +if OPENROUTER_API_KEY: + os.environ['OPENROUTER_API_KEY'] = OPENROUTER_API_KEY +if GEMINI_API_KEY: + os.environ['GEMINI_API_KEY'] = GEMINI_API_KEY + +# Add AWS environment variables if they exist +if AWS_ACCESS_KEY_ID: + os.environ['AWS_ACCESS_KEY_ID'] = AWS_ACCESS_KEY_ID +if AWS_SECRET_ACCESS_KEY: + os.environ['AWS_SECRET_ACCESS_KEY'] = AWS_SECRET_ACCESS_KEY +if AWS_REGION_NAME: + os.environ['AWS_REGION_NAME'] = AWS_REGION_NAME async def make_llm_api_call( messages: list, @@ -31,7 +55,8 @@ async def make_llm_api_call( api_base: str = None, agentops_session: Any = None, stream: bool = False, - top_p: float = None + top_p: float = None, + stop: Optional[Union[str, List[str]]] = None # Add stop parameter ) -> Union[Dict[str, Any], Any]: """ Make an API call to a language model using litellm. @@ -52,6 +77,7 @@ async def make_llm_api_call( agentops_session (Any, optional): Session for agentops integration stream (bool, optional): Whether to stream the response. Defaults to False top_p (float, optional): Top-p sampling parameter + stop (Union[str, List[str]], optional): Up to 4 sequences where the API will stop generating tokens Returns: Union[Dict[str, Any], Any]: API response, either complete or streaming @@ -59,7 +85,7 @@ async def make_llm_api_call( Raises: Exception: If API call fails after retries """ - # litellm.set_verbose = True + litellm.set_verbose = False async def attempt_api_call(api_call_func, max_attempts=3): """ @@ -75,10 +101,17 @@ async def attempt_api_call(api_call_func, max_attempts=3): Raises: Exception: If all retry attempts fail """ + nonlocal model_name # Add this to access model_name for attempt in range(max_attempts): try: return await api_call_func() except litellm.exceptions.RateLimitError as e: + # Check if it's Bedrock Claude and switch to direct Anthropic + if "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0" in model_name: + logging.info("Rate limit hit with Bedrock Claude, falling back to direct Anthropic API...") + model_name = "anthropic/claude-3-5-sonnet-latest" + continue + logging.warning(f"Rate limit exceeded. Waiting for 30 seconds before retrying...") await asyncio.sleep(30) except OpenAIError as e: @@ -105,6 +138,10 @@ async def api_call(): "stream": stream, } + # Add stop parameter if provided + if stop is not None: + api_call_params["stop"] = stop + # Add optional parameters if provided if api_key: api_call_params["api_key"] = api_key @@ -129,6 +166,32 @@ async def api_call(): "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" } + # Add OpenRouter specific parameters + if "openrouter" in model_name.lower(): + if settings.or_site_url: + api_call_params["headers"] = { + "HTTP-Referer": settings.or_site_url + } + if settings.or_app_name: + api_call_params["headers"] = { + "X-Title": settings.or_app_name + } + + # Add special handling for Deepseek + if "deepseek" in model_name.lower(): + api_call_params["frequency_penalty"] = 0.5 + api_call_params["temperature"] = 0.7 + api_call_params["presence_penalty"] = 0.1 + + # Add Bedrock-specific parameters + if "bedrock" in model_name.lower(): + if settings.aws_access_key_id: + api_call_params["aws_access_key_id"] = settings.aws_access_key_id + if settings.aws_secret_access_key: + api_call_params["aws_secret_access_key"] = settings.aws_secret_access_key + if settings.aws_region_name: + api_call_params["aws_region_name"] = settings.aws_region_name + # Log the API request # logging.info(f"Sending API request: {json.dumps(api_call_params, indent=2)}") @@ -137,10 +200,36 @@ async def api_call(): response = await agentops_session.patch(litellm.acompletion)(**api_call_params) else: response = await litellm.acompletion(**api_call_params) - - # Log the API response + # logging.info(f"Received API response: {response}") + # # For streaming responses, attach cost tracking + # if stream: + # # Create a wrapper object to track costs across chunks + # cost_tracker = { + # "prompt_tokens": 0, + # "completion_tokens": 0, + # "total_tokens": 0, + # "cost": 0.0 + # } + + # # Get the cost per token for the model + # model_cost = litellm.model_cost.get(model_name, {}) + # input_cost = model_cost.get('input_cost_per_token', 0) + # output_cost = model_cost.get('output_cost_per_token', 0) + + # # Attach the cost tracker to the response + # response.cost_tracker = cost_tracker + # response.model_info = { + # "input_cost_per_token": input_cost, + # "output_cost_per_token": output_cost + # } + # else: + # # For non-streaming, cost is already included in the response + # response._hidden_params = { + # "response_cost": litellm.completion_cost(completion_response=response) + # } + return response return await attempt_api_call(api_call) @@ -188,4 +277,37 @@ async def test_llm_api_call(stream=True): print(response.choices[0].message.content) print() - asyncio.run(test_llm_api_call()) + # asyncio.run(test_llm_api_call()) + + async def test_bedrock(): + """ + Test function for Bedrock API call. + """ + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello from Bedrock!"} + ] + model_name = "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0" + + response = await make_llm_api_call(messages, model_name, stream=True) + + print("\n🤖 Streaming response from Bedrock:\n") + buffer = "" + async for chunk in response: + if isinstance(chunk, dict) and 'choices' in chunk: + content = chunk['choices'][0]['delta'].get('content', '') + else: + content = chunk.choices[0].delta.content + + if content: + buffer += content + if content[-1].isspace(): + print(buffer, end='', flush=True) + buffer = "" + + if buffer: + print(buffer, flush=True) + print("\n✨ Stream completed.\n") + + # Add test_bedrock to the test runs + # asyncio.run(test_bedrock()) diff --git a/agentpress/processor/__init__.py b/agentpress/processor/__init__.py new file mode 100644 index 0000000..3ebeb22 --- /dev/null +++ b/agentpress/processor/__init__.py @@ -0,0 +1 @@ +# Empty file to mark as package \ No newline at end of file diff --git a/agentpress/processor/base_processors.py b/agentpress/processor/base_processors.py index ec72418..52e4943 100644 --- a/agentpress/processor/base_processors.py +++ b/agentpress/processor/base_processors.py @@ -172,7 +172,7 @@ class ResultsAdderBase(ABC): Attributes: add_message: Callback for adding new messages update_message: Callback for updating existing messages - get_messages: Callback for retrieving thread messages + get_llm_history_messages: Callback for retrieving thread messages message_added: Flag tracking if initial message has been added """ @@ -184,7 +184,7 @@ def __init__(self, thread_manager): """ self.add_message = thread_manager.add_message self.update_message = thread_manager._update_message - self.get_messages = thread_manager.get_messages + self.get_llm_history_messages = thread_manager.get_llm_history_messages self.message_added = False @abstractmethod diff --git a/agentpress/processor/llm_response_processor.py b/agentpress/processor/llm_response_processor.py index 2457221..5ecb19e 100644 --- a/agentpress/processor/llm_response_processor.py +++ b/agentpress/processor/llm_response_processor.py @@ -9,12 +9,9 @@ """ import asyncio -from typing import Callable, Dict, Any, AsyncGenerator, Optional +from typing import Dict, Any, AsyncGenerator import logging from agentpress.processor.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase -from agentpress.processor.standard.standard_tool_parser import StandardToolParser -from agentpress.processor.standard.standard_tool_executor import StandardToolExecutor -from agentpress.processor.standard.standard_results_adder import StandardResultsAdder class LLMResponseProcessor: """Handles LLM response processing and tool execution management. @@ -37,51 +34,30 @@ class LLMResponseProcessor: def __init__( self, thread_id: str, - available_functions: Dict = None, - add_message_callback: Callable = None, - update_message_callback: Callable = None, - get_messages_callback: Callable = None, - parallel_tool_execution: bool = True, - tool_parser: Optional[ToolParserBase] = None, - tool_executor: Optional[ToolExecutorBase] = None, - results_adder: Optional[ResultsAdderBase] = None, - thread_manager = None + tool_executor: ToolExecutorBase, + tool_parser: ToolParserBase, + available_functions: Dict, + results_adder: ResultsAdderBase ): """Initialize the response processor. Args: thread_id: ID of the conversation thread - available_functions: Dictionary of available tool functions - add_message_callback: Callback for adding messages - update_message_callback: Callback for updating messages - get_messages_callback: Callback for listing messages - parallel_tool_execution: Whether to execute tools in parallel - tool_parser: Custom tool parser implementation tool_executor: Custom tool executor implementation + tool_parser: Custom tool parser implementation + available_functions: Dictionary of available tool functions results_adder: Custom results adder implementation - thread_manager: Optional thread manager instance """ self.thread_id = thread_id - self.tool_executor = tool_executor or StandardToolExecutor(parallel=parallel_tool_execution) - self.tool_parser = tool_parser or StandardToolParser() - self.available_functions = available_functions or {} - - # Create minimal thread manager if needed - if thread_manager is None and (add_message_callback and update_message_callback and get_messages_callback): - class MinimalThreadManager: - def __init__(self, add_msg, update_msg, list_msg): - self.add_message = add_msg - self._update_message = update_msg - self.get_messages = list_msg - thread_manager = MinimalThreadManager(add_message_callback, update_message_callback, get_messages_callback) - - self.results_adder = results_adder or StandardResultsAdder(thread_manager) - - # State tracking for streaming - self.tool_calls_buffer = {} - self.processed_tool_calls = set() + self.tool_executor = tool_executor + self.tool_parser = tool_parser + self.available_functions = available_functions + self.results_adder = results_adder self.content_buffer = "" + self.tool_calls_buffer = {} self.tool_calls_accumulated = [] + self.processed_tool_calls = set() + self._executing_tools = set() # Track currently executing tools async def process_stream( self, @@ -92,8 +68,9 @@ async def process_stream( """Process streaming LLM response and handle tool execution.""" pending_tool_calls = [] background_tasks = set() + stream_completed = False # New flag to track stream completion - async def handle_message_management(chunk): + async def handle_message_management(chunk, is_final=False): try: # Accumulate content if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content: @@ -106,28 +83,37 @@ async def handle_message_management(chunk): self.tool_calls_buffer ) if parsed_message and 'tool_calls' in parsed_message: - self.tool_calls_accumulated = parsed_message['tool_calls'] + new_tool_calls = [ + tool_call for tool_call in parsed_message['tool_calls'] + if tool_call['id'] not in self.processed_tool_calls + ] + if new_tool_calls: + self.tool_calls_accumulated.extend(new_tool_calls) # Handle tool execution and results if execute_tools and self.tool_calls_accumulated: new_tool_calls = [ tool_call for tool_call in self.tool_calls_accumulated - if tool_call['id'] not in self.processed_tool_calls + if (tool_call['id'] not in self.processed_tool_calls and + tool_call['id'] not in self._executing_tools) ] if new_tool_calls: if execute_tools_on_stream: + for tool_call in new_tool_calls: + self._executing_tools.add(tool_call['id']) + results = await self.tool_executor.execute_tool_calls( tool_calls=new_tool_calls, available_functions=self.available_functions, thread_id=self.thread_id, executed_tool_calls=self.processed_tool_calls ) + for result in results: await self.results_adder.add_tool_result(self.thread_id, result) self.processed_tool_calls.add(result['tool_call_id']) - else: - pending_tool_calls.extend(new_tool_calls) + self._executing_tools.discard(result['tool_call_id']) # Add/update assistant message message = { @@ -152,7 +138,10 @@ async def handle_message_management(chunk): ) # Handle stream completion - if chunk.choices[0].finish_reason: + if chunk.choices[0].finish_reason or is_final: + nonlocal stream_completed + stream_completed = True + if not execute_tools_on_stream and pending_tool_calls: results = await self.tool_executor.execute_tool_calls( tool_calls=pending_tool_calls, @@ -165,8 +154,16 @@ async def handle_message_management(chunk): self.processed_tool_calls.add(result['tool_call_id']) pending_tool_calls.clear() + # Set final state on the chunk instead of returning it + chunk._final_state = { + "content": self.content_buffer, + "tool_calls": self.tool_calls_accumulated, + "processed_tool_calls": list(self.processed_tool_calls) + } + except Exception as e: logging.error(f"Error in background task: {e}") + raise try: async for chunk in response_stream: @@ -175,9 +172,22 @@ async def handle_message_management(chunk): task.add_done_callback(background_tasks.discard) yield chunk + # Create a final dummy chunk to handle completion + final_chunk = type('DummyChunk', (), { + 'choices': [type('DummyChoice', (), { + 'delta': type('DummyDelta', (), {'content': None}), + 'finish_reason': 'stop' + })] + })() + + # Process final state + await handle_message_management(final_chunk, is_final=True) + yield final_chunk + + # Wait for all background tasks to complete if background_tasks: await asyncio.gather(*background_tasks, return_exceptions=True) - + except Exception as e: logging.error(f"Error in stream processing: {e}") for task in background_tasks: diff --git a/agentpress/processor/standard/__init__.py b/agentpress/processor/standard/__init__.py new file mode 100644 index 0000000..3ebeb22 --- /dev/null +++ b/agentpress/processor/standard/__init__.py @@ -0,0 +1 @@ +# Empty file to mark as package \ No newline at end of file diff --git a/agentpress/processor/standard/standard_results_adder.py b/agentpress/processor/standard/standard_results_adder.py index 49d08b6..810091e 100644 --- a/agentpress/processor/standard/standard_results_adder.py +++ b/agentpress/processor/standard/standard_results_adder.py @@ -81,6 +81,6 @@ async def add_tool_result(self, thread_id: str, result: Dict[str, Any]): - Checks for duplicate tool results before adding - Adds result only if tool_call_id is unique """ - messages = await self.get_messages(thread_id) + messages = await self.get_llm_history_messages(thread_id) if not any(msg.get('tool_call_id') == result['tool_call_id'] for msg in messages): await self.add_message(thread_id, result) diff --git a/agentpress/processor/xml/__init__.py b/agentpress/processor/xml/__init__.py new file mode 100644 index 0000000..3ebeb22 --- /dev/null +++ b/agentpress/processor/xml/__init__.py @@ -0,0 +1 @@ +# Empty file to mark as package \ No newline at end of file diff --git a/agentpress/processor/xml/xml_results_adder.py b/agentpress/processor/xml/xml_results_adder.py index 49593da..4b8123d 100644 --- a/agentpress/processor/xml/xml_results_adder.py +++ b/agentpress/processor/xml/xml_results_adder.py @@ -79,7 +79,7 @@ async def add_tool_result(self, thread_id: str, result: Dict[str, Any]): """ try: # Get the original tool call to find the root tag - messages = await self.get_messages(thread_id) + messages = await self.get_llm_history_messages(thread_id) assistant_msg = next((msg for msg in reversed(messages) if msg['role'] == 'assistant'), None) @@ -107,10 +107,9 @@ async def add_tool_result(self, thread_id: str, result: Dict[str, Any]): await self.add_message(thread_id, result_message) except Exception as e: - logging.error(f"Error adding tool result: {e}") - # Ensure the result is still added even if there's an error + logging.error(f"Error adding tool result: {e}") # Ensure the result is still added even if there's an error result_message = { "role": "user", "content": f"Result for {result['name']}:\n{result['content']}" } - await self.add_message(thread_id, result_message) \ No newline at end of file + await self.add_message(thread_id, result_message) diff --git a/agentpress/processor/xml/xml_tool_executor.py b/agentpress/processor/xml/xml_tool_executor.py index 185e8e6..10bcb68 100644 --- a/agentpress/processor/xml/xml_tool_executor.py +++ b/agentpress/processor/xml/xml_tool_executor.py @@ -38,6 +38,8 @@ def __init__(self, parallel: bool = True, tool_registry: Optional[ToolRegistry] """ self.parallel = parallel self.tool_registry = tool_registry or ToolRegistry() + # Add internal tracking of executed tools + self._executed_tools = set() async def execute_tool_calls( self, @@ -65,20 +67,28 @@ async def execute_tool_calls( if executed_tool_calls is None: executed_tool_calls = set() + # Filter out already executed tool calls + new_tool_calls = [ + tool_call for tool_call in tool_calls + if tool_call['id'] not in executed_tool_calls + and tool_call['id'] not in self._executed_tools + ] + + if not new_tool_calls: + return [] + if self.parallel: - return await self._execute_parallel( - tool_calls, - available_functions, - thread_id, - executed_tool_calls - ) + results = await self._execute_parallel(new_tool_calls, available_functions, thread_id, executed_tool_calls) else: - return await self._execute_sequential( - tool_calls, - available_functions, - thread_id, - executed_tool_calls - ) + results = await self._execute_sequential(new_tool_calls, available_functions, thread_id, executed_tool_calls) + + # Track executed tools internally + for tool_call in new_tool_calls: + self._executed_tools.add(tool_call['id']) + if executed_tool_calls is not None: + executed_tool_calls.add(tool_call['id']) + + return results async def _execute_parallel( self, @@ -87,9 +97,10 @@ async def _execute_parallel( thread_id: str, executed_tool_calls: Set[str] ) -> List[Dict[str, Any]]: - async def execute_single_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]: - if tool_call['id'] in executed_tool_calls: - logging.info(f"Tool call {tool_call['id']} already executed") + async def execute_single_tool(tool_call: Dict[str, Any]) -> Optional[Dict[str, Any]]: + # Double-check the tool hasn't been executed + if (tool_call['id'] in executed_tool_calls or + tool_call['id'] in self._executed_tools): return None try: diff --git a/agentpress/state_manager.py b/agentpress/state_manager.py index 01bf1d4..4281e43 100644 --- a/agentpress/state_manager.py +++ b/agentpress/state_manager.py @@ -1,207 +1,118 @@ -""" -Manages persistent state storage for AgentPress components using thread-based events. - -The StateManager provides thread-safe access to state data stored as events in threads, -allowing components to save and retrieve data across sessions. Each state update -creates a new event containing the complete state. -""" - import json import logging from typing import Any, Optional, List, Dict -from asyncio import Lock +import uuid from agentpress.thread_manager import ThreadManager class StateManager: """ - Manages persistent state storage for AgentPress components using thread events. - - The StateManager provides thread-safe access to state data stored as events, - maintaining the complete state in each event for better consistency and tracking. - - Attributes: - lock (Lock): Asyncio lock for thread-safe state access - thread_id (str): Thread ID for state storage - thread_manager (ThreadManager): Thread manager instance for event handling + Manages state storage using thread messages. + Each state message contains a complete snapshot of the state at that point in time. """ def __init__(self, thread_id: str): - """ - Initialize StateManager with thread ID. - - Args: - thread_id (str): Thread ID for state storage - """ - self.lock = Lock() - self.thread_id = thread_id + """Initialize StateManager with a thread ID.""" self.thread_manager = ThreadManager() - logging.info(f"StateManager initialized with thread_id: {self.thread_id}") - - async def initialize(self): - """Initialize the thread manager.""" - await self.thread_manager.initialize() + self.thread_id = thread_id + self._state_cache = None + logging.info(f"StateManager initialized for thread: {thread_id}") - async def _ensure_initialized(self): - """Ensure thread manager is initialized.""" - if not self.thread_manager.db._initialized: - await self.initialize() + async def _get_state(self) -> Dict[str, Any]: + """Get the current complete state.""" + if self._state_cache is not None: + return self._state_cache.copy() # Return copy to prevent cache mutation - async def _get_current_state(self) -> dict: - """Get the current state from the latest state event.""" - await self._ensure_initialized() - events = await self.thread_manager.get_thread_events( - thread_id=self.thread_id, - event_types=["state"], - order_by="created_at", - order="DESC" + # Get the latest state message + rows = await self.thread_manager.db.fetch_all( + """ + SELECT content + FROM messages + WHERE thread_id = ? AND type = 'state_message' + ORDER BY created_at DESC LIMIT 1 + """, + (self.thread_id,) ) - if events: - return events[0]["content"].get("state", {}) + + if rows: + try: + self._state_cache = json.loads(rows[0][0]) + return self._state_cache.copy() + except json.JSONDecodeError: + logging.error("Failed to parse state JSON") + return {} - async def _save_state(self, state: dict): - """Save the complete state as a new event.""" - await self._ensure_initialized() - await self.thread_manager.create_event( + async def _save_state(self, state: Dict[str, Any]): + """Save a new complete state snapshot.""" + # Format state as a string with proper indentation + formatted_state = json.dumps(state, indent=2) + + # Save new state message with complete snapshot + await self.thread_manager.add_message( thread_id=self.thread_id, - event_type="state", - content={"state": state} + message_data=formatted_state, + message_type='state_message', + include_in_llm_message_history=False ) + + # Update cache with a copy + self._state_cache = state.copy() async def set(self, key: str, data: Any) -> Any: - """ - Store data with a key in the state. - - Args: - key (str): Simple string key like "config" or "settings" - data (Any): Any JSON-serializable data - - Returns: - Any: The stored data - """ - async with self.lock: - try: - current_state = await self._get_current_state() - current_state[key] = data - await self._save_state(current_state) - logging.info(f'Updated state key: {key}') - return data - except Exception as e: - logging.error(f"Error setting state: {e}") - raise + """Store any JSON-serializable data with a key.""" + state = await self._get_state() + state[key] = data + await self._save_state(state) + logging.info(f'Updated state key: {key}') + return data async def get(self, key: str) -> Optional[Any]: - """ - Get data for a key from the current state. - - Args: - key (str): Simple string key like "config" or "settings" - - Returns: - Any: The stored data for the key, or None if key not found - """ - async with self.lock: - try: - current_state = await self._get_current_state() - if key in current_state: - logging.info(f'Retrieved key: {key}') - return current_state[key] - logging.info(f'Key not found: {key}') - return None - except Exception as e: - logging.error(f"Error getting state: {e}") - raise + """Get data for a key.""" + state = await self._get_state() + if key in state: + data = state[key] + logging.info(f'Retrieved key: {key}') + return data + logging.info(f'Key not found: {key}') + return None async def delete(self, key: str): - """ - Delete a key from the state. - - Args: - key (str): Simple string key like "config" or "settings" - """ - async with self.lock: - try: - current_state = await self._get_current_state() - if key in current_state: - del current_state[key] - await self._save_state(current_state) - logging.info(f"Deleted key: {key}") - except Exception as e: - logging.error(f"Error deleting state: {e}") - raise + """Delete data for a key.""" + state = await self._get_state() + if key in state: + del state[key] + await self._save_state(state) + logging.info(f"Deleted key: {key}") async def update(self, key: str, data: Dict[str, Any]) -> Optional[Any]: - """ - Update existing dictionary data for a key by merging. - - Args: - key (str): Simple string key like "config" or "settings" - data (Dict[str, Any]): Dictionary of updates to merge - - Returns: - Optional[Any]: Updated data if successful, None if key not found - """ - async with self.lock: - try: - current_state = await self._get_current_state() - if key in current_state and isinstance(current_state[key], dict): - current_state[key].update(data) - await self._save_state(current_state) - return current_state[key] - return None - except Exception as e: - logging.error(f"Error updating state: {e}") - raise + """Update existing data for a key by merging dictionaries.""" + state = await self._get_state() + if key in state and isinstance(state[key], dict): + state[key].update(data) + await self._save_state(state) + logging.info(f'Updated state key: {key}') + return state[key] + return None async def append(self, key: str, item: Any) -> Optional[List[Any]]: - """ - Append an item to a list stored at key. - - Args: - key (str): Simple string key like "config" or "settings" - item (Any): Item to append - - Returns: - Optional[List[Any]]: Updated list if successful, None if key not found - """ - async with self.lock: - try: - current_state = await self._get_current_state() - if key not in current_state: - current_state[key] = [] - if isinstance(current_state[key], list): - current_state[key].append(item) - await self._save_state(current_state) - return current_state[key] - return None - except Exception as e: - logging.error(f"Error appending to state: {e}") - raise + """Append an item to a list stored at key.""" + state = await self._get_state() + if key not in state: + state[key] = [] + if isinstance(state[key], list): + state[key].append(item) + await self._save_state(state) + logging.info(f'Appended to key: {key}') + return state[key] + return None - async def get_latest_state(self) -> dict: - """ - Get the latest complete state. - - Returns: - dict: Complete contents of the latest state - """ - async with self.lock: - try: - state = await self._get_current_state() - logging.info(f"Retrieved latest state with {len(state)} keys") - return state - except Exception as e: - logging.error(f"Error getting latest state: {e}") - raise + async def export_store(self) -> dict: + """Export entire state.""" + state = await self._get_state() + return state - async def clear_state(self): - """ - Clear the entire state. - """ - async with self.lock: - try: - await self._save_state({}) - logging.info("Cleared state") - except Exception as e: - logging.error(f"Error clearing state: {e}") - raise + async def clear_store(self): + """Clear entire state.""" + await self._save_state({}) + self._state_cache = {} + logging.info("Cleared state") diff --git a/agentpress/thread_manager.py b/agentpress/thread_manager.py index 8289b34..921e448 100644 --- a/agentpress/thread_manager.py +++ b/agentpress/thread_manager.py @@ -2,7 +2,7 @@ Conversation thread management system for AgentPress. This module provides comprehensive conversation management, including: -- Thread and Event CRUD operations +- Thread creation and persistence - Message handling with support for text and images - Tool registration and execution - LLM interaction with streaming support @@ -13,251 +13,104 @@ import logging import asyncio import uuid -from datetime import datetime from typing import List, Dict, Any, Optional, Type, Union, AsyncGenerator from agentpress.llm import make_llm_api_call from agentpress.tool import Tool, ToolResult from agentpress.tool_registry import ToolRegistry from agentpress.processor.llm_response_processor import LLMResponseProcessor -from agentpress.processor.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase from agentpress.db_connection import DBConnection -from agentpress.processor.xml.xml_tool_parser import XMLToolParser -from agentpress.processor.xml.xml_tool_executor import XMLToolExecutor -from agentpress.processor.xml.xml_results_adder import XMLResultsAdder from agentpress.processor.standard.standard_tool_parser import StandardToolParser from agentpress.processor.standard.standard_tool_executor import StandardToolExecutor from agentpress.processor.standard.standard_results_adder import StandardResultsAdder +from agentpress.processor.xml.xml_tool_parser import XMLToolParser +from agentpress.processor.xml.xml_tool_executor import XMLToolExecutor +from agentpress.processor.xml.xml_results_adder import XMLResultsAdder class ThreadManager: - """Manages conversation threads with LLM models and tool execution.""" + """Manages conversation threads with LLM models and tool execution. + + Provides comprehensive conversation management, handling message threading, + tool registration, and LLM interactions with support for both standard and + XML-based tool execution patterns. + """ def __init__(self): """Initialize ThreadManager.""" - self.tool_registry = ToolRegistry() self.db = DBConnection() - - async def initialize(self): - """Initialize async components.""" - await self.db.initialize() + self.tool_registry = ToolRegistry() def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs): """Add a tool to the ThreadManager.""" self.tool_registry.register_tool(tool_class, function_names, **kwargs) - async def thread_exists(self, thread_id: str) -> bool: - """Check if a thread exists.""" - await self._ensure_initialized() - result = await self.db.fetch_one( - "SELECT 1 FROM threads WHERE id = ?", - (thread_id,) - ) - return result is not None - - async def _ensure_initialized(self): - """Ensure database is initialized.""" - if not self.db._initialized: - await self.initialize() - - async def event_belongs_to_thread(self, event_id: str, thread_id: str) -> bool: - """Check if an event exists and belongs to a thread.""" - await self._ensure_initialized() - result = await self.db.fetch_one( - "SELECT 1 FROM events WHERE id = ? AND thread_id = ?", - (event_id, thread_id) - ) - return result is not None - - # Core Thread Operations async def create_thread(self) -> str: """Create a new conversation thread.""" - await self._ensure_initialized() thread_id = str(uuid.uuid4()) + await self.db.execute( + "INSERT INTO threads (id) VALUES (?)", + (thread_id,) + ) + return thread_id + + async def add_message( + self, + thread_id: str, + message_data: Union[str, Dict[str, Any]], + images: Optional[List[Dict[str, Any]]] = None, + include_in_llm_message_history: bool = True, + message_type: Optional[str] = None + ): + """Add a message to an existing thread.""" + logging.info(f"Adding message to thread {thread_id}") + try: - async with self.db.transaction() as conn: - await conn.execute( - "INSERT INTO threads (id) VALUES (?)", - (thread_id,) - ) - logging.info(f"Created thread: {thread_id}") - return thread_id - except Exception as e: - logging.error(f"Failed to create thread: {e}") - raise - - async def delete_thread(self, thread_id: str) -> bool: - """Delete a thread and all its events (cascade).""" - try: - result = await self.db.execute( - "DELETE FROM threads WHERE id = ?", - (thread_id,) - ) - # Check if any rows were affected - return result.rowcount > 0 - except Exception as e: - logging.error(f"Failed to delete thread {thread_id}: {e}") - return False - - # Core Event Operations - async def create_event( - self, - thread_id: str, - event_type: str, - content: Dict[str, Any], - include_in_llm_message_history: bool = False, - llm_message: Optional[Dict[str, Any]] = None - ) -> str: - """Create a new event in a thread.""" - await self._ensure_initialized() - event_id = str(uuid.uuid4()) - try: - async with self.db.transaction() as conn: - # First verify thread exists - cursor = await conn.execute("SELECT 1 FROM threads WHERE id = ?", (thread_id,)) - if not await cursor.fetchone(): - raise Exception(f"Thread {thread_id} does not exist") - - # Then create the event - await conn.execute( - """ - INSERT INTO events ( - id, thread_id, type, content, - include_in_llm_message_history, llm_message - ) VALUES (?, ?, ?, ?, ?, ?) - """, - ( - event_id, - thread_id, - event_type, - self.db._serialize_json(content), - 1 if include_in_llm_message_history else 0, - self.db._serialize_json(llm_message) if llm_message else None - ) - ) - logging.info(f"Created event {event_id} in thread {thread_id}") - return event_id - except Exception as e: - logging.error(f"Failed to create event in thread {thread_id}: {e}") - raise - - async def delete_event(self, event_id: str) -> bool: - """Delete a specific event.""" - try: - result = await self.db.execute( - "DELETE FROM events WHERE id = ?", - (event_id,) - ) - # Check if any rows were affected - return result.rowcount > 0 - except Exception as e: - logging.error(f"Failed to delete event {event_id}: {e}") - return False - - async def update_event( - self, - event_id: str, - thread_id: str, - content: Optional[Dict[str, Any]] = None, - include_in_llm_message_history: Optional[bool] = None, - llm_message: Optional[Dict[str, Any]] = None - ) -> bool: - """Update an existing event.""" - try: - # First verify the event exists and belongs to the thread - event = await self.db.fetch_one( - "SELECT 1 FROM events WHERE id = ? AND thread_id = ?", - (event_id, thread_id) - ) - if not event: - return False - - updates = [] - params = [] - if content is not None: - updates.append("content = ?") - params.append(self.db._serialize_json(content)) - if include_in_llm_message_history is not None: - updates.append("include_in_llm_message_history = ?") - params.append(1 if include_in_llm_message_history else 0) - if llm_message is not None: - updates.append("llm_message = ?") - params.append(self.db._serialize_json(llm_message)) - - if not updates: - return False - - query = f""" - UPDATE events - SET {', '.join(updates)}, updated_at = CURRENT_TIMESTAMP - WHERE id = ? AND thread_id = ? - """ - params.extend([event_id, thread_id]) - - result = await self.db.execute(query, tuple(params)) - return result.rowcount > 0 - except Exception as e: - logging.error(f"Failed to update event {event_id}: {e}") - return False - - async def get_thread_events( - self, - thread_id: str, - only_llm_messages: bool = False, - event_types: Optional[List[str]] = None, - order_by: str = "created_at", - order: str = "ASC" - ) -> List[Dict[str, Any]]: - """Get events from a thread with filtering options.""" - try: - query = ["SELECT * FROM events WHERE thread_id = ?"] - params = [thread_id] - - if only_llm_messages: - query.append("AND include_in_llm_message_history = 1") + message_id = str(uuid.uuid4()) - if event_types: - placeholders = ','.join(['?' for _ in event_types]) - query.append(f"AND type IN ({placeholders})") - params.extend(event_types) - - query.append(f"ORDER BY {order_by} {order}") + # Handle string content + if isinstance(message_data, str): + content = message_data + role = 'unknown' + + # Determine message type only for LLM-related messages if not provided + if message_type is None: + type_mapping = { + 'user': 'user_message', + 'assistant': 'assistant_message', + 'tool': 'tool_message', + 'system': 'system_message' + } + message_type = type_mapping.get(role, 'unknown_message') + + else: + # For dict message_data, check if it's an LLM message format + if 'role' in message_data and 'content' in message_data: + content = message_data.get('content', '') + role = message_data.get('role', 'unknown') + + # Determine message type for LLM messages if not provided + if message_type is None: + type_mapping = { + 'user': 'user_message', + 'assistant': 'assistant_message', + 'tool': 'tool_message', + 'system': 'system_message' + } + message_type = type_mapping.get(role, 'unknown_message') + else: + # For non-LLM messages, use the entire message_data as content + content = message_data - rows = await self.db.fetch_all(' '.join(query), tuple(params)) + # Handle content processing + if isinstance(content, ToolResult): + content = str(content) - events = [] - for row in rows: - event = { - "id": row[0], - "thread_id": row[1], - "type": row[2], - "content": self.db._deserialize_json(row[3]), - "include_in_llm_message_history": bool(row[4]), - "llm_message": self.db._deserialize_json(row[5]) if row[5] else None, - "created_at": row[6], - "updated_at": row[7] - } - events.append(event) - - return events - except Exception as e: - logging.error(f"Failed to get events for thread {thread_id}: {e}") - return [] - - async def get_thread_llm_messages(self, thread_id: str) -> List[Dict[str, Any]]: - """Get LLM-formatted messages from thread events.""" - events = await self.get_thread_events(thread_id, only_llm_messages=True) - return [event["llm_message"] for event in events if event["llm_message"]] - - # Message handling methods refactored for event-based system - async def add_message(self, thread_id: str, message_data: Dict[str, Any], images: Optional[List[Dict[str, Any]]] = None): - """Add a message as an event to the thread.""" - try: # Handle image attachments if images: - if isinstance(message_data['content'], str): - content = [{"type": "text", "text": message_data['content']}] - else: - content = message_data['content'] if isinstance(message_data['content'], list) else [] + if isinstance(content, str): + content = [{"type": "text", "text": content}] + elif not isinstance(content, list): + content = [] for image in images: image_content = { @@ -268,75 +121,139 @@ async def add_message(self, thread_id: str, message_data: Dict[str, Any], images } } content.append(image_content) - else: - content = message_data['content'] - # Create event - event_type = f"message_{message_data['role']}" - await self.create_event( - thread_id=thread_id, - event_type=event_type, - content={"raw_content": content}, - include_in_llm_message_history=True, - llm_message=message_data + # Convert content to JSON string if it's a dict or list + if isinstance(content, (dict, list)): + content = json.dumps(content) + + # Insert the message + await self.db.execute( + """ + INSERT INTO messages ( + id, thread_id, type, content, include_in_llm_message_history + ) VALUES (?, ?, ?, ?, ?) + """, + (message_id, thread_id, message_type, content, include_in_llm_message_history) ) + logging.info(f"Message added to thread {thread_id}") + except Exception as e: logging.error(f"Failed to add message to thread {thread_id}: {e}") raise - async def get_messages( - self, + async def get_llm_history_messages( + self, thread_id: str, hide_tool_msgs: bool = False, only_latest_assistant: bool = False, - regular_list: bool = True ) -> List[Dict[str, Any]]: - """Get messages from thread events with filtering.""" - messages = await self.get_thread_llm_messages(thread_id) - + """ + Retrieve messages from a thread that are marked for LLM history. + + Args: + thread_id: The thread to get messages from + hide_tool_msgs: Whether to hide tool messages + only_latest_assistant: Whether to only return the latest assistant message + + Returns: + List of messages formatted for LLM context + """ + + # Get only messages marked for LLM history + rows = await self.db.fetch_all( + """ + SELECT type, content + FROM messages + WHERE thread_id = ? + AND include_in_llm_message_history = TRUE + ORDER BY created_at ASC + """, + (thread_id,) + ) + + # Convert DB rows to message format + messages = [] + type_to_role = { + 'user_message': 'user', + 'assistant_message': 'assistant', + 'tool_message': 'tool', + 'system_message': 'system' + } + + for row in rows: + msg_type, content = row + + # Try to parse JSON content + try: + content = json.loads(content) + except (json.JSONDecodeError, TypeError): + pass # Keep content as is if it's not JSON + + message = { + 'role': type_to_role.get(msg_type, 'unknown'), + 'content': content + } + messages.append(message) + + # Apply filters if only_latest_assistant: for msg in reversed(messages): if msg.get('role') == 'assistant': return [msg] return [] - + if hide_tool_msgs: messages = [ {k: v for k, v in msg.items() if k != 'tool_calls'} for msg in messages if msg.get('role') != 'tool' ] - - if regular_list: - messages = [ - msg for msg in messages - if msg.get('role') in ['system', 'assistant', 'tool', 'user'] - ] - + + return messages async def _update_message(self, thread_id: str, message: Dict[str, Any]): - """Update the last assistant message event.""" - events = await self.get_thread_events( - thread_id=thread_id, - event_types=["message_assistant"], - order_by="created_at", - order="DESC" - ) - - if events: - last_assistant_event = events[0] - await self.update_event( - event_id=last_assistant_event["id"], - thread_id=thread_id, - content={"raw_content": message.get("content")}, - llm_message=message + """Update an existing message in the thread.""" + try: + # Find the latest assistant message for this thread + row = await self.db.fetch_one( + """ + SELECT id FROM messages + WHERE thread_id = ? AND type = 'assistant_message' + ORDER BY created_at DESC LIMIT 1 + """, + (thread_id,) ) + + if not row: + return + + message_id = row[0] + + # Convert content to JSON string if needed + content = message.get('content', '') + if isinstance(content, (dict, list)): + content = json.dumps(content) + + # Update the message + async with self.db.transaction() as conn: + await conn.execute( + """ + UPDATE messages + SET content = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = ? + """, + (content, message_id) + ) + + except Exception as e: + logging.error(f"Failed to update message: {e}") + raise async def cleanup_incomplete_tool_calls(self, thread_id: str): """Clean up incomplete tool calls in a thread.""" - messages = await self.get_messages(thread_id) + messages = await self.get_llm_history_messages(thread_id) last_assistant_message = next((m for m in reversed(messages) if m['role'] == 'assistant' and 'tool_calls' in m), None) @@ -374,21 +291,19 @@ async def cleanup_incomplete_tool_calls(self, thread_id: str): async def run_thread( self, thread_id: str, - system_message: Dict[str, Any], + system_message: Dict[str, str], model_name: str, - temperature: float = 0, - max_tokens: Optional[int] = None, + temperature: float = 0.7, + max_tokens: int = 4096, tool_choice: str = "auto", - temporary_message: Optional[Dict[str, Any]] = None, - native_tool_calling: bool = False, + temporary_message: Optional[Dict[str, str]] = None, + native_tool_calling: bool = True, xml_tool_calling: bool = False, - execute_tools: bool = True, stream: bool = False, - execute_tools_on_stream: bool = False, - parallel_tool_execution: bool = False, - tool_parser: Optional[ToolParserBase] = None, - tool_executor: Optional[ToolExecutorBase] = None, - results_adder: Optional[ResultsAdderBase] = None + execute_tools: bool = True, + execute_tools_on_stream: bool = True, + parallel_tool_execution: bool = True, + stop: Optional[Union[str, List[str]]] = None ) -> Union[Dict[str, Any], AsyncGenerator]: """Run a conversation thread with specified parameters. @@ -406,9 +321,7 @@ async def run_thread( stream: Whether to stream the response execute_tools_on_stream: Whether to execute tools during streaming parallel_tool_execution: Whether to execute tools in parallel - tool_parser: Custom tool parser implementation - tool_executor: Custom tool executor implementation - results_adder: Custom results adder implementation + stop (Union[str, List[str]], optional): Up to 4 sequences where the API will stop generating tokens Returns: Union[Dict[str, Any], AsyncGenerator]: Response or stream @@ -421,71 +334,185 @@ async def run_thread( - Cannot use both native and XML tool calling simultaneously - Streaming responses include both content and tool results """ - # Validate tool calling configuration - if native_tool_calling and xml_tool_calling: - raise ValueError("Cannot use both native LLM tool calling and XML tool calling simultaneously") - - # Initialize tool components if any tool calling is enabled - if native_tool_calling or xml_tool_calling: - if tool_parser is None: - tool_parser = XMLToolParser(tool_registry=self.tool_registry) if xml_tool_calling else StandardToolParser() - - if tool_executor is None: - tool_executor = XMLToolExecutor(parallel=parallel_tool_execution, tool_registry=self.tool_registry) if xml_tool_calling else StandardToolExecutor(parallel=parallel_tool_execution) - - if results_adder is None: - results_adder = XMLResultsAdder(self) if xml_tool_calling else StandardResultsAdder(self) - try: - messages = await self.get_messages(thread_id) - prepared_messages = [system_message] + messages - if temporary_message: - prepared_messages.append(temporary_message) - - openapi_tool_schemas = None - if native_tool_calling: - openapi_tool_schemas = self.tool_registry.get_openapi_schemas() - available_functions = self.tool_registry.get_available_functions() - elif xml_tool_calling: - available_functions = self.tool_registry.get_available_functions() - else: - available_functions = {} - - response_processor = LLMResponseProcessor( + # Add thread run start message + await self.add_message( thread_id=thread_id, - available_functions=available_functions, - add_message_callback=self.add_message, - update_message_callback=self._update_message, - get_messages_callback=self.get_messages, - parallel_tool_execution=parallel_tool_execution, - tool_parser=tool_parser, - tool_executor=tool_executor, - results_adder=results_adder + message_data={ + "name": "thread_run", + "status": "started", + "details": { + "model": model_name, + "temperature": temperature, + "native_tool_calling": native_tool_calling, + "xml_tool_calling": xml_tool_calling, + "execute_tools": execute_tools, + "stream": stream + } + }, + message_type="agentpress_system", + include_in_llm_message_history=False ) - llm_response = await self._run_thread_completion( - messages=prepared_messages, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tools=openapi_tool_schemas, - tool_choice=tool_choice if native_tool_calling else None, - stream=stream - ) + try: + messages = await self.get_llm_history_messages(thread_id) + prepared_messages = [system_message] + messages + if temporary_message: + prepared_messages.append(temporary_message) + + openapi_tool_schemas = None + if native_tool_calling: + openapi_tool_schemas = self.tool_registry.get_openapi_schemas() + available_functions = self.tool_registry.get_available_functions() + elif xml_tool_calling: + available_functions = self.tool_registry.get_available_functions() + else: + available_functions = {} - if stream: - return response_processor.process_stream( - response_stream=llm_response, - execute_tools=execute_tools, - execute_tools_on_stream=execute_tools_on_stream + # Initialize appropriate tool parser and executor based on calling type + if xml_tool_calling: + tool_parser = XMLToolParser(tool_registry=self.tool_registry) + tool_executor = XMLToolExecutor(parallel=parallel_tool_execution, tool_registry=self.tool_registry) + results_adder = XMLResultsAdder(self) + else: + tool_parser = StandardToolParser() + tool_executor = StandardToolExecutor(parallel=parallel_tool_execution) + results_adder = StandardResultsAdder(self) + + # Create a SINGLE response processor instance + response_processor = LLMResponseProcessor( + thread_id=thread_id, + tool_executor=tool_executor, + tool_parser=tool_parser, + available_functions=available_functions, + results_adder=results_adder ) - await response_processor.process_response( - response=llm_response, - execute_tools=execute_tools - ) + response = await self._run_thread_completion( + messages=prepared_messages, + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + tools=openapi_tool_schemas, + tool_choice=tool_choice if native_tool_calling else None, + stream=stream, + stop=stop + ) + + if stream: + async def stream_with_completion(): + processor = response_processor.process_stream( + response_stream=response, + execute_tools=execute_tools, + execute_tools_on_stream=execute_tools_on_stream + ) + + final_state = None + async for chunk in processor: + yield chunk + if hasattr(chunk, '_final_state'): + final_state = chunk._final_state + + # Add completion message after stream ends + completion_message = { + "name": "thread_run", + "status": "completed", + "details": { + "model": model_name, + "temperature": temperature, + "native_tool_calling": native_tool_calling, + "xml_tool_calling": xml_tool_calling, + "execute_tools": execute_tools, + "stream": stream + } + } + + # TODO: Add usage, cost information – from final llm response + + # # Add usage information from final state + # if final_state and 'usage' in final_state: + # completion_message["usage"] = final_state["usage"] + # elif hasattr(response, 'cost_tracker'): + # completion_message["usage"] = { + # "prompt_tokens": response.cost_tracker['prompt_tokens'], + # "completion_tokens": response.cost_tracker['completion_tokens'], + # "total_tokens": response.cost_tracker['total_tokens'], + # "cost_usd": response.cost_tracker['cost'] + # } + + await self.add_message( + thread_id=thread_id, + message_data=completion_message, + message_type="agentpress_system", + include_in_llm_message_history=False + ) + + return stream_with_completion() + + # For non-streaming, process response once + await response_processor.process_response( + response=response, + execute_tools=execute_tools + ) + + # Add completion message on success with cost information + completion_message = { + "name": "thread_run", + "status": "completed", + "details": { + "model": model_name, + "temperature": temperature, + "native_tool_calling": native_tool_calling, + "xml_tool_calling": xml_tool_calling, + "execute_tools": execute_tools, + "stream": stream + } + } + + # TODO: Add usage, cost information – from final llm response + + # # Add cost information if available + # if hasattr(response, 'cost'): + # completion_message["usage"] = { + # "cost_usd": response.cost + # } + # if hasattr(response, 'usage'): + # if "usage" not in completion_message: + # completion_message["usage"] = {} + # completion_message["usage"].update({ + # "prompt_tokens": response.usage.prompt_tokens, + # "completion_tokens": response.usage.completion_tokens, + # "total_tokens": response.usage.total_tokens + # }) + + await self.add_message( + thread_id=thread_id, + message_data=completion_message, + message_type="agentpress_system", + include_in_llm_message_history=False + ) + + return response - return llm_response + except Exception as e: + # Add error message if something goes wrong + + # TODO: FIX THAT THREAD RUN CATCHES ERRORS FROM LLM.PY from RUN_THREAD_COMPLETION & CORRECTLY ADD ERROR MESSAGE TO THREAD + + await self.add_message( + thread_id=thread_id, + message_data={ + "name": "thread_run", + "status": "error", + "error": str(e), + "details": { + "error_type": type(e).__name__ + } + }, + message_type="agentpress_system", + include_in_llm_message_history=False + ) + raise except Exception as e: logging.error(f"Error in run_thread: {str(e)}") @@ -502,104 +529,191 @@ async def _run_thread_completion( max_tokens: Optional[int], tools: Optional[List[Dict[str, Any]]], tool_choice: Optional[str], - stream: bool - ) -> Union[Any, AsyncGenerator]: + stream: bool, + stop: Optional[Union[str, List[str]]] = None + ) -> Union[Dict[str, Any], AsyncGenerator]: """Get completion from LLM API.""" - return await make_llm_api_call( + response = await make_llm_api_call( messages, model_name, temperature=temperature, max_tokens=max_tokens, tools=tools, tool_choice=tool_choice, - stream=stream + stream=stream, + stop=stop ) -if __name__ == "__main__": - import asyncio - from agentpress.examples.example_agent.tools.files_tool import FilesTool + # For streaming responses, wrap in a cost-tracking generator + if stream: + async def cost_tracking_stream(): + try: + async for chunk in response: + # Update token counts if available + if hasattr(chunk, 'usage'): + response.cost_tracker['prompt_tokens'] = chunk.usage.prompt_tokens + response.cost_tracker['completion_tokens'] = chunk.usage.completion_tokens + response.cost_tracker['total_tokens'] = chunk.usage.total_tokens + + # Calculate running cost + input_cost = response.model_info['input_cost_per_token'] + output_cost = response.model_info['output_cost_per_token'] + + cost = (response.cost_tracker['prompt_tokens'] * input_cost + + response.cost_tracker['completion_tokens'] * output_cost) + response.cost_tracker['cost'] = cost + + # Attach cost tracker to the chunk for final state + if hasattr(chunk, '_final_state'): + chunk._final_state['usage'] = { + "prompt_tokens": response.cost_tracker['prompt_tokens'], + "completion_tokens": response.cost_tracker['completion_tokens'], + "total_tokens": response.cost_tracker['total_tokens'], + "cost_usd": response.cost_tracker['cost'] + } + yield chunk + except Exception as e: + logging.error(f"Error in cost tracking stream: {e}") + raise + + return cost_tracking_stream() + + return response - async def main(): - # Initialize managers - thread_manager = ThreadManager() - - # Register available tools - thread_manager.add_tool(FilesTool) - - # Create a new thread - thread_id = await thread_manager.create_thread() + async def get_messages( + self, + thread_id: str, + message_types: Optional[List[str]] = None, + limit: Optional[int] = 50, # Default limit of 50 messages + offset: Optional[int] = 0, # Starting offset for pagination + before_timestamp: Optional[str] = None, + after_timestamp: Optional[str] = None, + include_in_llm_message_history: Optional[bool] = None, + order: str = "asc" + ) -> Dict[str, Any]: + """ + Retrieve messages from a thread with optional filtering and pagination. - # Add a test message - await thread_manager.add_message(thread_id, { - "role": "user", - "content": "Please create 10x files – Each should be a chapter of a book about an Introduction to Robotics.." - }) - - # Define system message - system_message = { - "role": "system", - "content": "You are a helpful assistant that can create, read, update, and delete files." - } + Args: + thread_id: The thread to get messages from + message_types: Optional list of message types to filter by + limit: Maximum number of messages to return (default: 50) + offset: Number of messages to skip (for pagination) + before_timestamp: Optional timestamp to filter messages before + after_timestamp: Optional timestamp to filter messages after + include_in_llm_message_history: Optional bool to filter messages by LLM history inclusion + order: Sort order - "asc" or "desc" + + Returns: + Dict containing messages list and pagination info + """ + try: + # Build the base query for total count + count_query = """ + SELECT COUNT(*) + FROM messages + WHERE thread_id = ? + """ + count_params = [thread_id] - # Test with streaming response and tool execution - print("\n🤖 Testing streaming response with tools:") - response = await thread_manager.run_thread( - thread_id=thread_id, - system_message=system_message, - model_name="anthropic/claude-3-5-haiku-latest", - temperature=0.7, - max_tokens=4096, - stream=True, - native_tool_calling=True, - execute_tools=True, - execute_tools_on_stream=True, - parallel_tool_execution=True - ) + # Build the base query for messages + query = """ + SELECT id, type, content, created_at, updated_at, include_in_llm_message_history + FROM messages + WHERE thread_id = ? + """ + params = [thread_id] - # Handle streaming response - if isinstance(response, AsyncGenerator): - print("\nAssistant is responding:") - content_buffer = "" - try: - async for chunk in response: - if hasattr(chunk.choices[0], 'delta'): - delta = chunk.choices[0].delta - - # Handle content streaming - if hasattr(delta, 'content') and delta.content is not None: - content_buffer += delta.content - if delta.content.endswith((' ', '\n')): - print(content_buffer, end='', flush=True) - content_buffer = "" - - # Handle tool calls - if hasattr(delta, 'tool_calls') and delta.tool_calls: - for tool_call in delta.tool_calls: - # Print tool name when it first appears - if tool_call.function and tool_call.function.name: - print(f"\n🛠️ Tool Call: {tool_call.function.name}", flush=True) - - # Print arguments as they stream in - if tool_call.function and tool_call.function.arguments: - print(f" {tool_call.function.arguments}", end='', flush=True) + # Add filters to both queries + if message_types: + placeholders = ','.join('?' * len(message_types)) + filter_sql = f" AND type IN ({placeholders})" + query += filter_sql + count_query += filter_sql + params.extend(message_types) + count_params.extend(message_types) - # Print any remaining content - if content_buffer: - print(content_buffer, flush=True) - print("\n✨ Response completed\n") + if before_timestamp: + query += " AND created_at < ?" + count_query += " AND created_at < ?" + params.append(before_timestamp) + count_params.append(before_timestamp) - except Exception as e: - print(f"\n❌ Error processing stream: {e}") - else: - print("\n✨ Response completed\n") + if after_timestamp: + query += " AND created_at > ?" + count_query += " AND created_at > ?" + params.append(after_timestamp) + count_params.append(after_timestamp) + + if include_in_llm_message_history is not None: + query += " AND include_in_llm_message_history = ?" + count_query += " AND include_in_llm_message_history = ?" + params.append(include_in_llm_message_history) + count_params.append(include_in_llm_message_history) + + # Get total count for pagination + total_count = await self.db.fetch_one(count_query, tuple(count_params)) + total_count = total_count[0] if total_count else 0 + + # Add ordering and pagination + query += f" ORDER BY created_at {'ASC' if order.lower() == 'asc' else 'DESC'}" + query += " LIMIT ? OFFSET ?" + params.extend([limit, offset]) + + # Execute query + rows = await self.db.fetch_all(query, tuple(params)) + + # Convert rows to dictionaries + messages = [] + for row in rows: + message = { + 'id': row[0], + 'type': row[1], + 'content': row[2], + 'created_at': row[3], + 'updated_at': row[4], + 'include_in_llm_message_history': bool(row[5]) + } - # Display final thread state - messages = await thread_manager.get_messages(thread_id) - print("\n📝 Final Thread State:") - for msg in messages: - role = msg.get('role', 'unknown') - content = msg.get('content', '') - print(f"\n{role.upper()}: {content[:100]}...") + # Try to parse JSON content + try: + message['content'] = json.loads(message['content']) + except (json.JSONDecodeError, TypeError): + pass # Keep content as is if it's not JSON - asyncio.run(main()) + messages.append(message) + + # Return messages with pagination info + return { + "messages": messages, + "pagination": { + "total": total_count, + "limit": limit, + "offset": offset, + "has_more": (offset + limit) < total_count + } + } + + except Exception as e: + logging.error(f"Failed to get messages from thread {thread_id}: {e}") + raise + async def thread_exists(self, thread_id: str) -> bool: + """ + Check if a thread exists. + + Args: + thread_id: The ID of the thread to check + + Returns: + bool: True if thread exists, False otherwise + """ + try: + row = await self.db.fetch_one( + "SELECT 1 FROM threads WHERE id = ?", + (thread_id,) + ) + return row is not None + except Exception as e: + logging.error(f"Error checking thread existence: {e}") + return False diff --git a/agentpress/thread_viewer_ui.py b/agentpress/thread_viewer_ui.py index 944cb49..4af1fd9 100644 --- a/agentpress/thread_viewer_ui.py +++ b/agentpress/thread_viewer_ui.py @@ -1,64 +1,111 @@ import streamlit as st from datetime import datetime -from agentpress.thread_manager import ThreadManager from agentpress.db_connection import DBConnection import asyncio +import json def format_message_content(content): - """Format message content handling both string and list formats.""" - if isinstance(content, str): - return content - elif isinstance(content, list): - formatted_content = [] - for item in content: - if item.get('type') == 'text': - formatted_content.append(item['text']) - elif item.get('type') == 'image_url': - formatted_content.append("[Image]") - return "\n".join(formatted_content) - return str(content) + """Format message content handling various formats.""" + try: + if isinstance(content, str): + # Try to parse JSON strings + try: + parsed = json.loads(content) + if isinstance(parsed, (dict, list)): + return json.dumps(parsed, indent=2) + except json.JSONDecodeError: + return content + elif isinstance(content, list): + formatted_content = [] + for item in content: + if item.get('type') == 'text': + formatted_content.append(item['text']) + elif item.get('type') == 'image_url': + formatted_content.append("[Image]") + return "\n".join(formatted_content) + return json.dumps(content, indent=2) + except: + return str(content) async def load_threads(): """Load all thread IDs from the database.""" db = DBConnection() - rows = await db.fetch_all("SELECT thread_id, created_at FROM threads ORDER BY created_at DESC") + rows = await db.fetch_all( + """ + SELECT id, created_at + FROM threads + ORDER BY created_at DESC + """ + ) return rows -async def load_thread_content(thread_id: str): - """Load the content of a specific thread from the database.""" - thread_manager = ThreadManager() - return await thread_manager.get_messages(thread_id) +async def load_thread_content(thread_id: str, filters: dict): + """Load messages from a thread with filters.""" + db = DBConnection() + + query_parts = ["SELECT type, content, include_in_llm_message_history, created_at FROM messages WHERE thread_id = ?"] + params = [thread_id] + + if filters.get('message_types'): + # Convert comma-separated string to list and clean up whitespace + types_list = [t.strip() for t in filters['message_types'].split(',') if t.strip()] + if types_list: + query_parts.append("AND type IN (" + ",".join(["?" for _ in types_list]) + ")") + params.extend(types_list) + + if filters.get('exclude_message_types'): + # Convert comma-separated string to list and clean up whitespace + exclude_types_list = [t.strip() for t in filters['exclude_message_types'].split(',') if t.strip()] + if exclude_types_list: + query_parts.append("AND type NOT IN (" + ",".join(["?" for _ in exclude_types_list]) + ")") + params.extend(exclude_types_list) + + if filters.get('before_timestamp'): + query_parts.append("AND created_at < ?") + params.append(filters['before_timestamp']) + + if filters.get('after_timestamp'): + query_parts.append("AND created_at > ?") + params.append(filters['after_timestamp']) + + if filters.get('include_in_llm_message_history') is not None: + query_parts.append("AND include_in_llm_message_history = ?") + params.append(filters['include_in_llm_message_history']) + + # Add ordering + order_direction = "DESC" if filters.get('order', 'asc').lower() == 'desc' else "ASC" + query_parts.append(f"ORDER BY created_at {order_direction}") + + # Add limit and offset + if filters.get('limit'): + query_parts.append("LIMIT ?") + params.append(filters['limit']) + + if filters.get('offset'): + query_parts.append("OFFSET ?") + params.append(filters['offset']) + + query = " ".join(query_parts) + rows = await db.fetch_all(query, tuple(params)) + return rows -def render_message(role, content, avatar): - """Render a message with a consistent chat-like style.""" - # Create columns for avatar and message - col1, col2 = st.columns([1, 11]) - - # Style based on role - if role == "assistant": - bgcolor = "rgba(25, 25, 25, 0.05)" - elif role == "user": - bgcolor = "rgba(25, 120, 180, 0.05)" - elif role == "system": - bgcolor = "rgba(180, 25, 25, 0.05)" - else: - bgcolor = "rgba(100, 100, 100, 0.05)" - - # Display avatar in first column +def render_message(msg_type: str, content: str, include_in_llm: bool, timestamp: str): + """Render a message using Streamlit components.""" + # Message type and metadata + col1, col2 = st.columns([3, 1]) with col1: - st.markdown(f"
{avatar}
", unsafe_allow_html=True) - - # Display message in second column + st.text(f"Type: {msg_type}") with col2: - st.markdown( - f""" -
- {role.upper()}
- {content} -
- """, - unsafe_allow_html=True - ) + st.text("🟢 LLM" if include_in_llm else "⚫ Non-LLM") + + # Timestamp + st.text(f"Time: {datetime.fromisoformat(timestamp).strftime('%Y-%m-%d %H:%M:%S')}") + + # Message content + st.code(content, language="json") + + # Separator + st.divider() def main(): st.title("Thread Viewer") @@ -86,7 +133,6 @@ def main(): ) if selected_thread_display: - # Get the actual thread ID from the display string selected_thread_id = thread_options[selected_thread_display] # Display thread ID in sidebar @@ -95,46 +141,77 @@ def main(): # Add refresh button if st.sidebar.button("🔄 Refresh Thread"): st.session_state.threads = asyncio.run(load_threads()) - st.experimental_rerun() + st.rerun() + + # Advanced filtering options in sidebar + st.sidebar.title("Filter Options") + + # Message type filters + col1, col2 = st.sidebar.columns(2) + with col1: + message_types = st.text_input( + "Include Types", + help="Enter message types to include, separated by commas" + ) + with col2: + exclude_message_types = st.text_input( + "Exclude Types", + help="Enter message types to exclude, separated by commas" + ) + + # Limit and offset + col1, col2 = st.sidebar.columns(2) + with col1: + limit = st.number_input("Limit", min_value=1, value=50) + with col2: + offset = st.number_input("Offset", min_value=0, value=0) + + # Timestamp filters + st.sidebar.subheader("Time Range") + before_timestamp = st.sidebar.date_input("Before Date", value=None) + after_timestamp = st.sidebar.date_input("After Date", value=None) + + # LLM history filter + include_in_llm = st.sidebar.radio( + "LLM History Filter", + options=["All Messages", "LLM Only", "Non-LLM Only"] + ) + + # Sort order + order = st.sidebar.radio("Sort Order", ["Ascending", "Descending"]) + + # Prepare filters + filters = { + 'message_types': message_types if message_types else None, + 'exclude_message_types': exclude_message_types if exclude_message_types else None, + 'limit': limit, + 'offset': offset, + 'order': 'desc' if order == "Descending" else 'asc' + } + + # Add timestamp filters if selected + if before_timestamp: + filters['before_timestamp'] = before_timestamp.isoformat() + if after_timestamp: + filters['after_timestamp'] = after_timestamp.isoformat() + + # Add LLM history filter + if include_in_llm == "LLM Only": + filters['include_in_llm_message_history'] = True + elif include_in_llm == "Non-LLM Only": + filters['include_in_llm_message_history'] = False + + # Load messages with filters + messages = asyncio.run(load_thread_content(selected_thread_id, filters)) - # Load and display messages - messages = asyncio.run(load_thread_content(selected_thread_id)) + if not messages: + st.info("No messages found with current filters") + return - # Display messages in chat-like interface - for message in messages: - role = message.get("role", "unknown") - content = message.get("content", "") - - # Determine avatar based on role - if role == "assistant": - avatar = "🤖" - elif role == "user": - avatar = "👤" - elif role == "system": - avatar = "⚙️" - elif role == "tool": - avatar = "🔧" - else: - avatar = "❓" - - # Format the content + # Display messages + for msg_type, content, include_in_llm, timestamp in messages: formatted_content = format_message_content(content) - - # Render the message - render_message(role, formatted_content, avatar) - - # Display tool calls if present - if "tool_calls" in message: - with st.expander("🛠️ Tool Calls"): - for tool_call in message["tool_calls"]: - st.code( - f"Function: {tool_call['function']['name']}\n" - f"Arguments: {tool_call['function']['arguments']}", - language="json" - ) - - # Add some spacing between messages - st.markdown("
", unsafe_allow_html=True) + render_message(msg_type, formatted_content, include_in_llm, timestamp) if __name__ == "__main__": main() diff --git a/example/__init__.py b/example/__init__.py new file mode 100644 index 0000000..3ebeb22 --- /dev/null +++ b/example/__init__.py @@ -0,0 +1 @@ +# Empty file to mark as package \ No newline at end of file diff --git a/example/agent.py b/example/agent.py new file mode 100644 index 0000000..2a7be66 --- /dev/null +++ b/example/agent.py @@ -0,0 +1,261 @@ +""" +Interactive web development agent supporting both XML and Standard LLM tool calling. + +This agent can: +- Create and modify web projects +- Execute terminal commands +- Handle file operations +- Use either XML or Standard tool calling patterns +""" + +import asyncio +import json +from agentpress.thread_manager import ThreadManager +from example.tools.files_tool import FilesTool +from agentpress.state_manager import StateManager +from example.tools.terminal_tool import TerminalTool +import logging +from typing import AsyncGenerator, Optional, Dict, Any +import sys + +from agentpress.api.api_factory import register_thread_task_api + +BASE_SYSTEM_MESSAGE = """ +You are a world-class web developer who can create, edit, and delete files, and execute terminal commands. +You write clean, well-structured code. Keep iterating on existing files, continue working on this existing +codebase - do not omit previous progress; instead, keep iterating. +Available tools: +- create_file: Create new files with specified content +- delete_file: Remove existing files +- str_replace: Make precise text replacements in files +- execute_command: Run terminal commands + + +RULES: +- All current file contents are available to you in the section +- Each file in the workspace state includes its full content +- Use str_replace for precise replacements in files +- NEVER include comments in any code you write - the code should be self-documenting +- Always maintain the full context of files when making changes +- When creating new files, write clean code without any comments or documentation + + +[create_file(file_path, file_contents)] - Create new files +[delete_file(file_path)] - Delete existing files +[str_replace(file_path, old_str, new_str)] - Replace specific text in files +[execute_command(command)] - Execute terminal commands + + +ALWAYS RESPOND WITH MULTIPLE SIMULTANEOUS ACTIONS: + +[Provide a concise overview of your planned changes and implementations] + + + +[Include multiple tool calls] + + +EDITING GUIDELINES: +1. Review the current file contents in the workspace state +2. Make targeted changes with str_replace +3. Write clean, self-documenting code without comments +4. Use create_file for new files and str_replace for modifications + +Example workspace state for a file: +{ + "index.html": { + "content": "\\n\\n..." + } +} +Think deeply and step by step. +""" + +XML_FORMAT = """ +RESPONSE FORMAT: +Use XML tags to specify file operations: + + +file contents here + + + +text to replace +replacement text + + + + + + +""" + +@register_thread_task_api("/agent") +async def run_agent( + thread_id: str, + max_iterations: int = 5, + user_input: Optional[str] = None, +) -> Dict[str, Any]: + """Run the development agent with specified configuration. + + Args: + thread_id (str): The ID of the thread. + max_iterations (int, optional): The maximum number of iterations. Defaults to 5. + user_input (Optional[str], optional): The user input. Defaults to None. + """ + thread_manager = ThreadManager() + state_manager = StateManager(thread_id) + + if user_input: + await thread_manager.add_message( + thread_id, + { + "role": "user", + "content": user_input + } + ) + + thread_manager.add_tool(FilesTool, thread_id=thread_id) + thread_manager.add_tool(TerminalTool, thread_id=thread_id) + + system_message = { + "role": "system", + "content": BASE_SYSTEM_MESSAGE + XML_FORMAT + } + + iteration = 0 + while iteration < max_iterations: + iteration += 1 + + files_tool = FilesTool(thread_id=thread_id) + + state = await state_manager.export_store() + + temporary_message_content = f""" + You are tasked to complete the LATEST USER REQUEST! + + {user_input} + + + Current development environment workspace state: + + {json.dumps(state, indent=2) if state else "{}"} + + + CONTINUE WITH THE TASK! USE THE SESSION TOOL TO STOP THE SESSION IF THE TASK IS COMPLETE. + """ + + await thread_manager.add_message( + thread_id=thread_id, + message_data=temporary_message_content, + message_type="temporary_message", + include_in_llm_message_history=False + ) + + temporary_message = { + "role": "user", + "content": temporary_message_content + } + + model_name = "anthropic/claude-3-5-sonnet-latest" + + response = await thread_manager.run_thread( + thread_id=thread_id, + system_message=system_message, + model_name=model_name, + temperature=0.1, + max_tokens=8096, + tool_choice="auto", + temporary_message=temporary_message, + native_tool_calling=False, + xml_tool_calling=True, + stream=True, + execute_tools_on_stream=True, + parallel_tool_execution=True, + ) + + if isinstance(response, AsyncGenerator): + print("\n🤖 Assistant is responding:") + try: + async for chunk in response: + if hasattr(chunk.choices[0], 'delta'): + delta = chunk.choices[0].delta + + if hasattr(delta, 'content') and delta.content is not None: + content = delta.content + print(content, end='', flush=True) + + # Check for open_files_in_editor tag and continue if found + if '' in content: + print("\n📂 Opening files in editor, continuing to next iteration...") + continue + + if hasattr(delta, 'tool_calls') and delta.tool_calls: + for tool_call in delta.tool_calls: + if tool_call.function: + if tool_call.function.name: + print(f"\n🛠️ Tool Call: {tool_call.function.name}", flush=True) + if tool_call.function.arguments: + print(f" {tool_call.function.arguments}", end='', flush=True) + + print("\n✨ Response completed\n") + + except Exception as e: + print(f"\n❌ Error processing stream: {e}", file=sys.stderr) + logging.error(f"Error processing stream: {e}") + else: + print("\nNon-streaming response received:", response) + + # # Get latest assistant message and check for stop_session + # latest_msg = await thread_manager.get_llm_history_messages( + # thread_id=thread_id, + # only_latest_assistant=True + # ) + # if latest_msg and '' in latest_msg: + # break + + +if __name__ == "__main__": + print("\n🚀 Welcome to AgentPress!") + + project_description = input("What would you like to build? (default: Create a modern, responsive landing page)\n> ") + if not project_description.strip(): + project_description = "Create a modern, responsive landing page" + + print("\nChoose your agent type:") + print("1. XML-based Tool Calling") + print(" - Structured XML format for tool execution") + print(" - Parses tool calls using XML outputs in the LLM response") + + print("\n2. Standard Function Calling") + print(" - Native LLM function calling format") + print(" - JSON-based parameter passing") + + use_xml = input("\nSelect tool calling format [1/2] (default: 1): ").strip() != "2" + + print(f"\n{'XML-based' if use_xml else 'Standard'} agent will help you build: {project_description}") + print("Use Ctrl+C to stop the agent at any time.") + + async def test_agent(): + thread_manager = ThreadManager() + thread_id = await thread_manager.create_thread() + logging.info(f"Created new thread: {thread_id}") + + try: + result = await run_agent( + thread_id=thread_id, + max_iterations=5, + user_input=project_description, + ) + print("\n✅ Test completed successfully!") + + except Exception as e: + print(f"\n❌ Test failed: {str(e)}") + raise + + try: + asyncio.run(test_agent()) + except KeyboardInterrupt: + print("\n⚠️ Test interrupted by user") + except Exception as e: + print(f"\n❌ Test failed with error: {str(e)}") + raise \ No newline at end of file diff --git a/example/tools/files_tool.py b/example/tools/files_tool.py new file mode 100644 index 0000000..0d5948d --- /dev/null +++ b/example/tools/files_tool.py @@ -0,0 +1,297 @@ +import os +import asyncio +from pathlib import Path +from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema +from agentpress.state_manager import StateManager +from typing import Optional + +class FilesTool(Tool): + """File management tool for creating, updating, and deleting files. + + This tool provides file operations within a workspace directory, with built-in + file filtering and state tracking capabilities. + + Attributes: + workspace (str): Path to the workspace directory + EXCLUDED_FILES (set): Files to exclude from operations + EXCLUDED_DIRS (set): Directories to exclude + EXCLUDED_EXT (set): File extensions to exclude + SNIPPET_LINES (int): Context lines for edit previews + """ + + # Excluded files, directories, and extensions + EXCLUDED_FILES = { + ".DS_Store", + ".gitignore", + "package-lock.json", + "postcss.config.js", + "postcss.config.mjs", + "jsconfig.json", + "components.json", + "tsconfig.tsbuildinfo", + "tsconfig.json", + } + + EXCLUDED_DIRS = { + "node_modules", + ".next", + "dist", + "build", + ".git" + } + + EXCLUDED_EXT = { + ".ico", + ".svg", + ".png", + ".jpg", + ".jpeg", + ".gif", + ".bmp", + ".tiff", + ".webp", + ".db", + ".sql" + } + + def __init__(self, thread_id: Optional[str] = None): + super().__init__() + self.workspace = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'workspace') + os.makedirs(self.workspace, exist_ok=True) + if thread_id: + self.state_manager = StateManager(thread_id) + asyncio.create_task(self._init_workspace_state()) + self.SNIPPET_LINES = 4 + + def _should_exclude_file(self, rel_path: str) -> bool: + """Check if a file should be excluded based on path, name, or extension""" + # Check filename + filename = os.path.basename(rel_path) + if filename in self.EXCLUDED_FILES: + return True + + # Check directory + dir_path = os.path.dirname(rel_path) + if any(excluded in dir_path for excluded in self.EXCLUDED_DIRS): + return True + + # Check extension + _, ext = os.path.splitext(filename) + if ext.lower() in self.EXCLUDED_EXT: + return True + + return False + + async def _init_workspace_state(self): + """Initialize or update the workspace state in JSON""" + files_state = {} + + # Walk through workspace and record all files + for root, _, files in os.walk(self.workspace): + for file in files: + full_path = os.path.join(root, file) + rel_path = os.path.relpath(full_path, self.workspace) + + # Skip excluded files + if self._should_exclude_file(rel_path): + continue + + try: + with open(full_path, 'r') as f: + content = f.read() + files_state[rel_path] = { + "content": content + } + except Exception as e: + print(f"Error reading file {rel_path}: {e}") + except UnicodeDecodeError: + print(f"Skipping binary file: {rel_path}") + + if hasattr(self, 'state_manager'): + await self.state_manager.set("files", files_state) + + async def _update_workspace_state(self): + """Update the workspace state after any file operation""" + await self._init_workspace_state() + + @openapi_schema({ + "type": "function", + "function": { + "name": "create_file", + "description": "Create a new file with the provided contents at a given path in the workspace", + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Path to the file to be created" + }, + "file_contents": { + "type": "string", + "description": "The content to write to the file" + } + }, + "required": ["file_path", "file_contents"] + } + } + }) + @xml_schema( + tag_name="create-file", + mappings=[ + {"param_name": "file_path", "node_type": "attribute", "path": "."}, + {"param_name": "file_contents", "node_type": "content", "path": "."} + ], + example=''' + + File contents go here + + ''' + ) + async def create_file(self, file_path: str, file_contents: str) -> ToolResult: + try: + full_path = os.path.join(self.workspace, file_path) + if os.path.exists(full_path): + return self.fail_response(f"File '{file_path}' already exists. Use update_file to modify existing files.") + + os.makedirs(os.path.dirname(full_path), exist_ok=True) + with open(full_path, 'w') as f: + f.write(file_contents) + + await self._update_workspace_state() + return self.success_response(f"File '{file_path}' created successfully.") + except Exception as e: + return self.fail_response(f"Error creating file: {str(e)}") + + @openapi_schema({ + "type": "function", + "function": { + "name": "delete_file", + "description": "Delete a file at the given path", + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Path to the file to be deleted" + } + }, + "required": ["file_path"] + } + } + }) + @xml_schema( + tag_name="delete-file", + mappings=[ + {"param_name": "file_path", "node_type": "attribute", "path": "."} + ], + example=''' + + + ''' + ) + async def delete_file(self, file_path: str) -> ToolResult: + try: + full_path = os.path.join(self.workspace, file_path) + os.remove(full_path) + + await self._update_workspace_state() + return self.success_response(f"File '{file_path}' deleted successfully.") + except Exception as e: + return self.fail_response(f"Error deleting file: {str(e)}") + + @openapi_schema({ + "type": "function", + "function": { + "name": "str_replace", + "description": "Replace text in file", + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Path to the target file" + }, + "old_str": { + "type": "string", + "description": "Text to be replaced (must appear exactly once)" + }, + "new_str": { + "type": "string", + "description": "Replacement text" + } + }, + "required": ["file_path", "old_str", "new_str"] + } + } + }) + @xml_schema( + tag_name="str-replace", + mappings=[ + {"param_name": "file_path", "node_type": "attribute", "path": "file_path"}, + {"param_name": "old_str", "node_type": "element", "path": "old_str"}, + {"param_name": "new_str", "node_type": "element", "path": "new_str"} + ], + example=''' + + text to replace + replacement text + + ''' + ) + async def str_replace(self, file_path: str, old_str: str, new_str: str) -> ToolResult: + try: + full_path = Path(os.path.join(self.workspace, file_path)) + if not full_path.exists(): + return self.fail_response(f"File '{file_path}' does not exist") + + content = full_path.read_text().expandtabs() + old_str = old_str.expandtabs() + new_str = new_str.expandtabs() + + occurrences = content.count(old_str) + if occurrences == 0: + return self.fail_response(f"String '{old_str}' not found in file") + if occurrences > 1: + lines = [i+1 for i, line in enumerate(content.split('\n')) if old_str in line] + return self.fail_response(f"Multiple occurrences found in lines {lines}. Please ensure string is unique") + + # Perform replacement + new_content = content.replace(old_str, new_str) + full_path.write_text(new_content) + + # Update state after file modification + await self._update_workspace_state() + + # Show snippet around the edit + replacement_line = content.split(old_str)[0].count('\n') + start_line = max(0, replacement_line - self.SNIPPET_LINES) + end_line = replacement_line + self.SNIPPET_LINES + new_str.count('\n') + snippet = '\n'.join(new_content.split('\n')[start_line:end_line + 1]) + + return self.success_response(f"Replacement successful. Snippet of changes:\n{snippet}") + + except Exception as e: + return self.fail_response(f"Error replacing string: {str(e)}") + +if __name__ == "__main__": + async def test_files_tool(): + files_tool = FilesTool() + test_file_path = "test_file.txt" + test_content = "This is a test file." + updated_content = "This is an updated test file." + + print(f"Using workspace directory: {files_tool.workspace}") + + # Test create_file + create_result = await files_tool.create_file(test_file_path, test_content) + print("Create file result:", create_result) + + # Test delete_file + delete_result = await files_tool.delete_file(test_file_path) + print("Delete file result:", delete_result) + + # Test read_file after delete (should fail) + read_deleted_result = await files_tool.read_file(test_file_path) + print("Read deleted file result:", read_deleted_result) + + asyncio.run(test_files_tool()) \ No newline at end of file diff --git a/example/tools/terminal_tool.py b/example/tools/terminal_tool.py new file mode 100644 index 0000000..c9736dc --- /dev/null +++ b/example/tools/terminal_tool.py @@ -0,0 +1,73 @@ +import os +import asyncio +import subprocess +from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema +from typing import Optional + +class TerminalTool(Tool): + """Terminal command execution tool for workspace operations.""" + + def __init__(self, thread_id: Optional[str] = None): + super().__init__() + self.workspace = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'workspace') + os.makedirs(self.workspace, exist_ok=True) + + @openapi_schema({ + "type": "function", + "function": { + "name": "execute_command", + "description": "Execute a shell command in the workspace directory", + "parameters": { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The shell command to execute" + } + }, + "required": ["command"] + } + } + }) + @xml_schema( + tag_name="execute-command", + mappings=[ + {"param_name": "command", "node_type": "content", "path": "."} + ], + example=''' + + npm install package-name + + ''' + ) + async def execute_command(self, command: str) -> ToolResult: + original_dir = os.getcwd() + try: + os.chdir(self.workspace) + + process = await asyncio.create_subprocess_shell( + command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=self.workspace + ) + stdout, stderr = await process.communicate() + + output = stdout.decode() if stdout else "" + error = stderr.decode() if stderr else "" + success = process.returncode == 0 + + if success: + return self.success_response({ + "output": output, + "error": error, + "exit_code": process.returncode, + "cwd": self.workspace + }) + else: + return self.fail_response(f"Command failed with exit code {process.returncode}: {error}") + + except Exception as e: + return self.fail_response(f"Error executing command: {str(e)}") + finally: + os.chdir(original_dir) diff --git a/example/workspace/css/styles.css b/example/workspace/css/styles.css new file mode 100644 index 0000000..f7d8379 --- /dev/null +++ b/example/workspace/css/styles.css @@ -0,0 +1,373 @@ +:root { + --primary-color: #2563eb; + --secondary-color: #1e40af; + --text-color: #1f2937; + --light-text: #6b7280; + --background: #ffffff; + --section-bg: #f3f4f6; +} + +* { + margin: 0; + padding: 0; + box-sizing: border-box; +} + +html { + scroll-behavior: smooth; +} + +.scroll-top { + position: fixed; + bottom: 2rem; + right: 2rem; + background: var(--primary-color); + color: white; + width: 45px; + height: 45px; + border-radius: 50%; + display: flex; + align-items: center; + justify-content: center; + cursor: pointer; + opacity: 0; + visibility: hidden; + transition: all 0.3s ease; + z-index: 1000; +} + +.scroll-top.visible { + opacity: 1; + visibility: visible; +} + +.scroll-top:hover { + transform: translateY(-3px); + background: var(--secondary-color); +} + +body { + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif; + line-height: 1.6; + color: var(--text-color); +} + +.container { + max-width: 1200px; + margin: 0 auto; + padding: 0 2rem; +} + +.header { + position: fixed; + width: 100%; + background: var(--background); + box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1); + z-index: 1000; + transition: transform 0.3s ease; +} + +.header.scrolled { + transform: translateY(-100%); +} + +.header:hover { + transform: translateY(0); +} + +.nav { + display: flex; + justify-content: space-between; + align-items: center; + padding: 1rem 2rem; +} + +.logo { + font-size: 1.5rem; + font-weight: bold; + color: var(--primary-color); +} + +.nav-links { + display: flex; + gap: 2rem; + list-style: none; +} + +.nav-links a { + text-decoration: none; + color: var(--text-color); + font-weight: 500; + transition: color 0.3s ease; +} + +.nav-links a:hover { + color: var(--primary-color); +} + +.mobile-nav-toggle { + display: none; +} + +.hero { + padding: 8rem 0 4rem; + background: linear-gradient(135deg, var(--primary-color), var(--secondary-color)); + background-size: 200% 200%; + color: white; + text-align: center; + animation: gradientMove 10s ease infinite; +} + +@keyframes gradientMove { + 0% { background-position: 0% 50%; } + 50% { background-position: 100% 50%; } + 100% { background-position: 0% 50%; } +} + +.hero h1 { + font-size: 3rem; + margin-bottom: 1rem; +} + +.hero p { + font-size: 1.25rem; + margin-bottom: 2rem; +} + +.cta-button { + padding: 1rem 2rem; + font-size: 1.1rem; + background: white; + color: var(--primary-color); + border: none; + border-radius: 5px; + cursor: pointer; + transition: transform 0.3s ease; +} + +.cta-button:hover { + transform: translateY(-2px); +} + +.features { + padding: 4rem 0; + background: var(--section-bg); +} + +.features h2 { + text-align: center; + margin-bottom: 3rem; +} + +.feature-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(250px, 1fr)); + gap: 2rem; +} + +.feature-card { + background: white; + padding: 2rem; + border-radius: 10px; + text-align: center; + box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); + transition: all 0.3s ease; + opacity: 0; + transform: translateY(20px); + animation: fadeInUp 0.6s ease forwards; +} + +@keyframes fadeInUp { + to { + opacity: 1; + transform: translateY(0); + } +} + +.feature-card:hover { + transform: translateY(-5px); +} + +.feature-icon { + font-size: 2.5rem; + margin-bottom: 1rem; +} + +.contact { + padding: 4rem 0; +} + +.contact h2 { + text-align: center; + margin-bottom: 2rem; +} + +.contact-form { + max-width: 600px; + margin: 0 auto; + display: flex; + flex-direction: column; + gap: 1rem; +} + +.contact-form input, +.contact-form textarea { + padding: 0.8rem; + border: 2px solid #ddd; + border-radius: 5px; + font-size: 1rem; + transition: all 0.3s ease; + outline: none; +} + +.contact-form input:focus, +.contact-form textarea:focus { + border-color: var(--primary-color); + box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1); +} + +.contact-form input:hover, +.contact-form textarea:hover { + border-color: var(--primary-color); +} + +.contact-form textarea { + height: 150px; + resize: vertical; +} + +.submit-button { + padding: 1rem; + background: var(--primary-color); + color: white; + border: none; + border-radius: 5px; + cursor: pointer; + transition: background-color 0.3s ease; +} + +.submit-button:hover { + background: var(--secondary-color); +} + +.testimonials { + padding: 4rem 0; + background: var(--section-bg); +} + +.testimonials h2 { + text-align: center; + margin-bottom: 3rem; +} + +.testimonial-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); + gap: 2rem; + max-width: 1200px; + margin: 0 auto; + padding: 0 2rem; +} + +.testimonial-card { + background: white; + padding: 2rem; + border-radius: 10px; + box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); +} + +.testimonial-text { + font-style: italic; + margin-bottom: 1rem; +} + +.testimonial-author { + font-weight: bold; + color: var(--primary-color); +} + +.footer { + background: var(--text-color); + color: white; + padding: 2rem 0; + text-align: center; +} + +.social-links { + display: flex; + justify-content: center; + gap: 1.5rem; + margin: 1rem 0; +} + +.social-links a { + color: white; + text-decoration: none; + font-size: 1.5rem; + transition: color 0.3s ease; +} + +.social-links a:hover { + color: var(--primary-color); +} + +@media (max-width: 768px) { + .nav-links { + display: none; + position: fixed; + top: 70px; + left: 0; + right: 0; + background: var(--background); + flex-direction: column; + padding: 2rem; + box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1); + text-align: center; + transform: translateY(-100%); + opacity: 0; + transition: transform 0.3s ease, opacity 0.3s ease; + } + + .nav-links.active { + display: flex; + transform: translateY(0); + opacity: 1; + } + + .mobile-nav-toggle { + display: block; + background: none; + border: none; + cursor: pointer; + } + + .mobile-nav-toggle span { + display: block; + width: 25px; + height: 3px; + background: var(--text-color); + margin: 5px 0; + transition: 0.3s; + position: relative; + } + + .mobile-nav-toggle.active span:nth-child(1) { + transform: rotate(45deg) translate(5px, 5px); + } + + .mobile-nav-toggle.active span:nth-child(2) { + opacity: 0; + } + + .mobile-nav-toggle.active span:nth-child(3) { + transform: rotate(-45deg) translate(7px, -6px); + } + + .hero h1 { + font-size: 2.5rem; + } + + .feature-grid { + grid-template-columns: 1fr; + } +} \ No newline at end of file diff --git a/example/workspace/index.html b/example/workspace/index.html new file mode 100644 index 0000000..40277ec --- /dev/null +++ b/example/workspace/index.html @@ -0,0 +1,113 @@ + + + + + + Modern Landing Page + + + + +
+
+ +
+ +
+
+
+

Welcome to the Future

+

Experience innovation like never before

+ +
+
+ +
+
+

Our Features

+
+
+ 🚀 +

Fast Performance

+

Lightning-quick loading times

+
+
+ 🎨 +

Beautiful Design

+

Stunning visual experience

+
+
+ 📱 +

Responsive

+

Works on all devices

+
+
+
+
+ +
+
+

Get in Touch

+
+ + + + +
+
+
+
+ + + + + + + \ No newline at end of file diff --git a/example/workspace/js/main.js b/example/workspace/js/main.js new file mode 100644 index 0000000..8cefa43 --- /dev/null +++ b/example/workspace/js/main.js @@ -0,0 +1,85 @@ +const mobileNavToggle = document.querySelector('.mobile-nav-toggle'); +const navLinks = document.querySelector('.nav-links'); + +mobileNavToggle.addEventListener('click', () => { + navLinks.classList.toggle('active'); + mobileNavToggle.classList.toggle('active'); +}); + +document.addEventListener('click', (e) => { + if (!e.target.closest('.nav') && navLinks.classList.contains('active')) { + navLinks.classList.remove('active'); + mobileNavToggle.classList.remove('active'); + } +}); + +document.querySelectorAll('a[href^="#"]').forEach(anchor => { + anchor.addEventListener('click', function (e) { + e.preventDefault(); + document.querySelector(this.getAttribute('href')).scrollIntoView({ + behavior: 'smooth' + }); + }); +}); + +const form = document.querySelector('.contact-form'); +form.addEventListener('submit', (e) => { + e.preventDefault(); + const formData = new FormData(form); + const data = Object.fromEntries(formData); + console.log('Form submitted:', data); + form.reset(); +}); + +let lastScroll = 0; +window.addEventListener('load', () => { + setTimeout(() => { + document.querySelector('.loading').classList.add('loaded'); + }, 500); +}); + +const scrollTopBtn = document.querySelector('.scroll-top'); +scrollTopBtn.addEventListener('click', () => { + window.scrollTo({ + top: 0, + behavior: 'smooth' + }); +}); + +const observeElements = () => { + const observer = new IntersectionObserver((entries) => { + entries.forEach(entry => { + if (entry.isIntersecting) { + entry.target.style.opacity = '1'; + entry.target.style.transform = 'translateY(0)'; + } + }); + }, { threshold: 0.1 }); + + document.querySelectorAll('.feature-card').forEach(card => { + observer.observe(card); + card.style.opacity = '0'; + card.style.transform = 'translateY(20px)'; + card.style.transition = 'all 0.6s ease'; + }); +}; + +document.addEventListener('DOMContentLoaded', observeElements); + +window.addEventListener('scroll', () => { + scrollTopBtn.classList.toggle('visible', window.scrollY > 500); + const header = document.querySelector('.header'); + const currentScroll = window.pageYOffset; + + if (currentScroll <= 0) { + header.classList.remove('scrolled'); + return; + } + + if (currentScroll > lastScroll && !header.classList.contains('scrolled')) { + header.classList.add('scrolled'); + } else if (currentScroll < lastScroll && header.classList.contains('scrolled')) { + header.classList.remove('scrolled'); + } + lastScroll = currentScroll; +}); \ No newline at end of file