Skip to content

Commit

Permalink
- Enhanced storage capabilities for production readiness:
Browse files Browse the repository at this point in the history
  - Added SQLite as primary storage backend for ThreadManager and StateManager
  - Implemented persistent storage with unique store IDs
  - Added CRUD operations for state management
  - Enabled multiple concurrent stores with referential integrity
  - Improved state persistence and retrieval mechanisms
  • Loading branch information
markokraemer committed Nov 19, 2024
1 parent 13e9867 commit cb9f7b6
Show file tree
Hide file tree
Showing 20 changed files with 613 additions and 428 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
0.1.9
- Enhanced storage capabilities for production readiness:
- Added SQLite as primary storage backend for ThreadManager and StateManager
- Implemented persistent storage with unique store IDs
- Added CRUD operations for state management
- Enabled multiple concurrent stores with referential integrity
- Improved state persistence and retrieval mechanisms

0.1.8
- Added base processor classes for extensible tool handling:
- ToolParserBase: Abstract base class for parsing LLM responses
Expand Down
Binary file added agentpress.db
Binary file not shown.
10 changes: 6 additions & 4 deletions agentpress/agents/simple_web_dev/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,13 @@
async def run_agent(thread_id: str, use_xml: bool = True, max_iterations: int = 5):
"""Run the development agent with specified configuration."""
thread_manager = ThreadManager()
state_manager = StateManager()

store_id = await StateManager.create_store()
state_manager = StateManager(store_id)

thread_manager.add_tool(FilesTool)
thread_manager.add_tool(TerminalTool)
thread_manager.add_tool(FilesTool, store_id=store_id)
thread_manager.add_tool(TerminalTool, store_id=store_id)

