diff --git a/agentpress/__init__.py b/agentpress/__init__.py
new file mode 100644
index 0000000..3ebeb22
--- /dev/null
+++ b/agentpress/__init__.py
@@ -0,0 +1 @@
+# Empty file to mark as package
\ No newline at end of file
diff --git a/agentpress/examples/.env.example b/agentpress/agents/.env.example
similarity index 100%
rename from agentpress/examples/.env.example
rename to agentpress/agents/.env.example
diff --git a/agentpress/examples/simple_web_dev/agent.py b/agentpress/agents/simple_web_dev/agent.py
similarity index 52%
rename from agentpress/examples/simple_web_dev/agent.py
rename to agentpress/agents/simple_web_dev/agent.py
index 16c0f15..2a7be66 100644
--- a/agentpress/examples/simple_web_dev/agent.py
+++ b/agentpress/agents/simple_web_dev/agent.py
@@ -8,23 +8,22 @@
- Use either XML or Standard tool calling patterns
"""
-import os
import asyncio
import json
from agentpress.thread_manager import ThreadManager
-from tools.files_tool import FilesTool
+from example.tools.files_tool import FilesTool
from agentpress.state_manager import StateManager
-from tools.terminal_tool import TerminalTool
-from agentpress.api_factory import register_api_endpoint
+from example.tools.terminal_tool import TerminalTool
import logging
from typing import AsyncGenerator, Optional, Dict, Any
import sys
+from agentpress.api.api_factory import register_thread_task_api
+
BASE_SYSTEM_MESSAGE = """
You are a world-class web developer who can create, edit, and delete files, and execute terminal commands.
You write clean, well-structured code. Keep iterating on existing files, continue working on this existing
codebase - do not omit previous progress; instead, keep iterating.
-
Available tools:
- create_file: Create new files with specified content
- delete_file: Remove existing files
@@ -69,7 +68,6 @@
}
}
Think deeply and step by step.
-
"""
XML_FORMAT = """
@@ -88,87 +86,77 @@
+
"""
-def get_anthropic_api_key():
- """Get Anthropic API key from environment or prompt user."""
- api_key = os.getenv("ANTHROPIC_API_KEY")
- if not api_key:
- api_key = input("\n🔑 Please enter your Anthropic API key: ").strip()
- if not api_key:
- print("❌ No API key provided. Please set ANTHROPIC_API_KEY environment variable or enter a key.")
- sys.exit(1)
- os.environ["ANTHROPIC_API_KEY"] = api_key
- return api_key
-
-@register_api_endpoint("/main_agent")
+@register_thread_task_api("/agent")
async def run_agent(
thread_id: str,
- use_xml: bool = True,
max_iterations: int = 5,
- project_description: Optional[str] = None
+ user_input: Optional[str] = None,
) -> Dict[str, Any]:
- """Run the development agent with specified configuration."""
- # Initialize managers
- thread_manager = ThreadManager()
- await thread_manager.initialize()
+ """Run the development agent with specified configuration.
- state_manager = StateManager(thread_id)
- await state_manager.initialize()
-
- # Register tools
- thread_manager.add_tool(FilesTool, thread_id=thread_id)
- thread_manager.add_tool(TerminalTool, thread_id=thread_id)
+ Args:
+ thread_id (str): The ID of the thread.
+ max_iterations (int, optional): The maximum number of iterations. Defaults to 5.
+ user_input (Optional[str], optional): The user input. Defaults to None.
+ """
+ thread_manager = ThreadManager()
+ state_manager = StateManager(thread_id)
- # Add initial project description if provided
- if project_description:
+ if user_input:
await thread_manager.add_message(
thread_id,
{
"role": "user",
- "content": project_description
+ "content": user_input
}
)
- # Set up system message with appropriate format
+ thread_manager.add_tool(FilesTool, thread_id=thread_id)
+ thread_manager.add_tool(TerminalTool, thread_id=thread_id)
+
system_message = {
"role": "system",
- "content": BASE_SYSTEM_MESSAGE + (XML_FORMAT if use_xml else "")
+ "content": BASE_SYSTEM_MESSAGE + XML_FORMAT
}
- # Create initial event to track agent loop
- await thread_manager.create_event(
- thread_id=thread_id,
- event_type="agent_loop_started",
- content={
- "max_iterations": max_iterations,
- "use_xml": use_xml,
- "project_description": project_description
- },
- include_in_llm_message_history=False
- )
-
- results = []
iteration = 0
while iteration < max_iterations:
iteration += 1
- files_tool = FilesTool(thread_id)
- await files_tool._init_workspace_state()
+ files_tool = FilesTool(thread_id=thread_id)
- state = await state_manager.get_latest_state()
-
- state_message = {
- "role": "user",
- "content": f"""
-Current development environment workspace state:
-
-{json.dumps(state, indent=2)}
-
- """
+ state = await state_manager.export_store()
+
+ temporary_message_content = f"""
+ You are tasked to complete the LATEST USER REQUEST!
+
+ {user_input}
+
+
+ Current development environment workspace state:
+
+ {json.dumps(state, indent=2) if state else "{}"}
+
+
+ CONTINUE WITH THE TASK! USE THE SESSION TOOL TO STOP THE SESSION IF THE TASK IS COMPLETE.
+ """
+
+ await thread_manager.add_message(
+ thread_id=thread_id,
+ message_data=temporary_message_content,
+ message_type="temporary_message",
+ include_in_llm_message_history=False
+ )
+
+ temporary_message = {
+ "role": "user",
+ "content": temporary_message_content
}
- model_name = "anthropic/claude-3-5-sonnet-latest"
+ model_name = "anthropic/claude-3-5-sonnet-latest"
response = await thread_manager.run_thread(
thread_id=thread_id,
@@ -177,51 +165,57 @@ async def run_agent(
temperature=0.1,
max_tokens=8096,
tool_choice="auto",
- temporary_message=state_message,
- native_tool_calling=not use_xml,
- xml_tool_calling=use_xml,
+ temporary_message=temporary_message,
+ native_tool_calling=False,
+ xml_tool_calling=True,
stream=True,
- execute_tools_on_stream=False,
- parallel_tool_execution=True
+ execute_tools_on_stream=True,
+ parallel_tool_execution=True,
)
- # Handle both streaming and regular responses
- if hasattr(response, '__aiter__'):
- chunks = []
+ if isinstance(response, AsyncGenerator):
+ print("\n🤖 Assistant is responding:")
try:
async for chunk in response:
- chunks.append(chunk)
+ if hasattr(chunk.choices[0], 'delta'):
+ delta = chunk.choices[0].delta
+
+ if hasattr(delta, 'content') and delta.content is not None:
+ content = delta.content
+ print(content, end='', flush=True)
+
+ # Check for open_files_in_editor tag and continue if found
+ if '' in content:
+ print("\n📂 Opening files in editor, continuing to next iteration...")
+ continue
+
+ if hasattr(delta, 'tool_calls') and delta.tool_calls:
+ for tool_call in delta.tool_calls:
+ if tool_call.function:
+ if tool_call.function.name:
+ print(f"\n🛠️ Tool Call: {tool_call.function.name}", flush=True)
+ if tool_call.function.arguments:
+ print(f" {tool_call.function.arguments}", end='', flush=True)
+
+ print("\n✨ Response completed\n")
+
except Exception as e:
+ print(f"\n❌ Error processing stream: {e}", file=sys.stderr)
logging.error(f"Error processing stream: {e}")
- raise
- response = chunks
-
- results.append({
- "iteration": iteration,
- "response": response
- })
+ else:
+ print("\nNon-streaming response received:", response)
- # Create iteration completion event
- await thread_manager.create_event(
- thread_id=thread_id,
- event_type="iteration_complete",
- content={
- "iteration_number": iteration,
- "max_iterations": max_iterations,
- # "state": state
- },
- include_in_llm_message_history=False
- )
+ # # Get latest assistant message and check for stop_session
+ # latest_msg = await thread_manager.get_llm_history_messages(
+ # thread_id=thread_id,
+ # only_latest_assistant=True
+ # )
+ # if latest_msg and '' in latest_msg:
+ # break
- return {
- "thread_id": thread_id,
- "iterations": results,
- }
if __name__ == "__main__":
- print("\n🚀 Welcome to AgentPress Web Developer Example!")
-
- get_anthropic_api_key()
+ print("\n🚀 Welcome to AgentPress!")
project_description = input("What would you like to build? (default: Create a modern, responsive landing page)\n> ")
if not project_description.strip():
@@ -241,10 +235,27 @@ async def run_agent(
print(f"\n{'XML-based' if use_xml else 'Standard'} agent will help you build: {project_description}")
print("Use Ctrl+C to stop the agent at any time.")
- async def async_main():
+ async def test_agent():
thread_manager = ThreadManager()
thread_id = await thread_manager.create_thread()
logging.info(f"Created new thread: {thread_id}")
- await run_agent(thread_id, use_xml, project_description=project_description)
-
- asyncio.run(async_main())
\ No newline at end of file
+
+ try:
+ result = await run_agent(
+ thread_id=thread_id,
+ max_iterations=5,
+ user_input=project_description,
+ )
+ print("\n✅ Test completed successfully!")
+
+ except Exception as e:
+ print(f"\n❌ Test failed: {str(e)}")
+ raise
+
+ try:
+ asyncio.run(test_agent())
+ except KeyboardInterrupt:
+ print("\n⚠️ Test interrupted by user")
+ except Exception as e:
+ print(f"\n❌ Test failed with error: {str(e)}")
+ raise
\ No newline at end of file
diff --git a/agentpress/examples/simple_web_dev/tools/files_tool.py b/agentpress/agents/simple_web_dev/tools/files_tool.py
similarity index 97%
rename from agentpress/examples/simple_web_dev/tools/files_tool.py
rename to agentpress/agents/simple_web_dev/tools/files_tool.py
index 9dca71a..0d5948d 100644
--- a/agentpress/examples/simple_web_dev/tools/files_tool.py
+++ b/agentpress/agents/simple_web_dev/tools/files_tool.py
@@ -60,13 +60,8 @@ def __init__(self, thread_id: Optional[str] = None):
os.makedirs(self.workspace, exist_ok=True)
if thread_id:
self.state_manager = StateManager(thread_id)
- asyncio.create_task(self._init_state())
- self.SNIPPET_LINES = 4 # Number of context lines to show around edits
-
- async def _init_state(self):
- """Initialize state manager and workspace state."""
- await self.state_manager.initialize()
- await self._init_workspace_state()
+ asyncio.create_task(self._init_workspace_state())
+ self.SNIPPET_LINES = 4
def _should_exclude_file(self, rel_path: str) -> bool:
"""Check if a file should be excluded based on path, name, or extension"""
@@ -264,6 +259,9 @@ async def str_replace(self, file_path: str, old_str: str, new_str: str) -> ToolR
new_content = content.replace(old_str, new_str)
full_path.write_text(new_content)
+ # Update state after file modification
+ await self._update_workspace_state()
+
# Show snippet around the edit
replacement_line = content.split(old_str)[0].count('\n')
start_line = max(0, replacement_line - self.SNIPPET_LINES)
diff --git a/agentpress/examples/simple_web_dev/tools/terminal_tool.py b/agentpress/agents/simple_web_dev/tools/terminal_tool.py
similarity index 69%
rename from agentpress/examples/simple_web_dev/tools/terminal_tool.py
rename to agentpress/agents/simple_web_dev/tools/terminal_tool.py
index 3c5b176..c9736dc 100644
--- a/agentpress/examples/simple_web_dev/tools/terminal_tool.py
+++ b/agentpress/agents/simple_web_dev/tools/terminal_tool.py
@@ -2,7 +2,6 @@
import asyncio
import subprocess
from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
-from agentpress.state_manager import StateManager
from typing import Optional
class TerminalTool(Tool):
@@ -12,24 +11,6 @@ def __init__(self, thread_id: Optional[str] = None):
super().__init__()
self.workspace = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'workspace')
os.makedirs(self.workspace, exist_ok=True)
- if thread_id:
- self.state_manager = StateManager(thread_id)
- asyncio.create_task(self._init_state())
-
- async def _init_state(self):
- """Initialize state manager."""
- await self.state_manager.initialize()
-
- async def _update_command_history(self, command: str, output: str, success: bool):
- """Update command history in state"""
- history = await self.state_manager.get("terminal_history") or []
- history.append({
- "command": command,
- "output": output,
- "success": success,
- "cwd": os.path.relpath(os.getcwd(), self.workspace)
- })
- await self.state_manager.set("terminal_history", history)
@openapi_schema({
"type": "function",
@@ -76,12 +57,6 @@ async def execute_command(self, command: str) -> ToolResult:
error = stderr.decode() if stderr else ""
success = process.returncode == 0
- await self._update_command_history(
- command=command,
- output=output + error,
- success=success
- )
-
if success:
return self.success_response({
"output": output,
@@ -93,11 +68,6 @@ async def execute_command(self, command: str) -> ToolResult:
return self.fail_response(f"Command failed with exit code {process.returncode}: {error}")
except Exception as e:
- await self._update_command_history(
- command=command,
- output=str(e),
- success=False
- )
return self.fail_response(f"Error executing command: {str(e)}")
finally:
os.chdir(original_dir)
diff --git a/agentpress/api.py b/agentpress/api.py
deleted file mode 100644
index e52bdb4..0000000
--- a/agentpress/api.py
+++ /dev/null
@@ -1,253 +0,0 @@
-from contextlib import asynccontextmanager
-from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect, BackgroundTasks
-from fastapi.middleware.cors import CORSMiddleware
-from typing import Optional, List, Dict, Any
-from pydantic import BaseModel
-from agentpress.thread_manager import ThreadManager
-import asyncio
-import json
-import uvicorn
-import logging
-import importlib
-from agentpress.api_factory import app as api_app, discover_tasks
-
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-# Global managers
-thread_manager: Optional[ThreadManager] = None
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
- """Lifespan context manager for FastAPI application."""
- # Startup
- global thread_manager
- thread_manager = ThreadManager()
- await thread_manager.initialize()
-
- yield
-
- # Shutdown
- # Add any cleanup code here if needed
-
-# Create FastAPI app
-app = FastAPI(title="AgentPress API", lifespan=lifespan)
-
-# Enable CORS
-app.add_middleware(
- CORSMiddleware,
- allow_origins=["*"],
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
-)
-
-# Import and mount the API Factory app
-try:
- # Run task discovery
- discover_tasks()
- logger.info("Task discovery completed")
-
- # Mount the API Factory app at /tasks instead of root
- app.mount("/tasks", api_app)
- logger.info("Mounted API Factory app at /tasks")
-except Exception as e:
- logger.error(f"Error setting up API Factory: {e}")
- raise
-
-# WebSocket connection manager
-class WebSocketManager:
- """Manages WebSocket connections for real-time thread updates."""
-
- def __init__(self):
- self.active_connections: Dict[str, List[WebSocket]] = {}
-
- async def connect(self, websocket: WebSocket, thread_id: str):
- """Connect a WebSocket to a thread."""
- await websocket.accept()
- if thread_id not in self.active_connections:
- self.active_connections[thread_id] = []
- self.active_connections[thread_id].append(websocket)
-
- def disconnect(self, websocket: WebSocket, thread_id: str):
- """Disconnect a WebSocket from a thread."""
- if thread_id in self.active_connections:
- self.active_connections[thread_id].remove(websocket)
- if not self.active_connections[thread_id]:
- del self.active_connections[thread_id]
-
- async def broadcast_to_thread(self, thread_id: str, message: dict):
- """Broadcast a message to all connections in a thread."""
- if thread_id in self.active_connections:
- for connection in self.active_connections[thread_id]:
- try:
- await connection.send_json(message)
- except WebSocketDisconnect:
- self.disconnect(connection, thread_id)
-
-# Initialize WebSocket manager
-ws_manager = WebSocketManager()
-
-# Pydantic models for request/response validation
-class EventCreate(BaseModel):
- event_type: str
- content: Dict[str, Any]
- include_in_llm_message_history: bool = False
- llm_message: Optional[Dict[str, Any]] = None
-
-class EventUpdate(BaseModel):
- content: Optional[Dict[str, Any]] = None
- include_in_llm_message_history: Optional[bool] = None
- llm_message: Optional[Dict[str, Any]] = None
-
-class ThreadEvents(BaseModel):
- only_llm_messages: bool = False
- event_types: Optional[List[str]] = None
- order_by: str = "created_at"
- order: str = "ASC"
-
-# REST API Endpoints
-@app.post("/threads", response_model=dict, status_code=201)
-async def create_thread():
- """Create a new thread."""
- thread_id = await thread_manager.create_thread()
- return {"thread_id": thread_id}
-
-@app.delete("/threads/{thread_id}", status_code=204)
-async def delete_thread(thread_id: str):
- """Delete a thread and all its events."""
- success = await thread_manager.delete_thread(thread_id)
- if not success:
- raise HTTPException(status_code=404, detail="Thread not found")
- return None
-
-@app.post("/threads/{thread_id}/events", response_model=dict, status_code=201)
-async def create_event(thread_id: str, event: EventCreate, background_tasks: BackgroundTasks):
- """Create a new event in a thread."""
- # First verify thread exists
- if not await thread_manager.thread_exists(thread_id):
- raise HTTPException(status_code=404, detail="Thread not found")
-
- try:
- event_id = await thread_manager.create_event(
- thread_id=thread_id,
- event_type=event.event_type,
- content=event.content,
- include_in_llm_message_history=event.include_in_llm_message_history,
- llm_message=event.llm_message
- )
- # Broadcast to WebSocket connections
- background_tasks.add_task(
- ws_manager.broadcast_to_thread,
- thread_id,
- {"type": "event_created", "event_id": event_id, "event": event.dict()}
- )
- return {"event_id": event_id}
- except Exception as e:
- raise HTTPException(status_code=400, detail=str(e))
-
-@app.delete("/threads/{thread_id}/events/{event_id}", status_code=204)
-async def delete_event(thread_id: str, event_id: str, background_tasks: BackgroundTasks):
- """Delete a specific event."""
- # First verify thread exists
- if not await thread_manager.thread_exists(thread_id):
- raise HTTPException(status_code=404, detail="Thread not found")
-
- # Then verify event exists and belongs to thread
- if not await thread_manager.event_belongs_to_thread(event_id, thread_id):
- raise HTTPException(status_code=404, detail="Event not found in this thread")
-
- success = await thread_manager.delete_event(event_id)
- if not success:
- raise HTTPException(status_code=500, detail="Failed to delete event")
-
- # Broadcast to WebSocket connections
- background_tasks.add_task(
- ws_manager.broadcast_to_thread,
- thread_id,
- {"type": "event_deleted", "event_id": event_id}
- )
- return None
-
-@app.patch("/threads/{thread_id}/events/{event_id}", status_code=200)
-async def update_event(thread_id: str, event_id: str, event: EventUpdate, background_tasks: BackgroundTasks):
- """Update an existing event."""
- # First verify thread exists
- if not await thread_manager.thread_exists(thread_id):
- raise HTTPException(status_code=404, detail="Thread not found")
-
- # Then verify event exists and belongs to thread
- if not await thread_manager.event_belongs_to_thread(event_id, thread_id):
- raise HTTPException(status_code=404, detail="Event not found in this thread")
-
- success = await thread_manager.update_event(
- event_id=event_id,
- thread_id=thread_id,
- content=event.content,
- include_in_llm_message_history=event.include_in_llm_message_history,
- llm_message=event.llm_message
- )
- if not success:
- raise HTTPException(status_code=500, detail="Failed to update event")
-
- # Broadcast to WebSocket connections
- background_tasks.add_task(
- ws_manager.broadcast_to_thread,
- thread_id,
- {"type": "event_updated", "event_id": event_id, "updates": event.dict(exclude_unset=True)}
- )
- return {"status": "success"}
-
-@app.get("/threads/{thread_id}/events")
-async def get_thread_events(
- thread_id: str,
- only_llm_messages: bool = False,
- event_types: Optional[List[str]] = None,
- order_by: str = "created_at",
- order: str = "ASC"
-):
- """Get events from a thread with filtering options."""
- if not await thread_manager.thread_exists(thread_id):
- raise HTTPException(status_code=404, detail="Thread not found")
-
- events = await thread_manager.get_thread_events(
- thread_id=thread_id,
- only_llm_messages=only_llm_messages,
- event_types=event_types,
- order_by=order_by,
- order=order
- )
- return {"events": events}
-
-@app.get("/threads/{thread_id}/messages")
-async def get_thread_messages(thread_id: str):
- """Get LLM-formatted messages from thread events."""
- if not await thread_manager.thread_exists(thread_id):
- raise HTTPException(status_code=404, detail="Thread not found")
-
- messages = await thread_manager.get_thread_llm_messages(thread_id)
- return {"messages": messages}
-
-# WebSocket Endpoint
-@app.websocket("/ws/threads/{thread_id}")
-async def websocket_endpoint(websocket: WebSocket, thread_id: str):
- """WebSocket endpoint for real-time thread updates."""
- # Verify thread exists before accepting connection
- if not await thread_manager.thread_exists(thread_id):
- await websocket.close(code=4004, reason="Thread not found")
- return
-
- await ws_manager.connect(websocket, thread_id)
- try:
- while True:
- await websocket.receive_json()
-
- except WebSocketDisconnect:
- ws_manager.disconnect(websocket, thread_id)
- except Exception as e:
- await websocket.send_json({"type": "error", "detail": str(e)})
- ws_manager.disconnect(websocket, thread_id)
-
-if __name__ == "__main__":
- uvicorn.run(app, host="0.0.0.0", port=8000)
\ No newline at end of file
diff --git a/agentpress/api/__init__.py b/agentpress/api/__init__.py
new file mode 100644
index 0000000..3ebeb22
--- /dev/null
+++ b/agentpress/api/__init__.py
@@ -0,0 +1 @@
+# Empty file to mark as package
\ No newline at end of file
diff --git a/agentpress/api/api.py b/agentpress/api/api.py
new file mode 100644
index 0000000..480e3d5
--- /dev/null
+++ b/agentpress/api/api.py
@@ -0,0 +1,255 @@
+from contextlib import asynccontextmanager
+from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect, BackgroundTasks
+from fastapi.middleware.cors import CORSMiddleware
+from typing import Optional, List, Dict, Any, Union
+from pydantic import BaseModel
+from agentpress.thread_manager import ThreadManager
+import asyncio
+import uvicorn
+import logging
+from agentpress.api.ws import ws_manager
+from agentpress.api.api_factory import (
+ app as thread_task_app,
+ register_thread_task_api,
+ discover_tasks,
+ thread_manager as task_thread_manager
+)
+# from agentpress.api_factory import app as api_app, discover_tasks
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Global managers
+thread_manager: Optional[ThreadManager] = None
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ """Lifespan context manager for FastAPI application."""
+ # Startup
+ global thread_manager
+ thread_manager = ThreadManager()
+
+ # Share thread_manager with task API
+ global task_thread_manager
+ task_thread_manager = thread_manager
+
+ # Wait for DB initialization
+ db = thread_manager.db
+ if db._initialization_task:
+ await db._initialization_task
+
+ # Run task discovery during startup
+ discover_tasks()
+
+ yield
+
+ # Shutdown
+ # Add any cleanup code here if needed
+
+# Create FastAPI app
+app = FastAPI(title="AgentPress API", lifespan=lifespan)
+
+# Enable CORS
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+
+# # Import and mount the API Factory app
+# try:
+# # Run task discovery
+# # discover_tasks()
+# logger.info("Task discovery completed")
+
+# # Mount the API Factory app at /tasks instead of root
+# app.mount("/tasks", api_app)
+# logger.info("Mounted API Factory app at /tasks")
+# except Exception as e:
+# logger.error(f"Error setting up API Factory: {e}")
+# raise
+
+# Pydantic models for request/response validation
+class MessageCreate(BaseModel):
+ """Model for creating messages in a thread."""
+ message_data: Union[str, Dict[str, Any]]
+ images: Optional[List[Dict[str, Any]]] = None
+ include_in_llm_message_history: bool = True
+ message_type: Optional[str] = None
+
+# REST API Endpoints
+@app.post("/threads", response_model=dict, status_code=201)
+async def create_thread():
+ """Create a new thread."""
+ thread_id = await thread_manager.create_thread()
+ return {"thread_id": thread_id}
+
+@app.post("/threads/{thread_id}/messages", response_model=dict, status_code=201)
+async def create_message(thread_id: str, message: MessageCreate, background_tasks: BackgroundTasks):
+ """Create a new message in a thread."""
+ if not await thread_manager.thread_exists(thread_id):
+ raise HTTPException(status_code=404, detail="Thread not found")
+
+ try:
+ await thread_manager.add_message(
+ thread_id=thread_id,
+ message_data=message.message_data,
+ images=message.images,
+ include_in_llm_message_history=message.include_in_llm_message_history,
+ message_type=message.message_type
+ )
+
+ # Broadcast to WebSocket connections
+ background_tasks.add_task(
+ ws_manager.broadcast_to_thread,
+ thread_id,
+ {"type": "message_created", "message": message.dict()}
+ )
+ return {"status": "success"}
+ except Exception as e:
+ raise HTTPException(status_code=400, detail=str(e))
+
+# TODO: BROKEN FOR SOME REASON – RETURNS [] SHOULD RETURN, LLM MESSAGE STYLE
+@app.get("/threads/{thread_id}/llm_history_messages")
+async def get_thread_llm_messages(
+ thread_id: str,
+ hide_tool_msgs: bool = False,
+ only_latest_assistant: bool = False,
+):
+ """Get messages from a thread with filtering options."""
+ if not await thread_manager.thread_exists(thread_id):
+ raise HTTPException(status_code=404, detail="Thread not found")
+
+ messages = await thread_manager.get_llm_history_messages(
+ thread_id=thread_id,
+ hide_tool_msgs=hide_tool_msgs,
+ only_latest_assistant=only_latest_assistant,
+ )
+ return {"messages": messages}
+
+@app.get("/threads/{thread_id}/messages")
+async def get_thread_messages(
+ thread_id: str,
+ message_types: Optional[List[str]] = None,
+ limit: Optional[int] = 50,
+ offset: Optional[int] = 0,
+ before_timestamp: Optional[str] = None,
+ after_timestamp: Optional[str] = None,
+ include_in_llm_message_history: Optional[bool] = None,
+ order: str = "asc"
+):
+ """
+ Get messages from a thread with comprehensive filtering options.
+
+ Args:
+ thread_id: Thread identifier
+ message_types: Optional list of message types to filter by
+ limit: Maximum number of messages to return (default: 50)
+ offset: Number of messages to skip for pagination
+ before_timestamp: Optional filter for messages before timestamp
+ after_timestamp: Optional filter for messages after timestamp
+ include_in_llm_message_history: Optional filter for LLM history inclusion
+ order: Sort order - "asc" or "desc"
+ """
+ if not await thread_manager.thread_exists(thread_id):
+ raise HTTPException(status_code=404, detail="Thread not found")
+
+ try:
+ messages = await thread_manager.get_messages(
+ thread_id=thread_id,
+ message_types=message_types,
+ limit=limit,
+ offset=offset,
+ before_timestamp=before_timestamp,
+ after_timestamp=after_timestamp,
+ include_in_llm_message_history=include_in_llm_message_history,
+ order=order
+ )
+ return {"messages": messages}
+ except Exception as e:
+ raise HTTPException(status_code=400, detail=str(e))
+
+# TODO ONLY SEND POLLING UPDATES (IN EVEN HIGHER FREQUENCY THEN 1per sec) - IF THEY ARE ANY ACTIVE TASKS FOR THAT THREAD. AS LONG AS THEY ARE ACTIVE TASKS START & STOP THE POLLING BASED ON WHETHER THERE IS AN ACTIVE TASK FOR THE THREAD. IMPLEMENT in API_FACTORY as well to broadcast this ofc & trigger/disable the polling.
+
+# WebSocket Endpoint
+@app.websocket("/threads/{thread_id}")
+async def websocket_endpoint(
+ websocket: WebSocket,
+ thread_id: str,
+ message_types: Optional[List[str]] = None,
+ limit: Optional[int] = 50,
+ offset: Optional[int] = 0,
+ before_timestamp: Optional[str] = None,
+ after_timestamp: Optional[str] = None,
+ include_in_llm_message_history: Optional[bool] = None,
+ order: str = "desc"
+):
+ """
+ WebSocket endpoint for real-time thread updates with filtering and pagination.
+
+ Query Parameters:
+ message_types: Optional list of message types to filter by
+ limit: Maximum number of messages to return (default: 50)
+ offset: Number of messages to skip (for pagination)
+ before_timestamp: Optional timestamp to filter messages before
+ after_timestamp: Optional timestamp to filter messages after
+ include_in_llm_message_history: Optional bool to filter messages by LLM history inclusion
+ order: Sort order - "asc" or "desc" (default: desc)
+ """
+ try:
+ if not await thread_manager.thread_exists(thread_id):
+ await websocket.close(code=4004, reason="Thread not found")
+ return
+
+ await ws_manager.connect(websocket, thread_id)
+
+ while True:
+ try:
+ # Get messages with all filters
+ result = await thread_manager.get_messages(
+ thread_id=thread_id,
+ message_types=message_types,
+ limit=limit,
+ offset=offset,
+ before_timestamp=before_timestamp,
+ after_timestamp=after_timestamp,
+ include_in_llm_message_history=include_in_llm_message_history,
+ order=order
+ )
+
+ # Send messages and pagination info
+ await websocket.send_json({
+ "type": "messages",
+ "data": result
+ })
+
+ # Poll every second
+ await asyncio.sleep(1)
+
+ except WebSocketDisconnect:
+ ws_manager.disconnect(websocket, thread_id)
+ break
+ except Exception as e:
+ logging.error(f"WebSocket error: {e}")
+ await websocket.send_json({
+ "type": "error",
+ "data": str(e)
+ })
+ ws_manager.disconnect(websocket, thread_id)
+ break
+
+ except Exception as e:
+ logging.error(f"WebSocket connection error: {e}")
+ try:
+ await websocket.close(code=1011, reason=str(e))
+ except:
+ pass
+
+# Update the mounting of thread_task_app
+app.mount("/tasks", thread_task_app)
+
+if __name__ == "__main__":
+ uvicorn.run(app, host="0.0.0.0", port=8000)
\ No newline at end of file
diff --git a/agentpress/api/api_factory.py b/agentpress/api/api_factory.py
new file mode 100644
index 0000000..70d37c9
--- /dev/null
+++ b/agentpress/api/api_factory.py
@@ -0,0 +1,349 @@
+"""
+Thread Task API Factory for registering and managing thread-associated long-running tasks.
+"""
+
+import sys
+import os
+import inspect
+import uuid
+import asyncio
+import logging
+import importlib
+from functools import wraps
+from typing import Callable, Dict, Any, Optional, List, ForwardRef
+from fastapi import FastAPI, HTTPException, Request
+from pydantic import create_model
+from agentpress.thread_manager import ThreadManager
+from contextlib import asynccontextmanager
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Initialize managers at module level
+thread_manager: Optional[ThreadManager] = None
+_decorated_functions: Dict[str, Callable] = {}
+_running_tasks: Dict[str, Dict[str, Any]] = {}
+
+def find_project_root():
+ """Find the project root by looking for pyproject.toml"""
+ current = os.path.abspath(os.path.dirname(__file__))
+ while current != '/':
+ if os.path.exists(os.path.join(current, 'pyproject.toml')):
+ return current
+ current = os.path.dirname(current)
+ return None
+
+def discover_tasks():
+ """
+ Discover all decorated functions in the project.
+ Scans from the project root (where pyproject.toml is located).
+ """
+ logger.info("Starting task discovery")
+
+ # Find project root
+ project_root = find_project_root()
+ logger.info(f"Project root found at: {project_root}")
+
+ # Add project root to Python path if not already there
+ if project_root not in sys.path:
+ sys.path.insert(0, project_root)
+
+ # Walk through all Python files in the project
+ for root, _, files in os.walk(project_root):
+ for file in files:
+ if file.endswith('.py'):
+ module_path = os.path.join(root, file)
+ module_name = os.path.relpath(module_path, project_root)
+ module_name = os.path.splitext(module_name)[0].replace(os.path.sep, '.')
+
+ try:
+ logger.info(f"Attempting to import module: {module_name}")
+ module = importlib.import_module(module_name)
+
+ # Inspect all module members
+ for name, obj in inspect.getmembers(module):
+ if inspect.isfunction(obj):
+ # Check if this function has been decorated with register_thread_task_api
+ if hasattr(obj, '__closure__') and obj.__closure__:
+ for cell in obj.__closure__:
+ if cell.cell_contents in _decorated_functions.values():
+ path = next(
+ p for p, f in _decorated_functions.items()
+ if f == cell.cell_contents
+ )
+ _decorated_functions[path] = obj
+ logger.info(f"Registered function: {obj.__name__} at path: {path}")
+
+ except Exception as e:
+ logger.error(f"Error importing {module_name}: {e}")
+
+ logger.info(f"Task discovery complete. Registered paths: {list(_decorated_functions.keys())}")
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ """Lifespan context manager for FastAPI application."""
+ global thread_manager
+
+ # Initialize ThreadManager if not already initialized
+ if thread_manager is None:
+ thread_manager = ThreadManager()
+ # Wait for DB initialization
+ if thread_manager.db._initialization_task:
+ await thread_manager.db._initialization_task
+
+ # Run task discovery during startup
+ discover_tasks()
+
+ yield
+ # Cleanup if needed
+
+# Create FastAPI app
+app = FastAPI(
+ title="Thread Task API",
+ description="API for managing thread-associated long-running tasks",
+ openapi_tags=[{
+ "name": "Generated Thread Tasks",
+ "description": "Dynamically generated endpoints for thread-associated tasks"
+ }],
+ lifespan=lifespan
+)
+
+# Add middleware to ensure thread_manager is available
+@app.middleware("http")
+async def ensure_thread_manager(request: Request, call_next):
+ """Ensure thread_manager is initialized before handling requests."""
+ global thread_manager
+ if thread_manager is None:
+ thread_manager = ThreadManager()
+ if thread_manager.db._initialization_task:
+ await thread_manager.db._initialization_task
+ return await call_next(request)
+
+def register_thread_task_api(path: str):
+ """
+ Decorator to register a function as a thread-associated task API endpoint.
+ The decorated function must have thread_id as its first parameter.
+ """
+ def decorator(func: Callable):
+ logger.info(f"Registering thread task API endpoint: {path} for function {func.__name__}")
+ _decorated_functions[path] = func
+
+ # Validate that thread_id is the first parameter
+ params = inspect.signature(func).parameters
+ if 'thread_id' not in params:
+ raise ValueError(f"Function {func.__name__} must have thread_id as a parameter")
+
+ # Create Pydantic model for function parameters
+ model_fields = {}
+ for name, param in params.items():
+ if name == 'self': # Skip self parameter for methods
+ continue
+
+ annotation = param.annotation
+ if annotation == inspect.Parameter.empty:
+ annotation = Any
+
+ # Convert string annotations to ForwardRef
+ if isinstance(annotation, str):
+ annotation = ForwardRef(annotation)
+
+ default = ... if param.default == inspect.Parameter.empty else param.default
+ model_fields[name] = (annotation, default)
+
+ RequestModel = create_model(f'{func.__name__}Request', **model_fields)
+
+ # Register the start endpoint
+ @app.post(
+ f"{path}/start",
+ response_model=dict,
+ summary=f"Start {func.__name__}",
+ description=f"Start a new {func.__name__} task associated with a thread",
+ tags=["Generated Thread Tasks"]
+ )
+ async def start_task(params: RequestModel):
+ logger.info(f"Starting task at {path}/start")
+
+ # Validate thread exists
+ if not await thread_manager.thread_exists(params.thread_id):
+ raise HTTPException(status_code=404, detail="Thread not found")
+
+ task_id = str(uuid.uuid4())
+ kwargs = params.dict()
+
+ # Create the task
+ task = asyncio.create_task(func(**kwargs))
+
+ # Store task with thread association
+ _running_tasks[task_id] = {
+ "thread_id": params.thread_id,
+ "task": task,
+ "status": "running",
+ "path": path,
+ "started_at": asyncio.get_event_loop().time()
+ }
+
+ # Add task info to thread messages
+ await thread_manager.add_message(
+ thread_id=params.thread_id,
+ message_data={
+ "type": "task_started",
+ "task_id": task_id,
+ "path": path,
+ "status": "running"
+ },
+ message_type="task_status",
+ include_in_llm_message_history=False
+ )
+
+ return {"task_id": task_id}
+
+ # Register the stop endpoint
+ @app.post(
+ f"{path}/stop/{{task_id}}",
+ response_model=dict,
+ summary=f"Stop {func.__name__}",
+ description=f"Stop a running {func.__name__} task",
+ tags=["Generated Thread Tasks"]
+ )
+ async def stop_task(task_id: str):
+ if task_id not in _running_tasks:
+ raise HTTPException(status_code=404, detail="Task not found")
+
+ task_info = _running_tasks[task_id]
+ task_info["task"].cancel()
+ task_info["status"] = "cancelled"
+
+ # Update thread with task cancellation
+ await thread_manager.add_message(
+ thread_id=task_info["thread_id"],
+ message_data={
+ "type": "task_stopped",
+ "task_id": task_id,
+ "path": task_info["path"],
+ "status": "cancelled"
+ },
+ message_type="task_status",
+ include_in_llm_message_history=False
+ )
+
+ return {"status": "stopped"}
+
+ # Register the status endpoint
+ @app.get(
+ f"{path}/status/{{task_id}}",
+ response_model=dict,
+ summary=f"Get {func.__name__} status",
+ description=f"Get the status of a {func.__name__} task",
+ tags=["Generated Thread Tasks"]
+ )
+ async def get_status(task_id: str):
+ if task_id not in _running_tasks:
+ raise HTTPException(status_code=404, detail="Task not found")
+
+ task_info = _running_tasks[task_id]
+ task = task_info["task"]
+
+ if task.done():
+ try:
+ result = task.result()
+ status = "completed"
+ if hasattr(result, '__aiter__'):
+ status = "streaming"
+
+ # Update thread with task completion
+ await thread_manager.add_message(
+ thread_id=task_info["thread_id"],
+ message_data={
+ "type": "task_completed",
+ "task_id": task_id,
+ "path": task_info["path"],
+ "status": status,
+ "result": result if status == "completed" else None
+ },
+ message_type="task_status",
+ include_in_llm_message_history=False
+ )
+
+ return {
+ "status": status,
+ "result": result if status == "completed" else None
+ }
+
+ except asyncio.CancelledError:
+ return {"status": "cancelled"}
+ except Exception as e:
+ error_status = {
+ "status": "failed",
+ "error": str(e)
+ }
+
+ # Update thread with task failure
+ await thread_manager.add_message(
+ thread_id=task_info["thread_id"],
+ message_data={
+ "type": "task_failed",
+ "task_id": task_id,
+ "path": task_info["path"],
+ "status": "failed",
+ "error": str(e)
+ },
+ message_type="task_status",
+ include_in_llm_message_history=False
+ )
+
+ return error_status
+
+ return {"status": "running"}
+
+ @wraps(func)
+ async def wrapper(*args, **kwargs):
+ return await func(*args, **kwargs)
+ return wrapper
+ return decorator
+
+@app.get("/threads/{thread_id}/tasks")
+async def get_thread_tasks(thread_id: str):
+ """Get all tasks associated with a thread."""
+ if not await thread_manager.thread_exists(thread_id):
+ raise HTTPException(status_code=404, detail="Thread not found")
+
+ thread_tasks = {
+ task_id: {
+ "path": info["path"],
+ "status": info["status"],
+ "started_at": info["started_at"]
+ }
+ for task_id, info in _running_tasks.items()
+ if info["thread_id"] == thread_id
+ }
+
+ return {"tasks": thread_tasks}
+
+@app.delete("/threads/{thread_id}/tasks")
+async def stop_thread_tasks(thread_id: str):
+ """Stop all tasks associated with a thread."""
+ if not await thread_manager.thread_exists(thread_id):
+ raise HTTPException(status_code=404, detail="Thread not found")
+
+ stopped_tasks = []
+ for task_id, info in list(_running_tasks.items()):
+ if info["thread_id"] == thread_id:
+ info["task"].cancel()
+ info["status"] = "cancelled"
+ stopped_tasks.append(task_id)
+
+ # Update thread with task cancellations
+ if stopped_tasks:
+ await thread_manager.add_message(
+ thread_id=thread_id,
+ message_data={
+ "type": "tasks_stopped",
+ "task_ids": stopped_tasks,
+ "status": "cancelled"
+ },
+ message_type="task_status",
+ include_in_llm_message_history=False
+ )
+
+ return {"stopped_tasks": stopped_tasks}
\ No newline at end of file
diff --git a/agentpress/api/ws.py b/agentpress/api/ws.py
new file mode 100644
index 0000000..2ee5a28
--- /dev/null
+++ b/agentpress/api/ws.py
@@ -0,0 +1,45 @@
+"""WebSocket management system for real-time updates."""
+
+import logging
+from typing import Dict, List
+from fastapi import WebSocket, WebSocketDisconnect
+
+class WebSocketManager:
+ """Manages WebSocket connections for real-time thread updates."""
+
+ def __init__(self):
+ self.active_connections: Dict[str, List[WebSocket]] = {}
+
+ async def connect(self, websocket: WebSocket, thread_id: str):
+ """Connect a WebSocket to a thread."""
+ await websocket.accept()
+ if thread_id not in self.active_connections:
+ self.active_connections[thread_id] = []
+ self.active_connections[thread_id].append(websocket)
+
+ def disconnect(self, websocket: WebSocket, thread_id: str):
+ """Disconnect a WebSocket from a thread."""
+ if thread_id in self.active_connections:
+ self.active_connections[thread_id].remove(websocket)
+ if not self.active_connections[thread_id]:
+ del self.active_connections[thread_id]
+
+ async def broadcast_to_thread(self, thread_id: str, message: dict):
+ """Broadcast a message to all connections in a thread."""
+ if thread_id in self.active_connections:
+ disconnected = []
+ for connection in self.active_connections[thread_id]:
+ try:
+ await connection.send_json(message)
+ except WebSocketDisconnect:
+ disconnected.append(connection)
+ except Exception as e:
+ logging.warning(f"Failed to send message to websocket: {e}")
+ disconnected.append(connection)
+
+ # Clean up disconnected connections
+ for connection in disconnected:
+ self.disconnect(connection, thread_id)
+
+# Global WebSocket manager instance
+ws_manager = WebSocketManager()
\ No newline at end of file
diff --git a/agentpress/api_factory.py b/agentpress/api_factory.py
deleted file mode 100644
index 69d0257..0000000
--- a/agentpress/api_factory.py
+++ /dev/null
@@ -1,170 +0,0 @@
-"""
-API Factory for registering and managing FastAPI endpoints.
-"""
-
-import sys
-import inspect
-import importlib
-import pkgutil
-import uuid
-import asyncio
-import logging
-import os
-from functools import wraps
-from typing import Callable, Dict, Any, Optional, List
-from fastapi import FastAPI, BackgroundTasks, HTTPException
-from pydantic import create_model, BaseModel
-
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-app = FastAPI()
-_decorated_functions: Dict[str, Callable] = {}
-_running_tasks: Dict[str, asyncio.Task] = {}
-
-def register_api_endpoint(path: str):
- """Decorator to register a function as an API endpoint with task management."""
- def decorator(func: Callable):
- logger.info(f"Registering API endpoint: {path} for function {func.__name__}")
- _decorated_functions[path] = func
-
- # Create Pydantic model for function parameters
- params = inspect.signature(func).parameters
- model_fields = {
- name: (param.annotation if param.annotation != inspect.Parameter.empty else Any, ... if param.default == inspect.Parameter.empty else param.default)
- for name, param in params.items()
- if name != 'self' # Skip self parameter for methods
- }
- RequestModel = create_model(f'{func.__name__}Request', **model_fields)
-
- # Register the start endpoint
- @app.post(f"{path}/start", response_model=dict)
- async def start_task(params: Optional[RequestModel] = None):
- logger.info(f"Starting task at {path}/start")
- task_id = str(uuid.uuid4())
- kwargs = params.dict() if params else {}
- _running_tasks[task_id] = asyncio.create_task(func(**kwargs))
- return {"task_id": task_id}
-
- # Register the stop endpoint
- @app.post(f"{path}/stop/{{task_id}}")
- async def stop_task(task_id: str):
- if task_id not in _running_tasks:
- raise HTTPException(status_code=404, detail="Task not found")
- _running_tasks[task_id].cancel()
- return {"status": "stopped"}
-
- # Register the status endpoint
- @app.get(f"{path}/status/{{task_id}}")
- async def get_status(task_id: str):
- if task_id not in _running_tasks:
- raise HTTPException(status_code=404, detail="Task not found")
- task = _running_tasks[task_id]
-
- if task.done():
- try:
- result = task.result()
- # Check if this is a streaming response
- if hasattr(result, '__aiter__') or (
- isinstance(result, dict) and (
- any(hasattr(v, '__aiter__') for v in result.values()) or
- # Check for streaming responses in iterations
- (
- 'iterations' in result and
- result['iterations'] and
- any(hasattr(r.get('response'), '__aiter__') for r in result['iterations'])
- )
- )
- ):
- return {"status": "streaming"}
- return {"status": "completed", "result": result}
- except asyncio.CancelledError:
- return {"status": "cancelled"}
- except Exception as e:
- return {"status": "failed", "error": str(e)}
- return {"status": "running"}
-
- # Also register a direct endpoint for simple calls
- @app.post(path)
- async def direct_call(background_tasks: BackgroundTasks, params: Optional[RequestModel] = None):
- kwargs = params.dict() if params else {}
- task_id = str(uuid.uuid4())
- task = asyncio.create_task(func(**kwargs))
- _running_tasks[task_id] = task
- return {"task_id": task_id}
-
- logger.info(f"Successfully registered all endpoints for {path}")
-
- @wraps(func)
- async def wrapper(*args, **kwargs):
- return await func(*args, **kwargs)
- return wrapper
- return decorator
-
-def find_project_root() -> str:
- """Find the project root by looking for pyproject.toml."""
- current = os.path.abspath(os.path.dirname(__file__))
- while current != '/':
- if os.path.exists(os.path.join(current, 'pyproject.toml')):
- return current
- current = os.path.dirname(current)
- return os.path.dirname(__file__) # Fallback to current directory
-
-def discover_tasks():
- """
- Discover all decorated functions in the project.
- Scans from the project root (where pyproject.toml is located).
- """
- logger.info("Starting task discovery")
-
- # Find project root
- project_root = find_project_root()
- logger.info(f"Project root found at: {project_root}")
-
- # Add project root to Python path if not already there
- if project_root not in sys.path:
- sys.path.insert(0, project_root)
-
- # Walk through all Python files in the project
- for root, _, files in os.walk(project_root):
- for file in files:
- if file.endswith('.py'):
- module_path = os.path.join(root, file)
- module_name = os.path.relpath(module_path, project_root)
- module_name = os.path.splitext(module_name)[0].replace(os.path.sep, '.')
-
- try:
- logger.info(f"Attempting to import module: {module_name}")
- module = importlib.import_module(module_name)
-
- # Inspect all module members
- for name, obj in inspect.getmembers(module):
- if inspect.isfunction(obj):
- # Check if this function has been decorated with @task
- if any(
- path for path, func in _decorated_functions.items()
- if func.__name__ == obj.__name__ and func.__module__ == obj.__module__
- ):
- logger.info(f"Found already registered function: {obj.__name__}")
- continue
-
- # Check for our decorator in the function's closure
- if hasattr(obj, '__closure__') and obj.__closure__:
- for cell in obj.__closure__:
- if cell.cell_contents in _decorated_functions.values():
- # Found a decorated function that wasn't registered
- path = next(
- p for p, f in _decorated_functions.items()
- if f == cell.cell_contents
- )
- _decorated_functions[path] = obj
- logger.info(f"Registered function: {obj.__name__} at path: {path}")
-
- except Exception as e:
- logger.error(f"Error importing {module_name}: {e}")
-
- logger.info(f"Task discovery complete. Registered paths: {list(_decorated_functions.keys())}")
-
-# Auto-discover tasks on import
-discover_tasks()
\ No newline at end of file
diff --git a/agentpress/cli.py b/agentpress/cli.py
index 3300cb5..471fa7f 100644
--- a/agentpress/cli.py
+++ b/agentpress/cli.py
@@ -2,99 +2,44 @@
import shutil
import click
import questionary
-from typing import List, Dict, Optional, Tuple
+from typing import Dict
import time
import pkg_resources
import requests
from packaging import version
-import re
-MODULES = {
- "llm": {
- "required": True,
- "files": ["llm.py"],
- "description": "LLM Interface - Core module for interacting with large language models (OpenAI, Anthropic, 100+ LLMs using the OpenAI Input/Output Format powered by LiteLLM). Handles API calls, response streaming, and model-specific configurations."
- },
- "tool": {
- "required": True,
- "files": [
- "tool.py",
- "tool_registry.py"
- ],
- "description": "Tool System Foundation - Defines the base architecture for creating and managing tools. Includes the tool registry for registering, organizing, and accessing tool functions."
- },
- "processors": {
- "required": True,
- "files": [
- "processor/base_processors.py",
- "processor/llm_response_processor.py",
- "processor/standard/standard_tool_parser.py",
- "processor/standard/standard_tool_executor.py",
- "processor/standard/standard_results_adder.py",
- "processor/xml/xml_tool_parser.py",
- "processor/xml/xml_tool_executor.py",
- "processor/xml/xml_results_adder.py"
- ],
- "description": "Response Processing System - Handles parsing and executing LLM responses, managing tool calls, and processing results. Supports both standard OpenAI-style function calling and XML-based tool execution patterns."
- },
- "thread_management": {
- "required": True,
- "files": [
- "thread_manager.py",
- "thread_viewer_ui.py"
- ],
- "description": "Conversation Management System - Handles message threading, conversation history, and provides a UI for viewing conversation threads. Manages the flow of messages between the user, LLM, and tools."
- },
- "state_management": {
- "required": True,
- "files": ["state_manager.py"],
- "description": "State Persistence System - Provides thread-safe storage and retrieval of conversation state, tool data, and other persistent information. Enables maintaining context across sessions and managing shared state between components."
- },
- "db_connection": {
- "required": True,
- "files": ["db_connection.py"],
- "description": "Database Connection - Provides a connection to a SQLite database for storing and retrieving conversation state, tool data, and other persistent information."
- }
-}
+PACKAGE_NAME = "agentpress"
+PYPI_URL = f"https://pypi.org/pypi/{PACKAGE_NAME}/json"
STARTER_EXAMPLES = {
"simple_web_dev_example_agent": {
"description": "Interactive web development agent with file and terminal manipulation capabilities. Demonstrates both standard and XML-based tool calling patterns.",
"files": {
- "agent.py": "examples/simple_web_dev/agent.py",
- "tools/files_tool.py": "examples/simple_web_dev/tools/files_tool.py",
- "tools/terminal_tool.py": "examples/simple_web_dev/tools/terminal_tool.py",
- ".env.example": "examples/.env.example"
+ "agent.py": "agents/simple_web_dev/agent.py",
+ "tools/files_tool.py": "agents/simple_web_dev/tools/files_tool.py",
+ "tools/terminal_tool.py": "agents/simple_web_dev/tools/terminal_tool.py",
+ ".env.example": "agents/.env.example"
}
}
}
-PACKAGE_NAME = "agentpress"
-PYPI_URL = f"https://pypi.org/pypi/{PACKAGE_NAME}/json"
-
-def check_for_updates() -> Tuple[Optional[str], Optional[str], bool]:
- """
- Check if there's a newer version available on PyPI
- Returns: (current_version, latest_version, update_available)
- """
+def check_for_updates():
+ """Check if there's a newer version available on PyPI"""
try:
current_version = pkg_resources.get_distribution(PACKAGE_NAME).version
response = requests.get(PYPI_URL, timeout=2)
- response.raise_for_status() # Raise exception for bad status codes
+ response.raise_for_status()
latest_version = response.json()["info"]["version"]
- # Compare versions properly using packaging.version
current_ver = version.parse(current_version)
latest_ver = version.parse(latest_version)
return current_version, latest_version, latest_ver > current_ver
except requests.RequestException:
- # Handle network-related errors silently
return None, None, False
except Exception as e:
- # Log other unexpected errors but don't break the CLI
click.echo(f"Warning: Failed to check for updates: {str(e)}", err=True)
return None, None, False
@@ -102,7 +47,6 @@ def show_welcome():
"""Display welcome message with ASCII art"""
click.clear()
- # Check for updates
current_version, latest_version, update_available = check_for_updates()
click.echo("""
@@ -122,16 +66,17 @@ def show_welcome():
time.sleep(1)
-def copy_module_files(src_dir: str, dest_dir: str, files: List[str]):
- """Copy module files from package to destination"""
+def copy_package_files(src_dir: str, dest_dir: str):
+ """Copy all package files except agents folder to destination"""
os.makedirs(dest_dir, exist_ok=True)
- with click.progressbar(files, label='Copying files') as file_list:
- for file in file_list:
- src = os.path.join(src_dir, file)
- dst = os.path.join(dest_dir, file)
- os.makedirs(os.path.dirname(dst), exist_ok=True)
- shutil.copy2(src, dst)
+ def ignore_patterns(path, names):
+ # Ignore the agents directory and any __pycache__ directories
+ return [n for n in names if n == 'agents' or n == '__pycache__']
+
+ with click.progressbar(length=1, label='Copying files') as bar:
+ shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True, ignore=ignore_patterns)
+ bar.update(1)
def copy_example_files(src_dir: str, dest_dir: str, files: Dict[str, str]):
"""Copy example files from package to destination"""
@@ -142,19 +87,6 @@ def copy_example_files(src_dir: str, dest_dir: str, files: Dict[str, str]):
shutil.copy2(src, dst)
click.echo(f" ✓ Created {dest_path}")
-def update_file_paths(file_path: str, replacements: Dict[str, str]):
- """Update file paths in the given file"""
- with open(file_path, 'r') as f:
- content = f.read()
-
- for old, new in replacements.items():
- # Escape special characters in the old string
- escaped_old = re.escape(old)
- content = re.sub(escaped_old, new, content)
-
- with open(file_path, 'w') as f:
- f.write(content)
-
@click.group()
def cli():
"""AgentPress CLI - Initialize your AgentPress modules"""
@@ -165,7 +97,6 @@ def init():
"""Initialize AgentPress modules in your project"""
show_welcome()
- # Set components directory name to 'agentpress'
components_dir = "agentpress"
if os.path.exists(components_dir):
@@ -195,41 +126,20 @@ def init():
# Get package directory
package_dir = os.path.dirname(os.path.abspath(__file__))
- # Show all modules status
- click.echo("\n🔧 AgentPress Modules Configuration\n")
-
- # Show required modules including state_manager
- click.echo("📦 Required Modules (pre-selected):")
- required_modules = {name: module for name, module in MODULES.items()
- if module["required"] or name == "state_management"}
- for name, module in required_modules.items():
- click.echo(f" ✓ {click.style(name, fg='green')} - {module['description']}")
-
- # Create selections dict with required modules pre-selected
- selections = {name: True for name in required_modules.keys()}
-
click.echo("\n🚀 Setting up your AgentPress...")
time.sleep(0.5)
try:
- # Copy selected modules
- selected_modules = [name for name, selected in selections.items() if selected]
- all_files = []
- for module in selected_modules:
- all_files.extend(MODULES[module]["files"])
-
- # Create components directory and copy module files
+ # Create components directory and copy all files except agents folder
components_dir_path = os.path.abspath(components_dir)
- copy_module_files(package_dir, components_dir_path, all_files)
-
+ copy_package_files(package_dir, components_dir_path)
-
- # Copy example only if a valid example (not None) was selected
+ # Copy example if selected
if selected_example and selected_example in STARTER_EXAMPLES:
click.echo(f"\n📝 Creating {selected_example}...")
copy_example_files(
package_dir,
- os.getcwd(), # Use current working directory
+ os.getcwd(),
STARTER_EXAMPLES[selected_example]["files"]
)
@@ -246,7 +156,6 @@ def init():
click.echo(f"\nRun the example agent:")
click.echo(" python agent.py")
-
except Exception as e:
click.echo(f"\n❌ Error during setup: {str(e)}", err=True)
return
diff --git a/agentpress/db_connection.py b/agentpress/db_connection.py
index 8802984..cc52f1c 100644
--- a/agentpress/db_connection.py
+++ b/agentpress/db_connection.py
@@ -7,105 +7,122 @@
from contextlib import asynccontextmanager
import os
import asyncio
-import json
class DBConnection:
"""Singleton database connection manager."""
+
_instance = None
_initialized = False
- _db_path = "ap.db"
+ _db_path = "/Users/markokraemer/Projects/softgen/softgen-core/main.db"
+ _init_lock = asyncio.Lock()
+ _initialization_task = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
+ # Start initialization when instance is first created
+ cls._initialization_task = asyncio.create_task(cls._instance._initialize())
return cls._instance
- async def initialize(self):
- """Initialize the database connection and schema."""
- if self._initialized:
- return
+ def __init__(self):
+ """No initialization needed in __init__ as it's handled in __new__"""
+ pass
- try:
- # Ensure the database directory exists
- os.makedirs(os.path.dirname(os.path.abspath(self._db_path)), exist_ok=True)
+ @classmethod
+ async def _initialize(cls):
+ """Internal initialization method."""
+ if cls._initialized:
+ return
- # Initialize database and create schema
- async with aiosqlite.connect(self._db_path) as db:
- await db.execute("PRAGMA foreign_keys = ON")
+ async with cls._init_lock:
+ if cls._initialized: # Double-check after acquiring lock
+ return
- # Create threads table
- await db.execute("""
- CREATE TABLE IF NOT EXISTS threads (
- id TEXT PRIMARY KEY,
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
- )
- """)
+ try:
+ async with aiosqlite.connect(cls._db_path) as db:
+ # Threads table
+ await db.execute("""
+ CREATE TABLE IF NOT EXISTS threads (
+ id TEXT PRIMARY KEY,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+ )
+ """)
+
+ # Messages table
+ await db.execute("""
+ CREATE TABLE IF NOT EXISTS messages (
+ id TEXT PRIMARY KEY,
+ thread_id TEXT,
+ type TEXT,
+ content TEXT,
+ include_in_llm_message_history BOOLEAN DEFAULT TRUE,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ FOREIGN KEY (thread_id) REFERENCES threads (id)
+ )
+ """)
+
+ await db.commit()
+ cls._initialized = True
+ logging.info("Database schema initialized")
+ except Exception as e:
+ logging.error(f"Database initialization error: {e}")
+ raise
- # Create events table
- await db.execute("""
- CREATE TABLE IF NOT EXISTS events (
- id TEXT PRIMARY KEY,
- thread_id TEXT,
- type TEXT,
- content TEXT,
- include_in_llm_message_history INTEGER DEFAULT 0,
- llm_message TEXT,
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- FOREIGN KEY (thread_id) REFERENCES threads(id) ON DELETE CASCADE
- )
- """)
+ @classmethod
+ def set_db_path(cls, db_path: str):
+ """Set custom database path."""
+ if cls._initialized:
+ raise RuntimeError("Cannot change database path after initialization")
+ cls._db_path = db_path
+ logging.info(f"Updated database path to: {db_path}")
- await db.commit()
- logging.info("Database initialized successfully")
- self._initialized = True
-
- except Exception as e:
- logging.error(f"Failed to initialize database: {e}")
- raise
+ @asynccontextmanager
+ async def connection(self):
+ """Get a database connection."""
+ # Wait for initialization to complete if it hasn't already
+ if self._initialization_task and not self._initialized:
+ await self._initialization_task
+
+ async with aiosqlite.connect(self._db_path) as conn:
+ try:
+ yield conn
+ except Exception as e:
+ logging.error(f"Database error: {e}")
+ raise
@asynccontextmanager
async def transaction(self):
- """Get a database connection with transaction support."""
- if not self._initialized:
- raise Exception("Database not initialized. Call initialize() first.")
-
- async with aiosqlite.connect(self._db_path) as db:
- await db.execute("PRAGMA foreign_keys = ON")
+ """Execute operations in a transaction."""
+ async with self.connection() as db:
try:
yield db
await db.commit()
- logging.debug("Transaction committed successfully")
except Exception as e:
await db.rollback()
- logging.error(f"Transaction failed, rolling back: {e}")
+ logging.error(f"Transaction error: {e}")
raise
async def execute(self, query: str, params: tuple = ()):
- """Execute a query and return the cursor."""
- async with aiosqlite.connect(self._db_path) as db:
- await db.execute("PRAGMA foreign_keys = ON")
- return await db.execute(query, params)
-
- async def fetch_all(self, query: str, params: tuple = ()):
- """Execute a query and fetch all results."""
- async with aiosqlite.connect(self._db_path) as db:
- await db.execute("PRAGMA foreign_keys = ON")
- cursor = await db.execute(query, params)
- return await cursor.fetchall()
+ """Execute a single query."""
+ async with self.connection() as db:
+ try:
+ result = await db.execute(query, params)
+ await db.commit()
+ return result
+ except Exception as e:
+ logging.error(f"Query execution error: {e}")
+ raise
async def fetch_one(self, query: str, params: tuple = ()):
- """Execute a query and fetch one result."""
- async with aiosqlite.connect(self._db_path) as db:
- await db.execute("PRAGMA foreign_keys = ON")
- cursor = await db.execute(query, params)
- return await cursor.fetchone()
+ """Fetch a single row."""
+ async with self.connection() as db:
+ async with db.execute(query, params) as cursor:
+ return await cursor.fetchone()
- def _serialize_json(self, data):
- """Serialize data to JSON string."""
- return json.dumps(data) if data is not None else None
-
- def _deserialize_json(self, data):
- """Deserialize JSON string to data."""
- return json.loads(data) if data is not None else None
\ No newline at end of file
+ async def fetch_all(self, query: str, params: tuple = ()):
+ """Fetch all rows."""
+ async with self.connection() as db:
+ async with db.execute(query, params) as cursor:
+ return await cursor.fetchall()
\ No newline at end of file
diff --git a/agentpress/llm.py b/agentpress/llm.py
index 0ad4432..57194fd 100644
--- a/agentpress/llm.py
+++ b/agentpress/llm.py
@@ -1,4 +1,4 @@
-from typing import Union, Dict, Any
+from typing import Union, Dict, Any, Optional, List
import litellm
import os
import json
@@ -11,6 +11,14 @@
ANTHROPIC_API_KEY = os.environ.get('ANTHROPIC_API_KEY', None)
GROQ_API_KEY = os.environ.get('GROQ_API_KEY', None)
AGENTOPS_API_KEY = os.environ.get('AGENTOPS_API_KEY', None)
+FIREWORKS_API_KEY = os.environ.get('FIREWORKS_AI_API_KEY', None)
+DEEPSEEK_API_KEY = os.environ.get('DEEPSEEK_API_KEY', None)
+OPENROUTER_API_KEY = os.environ.get('OPENROUTER_API_KEY', None)
+GEMINI_API_KEY = os.environ.get('GEMINI_API_KEY', None)
+
+AWS_ACCESS_KEY_ID = os.environ.get('AWS_ACCESS_KEY_ID', None)
+AWS_SECRET_ACCESS_KEY = os.environ.get('AWS_SECRET_ACCESS_KEY', None)
+AWS_REGION_NAME = os.environ.get('AWS_REGION_NAME', None)
if OPENAI_API_KEY:
os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY
@@ -18,6 +26,22 @@
os.environ['ANTHROPIC_API_KEY'] = ANTHROPIC_API_KEY
if GROQ_API_KEY:
os.environ['GROQ_API_KEY'] = GROQ_API_KEY
+if FIREWORKS_API_KEY:
+ os.environ['FIREWORKS_AI_API_KEY'] = FIREWORKS_API_KEY
+if DEEPSEEK_API_KEY:
+ os.environ['DEEPSEEK_API_KEY'] = DEEPSEEK_API_KEY
+if OPENROUTER_API_KEY:
+ os.environ['OPENROUTER_API_KEY'] = OPENROUTER_API_KEY
+if GEMINI_API_KEY:
+ os.environ['GEMINI_API_KEY'] = GEMINI_API_KEY
+
+# Add AWS environment variables if they exist
+if AWS_ACCESS_KEY_ID:
+ os.environ['AWS_ACCESS_KEY_ID'] = AWS_ACCESS_KEY_ID
+if AWS_SECRET_ACCESS_KEY:
+ os.environ['AWS_SECRET_ACCESS_KEY'] = AWS_SECRET_ACCESS_KEY
+if AWS_REGION_NAME:
+ os.environ['AWS_REGION_NAME'] = AWS_REGION_NAME
async def make_llm_api_call(
messages: list,
@@ -31,7 +55,8 @@ async def make_llm_api_call(
api_base: str = None,
agentops_session: Any = None,
stream: bool = False,
- top_p: float = None
+ top_p: float = None,
+ stop: Optional[Union[str, List[str]]] = None # Add stop parameter
) -> Union[Dict[str, Any], Any]:
"""
Make an API call to a language model using litellm.
@@ -52,6 +77,7 @@ async def make_llm_api_call(
agentops_session (Any, optional): Session for agentops integration
stream (bool, optional): Whether to stream the response. Defaults to False
top_p (float, optional): Top-p sampling parameter
+ stop (Union[str, List[str]], optional): Up to 4 sequences where the API will stop generating tokens
Returns:
Union[Dict[str, Any], Any]: API response, either complete or streaming
@@ -59,7 +85,7 @@ async def make_llm_api_call(
Raises:
Exception: If API call fails after retries
"""
- # litellm.set_verbose = True
+ litellm.set_verbose = False
async def attempt_api_call(api_call_func, max_attempts=3):
"""
@@ -75,10 +101,17 @@ async def attempt_api_call(api_call_func, max_attempts=3):
Raises:
Exception: If all retry attempts fail
"""
+ nonlocal model_name # Add this to access model_name
for attempt in range(max_attempts):
try:
return await api_call_func()
except litellm.exceptions.RateLimitError as e:
+ # Check if it's Bedrock Claude and switch to direct Anthropic
+ if "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0" in model_name:
+ logging.info("Rate limit hit with Bedrock Claude, falling back to direct Anthropic API...")
+ model_name = "anthropic/claude-3-5-sonnet-latest"
+ continue
+
logging.warning(f"Rate limit exceeded. Waiting for 30 seconds before retrying...")
await asyncio.sleep(30)
except OpenAIError as e:
@@ -105,6 +138,10 @@ async def api_call():
"stream": stream,
}
+ # Add stop parameter if provided
+ if stop is not None:
+ api_call_params["stop"] = stop
+
# Add optional parameters if provided
if api_key:
api_call_params["api_key"] = api_key
@@ -129,6 +166,32 @@ async def api_call():
"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"
}
+ # Add OpenRouter specific parameters
+ if "openrouter" in model_name.lower():
+ if settings.or_site_url:
+ api_call_params["headers"] = {
+ "HTTP-Referer": settings.or_site_url
+ }
+ if settings.or_app_name:
+ api_call_params["headers"] = {
+ "X-Title": settings.or_app_name
+ }
+
+ # Add special handling for Deepseek
+ if "deepseek" in model_name.lower():
+ api_call_params["frequency_penalty"] = 0.5
+ api_call_params["temperature"] = 0.7
+ api_call_params["presence_penalty"] = 0.1
+
+ # Add Bedrock-specific parameters
+ if "bedrock" in model_name.lower():
+ if settings.aws_access_key_id:
+ api_call_params["aws_access_key_id"] = settings.aws_access_key_id
+ if settings.aws_secret_access_key:
+ api_call_params["aws_secret_access_key"] = settings.aws_secret_access_key
+ if settings.aws_region_name:
+ api_call_params["aws_region_name"] = settings.aws_region_name
+
# Log the API request
# logging.info(f"Sending API request: {json.dumps(api_call_params, indent=2)}")
@@ -137,10 +200,36 @@ async def api_call():
response = await agentops_session.patch(litellm.acompletion)(**api_call_params)
else:
response = await litellm.acompletion(**api_call_params)
-
- # Log the API response
+
# logging.info(f"Received API response: {response}")
+ # # For streaming responses, attach cost tracking
+ # if stream:
+ # # Create a wrapper object to track costs across chunks
+ # cost_tracker = {
+ # "prompt_tokens": 0,
+ # "completion_tokens": 0,
+ # "total_tokens": 0,
+ # "cost": 0.0
+ # }
+
+ # # Get the cost per token for the model
+ # model_cost = litellm.model_cost.get(model_name, {})
+ # input_cost = model_cost.get('input_cost_per_token', 0)
+ # output_cost = model_cost.get('output_cost_per_token', 0)
+
+ # # Attach the cost tracker to the response
+ # response.cost_tracker = cost_tracker
+ # response.model_info = {
+ # "input_cost_per_token": input_cost,
+ # "output_cost_per_token": output_cost
+ # }
+ # else:
+ # # For non-streaming, cost is already included in the response
+ # response._hidden_params = {
+ # "response_cost": litellm.completion_cost(completion_response=response)
+ # }
+
return response
return await attempt_api_call(api_call)
@@ -188,4 +277,37 @@ async def test_llm_api_call(stream=True):
print(response.choices[0].message.content)
print()
- asyncio.run(test_llm_api_call())
+ # asyncio.run(test_llm_api_call())
+
+ async def test_bedrock():
+ """
+ Test function for Bedrock API call.
+ """
+ messages = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Hello from Bedrock!"}
+ ]
+ model_name = "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0"
+
+ response = await make_llm_api_call(messages, model_name, stream=True)
+
+ print("\n🤖 Streaming response from Bedrock:\n")
+ buffer = ""
+ async for chunk in response:
+ if isinstance(chunk, dict) and 'choices' in chunk:
+ content = chunk['choices'][0]['delta'].get('content', '')
+ else:
+ content = chunk.choices[0].delta.content
+
+ if content:
+ buffer += content
+ if content[-1].isspace():
+ print(buffer, end='', flush=True)
+ buffer = ""
+
+ if buffer:
+ print(buffer, flush=True)
+ print("\n✨ Stream completed.\n")
+
+ # Add test_bedrock to the test runs
+ # asyncio.run(test_bedrock())
diff --git a/agentpress/processor/__init__.py b/agentpress/processor/__init__.py
new file mode 100644
index 0000000..3ebeb22
--- /dev/null
+++ b/agentpress/processor/__init__.py
@@ -0,0 +1 @@
+# Empty file to mark as package
\ No newline at end of file
diff --git a/agentpress/processor/base_processors.py b/agentpress/processor/base_processors.py
index ec72418..52e4943 100644
--- a/agentpress/processor/base_processors.py
+++ b/agentpress/processor/base_processors.py
@@ -172,7 +172,7 @@ class ResultsAdderBase(ABC):
Attributes:
add_message: Callback for adding new messages
update_message: Callback for updating existing messages
- get_messages: Callback for retrieving thread messages
+ get_llm_history_messages: Callback for retrieving thread messages
message_added: Flag tracking if initial message has been added
"""
@@ -184,7 +184,7 @@ def __init__(self, thread_manager):
"""
self.add_message = thread_manager.add_message
self.update_message = thread_manager._update_message
- self.get_messages = thread_manager.get_messages
+ self.get_llm_history_messages = thread_manager.get_llm_history_messages
self.message_added = False
@abstractmethod
diff --git a/agentpress/processor/llm_response_processor.py b/agentpress/processor/llm_response_processor.py
index 2457221..5ecb19e 100644
--- a/agentpress/processor/llm_response_processor.py
+++ b/agentpress/processor/llm_response_processor.py
@@ -9,12 +9,9 @@
"""
import asyncio
-from typing import Callable, Dict, Any, AsyncGenerator, Optional
+from typing import Dict, Any, AsyncGenerator
import logging
from agentpress.processor.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase
-from agentpress.processor.standard.standard_tool_parser import StandardToolParser
-from agentpress.processor.standard.standard_tool_executor import StandardToolExecutor
-from agentpress.processor.standard.standard_results_adder import StandardResultsAdder
class LLMResponseProcessor:
"""Handles LLM response processing and tool execution management.
@@ -37,51 +34,30 @@ class LLMResponseProcessor:
def __init__(
self,
thread_id: str,
- available_functions: Dict = None,
- add_message_callback: Callable = None,
- update_message_callback: Callable = None,
- get_messages_callback: Callable = None,
- parallel_tool_execution: bool = True,
- tool_parser: Optional[ToolParserBase] = None,
- tool_executor: Optional[ToolExecutorBase] = None,
- results_adder: Optional[ResultsAdderBase] = None,
- thread_manager = None
+ tool_executor: ToolExecutorBase,
+ tool_parser: ToolParserBase,
+ available_functions: Dict,
+ results_adder: ResultsAdderBase
):
"""Initialize the response processor.
Args:
thread_id: ID of the conversation thread
- available_functions: Dictionary of available tool functions
- add_message_callback: Callback for adding messages
- update_message_callback: Callback for updating messages
- get_messages_callback: Callback for listing messages
- parallel_tool_execution: Whether to execute tools in parallel
- tool_parser: Custom tool parser implementation
tool_executor: Custom tool executor implementation
+ tool_parser: Custom tool parser implementation
+ available_functions: Dictionary of available tool functions
results_adder: Custom results adder implementation
- thread_manager: Optional thread manager instance
"""
self.thread_id = thread_id
- self.tool_executor = tool_executor or StandardToolExecutor(parallel=parallel_tool_execution)
- self.tool_parser = tool_parser or StandardToolParser()
- self.available_functions = available_functions or {}
-
- # Create minimal thread manager if needed
- if thread_manager is None and (add_message_callback and update_message_callback and get_messages_callback):
- class MinimalThreadManager:
- def __init__(self, add_msg, update_msg, list_msg):
- self.add_message = add_msg
- self._update_message = update_msg
- self.get_messages = list_msg
- thread_manager = MinimalThreadManager(add_message_callback, update_message_callback, get_messages_callback)
-
- self.results_adder = results_adder or StandardResultsAdder(thread_manager)
-
- # State tracking for streaming
- self.tool_calls_buffer = {}
- self.processed_tool_calls = set()
+ self.tool_executor = tool_executor
+ self.tool_parser = tool_parser
+ self.available_functions = available_functions
+ self.results_adder = results_adder
self.content_buffer = ""
+ self.tool_calls_buffer = {}
self.tool_calls_accumulated = []
+ self.processed_tool_calls = set()
+ self._executing_tools = set() # Track currently executing tools
async def process_stream(
self,
@@ -92,8 +68,9 @@ async def process_stream(
"""Process streaming LLM response and handle tool execution."""
pending_tool_calls = []
background_tasks = set()
+ stream_completed = False # New flag to track stream completion
- async def handle_message_management(chunk):
+ async def handle_message_management(chunk, is_final=False):
try:
# Accumulate content
if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
@@ -106,28 +83,37 @@ async def handle_message_management(chunk):
self.tool_calls_buffer
)
if parsed_message and 'tool_calls' in parsed_message:
- self.tool_calls_accumulated = parsed_message['tool_calls']
+ new_tool_calls = [
+ tool_call for tool_call in parsed_message['tool_calls']
+ if tool_call['id'] not in self.processed_tool_calls
+ ]
+ if new_tool_calls:
+ self.tool_calls_accumulated.extend(new_tool_calls)
# Handle tool execution and results
if execute_tools and self.tool_calls_accumulated:
new_tool_calls = [
tool_call for tool_call in self.tool_calls_accumulated
- if tool_call['id'] not in self.processed_tool_calls
+ if (tool_call['id'] not in self.processed_tool_calls and
+ tool_call['id'] not in self._executing_tools)
]
if new_tool_calls:
if execute_tools_on_stream:
+ for tool_call in new_tool_calls:
+ self._executing_tools.add(tool_call['id'])
+
results = await self.tool_executor.execute_tool_calls(
tool_calls=new_tool_calls,
available_functions=self.available_functions,
thread_id=self.thread_id,
executed_tool_calls=self.processed_tool_calls
)
+
for result in results:
await self.results_adder.add_tool_result(self.thread_id, result)
self.processed_tool_calls.add(result['tool_call_id'])
- else:
- pending_tool_calls.extend(new_tool_calls)
+ self._executing_tools.discard(result['tool_call_id'])
# Add/update assistant message
message = {
@@ -152,7 +138,10 @@ async def handle_message_management(chunk):
)
# Handle stream completion
- if chunk.choices[0].finish_reason:
+ if chunk.choices[0].finish_reason or is_final:
+ nonlocal stream_completed
+ stream_completed = True
+
if not execute_tools_on_stream and pending_tool_calls:
results = await self.tool_executor.execute_tool_calls(
tool_calls=pending_tool_calls,
@@ -165,8 +154,16 @@ async def handle_message_management(chunk):
self.processed_tool_calls.add(result['tool_call_id'])
pending_tool_calls.clear()
+ # Set final state on the chunk instead of returning it
+ chunk._final_state = {
+ "content": self.content_buffer,
+ "tool_calls": self.tool_calls_accumulated,
+ "processed_tool_calls": list(self.processed_tool_calls)
+ }
+
except Exception as e:
logging.error(f"Error in background task: {e}")
+ raise
try:
async for chunk in response_stream:
@@ -175,9 +172,22 @@ async def handle_message_management(chunk):
task.add_done_callback(background_tasks.discard)
yield chunk
+ # Create a final dummy chunk to handle completion
+ final_chunk = type('DummyChunk', (), {
+ 'choices': [type('DummyChoice', (), {
+ 'delta': type('DummyDelta', (), {'content': None}),
+ 'finish_reason': 'stop'
+ })]
+ })()
+
+ # Process final state
+ await handle_message_management(final_chunk, is_final=True)
+ yield final_chunk
+
+ # Wait for all background tasks to complete
if background_tasks:
await asyncio.gather(*background_tasks, return_exceptions=True)
-
+
except Exception as e:
logging.error(f"Error in stream processing: {e}")
for task in background_tasks:
diff --git a/agentpress/processor/standard/__init__.py b/agentpress/processor/standard/__init__.py
new file mode 100644
index 0000000..3ebeb22
--- /dev/null
+++ b/agentpress/processor/standard/__init__.py
@@ -0,0 +1 @@
+# Empty file to mark as package
\ No newline at end of file
diff --git a/agentpress/processor/standard/standard_results_adder.py b/agentpress/processor/standard/standard_results_adder.py
index 49d08b6..810091e 100644
--- a/agentpress/processor/standard/standard_results_adder.py
+++ b/agentpress/processor/standard/standard_results_adder.py
@@ -81,6 +81,6 @@ async def add_tool_result(self, thread_id: str, result: Dict[str, Any]):
- Checks for duplicate tool results before adding
- Adds result only if tool_call_id is unique
"""
- messages = await self.get_messages(thread_id)
+ messages = await self.get_llm_history_messages(thread_id)
if not any(msg.get('tool_call_id') == result['tool_call_id'] for msg in messages):
await self.add_message(thread_id, result)
diff --git a/agentpress/processor/xml/__init__.py b/agentpress/processor/xml/__init__.py
new file mode 100644
index 0000000..3ebeb22
--- /dev/null
+++ b/agentpress/processor/xml/__init__.py
@@ -0,0 +1 @@
+# Empty file to mark as package
\ No newline at end of file
diff --git a/agentpress/processor/xml/xml_results_adder.py b/agentpress/processor/xml/xml_results_adder.py
index 49593da..4b8123d 100644
--- a/agentpress/processor/xml/xml_results_adder.py
+++ b/agentpress/processor/xml/xml_results_adder.py
@@ -79,7 +79,7 @@ async def add_tool_result(self, thread_id: str, result: Dict[str, Any]):
"""
try:
# Get the original tool call to find the root tag
- messages = await self.get_messages(thread_id)
+ messages = await self.get_llm_history_messages(thread_id)
assistant_msg = next((msg for msg in reversed(messages)
if msg['role'] == 'assistant'), None)
@@ -107,10 +107,9 @@ async def add_tool_result(self, thread_id: str, result: Dict[str, Any]):
await self.add_message(thread_id, result_message)
except Exception as e:
- logging.error(f"Error adding tool result: {e}")
- # Ensure the result is still added even if there's an error
+ logging.error(f"Error adding tool result: {e}") # Ensure the result is still added even if there's an error
result_message = {
"role": "user",
"content": f"Result for {result['name']}:\n{result['content']}"
}
- await self.add_message(thread_id, result_message)
\ No newline at end of file
+ await self.add_message(thread_id, result_message)
diff --git a/agentpress/processor/xml/xml_tool_executor.py b/agentpress/processor/xml/xml_tool_executor.py
index 185e8e6..10bcb68 100644
--- a/agentpress/processor/xml/xml_tool_executor.py
+++ b/agentpress/processor/xml/xml_tool_executor.py
@@ -38,6 +38,8 @@ def __init__(self, parallel: bool = True, tool_registry: Optional[ToolRegistry]
"""
self.parallel = parallel
self.tool_registry = tool_registry or ToolRegistry()
+ # Add internal tracking of executed tools
+ self._executed_tools = set()
async def execute_tool_calls(
self,
@@ -65,20 +67,28 @@ async def execute_tool_calls(
if executed_tool_calls is None:
executed_tool_calls = set()
+ # Filter out already executed tool calls
+ new_tool_calls = [
+ tool_call for tool_call in tool_calls
+ if tool_call['id'] not in executed_tool_calls
+ and tool_call['id'] not in self._executed_tools
+ ]
+
+ if not new_tool_calls:
+ return []
+
if self.parallel:
- return await self._execute_parallel(
- tool_calls,
- available_functions,
- thread_id,
- executed_tool_calls
- )
+ results = await self._execute_parallel(new_tool_calls, available_functions, thread_id, executed_tool_calls)
else:
- return await self._execute_sequential(
- tool_calls,
- available_functions,
- thread_id,
- executed_tool_calls
- )
+ results = await self._execute_sequential(new_tool_calls, available_functions, thread_id, executed_tool_calls)
+
+ # Track executed tools internally
+ for tool_call in new_tool_calls:
+ self._executed_tools.add(tool_call['id'])
+ if executed_tool_calls is not None:
+ executed_tool_calls.add(tool_call['id'])
+
+ return results
async def _execute_parallel(
self,
@@ -87,9 +97,10 @@ async def _execute_parallel(
thread_id: str,
executed_tool_calls: Set[str]
) -> List[Dict[str, Any]]:
- async def execute_single_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
- if tool_call['id'] in executed_tool_calls:
- logging.info(f"Tool call {tool_call['id']} already executed")
+ async def execute_single_tool(tool_call: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+ # Double-check the tool hasn't been executed
+ if (tool_call['id'] in executed_tool_calls or
+ tool_call['id'] in self._executed_tools):
return None
try:
diff --git a/agentpress/state_manager.py b/agentpress/state_manager.py
index 01bf1d4..4281e43 100644
--- a/agentpress/state_manager.py
+++ b/agentpress/state_manager.py
@@ -1,207 +1,118 @@
-"""
-Manages persistent state storage for AgentPress components using thread-based events.
-
-The StateManager provides thread-safe access to state data stored as events in threads,
-allowing components to save and retrieve data across sessions. Each state update
-creates a new event containing the complete state.
-"""
-
import json
import logging
from typing import Any, Optional, List, Dict
-from asyncio import Lock
+import uuid
from agentpress.thread_manager import ThreadManager
class StateManager:
"""
- Manages persistent state storage for AgentPress components using thread events.
-
- The StateManager provides thread-safe access to state data stored as events,
- maintaining the complete state in each event for better consistency and tracking.
-
- Attributes:
- lock (Lock): Asyncio lock for thread-safe state access
- thread_id (str): Thread ID for state storage
- thread_manager (ThreadManager): Thread manager instance for event handling
+ Manages state storage using thread messages.
+ Each state message contains a complete snapshot of the state at that point in time.
"""
def __init__(self, thread_id: str):
- """
- Initialize StateManager with thread ID.
-
- Args:
- thread_id (str): Thread ID for state storage
- """
- self.lock = Lock()
- self.thread_id = thread_id
+ """Initialize StateManager with a thread ID."""
self.thread_manager = ThreadManager()
- logging.info(f"StateManager initialized with thread_id: {self.thread_id}")
-
- async def initialize(self):
- """Initialize the thread manager."""
- await self.thread_manager.initialize()
+ self.thread_id = thread_id
+ self._state_cache = None
+ logging.info(f"StateManager initialized for thread: {thread_id}")
- async def _ensure_initialized(self):
- """Ensure thread manager is initialized."""
- if not self.thread_manager.db._initialized:
- await self.initialize()
+ async def _get_state(self) -> Dict[str, Any]:
+ """Get the current complete state."""
+ if self._state_cache is not None:
+ return self._state_cache.copy() # Return copy to prevent cache mutation
- async def _get_current_state(self) -> dict:
- """Get the current state from the latest state event."""
- await self._ensure_initialized()
- events = await self.thread_manager.get_thread_events(
- thread_id=self.thread_id,
- event_types=["state"],
- order_by="created_at",
- order="DESC"
+ # Get the latest state message
+ rows = await self.thread_manager.db.fetch_all(
+ """
+ SELECT content
+ FROM messages
+ WHERE thread_id = ? AND type = 'state_message'
+ ORDER BY created_at DESC LIMIT 1
+ """,
+ (self.thread_id,)
)
- if events:
- return events[0]["content"].get("state", {})
+
+ if rows:
+ try:
+ self._state_cache = json.loads(rows[0][0])
+ return self._state_cache.copy()
+ except json.JSONDecodeError:
+ logging.error("Failed to parse state JSON")
+
return {}
- async def _save_state(self, state: dict):
- """Save the complete state as a new event."""
- await self._ensure_initialized()
- await self.thread_manager.create_event(
+ async def _save_state(self, state: Dict[str, Any]):
+ """Save a new complete state snapshot."""
+ # Format state as a string with proper indentation
+ formatted_state = json.dumps(state, indent=2)
+
+ # Save new state message with complete snapshot
+ await self.thread_manager.add_message(
thread_id=self.thread_id,
- event_type="state",
- content={"state": state}
+ message_data=formatted_state,
+ message_type='state_message',
+ include_in_llm_message_history=False
)
+
+ # Update cache with a copy
+ self._state_cache = state.copy()
async def set(self, key: str, data: Any) -> Any:
- """
- Store data with a key in the state.
-
- Args:
- key (str): Simple string key like "config" or "settings"
- data (Any): Any JSON-serializable data
-
- Returns:
- Any: The stored data
- """
- async with self.lock:
- try:
- current_state = await self._get_current_state()
- current_state[key] = data
- await self._save_state(current_state)
- logging.info(f'Updated state key: {key}')
- return data
- except Exception as e:
- logging.error(f"Error setting state: {e}")
- raise
+ """Store any JSON-serializable data with a key."""
+ state = await self._get_state()
+ state[key] = data
+ await self._save_state(state)
+ logging.info(f'Updated state key: {key}')
+ return data
async def get(self, key: str) -> Optional[Any]:
- """
- Get data for a key from the current state.
-
- Args:
- key (str): Simple string key like "config" or "settings"
-
- Returns:
- Any: The stored data for the key, or None if key not found
- """
- async with self.lock:
- try:
- current_state = await self._get_current_state()
- if key in current_state:
- logging.info(f'Retrieved key: {key}')
- return current_state[key]
- logging.info(f'Key not found: {key}')
- return None
- except Exception as e:
- logging.error(f"Error getting state: {e}")
- raise
+ """Get data for a key."""
+ state = await self._get_state()
+ if key in state:
+ data = state[key]
+ logging.info(f'Retrieved key: {key}')
+ return data
+ logging.info(f'Key not found: {key}')
+ return None
async def delete(self, key: str):
- """
- Delete a key from the state.
-
- Args:
- key (str): Simple string key like "config" or "settings"
- """
- async with self.lock:
- try:
- current_state = await self._get_current_state()
- if key in current_state:
- del current_state[key]
- await self._save_state(current_state)
- logging.info(f"Deleted key: {key}")
- except Exception as e:
- logging.error(f"Error deleting state: {e}")
- raise
+ """Delete data for a key."""
+ state = await self._get_state()
+ if key in state:
+ del state[key]
+ await self._save_state(state)
+ logging.info(f"Deleted key: {key}")
async def update(self, key: str, data: Dict[str, Any]) -> Optional[Any]:
- """
- Update existing dictionary data for a key by merging.
-
- Args:
- key (str): Simple string key like "config" or "settings"
- data (Dict[str, Any]): Dictionary of updates to merge
-
- Returns:
- Optional[Any]: Updated data if successful, None if key not found
- """
- async with self.lock:
- try:
- current_state = await self._get_current_state()
- if key in current_state and isinstance(current_state[key], dict):
- current_state[key].update(data)
- await self._save_state(current_state)
- return current_state[key]
- return None
- except Exception as e:
- logging.error(f"Error updating state: {e}")
- raise
+ """Update existing data for a key by merging dictionaries."""
+ state = await self._get_state()
+ if key in state and isinstance(state[key], dict):
+ state[key].update(data)
+ await self._save_state(state)
+ logging.info(f'Updated state key: {key}')
+ return state[key]
+ return None
async def append(self, key: str, item: Any) -> Optional[List[Any]]:
- """
- Append an item to a list stored at key.
-
- Args:
- key (str): Simple string key like "config" or "settings"
- item (Any): Item to append
-
- Returns:
- Optional[List[Any]]: Updated list if successful, None if key not found
- """
- async with self.lock:
- try:
- current_state = await self._get_current_state()
- if key not in current_state:
- current_state[key] = []
- if isinstance(current_state[key], list):
- current_state[key].append(item)
- await self._save_state(current_state)
- return current_state[key]
- return None
- except Exception as e:
- logging.error(f"Error appending to state: {e}")
- raise
+ """Append an item to a list stored at key."""
+ state = await self._get_state()
+ if key not in state:
+ state[key] = []
+ if isinstance(state[key], list):
+ state[key].append(item)
+ await self._save_state(state)
+ logging.info(f'Appended to key: {key}')
+ return state[key]
+ return None
- async def get_latest_state(self) -> dict:
- """
- Get the latest complete state.
-
- Returns:
- dict: Complete contents of the latest state
- """
- async with self.lock:
- try:
- state = await self._get_current_state()
- logging.info(f"Retrieved latest state with {len(state)} keys")
- return state
- except Exception as e:
- logging.error(f"Error getting latest state: {e}")
- raise
+ async def export_store(self) -> dict:
+ """Export entire state."""
+ state = await self._get_state()
+ return state
- async def clear_state(self):
- """
- Clear the entire state.
- """
- async with self.lock:
- try:
- await self._save_state({})
- logging.info("Cleared state")
- except Exception as e:
- logging.error(f"Error clearing state: {e}")
- raise
+ async def clear_store(self):
+ """Clear entire state."""
+ await self._save_state({})
+ self._state_cache = {}
+ logging.info("Cleared state")
diff --git a/agentpress/thread_manager.py b/agentpress/thread_manager.py
index 8289b34..921e448 100644
--- a/agentpress/thread_manager.py
+++ b/agentpress/thread_manager.py
@@ -2,7 +2,7 @@
Conversation thread management system for AgentPress.
This module provides comprehensive conversation management, including:
-- Thread and Event CRUD operations
+- Thread creation and persistence
- Message handling with support for text and images
- Tool registration and execution
- LLM interaction with streaming support
@@ -13,251 +13,104 @@
import logging
import asyncio
import uuid
-from datetime import datetime
from typing import List, Dict, Any, Optional, Type, Union, AsyncGenerator
from agentpress.llm import make_llm_api_call
from agentpress.tool import Tool, ToolResult
from agentpress.tool_registry import ToolRegistry
from agentpress.processor.llm_response_processor import LLMResponseProcessor
-from agentpress.processor.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase
from agentpress.db_connection import DBConnection
-from agentpress.processor.xml.xml_tool_parser import XMLToolParser
-from agentpress.processor.xml.xml_tool_executor import XMLToolExecutor
-from agentpress.processor.xml.xml_results_adder import XMLResultsAdder
from agentpress.processor.standard.standard_tool_parser import StandardToolParser
from agentpress.processor.standard.standard_tool_executor import StandardToolExecutor
from agentpress.processor.standard.standard_results_adder import StandardResultsAdder
+from agentpress.processor.xml.xml_tool_parser import XMLToolParser
+from agentpress.processor.xml.xml_tool_executor import XMLToolExecutor
+from agentpress.processor.xml.xml_results_adder import XMLResultsAdder
class ThreadManager:
- """Manages conversation threads with LLM models and tool execution."""
+ """Manages conversation threads with LLM models and tool execution.
+
+ Provides comprehensive conversation management, handling message threading,
+ tool registration, and LLM interactions with support for both standard and
+ XML-based tool execution patterns.
+ """
def __init__(self):
"""Initialize ThreadManager."""
- self.tool_registry = ToolRegistry()
self.db = DBConnection()
-
- async def initialize(self):
- """Initialize async components."""
- await self.db.initialize()
+ self.tool_registry = ToolRegistry()
def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
"""Add a tool to the ThreadManager."""
self.tool_registry.register_tool(tool_class, function_names, **kwargs)
- async def thread_exists(self, thread_id: str) -> bool:
- """Check if a thread exists."""
- await self._ensure_initialized()
- result = await self.db.fetch_one(
- "SELECT 1 FROM threads WHERE id = ?",
- (thread_id,)
- )
- return result is not None
-
- async def _ensure_initialized(self):
- """Ensure database is initialized."""
- if not self.db._initialized:
- await self.initialize()
-
- async def event_belongs_to_thread(self, event_id: str, thread_id: str) -> bool:
- """Check if an event exists and belongs to a thread."""
- await self._ensure_initialized()
- result = await self.db.fetch_one(
- "SELECT 1 FROM events WHERE id = ? AND thread_id = ?",
- (event_id, thread_id)
- )
- return result is not None
-
- # Core Thread Operations
async def create_thread(self) -> str:
"""Create a new conversation thread."""
- await self._ensure_initialized()
thread_id = str(uuid.uuid4())
+ await self.db.execute(
+ "INSERT INTO threads (id) VALUES (?)",
+ (thread_id,)
+ )
+ return thread_id
+
+ async def add_message(
+ self,
+ thread_id: str,
+ message_data: Union[str, Dict[str, Any]],
+ images: Optional[List[Dict[str, Any]]] = None,
+ include_in_llm_message_history: bool = True,
+ message_type: Optional[str] = None
+ ):
+ """Add a message to an existing thread."""
+ logging.info(f"Adding message to thread {thread_id}")
+
try:
- async with self.db.transaction() as conn:
- await conn.execute(
- "INSERT INTO threads (id) VALUES (?)",
- (thread_id,)
- )
- logging.info(f"Created thread: {thread_id}")
- return thread_id
- except Exception as e:
- logging.error(f"Failed to create thread: {e}")
- raise
-
- async def delete_thread(self, thread_id: str) -> bool:
- """Delete a thread and all its events (cascade)."""
- try:
- result = await self.db.execute(
- "DELETE FROM threads WHERE id = ?",
- (thread_id,)
- )
- # Check if any rows were affected
- return result.rowcount > 0
- except Exception as e:
- logging.error(f"Failed to delete thread {thread_id}: {e}")
- return False
-
- # Core Event Operations
- async def create_event(
- self,
- thread_id: str,
- event_type: str,
- content: Dict[str, Any],
- include_in_llm_message_history: bool = False,
- llm_message: Optional[Dict[str, Any]] = None
- ) -> str:
- """Create a new event in a thread."""
- await self._ensure_initialized()
- event_id = str(uuid.uuid4())
- try:
- async with self.db.transaction() as conn:
- # First verify thread exists
- cursor = await conn.execute("SELECT 1 FROM threads WHERE id = ?", (thread_id,))
- if not await cursor.fetchone():
- raise Exception(f"Thread {thread_id} does not exist")
-
- # Then create the event
- await conn.execute(
- """
- INSERT INTO events (
- id, thread_id, type, content,
- include_in_llm_message_history, llm_message
- ) VALUES (?, ?, ?, ?, ?, ?)
- """,
- (
- event_id,
- thread_id,
- event_type,
- self.db._serialize_json(content),
- 1 if include_in_llm_message_history else 0,
- self.db._serialize_json(llm_message) if llm_message else None
- )
- )
- logging.info(f"Created event {event_id} in thread {thread_id}")
- return event_id
- except Exception as e:
- logging.error(f"Failed to create event in thread {thread_id}: {e}")
- raise
-
- async def delete_event(self, event_id: str) -> bool:
- """Delete a specific event."""
- try:
- result = await self.db.execute(
- "DELETE FROM events WHERE id = ?",
- (event_id,)
- )
- # Check if any rows were affected
- return result.rowcount > 0
- except Exception as e:
- logging.error(f"Failed to delete event {event_id}: {e}")
- return False
-
- async def update_event(
- self,
- event_id: str,
- thread_id: str,
- content: Optional[Dict[str, Any]] = None,
- include_in_llm_message_history: Optional[bool] = None,
- llm_message: Optional[Dict[str, Any]] = None
- ) -> bool:
- """Update an existing event."""
- try:
- # First verify the event exists and belongs to the thread
- event = await self.db.fetch_one(
- "SELECT 1 FROM events WHERE id = ? AND thread_id = ?",
- (event_id, thread_id)
- )
- if not event:
- return False
-
- updates = []
- params = []
- if content is not None:
- updates.append("content = ?")
- params.append(self.db._serialize_json(content))
- if include_in_llm_message_history is not None:
- updates.append("include_in_llm_message_history = ?")
- params.append(1 if include_in_llm_message_history else 0)
- if llm_message is not None:
- updates.append("llm_message = ?")
- params.append(self.db._serialize_json(llm_message))
-
- if not updates:
- return False
-
- query = f"""
- UPDATE events
- SET {', '.join(updates)}, updated_at = CURRENT_TIMESTAMP
- WHERE id = ? AND thread_id = ?
- """
- params.extend([event_id, thread_id])
-
- result = await self.db.execute(query, tuple(params))
- return result.rowcount > 0
- except Exception as e:
- logging.error(f"Failed to update event {event_id}: {e}")
- return False
-
- async def get_thread_events(
- self,
- thread_id: str,
- only_llm_messages: bool = False,
- event_types: Optional[List[str]] = None,
- order_by: str = "created_at",
- order: str = "ASC"
- ) -> List[Dict[str, Any]]:
- """Get events from a thread with filtering options."""
- try:
- query = ["SELECT * FROM events WHERE thread_id = ?"]
- params = [thread_id]
-
- if only_llm_messages:
- query.append("AND include_in_llm_message_history = 1")
+ message_id = str(uuid.uuid4())
- if event_types:
- placeholders = ','.join(['?' for _ in event_types])
- query.append(f"AND type IN ({placeholders})")
- params.extend(event_types)
-
- query.append(f"ORDER BY {order_by} {order}")
+ # Handle string content
+ if isinstance(message_data, str):
+ content = message_data
+ role = 'unknown'
+
+ # Determine message type only for LLM-related messages if not provided
+ if message_type is None:
+ type_mapping = {
+ 'user': 'user_message',
+ 'assistant': 'assistant_message',
+ 'tool': 'tool_message',
+ 'system': 'system_message'
+ }
+ message_type = type_mapping.get(role, 'unknown_message')
+
+ else:
+ # For dict message_data, check if it's an LLM message format
+ if 'role' in message_data and 'content' in message_data:
+ content = message_data.get('content', '')
+ role = message_data.get('role', 'unknown')
+
+ # Determine message type for LLM messages if not provided
+ if message_type is None:
+ type_mapping = {
+ 'user': 'user_message',
+ 'assistant': 'assistant_message',
+ 'tool': 'tool_message',
+ 'system': 'system_message'
+ }
+ message_type = type_mapping.get(role, 'unknown_message')
+ else:
+ # For non-LLM messages, use the entire message_data as content
+ content = message_data
- rows = await self.db.fetch_all(' '.join(query), tuple(params))
+ # Handle content processing
+ if isinstance(content, ToolResult):
+ content = str(content)
- events = []
- for row in rows:
- event = {
- "id": row[0],
- "thread_id": row[1],
- "type": row[2],
- "content": self.db._deserialize_json(row[3]),
- "include_in_llm_message_history": bool(row[4]),
- "llm_message": self.db._deserialize_json(row[5]) if row[5] else None,
- "created_at": row[6],
- "updated_at": row[7]
- }
- events.append(event)
-
- return events
- except Exception as e:
- logging.error(f"Failed to get events for thread {thread_id}: {e}")
- return []
-
- async def get_thread_llm_messages(self, thread_id: str) -> List[Dict[str, Any]]:
- """Get LLM-formatted messages from thread events."""
- events = await self.get_thread_events(thread_id, only_llm_messages=True)
- return [event["llm_message"] for event in events if event["llm_message"]]
-
- # Message handling methods refactored for event-based system
- async def add_message(self, thread_id: str, message_data: Dict[str, Any], images: Optional[List[Dict[str, Any]]] = None):
- """Add a message as an event to the thread."""
- try:
# Handle image attachments
if images:
- if isinstance(message_data['content'], str):
- content = [{"type": "text", "text": message_data['content']}]
- else:
- content = message_data['content'] if isinstance(message_data['content'], list) else []
+ if isinstance(content, str):
+ content = [{"type": "text", "text": content}]
+ elif not isinstance(content, list):
+ content = []
for image in images:
image_content = {
@@ -268,75 +121,139 @@ async def add_message(self, thread_id: str, message_data: Dict[str, Any], images
}
}
content.append(image_content)
- else:
- content = message_data['content']
- # Create event
- event_type = f"message_{message_data['role']}"
- await self.create_event(
- thread_id=thread_id,
- event_type=event_type,
- content={"raw_content": content},
- include_in_llm_message_history=True,
- llm_message=message_data
+ # Convert content to JSON string if it's a dict or list
+ if isinstance(content, (dict, list)):
+ content = json.dumps(content)
+
+ # Insert the message
+ await self.db.execute(
+ """
+ INSERT INTO messages (
+ id, thread_id, type, content, include_in_llm_message_history
+ ) VALUES (?, ?, ?, ?, ?)
+ """,
+ (message_id, thread_id, message_type, content, include_in_llm_message_history)
)
+ logging.info(f"Message added to thread {thread_id}")
+
except Exception as e:
logging.error(f"Failed to add message to thread {thread_id}: {e}")
raise
- async def get_messages(
- self,
+ async def get_llm_history_messages(
+ self,
thread_id: str,
hide_tool_msgs: bool = False,
only_latest_assistant: bool = False,
- regular_list: bool = True
) -> List[Dict[str, Any]]:
- """Get messages from thread events with filtering."""
- messages = await self.get_thread_llm_messages(thread_id)
-
+ """
+ Retrieve messages from a thread that are marked for LLM history.
+
+ Args:
+ thread_id: The thread to get messages from
+ hide_tool_msgs: Whether to hide tool messages
+ only_latest_assistant: Whether to only return the latest assistant message
+
+ Returns:
+ List of messages formatted for LLM context
+ """
+
+ # Get only messages marked for LLM history
+ rows = await self.db.fetch_all(
+ """
+ SELECT type, content
+ FROM messages
+ WHERE thread_id = ?
+ AND include_in_llm_message_history = TRUE
+ ORDER BY created_at ASC
+ """,
+ (thread_id,)
+ )
+
+ # Convert DB rows to message format
+ messages = []
+ type_to_role = {
+ 'user_message': 'user',
+ 'assistant_message': 'assistant',
+ 'tool_message': 'tool',
+ 'system_message': 'system'
+ }
+
+ for row in rows:
+ msg_type, content = row
+
+ # Try to parse JSON content
+ try:
+ content = json.loads(content)
+ except (json.JSONDecodeError, TypeError):
+ pass # Keep content as is if it's not JSON
+
+ message = {
+ 'role': type_to_role.get(msg_type, 'unknown'),
+ 'content': content
+ }
+ messages.append(message)
+
+ # Apply filters
if only_latest_assistant:
for msg in reversed(messages):
if msg.get('role') == 'assistant':
return [msg]
return []
-
+
if hide_tool_msgs:
messages = [
{k: v for k, v in msg.items() if k != 'tool_calls'}
for msg in messages
if msg.get('role') != 'tool'
]
-
- if regular_list:
- messages = [
- msg for msg in messages
- if msg.get('role') in ['system', 'assistant', 'tool', 'user']
- ]
-
+
+
return messages
async def _update_message(self, thread_id: str, message: Dict[str, Any]):
- """Update the last assistant message event."""
- events = await self.get_thread_events(
- thread_id=thread_id,
- event_types=["message_assistant"],
- order_by="created_at",
- order="DESC"
- )
-
- if events:
- last_assistant_event = events[0]
- await self.update_event(
- event_id=last_assistant_event["id"],
- thread_id=thread_id,
- content={"raw_content": message.get("content")},
- llm_message=message
+ """Update an existing message in the thread."""
+ try:
+ # Find the latest assistant message for this thread
+ row = await self.db.fetch_one(
+ """
+ SELECT id FROM messages
+ WHERE thread_id = ? AND type = 'assistant_message'
+ ORDER BY created_at DESC LIMIT 1
+ """,
+ (thread_id,)
)
+
+ if not row:
+ return
+
+ message_id = row[0]
+
+ # Convert content to JSON string if needed
+ content = message.get('content', '')
+ if isinstance(content, (dict, list)):
+ content = json.dumps(content)
+
+ # Update the message
+ async with self.db.transaction() as conn:
+ await conn.execute(
+ """
+ UPDATE messages
+ SET content = ?, updated_at = CURRENT_TIMESTAMP
+ WHERE id = ?
+ """,
+ (content, message_id)
+ )
+
+ except Exception as e:
+ logging.error(f"Failed to update message: {e}")
+ raise
async def cleanup_incomplete_tool_calls(self, thread_id: str):
"""Clean up incomplete tool calls in a thread."""
- messages = await self.get_messages(thread_id)
+ messages = await self.get_llm_history_messages(thread_id)
last_assistant_message = next((m for m in reversed(messages)
if m['role'] == 'assistant' and 'tool_calls' in m), None)
@@ -374,21 +291,19 @@ async def cleanup_incomplete_tool_calls(self, thread_id: str):
async def run_thread(
self,
thread_id: str,
- system_message: Dict[str, Any],
+ system_message: Dict[str, str],
model_name: str,
- temperature: float = 0,
- max_tokens: Optional[int] = None,
+ temperature: float = 0.7,
+ max_tokens: int = 4096,
tool_choice: str = "auto",
- temporary_message: Optional[Dict[str, Any]] = None,
- native_tool_calling: bool = False,
+ temporary_message: Optional[Dict[str, str]] = None,
+ native_tool_calling: bool = True,
xml_tool_calling: bool = False,
- execute_tools: bool = True,
stream: bool = False,
- execute_tools_on_stream: bool = False,
- parallel_tool_execution: bool = False,
- tool_parser: Optional[ToolParserBase] = None,
- tool_executor: Optional[ToolExecutorBase] = None,
- results_adder: Optional[ResultsAdderBase] = None
+ execute_tools: bool = True,
+ execute_tools_on_stream: bool = True,
+ parallel_tool_execution: bool = True,
+ stop: Optional[Union[str, List[str]]] = None
) -> Union[Dict[str, Any], AsyncGenerator]:
"""Run a conversation thread with specified parameters.
@@ -406,9 +321,7 @@ async def run_thread(
stream: Whether to stream the response
execute_tools_on_stream: Whether to execute tools during streaming
parallel_tool_execution: Whether to execute tools in parallel
- tool_parser: Custom tool parser implementation
- tool_executor: Custom tool executor implementation
- results_adder: Custom results adder implementation
+ stop (Union[str, List[str]], optional): Up to 4 sequences where the API will stop generating tokens
Returns:
Union[Dict[str, Any], AsyncGenerator]: Response or stream
@@ -421,71 +334,185 @@ async def run_thread(
- Cannot use both native and XML tool calling simultaneously
- Streaming responses include both content and tool results
"""
- # Validate tool calling configuration
- if native_tool_calling and xml_tool_calling:
- raise ValueError("Cannot use both native LLM tool calling and XML tool calling simultaneously")
-
- # Initialize tool components if any tool calling is enabled
- if native_tool_calling or xml_tool_calling:
- if tool_parser is None:
- tool_parser = XMLToolParser(tool_registry=self.tool_registry) if xml_tool_calling else StandardToolParser()
-
- if tool_executor is None:
- tool_executor = XMLToolExecutor(parallel=parallel_tool_execution, tool_registry=self.tool_registry) if xml_tool_calling else StandardToolExecutor(parallel=parallel_tool_execution)
-
- if results_adder is None:
- results_adder = XMLResultsAdder(self) if xml_tool_calling else StandardResultsAdder(self)
-
try:
- messages = await self.get_messages(thread_id)
- prepared_messages = [system_message] + messages
- if temporary_message:
- prepared_messages.append(temporary_message)
-
- openapi_tool_schemas = None
- if native_tool_calling:
- openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
- available_functions = self.tool_registry.get_available_functions()
- elif xml_tool_calling:
- available_functions = self.tool_registry.get_available_functions()
- else:
- available_functions = {}
-
- response_processor = LLMResponseProcessor(
+ # Add thread run start message
+ await self.add_message(
thread_id=thread_id,
- available_functions=available_functions,
- add_message_callback=self.add_message,
- update_message_callback=self._update_message,
- get_messages_callback=self.get_messages,
- parallel_tool_execution=parallel_tool_execution,
- tool_parser=tool_parser,
- tool_executor=tool_executor,
- results_adder=results_adder
+ message_data={
+ "name": "thread_run",
+ "status": "started",
+ "details": {
+ "model": model_name,
+ "temperature": temperature,
+ "native_tool_calling": native_tool_calling,
+ "xml_tool_calling": xml_tool_calling,
+ "execute_tools": execute_tools,
+ "stream": stream
+ }
+ },
+ message_type="agentpress_system",
+ include_in_llm_message_history=False
)
- llm_response = await self._run_thread_completion(
- messages=prepared_messages,
- model_name=model_name,
- temperature=temperature,
- max_tokens=max_tokens,
- tools=openapi_tool_schemas,
- tool_choice=tool_choice if native_tool_calling else None,
- stream=stream
- )
+ try:
+ messages = await self.get_llm_history_messages(thread_id)
+ prepared_messages = [system_message] + messages
+ if temporary_message:
+ prepared_messages.append(temporary_message)
+
+ openapi_tool_schemas = None
+ if native_tool_calling:
+ openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
+ available_functions = self.tool_registry.get_available_functions()
+ elif xml_tool_calling:
+ available_functions = self.tool_registry.get_available_functions()
+ else:
+ available_functions = {}
- if stream:
- return response_processor.process_stream(
- response_stream=llm_response,
- execute_tools=execute_tools,
- execute_tools_on_stream=execute_tools_on_stream
+ # Initialize appropriate tool parser and executor based on calling type
+ if xml_tool_calling:
+ tool_parser = XMLToolParser(tool_registry=self.tool_registry)
+ tool_executor = XMLToolExecutor(parallel=parallel_tool_execution, tool_registry=self.tool_registry)
+ results_adder = XMLResultsAdder(self)
+ else:
+ tool_parser = StandardToolParser()
+ tool_executor = StandardToolExecutor(parallel=parallel_tool_execution)
+ results_adder = StandardResultsAdder(self)
+
+ # Create a SINGLE response processor instance
+ response_processor = LLMResponseProcessor(
+ thread_id=thread_id,
+ tool_executor=tool_executor,
+ tool_parser=tool_parser,
+ available_functions=available_functions,
+ results_adder=results_adder
)
- await response_processor.process_response(
- response=llm_response,
- execute_tools=execute_tools
- )
+ response = await self._run_thread_completion(
+ messages=prepared_messages,
+ model_name=model_name,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ tools=openapi_tool_schemas,
+ tool_choice=tool_choice if native_tool_calling else None,
+ stream=stream,
+ stop=stop
+ )
+
+ if stream:
+ async def stream_with_completion():
+ processor = response_processor.process_stream(
+ response_stream=response,
+ execute_tools=execute_tools,
+ execute_tools_on_stream=execute_tools_on_stream
+ )
+
+ final_state = None
+ async for chunk in processor:
+ yield chunk
+ if hasattr(chunk, '_final_state'):
+ final_state = chunk._final_state
+
+ # Add completion message after stream ends
+ completion_message = {
+ "name": "thread_run",
+ "status": "completed",
+ "details": {
+ "model": model_name,
+ "temperature": temperature,
+ "native_tool_calling": native_tool_calling,
+ "xml_tool_calling": xml_tool_calling,
+ "execute_tools": execute_tools,
+ "stream": stream
+ }
+ }
+
+ # TODO: Add usage, cost information – from final llm response
+
+ # # Add usage information from final state
+ # if final_state and 'usage' in final_state:
+ # completion_message["usage"] = final_state["usage"]
+ # elif hasattr(response, 'cost_tracker'):
+ # completion_message["usage"] = {
+ # "prompt_tokens": response.cost_tracker['prompt_tokens'],
+ # "completion_tokens": response.cost_tracker['completion_tokens'],
+ # "total_tokens": response.cost_tracker['total_tokens'],
+ # "cost_usd": response.cost_tracker['cost']
+ # }
+
+ await self.add_message(
+ thread_id=thread_id,
+ message_data=completion_message,
+ message_type="agentpress_system",
+ include_in_llm_message_history=False
+ )
+
+ return stream_with_completion()
+
+ # For non-streaming, process response once
+ await response_processor.process_response(
+ response=response,
+ execute_tools=execute_tools
+ )
+
+ # Add completion message on success with cost information
+ completion_message = {
+ "name": "thread_run",
+ "status": "completed",
+ "details": {
+ "model": model_name,
+ "temperature": temperature,
+ "native_tool_calling": native_tool_calling,
+ "xml_tool_calling": xml_tool_calling,
+ "execute_tools": execute_tools,
+ "stream": stream
+ }
+ }
+
+ # TODO: Add usage, cost information – from final llm response
+
+ # # Add cost information if available
+ # if hasattr(response, 'cost'):
+ # completion_message["usage"] = {
+ # "cost_usd": response.cost
+ # }
+ # if hasattr(response, 'usage'):
+ # if "usage" not in completion_message:
+ # completion_message["usage"] = {}
+ # completion_message["usage"].update({
+ # "prompt_tokens": response.usage.prompt_tokens,
+ # "completion_tokens": response.usage.completion_tokens,
+ # "total_tokens": response.usage.total_tokens
+ # })
+
+ await self.add_message(
+ thread_id=thread_id,
+ message_data=completion_message,
+ message_type="agentpress_system",
+ include_in_llm_message_history=False
+ )
+
+ return response
- return llm_response
+ except Exception as e:
+ # Add error message if something goes wrong
+
+ # TODO: FIX THAT THREAD RUN CATCHES ERRORS FROM LLM.PY from RUN_THREAD_COMPLETION & CORRECTLY ADD ERROR MESSAGE TO THREAD
+
+ await self.add_message(
+ thread_id=thread_id,
+ message_data={
+ "name": "thread_run",
+ "status": "error",
+ "error": str(e),
+ "details": {
+ "error_type": type(e).__name__
+ }
+ },
+ message_type="agentpress_system",
+ include_in_llm_message_history=False
+ )
+ raise
except Exception as e:
logging.error(f"Error in run_thread: {str(e)}")
@@ -502,104 +529,191 @@ async def _run_thread_completion(
max_tokens: Optional[int],
tools: Optional[List[Dict[str, Any]]],
tool_choice: Optional[str],
- stream: bool
- ) -> Union[Any, AsyncGenerator]:
+ stream: bool,
+ stop: Optional[Union[str, List[str]]] = None
+ ) -> Union[Dict[str, Any], AsyncGenerator]:
"""Get completion from LLM API."""
- return await make_llm_api_call(
+ response = await make_llm_api_call(
messages,
model_name,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
tool_choice=tool_choice,
- stream=stream
+ stream=stream,
+ stop=stop
)
-if __name__ == "__main__":
- import asyncio
- from agentpress.examples.example_agent.tools.files_tool import FilesTool
+ # For streaming responses, wrap in a cost-tracking generator
+ if stream:
+ async def cost_tracking_stream():
+ try:
+ async for chunk in response:
+ # Update token counts if available
+ if hasattr(chunk, 'usage'):
+ response.cost_tracker['prompt_tokens'] = chunk.usage.prompt_tokens
+ response.cost_tracker['completion_tokens'] = chunk.usage.completion_tokens
+ response.cost_tracker['total_tokens'] = chunk.usage.total_tokens
+
+ # Calculate running cost
+ input_cost = response.model_info['input_cost_per_token']
+ output_cost = response.model_info['output_cost_per_token']
+
+ cost = (response.cost_tracker['prompt_tokens'] * input_cost +
+ response.cost_tracker['completion_tokens'] * output_cost)
+ response.cost_tracker['cost'] = cost
+
+ # Attach cost tracker to the chunk for final state
+ if hasattr(chunk, '_final_state'):
+ chunk._final_state['usage'] = {
+ "prompt_tokens": response.cost_tracker['prompt_tokens'],
+ "completion_tokens": response.cost_tracker['completion_tokens'],
+ "total_tokens": response.cost_tracker['total_tokens'],
+ "cost_usd": response.cost_tracker['cost']
+ }
+ yield chunk
+ except Exception as e:
+ logging.error(f"Error in cost tracking stream: {e}")
+ raise
+
+ return cost_tracking_stream()
+
+ return response
- async def main():
- # Initialize managers
- thread_manager = ThreadManager()
-
- # Register available tools
- thread_manager.add_tool(FilesTool)
-
- # Create a new thread
- thread_id = await thread_manager.create_thread()
+ async def get_messages(
+ self,
+ thread_id: str,
+ message_types: Optional[List[str]] = None,
+ limit: Optional[int] = 50, # Default limit of 50 messages
+ offset: Optional[int] = 0, # Starting offset for pagination
+ before_timestamp: Optional[str] = None,
+ after_timestamp: Optional[str] = None,
+ include_in_llm_message_history: Optional[bool] = None,
+ order: str = "asc"
+ ) -> Dict[str, Any]:
+ """
+ Retrieve messages from a thread with optional filtering and pagination.
- # Add a test message
- await thread_manager.add_message(thread_id, {
- "role": "user",
- "content": "Please create 10x files – Each should be a chapter of a book about an Introduction to Robotics.."
- })
-
- # Define system message
- system_message = {
- "role": "system",
- "content": "You are a helpful assistant that can create, read, update, and delete files."
- }
+ Args:
+ thread_id: The thread to get messages from
+ message_types: Optional list of message types to filter by
+ limit: Maximum number of messages to return (default: 50)
+ offset: Number of messages to skip (for pagination)
+ before_timestamp: Optional timestamp to filter messages before
+ after_timestamp: Optional timestamp to filter messages after
+ include_in_llm_message_history: Optional bool to filter messages by LLM history inclusion
+ order: Sort order - "asc" or "desc"
+
+ Returns:
+ Dict containing messages list and pagination info
+ """
+ try:
+ # Build the base query for total count
+ count_query = """
+ SELECT COUNT(*)
+ FROM messages
+ WHERE thread_id = ?
+ """
+ count_params = [thread_id]
- # Test with streaming response and tool execution
- print("\n🤖 Testing streaming response with tools:")
- response = await thread_manager.run_thread(
- thread_id=thread_id,
- system_message=system_message,
- model_name="anthropic/claude-3-5-haiku-latest",
- temperature=0.7,
- max_tokens=4096,
- stream=True,
- native_tool_calling=True,
- execute_tools=True,
- execute_tools_on_stream=True,
- parallel_tool_execution=True
- )
+ # Build the base query for messages
+ query = """
+ SELECT id, type, content, created_at, updated_at, include_in_llm_message_history
+ FROM messages
+ WHERE thread_id = ?
+ """
+ params = [thread_id]
- # Handle streaming response
- if isinstance(response, AsyncGenerator):
- print("\nAssistant is responding:")
- content_buffer = ""
- try:
- async for chunk in response:
- if hasattr(chunk.choices[0], 'delta'):
- delta = chunk.choices[0].delta
-
- # Handle content streaming
- if hasattr(delta, 'content') and delta.content is not None:
- content_buffer += delta.content
- if delta.content.endswith((' ', '\n')):
- print(content_buffer, end='', flush=True)
- content_buffer = ""
-
- # Handle tool calls
- if hasattr(delta, 'tool_calls') and delta.tool_calls:
- for tool_call in delta.tool_calls:
- # Print tool name when it first appears
- if tool_call.function and tool_call.function.name:
- print(f"\n🛠️ Tool Call: {tool_call.function.name}", flush=True)
-
- # Print arguments as they stream in
- if tool_call.function and tool_call.function.arguments:
- print(f" {tool_call.function.arguments}", end='', flush=True)
+ # Add filters to both queries
+ if message_types:
+ placeholders = ','.join('?' * len(message_types))
+ filter_sql = f" AND type IN ({placeholders})"
+ query += filter_sql
+ count_query += filter_sql
+ params.extend(message_types)
+ count_params.extend(message_types)
- # Print any remaining content
- if content_buffer:
- print(content_buffer, flush=True)
- print("\n✨ Response completed\n")
+ if before_timestamp:
+ query += " AND created_at < ?"
+ count_query += " AND created_at < ?"
+ params.append(before_timestamp)
+ count_params.append(before_timestamp)
- except Exception as e:
- print(f"\n❌ Error processing stream: {e}")
- else:
- print("\n✨ Response completed\n")
+ if after_timestamp:
+ query += " AND created_at > ?"
+ count_query += " AND created_at > ?"
+ params.append(after_timestamp)
+ count_params.append(after_timestamp)
+
+ if include_in_llm_message_history is not None:
+ query += " AND include_in_llm_message_history = ?"
+ count_query += " AND include_in_llm_message_history = ?"
+ params.append(include_in_llm_message_history)
+ count_params.append(include_in_llm_message_history)
+
+ # Get total count for pagination
+ total_count = await self.db.fetch_one(count_query, tuple(count_params))
+ total_count = total_count[0] if total_count else 0
+
+ # Add ordering and pagination
+ query += f" ORDER BY created_at {'ASC' if order.lower() == 'asc' else 'DESC'}"
+ query += " LIMIT ? OFFSET ?"
+ params.extend([limit, offset])
+
+ # Execute query
+ rows = await self.db.fetch_all(query, tuple(params))
+
+ # Convert rows to dictionaries
+ messages = []
+ for row in rows:
+ message = {
+ 'id': row[0],
+ 'type': row[1],
+ 'content': row[2],
+ 'created_at': row[3],
+ 'updated_at': row[4],
+ 'include_in_llm_message_history': bool(row[5])
+ }
- # Display final thread state
- messages = await thread_manager.get_messages(thread_id)
- print("\n📝 Final Thread State:")
- for msg in messages:
- role = msg.get('role', 'unknown')
- content = msg.get('content', '')
- print(f"\n{role.upper()}: {content[:100]}...")
+ # Try to parse JSON content
+ try:
+ message['content'] = json.loads(message['content'])
+ except (json.JSONDecodeError, TypeError):
+ pass # Keep content as is if it's not JSON
- asyncio.run(main())
+ messages.append(message)
+
+ # Return messages with pagination info
+ return {
+ "messages": messages,
+ "pagination": {
+ "total": total_count,
+ "limit": limit,
+ "offset": offset,
+ "has_more": (offset + limit) < total_count
+ }
+ }
+
+ except Exception as e:
+ logging.error(f"Failed to get messages from thread {thread_id}: {e}")
+ raise
+ async def thread_exists(self, thread_id: str) -> bool:
+ """
+ Check if a thread exists.
+
+ Args:
+ thread_id: The ID of the thread to check
+
+ Returns:
+ bool: True if thread exists, False otherwise
+ """
+ try:
+ row = await self.db.fetch_one(
+ "SELECT 1 FROM threads WHERE id = ?",
+ (thread_id,)
+ )
+ return row is not None
+ except Exception as e:
+ logging.error(f"Error checking thread existence: {e}")
+ return False
diff --git a/agentpress/thread_viewer_ui.py b/agentpress/thread_viewer_ui.py
index 944cb49..4af1fd9 100644
--- a/agentpress/thread_viewer_ui.py
+++ b/agentpress/thread_viewer_ui.py
@@ -1,64 +1,111 @@
import streamlit as st
from datetime import datetime
-from agentpress.thread_manager import ThreadManager
from agentpress.db_connection import DBConnection
import asyncio
+import json
def format_message_content(content):
- """Format message content handling both string and list formats."""
- if isinstance(content, str):
- return content
- elif isinstance(content, list):
- formatted_content = []
- for item in content:
- if item.get('type') == 'text':
- formatted_content.append(item['text'])
- elif item.get('type') == 'image_url':
- formatted_content.append("[Image]")
- return "\n".join(formatted_content)
- return str(content)
+ """Format message content handling various formats."""
+ try:
+ if isinstance(content, str):
+ # Try to parse JSON strings
+ try:
+ parsed = json.loads(content)
+ if isinstance(parsed, (dict, list)):
+ return json.dumps(parsed, indent=2)
+ except json.JSONDecodeError:
+ return content
+ elif isinstance(content, list):
+ formatted_content = []
+ for item in content:
+ if item.get('type') == 'text':
+ formatted_content.append(item['text'])
+ elif item.get('type') == 'image_url':
+ formatted_content.append("[Image]")
+ return "\n".join(formatted_content)
+ return json.dumps(content, indent=2)
+ except:
+ return str(content)
async def load_threads():
"""Load all thread IDs from the database."""
db = DBConnection()
- rows = await db.fetch_all("SELECT thread_id, created_at FROM threads ORDER BY created_at DESC")
+ rows = await db.fetch_all(
+ """
+ SELECT id, created_at
+ FROM threads
+ ORDER BY created_at DESC
+ """
+ )
return rows
-async def load_thread_content(thread_id: str):
- """Load the content of a specific thread from the database."""
- thread_manager = ThreadManager()
- return await thread_manager.get_messages(thread_id)
+async def load_thread_content(thread_id: str, filters: dict):
+ """Load messages from a thread with filters."""
+ db = DBConnection()
+
+ query_parts = ["SELECT type, content, include_in_llm_message_history, created_at FROM messages WHERE thread_id = ?"]
+ params = [thread_id]
+
+ if filters.get('message_types'):
+ # Convert comma-separated string to list and clean up whitespace
+ types_list = [t.strip() for t in filters['message_types'].split(',') if t.strip()]
+ if types_list:
+ query_parts.append("AND type IN (" + ",".join(["?" for _ in types_list]) + ")")
+ params.extend(types_list)
+
+ if filters.get('exclude_message_types'):
+ # Convert comma-separated string to list and clean up whitespace
+ exclude_types_list = [t.strip() for t in filters['exclude_message_types'].split(',') if t.strip()]
+ if exclude_types_list:
+ query_parts.append("AND type NOT IN (" + ",".join(["?" for _ in exclude_types_list]) + ")")
+ params.extend(exclude_types_list)
+
+ if filters.get('before_timestamp'):
+ query_parts.append("AND created_at < ?")
+ params.append(filters['before_timestamp'])
+
+ if filters.get('after_timestamp'):
+ query_parts.append("AND created_at > ?")
+ params.append(filters['after_timestamp'])
+
+ if filters.get('include_in_llm_message_history') is not None:
+ query_parts.append("AND include_in_llm_message_history = ?")
+ params.append(filters['include_in_llm_message_history'])
+
+ # Add ordering
+ order_direction = "DESC" if filters.get('order', 'asc').lower() == 'desc' else "ASC"
+ query_parts.append(f"ORDER BY created_at {order_direction}")
+
+ # Add limit and offset
+ if filters.get('limit'):
+ query_parts.append("LIMIT ?")
+ params.append(filters['limit'])
+
+ if filters.get('offset'):
+ query_parts.append("OFFSET ?")
+ params.append(filters['offset'])
+
+ query = " ".join(query_parts)
+ rows = await db.fetch_all(query, tuple(params))
+ return rows
-def render_message(role, content, avatar):
- """Render a message with a consistent chat-like style."""
- # Create columns for avatar and message
- col1, col2 = st.columns([1, 11])
-
- # Style based on role
- if role == "assistant":
- bgcolor = "rgba(25, 25, 25, 0.05)"
- elif role == "user":
- bgcolor = "rgba(25, 120, 180, 0.05)"
- elif role == "system":
- bgcolor = "rgba(180, 25, 25, 0.05)"
- else:
- bgcolor = "rgba(100, 100, 100, 0.05)"
-
- # Display avatar in first column
+def render_message(msg_type: str, content: str, include_in_llm: bool, timestamp: str):
+ """Render a message using Streamlit components."""
+ # Message type and metadata
+ col1, col2 = st.columns([3, 1])
with col1:
- st.markdown(f"
{avatar}
", unsafe_allow_html=True)
-
- # Display message in second column
+ st.text(f"Type: {msg_type}")
with col2:
- st.markdown(
- f"""
-
- {role.upper()}
- {content}
-
- """,
- unsafe_allow_html=True
- )
+ st.text("🟢 LLM" if include_in_llm else "⚫ Non-LLM")
+
+ # Timestamp
+ st.text(f"Time: {datetime.fromisoformat(timestamp).strftime('%Y-%m-%d %H:%M:%S')}")
+
+ # Message content
+ st.code(content, language="json")
+
+ # Separator
+ st.divider()
def main():
st.title("Thread Viewer")
@@ -86,7 +133,6 @@ def main():
)
if selected_thread_display:
- # Get the actual thread ID from the display string
selected_thread_id = thread_options[selected_thread_display]
# Display thread ID in sidebar
@@ -95,46 +141,77 @@ def main():
# Add refresh button
if st.sidebar.button("🔄 Refresh Thread"):
st.session_state.threads = asyncio.run(load_threads())
- st.experimental_rerun()
+ st.rerun()
+
+ # Advanced filtering options in sidebar
+ st.sidebar.title("Filter Options")
+
+ # Message type filters
+ col1, col2 = st.sidebar.columns(2)
+ with col1:
+ message_types = st.text_input(
+ "Include Types",
+ help="Enter message types to include, separated by commas"
+ )
+ with col2:
+ exclude_message_types = st.text_input(
+ "Exclude Types",
+ help="Enter message types to exclude, separated by commas"
+ )
+
+ # Limit and offset
+ col1, col2 = st.sidebar.columns(2)
+ with col1:
+ limit = st.number_input("Limit", min_value=1, value=50)
+ with col2:
+ offset = st.number_input("Offset", min_value=0, value=0)
+
+ # Timestamp filters
+ st.sidebar.subheader("Time Range")
+ before_timestamp = st.sidebar.date_input("Before Date", value=None)
+ after_timestamp = st.sidebar.date_input("After Date", value=None)
+
+ # LLM history filter
+ include_in_llm = st.sidebar.radio(
+ "LLM History Filter",
+ options=["All Messages", "LLM Only", "Non-LLM Only"]
+ )
+
+ # Sort order
+ order = st.sidebar.radio("Sort Order", ["Ascending", "Descending"])
+
+ # Prepare filters
+ filters = {
+ 'message_types': message_types if message_types else None,
+ 'exclude_message_types': exclude_message_types if exclude_message_types else None,
+ 'limit': limit,
+ 'offset': offset,
+ 'order': 'desc' if order == "Descending" else 'asc'
+ }
+
+ # Add timestamp filters if selected
+ if before_timestamp:
+ filters['before_timestamp'] = before_timestamp.isoformat()
+ if after_timestamp:
+ filters['after_timestamp'] = after_timestamp.isoformat()
+
+ # Add LLM history filter
+ if include_in_llm == "LLM Only":
+ filters['include_in_llm_message_history'] = True
+ elif include_in_llm == "Non-LLM Only":
+ filters['include_in_llm_message_history'] = False
+
+ # Load messages with filters
+ messages = asyncio.run(load_thread_content(selected_thread_id, filters))
- # Load and display messages
- messages = asyncio.run(load_thread_content(selected_thread_id))
+ if not messages:
+ st.info("No messages found with current filters")
+ return
- # Display messages in chat-like interface
- for message in messages:
- role = message.get("role", "unknown")
- content = message.get("content", "")
-
- # Determine avatar based on role
- if role == "assistant":
- avatar = "🤖"
- elif role == "user":
- avatar = "👤"
- elif role == "system":
- avatar = "⚙️"
- elif role == "tool":
- avatar = "🔧"
- else:
- avatar = "❓"
-
- # Format the content
+ # Display messages
+ for msg_type, content, include_in_llm, timestamp in messages:
formatted_content = format_message_content(content)
-
- # Render the message
- render_message(role, formatted_content, avatar)
-
- # Display tool calls if present
- if "tool_calls" in message:
- with st.expander("🛠️ Tool Calls"):
- for tool_call in message["tool_calls"]:
- st.code(
- f"Function: {tool_call['function']['name']}\n"
- f"Arguments: {tool_call['function']['arguments']}",
- language="json"
- )
-
- # Add some spacing between messages
- st.markdown("
", unsafe_allow_html=True)
+ render_message(msg_type, formatted_content, include_in_llm, timestamp)
if __name__ == "__main__":
main()
diff --git a/example/__init__.py b/example/__init__.py
new file mode 100644
index 0000000..3ebeb22
--- /dev/null
+++ b/example/__init__.py
@@ -0,0 +1 @@
+# Empty file to mark as package
\ No newline at end of file
diff --git a/example/agent.py b/example/agent.py
new file mode 100644
index 0000000..2a7be66
--- /dev/null
+++ b/example/agent.py
@@ -0,0 +1,261 @@
+"""
+Interactive web development agent supporting both XML and Standard LLM tool calling.
+
+This agent can:
+- Create and modify web projects
+- Execute terminal commands
+- Handle file operations
+- Use either XML or Standard tool calling patterns
+"""
+
+import asyncio
+import json
+from agentpress.thread_manager import ThreadManager
+from example.tools.files_tool import FilesTool
+from agentpress.state_manager import StateManager
+from example.tools.terminal_tool import TerminalTool
+import logging
+from typing import AsyncGenerator, Optional, Dict, Any
+import sys
+
+from agentpress.api.api_factory import register_thread_task_api
+
+BASE_SYSTEM_MESSAGE = """
+You are a world-class web developer who can create, edit, and delete files, and execute terminal commands.
+You write clean, well-structured code. Keep iterating on existing files, continue working on this existing
+codebase - do not omit previous progress; instead, keep iterating.
+Available tools:
+- create_file: Create new files with specified content
+- delete_file: Remove existing files
+- str_replace: Make precise text replacements in files
+- execute_command: Run terminal commands
+
+
+RULES:
+- All current file contents are available to you in the section
+- Each file in the workspace state includes its full content
+- Use str_replace for precise replacements in files
+- NEVER include comments in any code you write - the code should be self-documenting
+- Always maintain the full context of files when making changes
+- When creating new files, write clean code without any comments or documentation
+
+
+[create_file(file_path, file_contents)] - Create new files
+[delete_file(file_path)] - Delete existing files
+[str_replace(file_path, old_str, new_str)] - Replace specific text in files
+[execute_command(command)] - Execute terminal commands
+
+
+ALWAYS RESPOND WITH MULTIPLE SIMULTANEOUS ACTIONS:
+
+[Provide a concise overview of your planned changes and implementations]
+
+
+
+[Include multiple tool calls]
+
+
+EDITING GUIDELINES:
+1. Review the current file contents in the workspace state
+2. Make targeted changes with str_replace
+3. Write clean, self-documenting code without comments
+4. Use create_file for new files and str_replace for modifications
+
+Example workspace state for a file:
+{
+ "index.html": {
+ "content": "\\n\\n..."
+ }
+}
+Think deeply and step by step.
+"""
+
+XML_FORMAT = """
+RESPONSE FORMAT:
+Use XML tags to specify file operations:
+
+
+file contents here
+
+
+
+text to replace
+replacement text
+
+
+
+
+
+
+"""
+
+@register_thread_task_api("/agent")
+async def run_agent(
+ thread_id: str,
+ max_iterations: int = 5,
+ user_input: Optional[str] = None,
+) -> Dict[str, Any]:
+ """Run the development agent with specified configuration.
+
+ Args:
+ thread_id (str): The ID of the thread.
+ max_iterations (int, optional): The maximum number of iterations. Defaults to 5.
+ user_input (Optional[str], optional): The user input. Defaults to None.
+ """
+ thread_manager = ThreadManager()
+ state_manager = StateManager(thread_id)
+
+ if user_input:
+ await thread_manager.add_message(
+ thread_id,
+ {
+ "role": "user",
+ "content": user_input
+ }
+ )
+
+ thread_manager.add_tool(FilesTool, thread_id=thread_id)
+ thread_manager.add_tool(TerminalTool, thread_id=thread_id)
+
+ system_message = {
+ "role": "system",
+ "content": BASE_SYSTEM_MESSAGE + XML_FORMAT
+ }
+
+ iteration = 0
+ while iteration < max_iterations:
+ iteration += 1
+
+ files_tool = FilesTool(thread_id=thread_id)
+
+ state = await state_manager.export_store()
+
+ temporary_message_content = f"""
+ You are tasked to complete the LATEST USER REQUEST!
+
+ {user_input}
+
+
+ Current development environment workspace state:
+
+ {json.dumps(state, indent=2) if state else "{}"}
+
+
+ CONTINUE WITH THE TASK! USE THE SESSION TOOL TO STOP THE SESSION IF THE TASK IS COMPLETE.
+ """
+
+ await thread_manager.add_message(
+ thread_id=thread_id,
+ message_data=temporary_message_content,
+ message_type="temporary_message",
+ include_in_llm_message_history=False
+ )
+
+ temporary_message = {
+ "role": "user",
+ "content": temporary_message_content
+ }
+
+ model_name = "anthropic/claude-3-5-sonnet-latest"
+
+ response = await thread_manager.run_thread(
+ thread_id=thread_id,
+ system_message=system_message,
+ model_name=model_name,
+ temperature=0.1,
+ max_tokens=8096,
+ tool_choice="auto",
+ temporary_message=temporary_message,
+ native_tool_calling=False,
+ xml_tool_calling=True,
+ stream=True,
+ execute_tools_on_stream=True,
+ parallel_tool_execution=True,
+ )
+
+ if isinstance(response, AsyncGenerator):
+ print("\n🤖 Assistant is responding:")
+ try:
+ async for chunk in response:
+ if hasattr(chunk.choices[0], 'delta'):
+ delta = chunk.choices[0].delta
+
+ if hasattr(delta, 'content') and delta.content is not None:
+ content = delta.content
+ print(content, end='', flush=True)
+
+ # Check for open_files_in_editor tag and continue if found
+ if '' in content:
+ print("\n📂 Opening files in editor, continuing to next iteration...")
+ continue
+
+ if hasattr(delta, 'tool_calls') and delta.tool_calls:
+ for tool_call in delta.tool_calls:
+ if tool_call.function:
+ if tool_call.function.name:
+ print(f"\n🛠️ Tool Call: {tool_call.function.name}", flush=True)
+ if tool_call.function.arguments:
+ print(f" {tool_call.function.arguments}", end='', flush=True)
+
+ print("\n✨ Response completed\n")
+
+ except Exception as e:
+ print(f"\n❌ Error processing stream: {e}", file=sys.stderr)
+ logging.error(f"Error processing stream: {e}")
+ else:
+ print("\nNon-streaming response received:", response)
+
+ # # Get latest assistant message and check for stop_session
+ # latest_msg = await thread_manager.get_llm_history_messages(
+ # thread_id=thread_id,
+ # only_latest_assistant=True
+ # )
+ # if latest_msg and '' in latest_msg:
+ # break
+
+
+if __name__ == "__main__":
+ print("\n🚀 Welcome to AgentPress!")
+
+ project_description = input("What would you like to build? (default: Create a modern, responsive landing page)\n> ")
+ if not project_description.strip():
+ project_description = "Create a modern, responsive landing page"
+
+ print("\nChoose your agent type:")
+ print("1. XML-based Tool Calling")
+ print(" - Structured XML format for tool execution")
+ print(" - Parses tool calls using XML outputs in the LLM response")
+
+ print("\n2. Standard Function Calling")
+ print(" - Native LLM function calling format")
+ print(" - JSON-based parameter passing")
+
+ use_xml = input("\nSelect tool calling format [1/2] (default: 1): ").strip() != "2"
+
+ print(f"\n{'XML-based' if use_xml else 'Standard'} agent will help you build: {project_description}")
+ print("Use Ctrl+C to stop the agent at any time.")
+
+ async def test_agent():
+ thread_manager = ThreadManager()
+ thread_id = await thread_manager.create_thread()
+ logging.info(f"Created new thread: {thread_id}")
+
+ try:
+ result = await run_agent(
+ thread_id=thread_id,
+ max_iterations=5,
+ user_input=project_description,
+ )
+ print("\n✅ Test completed successfully!")
+
+ except Exception as e:
+ print(f"\n❌ Test failed: {str(e)}")
+ raise
+
+ try:
+ asyncio.run(test_agent())
+ except KeyboardInterrupt:
+ print("\n⚠️ Test interrupted by user")
+ except Exception as e:
+ print(f"\n❌ Test failed with error: {str(e)}")
+ raise
\ No newline at end of file
diff --git a/example/tools/files_tool.py b/example/tools/files_tool.py
new file mode 100644
index 0000000..0d5948d
--- /dev/null
+++ b/example/tools/files_tool.py
@@ -0,0 +1,297 @@
+import os
+import asyncio
+from pathlib import Path
+from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
+from agentpress.state_manager import StateManager
+from typing import Optional
+
+class FilesTool(Tool):
+ """File management tool for creating, updating, and deleting files.
+
+ This tool provides file operations within a workspace directory, with built-in
+ file filtering and state tracking capabilities.
+
+ Attributes:
+ workspace (str): Path to the workspace directory
+ EXCLUDED_FILES (set): Files to exclude from operations
+ EXCLUDED_DIRS (set): Directories to exclude
+ EXCLUDED_EXT (set): File extensions to exclude
+ SNIPPET_LINES (int): Context lines for edit previews
+ """
+
+ # Excluded files, directories, and extensions
+ EXCLUDED_FILES = {
+ ".DS_Store",
+ ".gitignore",
+ "package-lock.json",
+ "postcss.config.js",
+ "postcss.config.mjs",
+ "jsconfig.json",
+ "components.json",
+ "tsconfig.tsbuildinfo",
+ "tsconfig.json",
+ }
+
+ EXCLUDED_DIRS = {
+ "node_modules",
+ ".next",
+ "dist",
+ "build",
+ ".git"
+ }
+
+ EXCLUDED_EXT = {
+ ".ico",
+ ".svg",
+ ".png",
+ ".jpg",
+ ".jpeg",
+ ".gif",
+ ".bmp",
+ ".tiff",
+ ".webp",
+ ".db",
+ ".sql"
+ }
+
+ def __init__(self, thread_id: Optional[str] = None):
+ super().__init__()
+ self.workspace = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'workspace')
+ os.makedirs(self.workspace, exist_ok=True)
+ if thread_id:
+ self.state_manager = StateManager(thread_id)
+ asyncio.create_task(self._init_workspace_state())
+ self.SNIPPET_LINES = 4
+
+ def _should_exclude_file(self, rel_path: str) -> bool:
+ """Check if a file should be excluded based on path, name, or extension"""
+ # Check filename
+ filename = os.path.basename(rel_path)
+ if filename in self.EXCLUDED_FILES:
+ return True
+
+ # Check directory
+ dir_path = os.path.dirname(rel_path)
+ if any(excluded in dir_path for excluded in self.EXCLUDED_DIRS):
+ return True
+
+ # Check extension
+ _, ext = os.path.splitext(filename)
+ if ext.lower() in self.EXCLUDED_EXT:
+ return True
+
+ return False
+
+ async def _init_workspace_state(self):
+ """Initialize or update the workspace state in JSON"""
+ files_state = {}
+
+ # Walk through workspace and record all files
+ for root, _, files in os.walk(self.workspace):
+ for file in files:
+ full_path = os.path.join(root, file)
+ rel_path = os.path.relpath(full_path, self.workspace)
+
+ # Skip excluded files
+ if self._should_exclude_file(rel_path):
+ continue
+
+ try:
+ with open(full_path, 'r') as f:
+ content = f.read()
+ files_state[rel_path] = {
+ "content": content
+ }
+ except Exception as e:
+ print(f"Error reading file {rel_path}: {e}")
+ except UnicodeDecodeError:
+ print(f"Skipping binary file: {rel_path}")
+
+ if hasattr(self, 'state_manager'):
+ await self.state_manager.set("files", files_state)
+
+ async def _update_workspace_state(self):
+ """Update the workspace state after any file operation"""
+ await self._init_workspace_state()
+
+ @openapi_schema({
+ "type": "function",
+ "function": {
+ "name": "create_file",
+ "description": "Create a new file with the provided contents at a given path in the workspace",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "file_path": {
+ "type": "string",
+ "description": "Path to the file to be created"
+ },
+ "file_contents": {
+ "type": "string",
+ "description": "The content to write to the file"
+ }
+ },
+ "required": ["file_path", "file_contents"]
+ }
+ }
+ })
+ @xml_schema(
+ tag_name="create-file",
+ mappings=[
+ {"param_name": "file_path", "node_type": "attribute", "path": "."},
+ {"param_name": "file_contents", "node_type": "content", "path": "."}
+ ],
+ example='''
+
+ File contents go here
+
+ '''
+ )
+ async def create_file(self, file_path: str, file_contents: str) -> ToolResult:
+ try:
+ full_path = os.path.join(self.workspace, file_path)
+ if os.path.exists(full_path):
+ return self.fail_response(f"File '{file_path}' already exists. Use update_file to modify existing files.")
+
+ os.makedirs(os.path.dirname(full_path), exist_ok=True)
+ with open(full_path, 'w') as f:
+ f.write(file_contents)
+
+ await self._update_workspace_state()
+ return self.success_response(f"File '{file_path}' created successfully.")
+ except Exception as e:
+ return self.fail_response(f"Error creating file: {str(e)}")
+
+ @openapi_schema({
+ "type": "function",
+ "function": {
+ "name": "delete_file",
+ "description": "Delete a file at the given path",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "file_path": {
+ "type": "string",
+ "description": "Path to the file to be deleted"
+ }
+ },
+ "required": ["file_path"]
+ }
+ }
+ })
+ @xml_schema(
+ tag_name="delete-file",
+ mappings=[
+ {"param_name": "file_path", "node_type": "attribute", "path": "."}
+ ],
+ example='''
+
+
+ '''
+ )
+ async def delete_file(self, file_path: str) -> ToolResult:
+ try:
+ full_path = os.path.join(self.workspace, file_path)
+ os.remove(full_path)
+
+ await self._update_workspace_state()
+ return self.success_response(f"File '{file_path}' deleted successfully.")
+ except Exception as e:
+ return self.fail_response(f"Error deleting file: {str(e)}")
+
+ @openapi_schema({
+ "type": "function",
+ "function": {
+ "name": "str_replace",
+ "description": "Replace text in file",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "file_path": {
+ "type": "string",
+ "description": "Path to the target file"
+ },
+ "old_str": {
+ "type": "string",
+ "description": "Text to be replaced (must appear exactly once)"
+ },
+ "new_str": {
+ "type": "string",
+ "description": "Replacement text"
+ }
+ },
+ "required": ["file_path", "old_str", "new_str"]
+ }
+ }
+ })
+ @xml_schema(
+ tag_name="str-replace",
+ mappings=[
+ {"param_name": "file_path", "node_type": "attribute", "path": "file_path"},
+ {"param_name": "old_str", "node_type": "element", "path": "old_str"},
+ {"param_name": "new_str", "node_type": "element", "path": "new_str"}
+ ],
+ example='''
+
+ text to replace
+ replacement text
+
+ '''
+ )
+ async def str_replace(self, file_path: str, old_str: str, new_str: str) -> ToolResult:
+ try:
+ full_path = Path(os.path.join(self.workspace, file_path))
+ if not full_path.exists():
+ return self.fail_response(f"File '{file_path}' does not exist")
+
+ content = full_path.read_text().expandtabs()
+ old_str = old_str.expandtabs()
+ new_str = new_str.expandtabs()
+
+ occurrences = content.count(old_str)
+ if occurrences == 0:
+ return self.fail_response(f"String '{old_str}' not found in file")
+ if occurrences > 1:
+ lines = [i+1 for i, line in enumerate(content.split('\n')) if old_str in line]
+ return self.fail_response(f"Multiple occurrences found in lines {lines}. Please ensure string is unique")
+
+ # Perform replacement
+ new_content = content.replace(old_str, new_str)
+ full_path.write_text(new_content)
+
+ # Update state after file modification
+ await self._update_workspace_state()
+
+ # Show snippet around the edit
+ replacement_line = content.split(old_str)[0].count('\n')
+ start_line = max(0, replacement_line - self.SNIPPET_LINES)
+ end_line = replacement_line + self.SNIPPET_LINES + new_str.count('\n')
+ snippet = '\n'.join(new_content.split('\n')[start_line:end_line + 1])
+
+ return self.success_response(f"Replacement successful. Snippet of changes:\n{snippet}")
+
+ except Exception as e:
+ return self.fail_response(f"Error replacing string: {str(e)}")
+
+if __name__ == "__main__":
+ async def test_files_tool():
+ files_tool = FilesTool()
+ test_file_path = "test_file.txt"
+ test_content = "This is a test file."
+ updated_content = "This is an updated test file."
+
+ print(f"Using workspace directory: {files_tool.workspace}")
+
+ # Test create_file
+ create_result = await files_tool.create_file(test_file_path, test_content)
+ print("Create file result:", create_result)
+
+ # Test delete_file
+ delete_result = await files_tool.delete_file(test_file_path)
+ print("Delete file result:", delete_result)
+
+ # Test read_file after delete (should fail)
+ read_deleted_result = await files_tool.read_file(test_file_path)
+ print("Read deleted file result:", read_deleted_result)
+
+ asyncio.run(test_files_tool())
\ No newline at end of file
diff --git a/example/tools/terminal_tool.py b/example/tools/terminal_tool.py
new file mode 100644
index 0000000..c9736dc
--- /dev/null
+++ b/example/tools/terminal_tool.py
@@ -0,0 +1,73 @@
+import os
+import asyncio
+import subprocess
+from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
+from typing import Optional
+
+class TerminalTool(Tool):
+ """Terminal command execution tool for workspace operations."""
+
+ def __init__(self, thread_id: Optional[str] = None):
+ super().__init__()
+ self.workspace = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'workspace')
+ os.makedirs(self.workspace, exist_ok=True)
+
+ @openapi_schema({
+ "type": "function",
+ "function": {
+ "name": "execute_command",
+ "description": "Execute a shell command in the workspace directory",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "command": {
+ "type": "string",
+ "description": "The shell command to execute"
+ }
+ },
+ "required": ["command"]
+ }
+ }
+ })
+ @xml_schema(
+ tag_name="execute-command",
+ mappings=[
+ {"param_name": "command", "node_type": "content", "path": "."}
+ ],
+ example='''
+
+ npm install package-name
+
+ '''
+ )
+ async def execute_command(self, command: str) -> ToolResult:
+ original_dir = os.getcwd()
+ try:
+ os.chdir(self.workspace)
+
+ process = await asyncio.create_subprocess_shell(
+ command,
+ stdout=asyncio.subprocess.PIPE,
+ stderr=asyncio.subprocess.PIPE,
+ cwd=self.workspace
+ )
+ stdout, stderr = await process.communicate()
+
+ output = stdout.decode() if stdout else ""
+ error = stderr.decode() if stderr else ""
+ success = process.returncode == 0
+
+ if success:
+ return self.success_response({
+ "output": output,
+ "error": error,
+ "exit_code": process.returncode,
+ "cwd": self.workspace
+ })
+ else:
+ return self.fail_response(f"Command failed with exit code {process.returncode}: {error}")
+
+ except Exception as e:
+ return self.fail_response(f"Error executing command: {str(e)}")
+ finally:
+ os.chdir(original_dir)
diff --git a/example/workspace/css/styles.css b/example/workspace/css/styles.css
new file mode 100644
index 0000000..f7d8379
--- /dev/null
+++ b/example/workspace/css/styles.css
@@ -0,0 +1,373 @@
+:root {
+ --primary-color: #2563eb;
+ --secondary-color: #1e40af;
+ --text-color: #1f2937;
+ --light-text: #6b7280;
+ --background: #ffffff;
+ --section-bg: #f3f4f6;
+}
+
+* {
+ margin: 0;
+ padding: 0;
+ box-sizing: border-box;
+}
+
+html {
+ scroll-behavior: smooth;
+}
+
+.scroll-top {
+ position: fixed;
+ bottom: 2rem;
+ right: 2rem;
+ background: var(--primary-color);
+ color: white;
+ width: 45px;
+ height: 45px;
+ border-radius: 50%;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ cursor: pointer;
+ opacity: 0;
+ visibility: hidden;
+ transition: all 0.3s ease;
+ z-index: 1000;
+}
+
+.scroll-top.visible {
+ opacity: 1;
+ visibility: visible;
+}
+
+.scroll-top:hover {
+ transform: translateY(-3px);
+ background: var(--secondary-color);
+}
+
+body {
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
+ line-height: 1.6;
+ color: var(--text-color);
+}
+
+.container {
+ max-width: 1200px;
+ margin: 0 auto;
+ padding: 0 2rem;
+}
+
+.header {
+ position: fixed;
+ width: 100%;
+ background: var(--background);
+ box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
+ z-index: 1000;
+ transition: transform 0.3s ease;
+}
+
+.header.scrolled {
+ transform: translateY(-100%);
+}
+
+.header:hover {
+ transform: translateY(0);
+}
+
+.nav {
+ display: flex;
+ justify-content: space-between;
+ align-items: center;
+ padding: 1rem 2rem;
+}
+
+.logo {
+ font-size: 1.5rem;
+ font-weight: bold;
+ color: var(--primary-color);
+}
+
+.nav-links {
+ display: flex;
+ gap: 2rem;
+ list-style: none;
+}
+
+.nav-links a {
+ text-decoration: none;
+ color: var(--text-color);
+ font-weight: 500;
+ transition: color 0.3s ease;
+}
+
+.nav-links a:hover {
+ color: var(--primary-color);
+}
+
+.mobile-nav-toggle {
+ display: none;
+}
+
+.hero {
+ padding: 8rem 0 4rem;
+ background: linear-gradient(135deg, var(--primary-color), var(--secondary-color));
+ background-size: 200% 200%;
+ color: white;
+ text-align: center;
+ animation: gradientMove 10s ease infinite;
+}
+
+@keyframes gradientMove {
+ 0% { background-position: 0% 50%; }
+ 50% { background-position: 100% 50%; }
+ 100% { background-position: 0% 50%; }
+}
+
+.hero h1 {
+ font-size: 3rem;
+ margin-bottom: 1rem;
+}
+
+.hero p {
+ font-size: 1.25rem;
+ margin-bottom: 2rem;
+}
+
+.cta-button {
+ padding: 1rem 2rem;
+ font-size: 1.1rem;
+ background: white;
+ color: var(--primary-color);
+ border: none;
+ border-radius: 5px;
+ cursor: pointer;
+ transition: transform 0.3s ease;
+}
+
+.cta-button:hover {
+ transform: translateY(-2px);
+}
+
+.features {
+ padding: 4rem 0;
+ background: var(--section-bg);
+}
+
+.features h2 {
+ text-align: center;
+ margin-bottom: 3rem;
+}
+
+.feature-grid {
+ display: grid;
+ grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
+ gap: 2rem;
+}
+
+.feature-card {
+ background: white;
+ padding: 2rem;
+ border-radius: 10px;
+ text-align: center;
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
+ transition: all 0.3s ease;
+ opacity: 0;
+ transform: translateY(20px);
+ animation: fadeInUp 0.6s ease forwards;
+}
+
+@keyframes fadeInUp {
+ to {
+ opacity: 1;
+ transform: translateY(0);
+ }
+}
+
+.feature-card:hover {
+ transform: translateY(-5px);
+}
+
+.feature-icon {
+ font-size: 2.5rem;
+ margin-bottom: 1rem;
+}
+
+.contact {
+ padding: 4rem 0;
+}
+
+.contact h2 {
+ text-align: center;
+ margin-bottom: 2rem;
+}
+
+.contact-form {
+ max-width: 600px;
+ margin: 0 auto;
+ display: flex;
+ flex-direction: column;
+ gap: 1rem;
+}
+
+.contact-form input,
+.contact-form textarea {
+ padding: 0.8rem;
+ border: 2px solid #ddd;
+ border-radius: 5px;
+ font-size: 1rem;
+ transition: all 0.3s ease;
+ outline: none;
+}
+
+.contact-form input:focus,
+.contact-form textarea:focus {
+ border-color: var(--primary-color);
+ box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
+}
+
+.contact-form input:hover,
+.contact-form textarea:hover {
+ border-color: var(--primary-color);
+}
+
+.contact-form textarea {
+ height: 150px;
+ resize: vertical;
+}
+
+.submit-button {
+ padding: 1rem;
+ background: var(--primary-color);
+ color: white;
+ border: none;
+ border-radius: 5px;
+ cursor: pointer;
+ transition: background-color 0.3s ease;
+}
+
+.submit-button:hover {
+ background: var(--secondary-color);
+}
+
+.testimonials {
+ padding: 4rem 0;
+ background: var(--section-bg);
+}
+
+.testimonials h2 {
+ text-align: center;
+ margin-bottom: 3rem;
+}
+
+.testimonial-grid {
+ display: grid;
+ grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
+ gap: 2rem;
+ max-width: 1200px;
+ margin: 0 auto;
+ padding: 0 2rem;
+}
+
+.testimonial-card {
+ background: white;
+ padding: 2rem;
+ border-radius: 10px;
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
+}
+
+.testimonial-text {
+ font-style: italic;
+ margin-bottom: 1rem;
+}
+
+.testimonial-author {
+ font-weight: bold;
+ color: var(--primary-color);
+}
+
+.footer {
+ background: var(--text-color);
+ color: white;
+ padding: 2rem 0;
+ text-align: center;
+}
+
+.social-links {
+ display: flex;
+ justify-content: center;
+ gap: 1.5rem;
+ margin: 1rem 0;
+}
+
+.social-links a {
+ color: white;
+ text-decoration: none;
+ font-size: 1.5rem;
+ transition: color 0.3s ease;
+}
+
+.social-links a:hover {
+ color: var(--primary-color);
+}
+
+@media (max-width: 768px) {
+ .nav-links {
+ display: none;
+ position: fixed;
+ top: 70px;
+ left: 0;
+ right: 0;
+ background: var(--background);
+ flex-direction: column;
+ padding: 2rem;
+ box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
+ text-align: center;
+ transform: translateY(-100%);
+ opacity: 0;
+ transition: transform 0.3s ease, opacity 0.3s ease;
+ }
+
+ .nav-links.active {
+ display: flex;
+ transform: translateY(0);
+ opacity: 1;
+ }
+
+ .mobile-nav-toggle {
+ display: block;
+ background: none;
+ border: none;
+ cursor: pointer;
+ }
+
+ .mobile-nav-toggle span {
+ display: block;
+ width: 25px;
+ height: 3px;
+ background: var(--text-color);
+ margin: 5px 0;
+ transition: 0.3s;
+ position: relative;
+ }
+
+ .mobile-nav-toggle.active span:nth-child(1) {
+ transform: rotate(45deg) translate(5px, 5px);
+ }
+
+ .mobile-nav-toggle.active span:nth-child(2) {
+ opacity: 0;
+ }
+
+ .mobile-nav-toggle.active span:nth-child(3) {
+ transform: rotate(-45deg) translate(7px, -6px);
+ }
+
+ .hero h1 {
+ font-size: 2.5rem;
+ }
+
+ .feature-grid {
+ grid-template-columns: 1fr;
+ }
+}
\ No newline at end of file
diff --git a/example/workspace/index.html b/example/workspace/index.html
new file mode 100644
index 0000000..40277ec
--- /dev/null
+++ b/example/workspace/index.html
@@ -0,0 +1,113 @@
+
+
+
+
+
+ Modern Landing Page
+
+
+
+
+
+
+
+
+
+
+
Welcome to the Future
+
Experience innovation like never before
+
+
+
+
+
+
+
Our Features
+
+
+
🚀
+
Fast Performance
+
Lightning-quick loading times
+
+
+
🎨
+
Beautiful Design
+
Stunning visual experience
+
+
+
📱
+
Responsive
+
Works on all devices
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/example/workspace/js/main.js b/example/workspace/js/main.js
new file mode 100644
index 0000000..8cefa43
--- /dev/null
+++ b/example/workspace/js/main.js
@@ -0,0 +1,85 @@
+const mobileNavToggle = document.querySelector('.mobile-nav-toggle');
+const navLinks = document.querySelector('.nav-links');
+
+mobileNavToggle.addEventListener('click', () => {
+ navLinks.classList.toggle('active');
+ mobileNavToggle.classList.toggle('active');
+});
+
+document.addEventListener('click', (e) => {
+ if (!e.target.closest('.nav') && navLinks.classList.contains('active')) {
+ navLinks.classList.remove('active');
+ mobileNavToggle.classList.remove('active');
+ }
+});
+
+document.querySelectorAll('a[href^="#"]').forEach(anchor => {
+ anchor.addEventListener('click', function (e) {
+ e.preventDefault();
+ document.querySelector(this.getAttribute('href')).scrollIntoView({
+ behavior: 'smooth'
+ });
+ });
+});
+
+const form = document.querySelector('.contact-form');
+form.addEventListener('submit', (e) => {
+ e.preventDefault();
+ const formData = new FormData(form);
+ const data = Object.fromEntries(formData);
+ console.log('Form submitted:', data);
+ form.reset();
+});
+
+let lastScroll = 0;
+window.addEventListener('load', () => {
+ setTimeout(() => {
+ document.querySelector('.loading').classList.add('loaded');
+ }, 500);
+});
+
+const scrollTopBtn = document.querySelector('.scroll-top');
+scrollTopBtn.addEventListener('click', () => {
+ window.scrollTo({
+ top: 0,
+ behavior: 'smooth'
+ });
+});
+
+const observeElements = () => {
+ const observer = new IntersectionObserver((entries) => {
+ entries.forEach(entry => {
+ if (entry.isIntersecting) {
+ entry.target.style.opacity = '1';
+ entry.target.style.transform = 'translateY(0)';
+ }
+ });
+ }, { threshold: 0.1 });
+
+ document.querySelectorAll('.feature-card').forEach(card => {
+ observer.observe(card);
+ card.style.opacity = '0';
+ card.style.transform = 'translateY(20px)';
+ card.style.transition = 'all 0.6s ease';
+ });
+};
+
+document.addEventListener('DOMContentLoaded', observeElements);
+
+window.addEventListener('scroll', () => {
+ scrollTopBtn.classList.toggle('visible', window.scrollY > 500);
+ const header = document.querySelector('.header');
+ const currentScroll = window.pageYOffset;
+
+ if (currentScroll <= 0) {
+ header.classList.remove('scrolled');
+ return;
+ }
+
+ if (currentScroll > lastScroll && !header.classList.contains('scrolled')) {
+ header.classList.add('scrolled');
+ } else if (currentScroll < lastScroll && header.classList.contains('scrolled')) {
+ header.classList.remove('scrolled');
+ }
+ lastScroll = currentScroll;
+});
\ No newline at end of file