# Combine base message with XML format if needed
system_message = {
"role": "system",
"content": BASE_SYSTEM_MESSAGE + (XML_FORMAT if use_xml else "")
Expand Down Expand Up @@ -199,6 +200,7 @@ def main():

async def async_main():
thread_manager = ThreadManager()

thread_id = await thread_manager.create_thread()
await thread_manager.add_message(
thread_id,
Expand Down
7 changes: 4 additions & 3 deletions agentpress/agents/simple_web_dev/tools/files_tool.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
import asyncio
from pathlib import Path
from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
from agentpress.tools.tool import Tool, ToolResult, openapi_schema, xml_schema
from agentpress.state_manager import StateManager
from typing import Optional

class FilesTool(Tool):
"""File management tool for creating, updating, and deleting files.
Expand Down Expand Up @@ -53,11 +54,11 @@ class FilesTool(Tool):
".sql"
}

def __init__(self):
def __init__(self, store_id: Optional[str] = None):
super().__init__()
self.workspace = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'workspace')
os.makedirs(self.workspace, exist_ok=True)
self.state_manager = StateManager("state.json")
self.state_manager = StateManager(store_id)
self.SNIPPET_LINES = 4 # Number of context lines to show around edits
asyncio.create_task(self._init_workspace_state())

Expand Down
5 changes: 3 additions & 2 deletions agentpress/agents/simple_web_dev/tools/terminal_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
import subprocess
from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
from agentpress.state_manager import StateManager
from typing import Optional

class TerminalTool(Tool):
"""Terminal command execution tool for workspace operations."""

def __init__(self):
def __init__(self, store_id: Optional[str] = None):
super().__init__()
self.workspace = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'workspace')
os.makedirs(self.workspace, exist_ok=True)
self.state_manager = StateManager("state.json")
self.state_manager = StateManager(store_id)

async def _update_command_history(self, command: str, output: str, success: bool):
"""Update command history in state"""
Expand Down
16 changes: 8 additions & 8 deletions agentpress/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
"processors": {
"required": True,
"files": [
"base_processors.py",
"llm_response_processor.py",
"standard_tool_parser.py",
"standard_tool_executor.py",
"standard_results_adder.py",
"xml_tool_parser.py",
"xml_tool_executor.py",
"xml_results_adder.py"
"processor/base_processors.py",
"processor/llm_response_processor.py",
"processor/standard/standard_tool_parser.py",
"processor/standard/standard_tool_executor.py",
"processor/standard/standard_results_adder.py",
"processor/xml/xml_tool_parser.py",
"processor/xml/xml_tool_executor.py",
"processor/xml/xml_results_adder.py"
],
"description": "Response Processing System - Handles parsing and executing LLM responses, managing tool calls, and processing results. Supports both standard OpenAI-style function calling and XML-based tool execution patterns."
},
Expand Down
125 changes: 125 additions & 0 deletions agentpress/db_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""
Centralized database connection management for AgentPress.
"""

import aiosqlite
import logging
from contextlib import asynccontextmanager
import os
import asyncio

class DBConnection:
"""Singleton database connection manager."""

_instance = None
_initialized = False
_db_path = os.path.join(os.getcwd(), "agentpress.db")
_init_lock = asyncio.Lock()
_initialization_task = None

def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
# Start initialization when instance is first created
cls._initialization_task = asyncio.create_task(cls._instance._initialize())
return cls._instance

def __init__(self):
"""No initialization needed in __init__ as it's handled in __new__"""
pass

@classmethod
async def _initialize(cls):
"""Internal initialization method."""
if cls._initialized:
return

async with cls._init_lock:
if cls._initialized: # Double-check after acquiring lock
return

try:
async with aiosqlite.connect(cls._db_path) as db:
# Threads table
await db.execute("""
CREATE TABLE IF NOT EXISTS threads (
thread_id TEXT PRIMARY KEY,
messages TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")

# State stores table
await db.execute("""
CREATE TABLE IF NOT EXISTS state_stores (
store_id TEXT PRIMARY KEY,
store_data TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")

await db.commit()
cls._initialized = True
logging.info("Database schema initialized")
except Exception as e:
logging.error(f"Database initialization error: {e}")
raise

@classmethod
def set_db_path(cls, db_path: str):
"""Set custom database path."""
if cls._initialized:
raise RuntimeError("Cannot change database path after initialization")
cls._db_path = db_path
logging.info(f"Updated database path to: {db_path}")

@asynccontextmanager
async def connection(self):
"""Get a database connection."""
# Wait for initialization to complete if it hasn't already
if self._initialization_task and not self._initialized:
await self._initialization_task

async with aiosqlite.connect(self._db_path) as conn:
try:
yield conn
except Exception as e:
logging.error(f"Database error: {e}")
raise

@asynccontextmanager
async def transaction(self):
"""Execute operations in a transaction."""
async with self.connection() as db:
try:
yield db
await db.commit()
except Exception as e:
await db.rollback()
logging.error(f"Transaction error: {e}")
raise

async def execute(self, query: str, params: tuple = ()):
"""Execute a single query."""
async with self.connection() as db:
try:
result = await db.execute(query, params)
await db.commit()
return result
except Exception as e:
logging.error(f"Query execution error: {e}")
raise

async def fetch_one(self, query: str, params: tuple = ()):
"""Fetch a single row."""
async with self.connection() as db:
async with db.execute(query, params) as cursor:
return await cursor.fetchone()

async def fetch_all(self, query: str, params: tuple = ()):
"""Fetch all rows."""
async with self.connection() as db:
async with db.execute(query, params) as cursor:
return await cursor.fetchall()
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class ResultsAdderBase(ABC):
Attributes:
add_message: Callback for adding new messages
update_message: Callback for updating existing messages
list_messages: Callback for retrieving thread messages
get_messages: Callback for retrieving thread messages
message_added: Flag tracking if initial message has been added
"""

Expand All @@ -184,7 +184,7 @@ def __init__(self, thread_manager):
"""
self.add_message = thread_manager.add_message
self.update_message = thread_manager._update_message
self.list_messages = thread_manager.list_messages
self.get_messages = thread_manager.get_messages
self.message_added = False

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
import asyncio
from typing import Callable, Dict, Any, AsyncGenerator, Optional
import logging
from agentpress.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase
from agentpress.standard_tool_parser import StandardToolParser
from agentpress.standard_tool_executor import StandardToolExecutor
from agentpress.standard_results_adder import StandardResultsAdder
from agentpress.processor.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase
from agentpress.processor.standard.standard_tool_parser import StandardToolParser
from agentpress.processor.standard.standard_tool_executor import StandardToolExecutor
from agentpress.processor.standard.standard_results_adder import StandardResultsAdder

class LLMResponseProcessor:
"""Handles LLM response processing and tool execution management.
Expand All @@ -40,9 +40,8 @@ def __init__(
available_functions: Dict = None,
add_message_callback: Callable = None,
update_message_callback: Callable = None,
list_messages_callback: Callable = None,
get_messages_callback: Callable = None,
parallel_tool_execution: bool = True,
threads_dir: str = "threads",
tool_parser: Optional[ToolParserBase] = None,
tool_executor: Optional[ToolExecutorBase] = None,
results_adder: Optional[ResultsAdderBase] = None,
Expand All @@ -55,9 +54,8 @@ def __init__(
available_functions: Dictionary of available tool functions
add_message_callback: Callback for adding messages
update_message_callback: Callback for updating messages
list_messages_callback: Callback for listing messages
get_messages_callback: Callback for listing messages
parallel_tool_execution: Whether to execute tools in parallel
threads_dir: Directory for thread storage
tool_parser: Custom tool parser implementation
tool_executor: Custom tool executor implementation
results_adder: Custom results adder implementation
Expand All @@ -67,16 +65,15 @@ def __init__(
self.tool_executor = tool_executor or StandardToolExecutor(parallel=parallel_tool_execution)
self.tool_parser = tool_parser or StandardToolParser()
self.available_functions = available_functions or {}
self.threads_dir = threads_dir

# Create minimal thread manager if needed
if thread_manager is None and (add_message_callback and update_message_callback and list_messages_callback):
if thread_manager is None and (add_message_callback and update_message_callback and get_messages_callback):
class MinimalThreadManager:
def __init__(self, add_msg, update_msg, list_msg):
self.add_message = add_msg
self._update_message = update_msg
self.list_messages = list_msg
thread_manager = MinimalThreadManager(add_message_callback, update_message_callback, list_messages_callback)
self.get_messages = list_msg
thread_manager = MinimalThreadManager(add_message_callback, update_message_callback, get_messages_callback)

self.results_adder = results_adder or StandardResultsAdder(thread_manager)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Dict, Any, List, Optional
from agentpress.base_processors import ResultsAdderBase
from agentpress.processor.base_processors import ResultsAdderBase

# --- Standard Results Adder Implementation ---

Expand Down Expand Up @@ -81,6 +81,6 @@ async def add_tool_result(self, thread_id: str, result: Dict[str, Any]):
- Checks for duplicate tool results before adding
- Adds result only if tool_call_id is unique
"""
messages = await self.list_messages(thread_id)
messages = await self.get_messages(thread_id)
if not any(msg.get('tool_call_id') == result['tool_call_id'] for msg in messages):
await self.add_message(thread_id, result)
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import json
import logging
from typing import Dict, Any, List, Set, Callable, Optional
from agentpress.base_processors import ToolExecutorBase
from agentpress.processor.base_processors import ToolExecutorBase
from agentpress.tool import ToolResult

# --- Standard Tool Executor Implementation ---
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from typing import Dict, Any, Optional
from agentpress.base_processors import ToolParserBase
from agentpress.processor.base_processors import ToolParserBase

# --- Standard Tool Parser Implementation ---

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from typing import Dict, Any, List, Optional
from agentpress.base_processors import ResultsAdderBase
from agentpress.processor.base_processors import ResultsAdderBase

class XMLResultsAdder(ResultsAdderBase):
"""XML-specific implementation for handling tool results and message processing.
Expand Down Expand Up @@ -79,7 +79,7 @@ async def add_tool_result(self, thread_id: str, result: Dict[str, Any]):
"""
try:
# Get the original tool call to find the root tag
messages = await self.list_messages(thread_id)
messages = await self.get_messages(thread_id)
assistant_msg = next((msg for msg in reversed(messages)
if msg['role'] == 'assistant'), None)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import asyncio
import json
import logging
from agentpress.base_processors import ToolExecutorBase
from agentpress.processor.base_processors import ToolExecutorBase
from agentpress.tool import ToolResult
from agentpress.tool_registry import ToolRegistry

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import logging
from typing import Dict, Any, Optional, List, Tuple
from agentpress.base_processors import ToolParserBase
from agentpress.processor.base_processors import ToolParserBase
import json
import re
from agentpress.tool_registry import ToolRegistry
Expand Down
Loading

0 comments on commit cb9f7b6

Please sign in to comment.