Skip to content

Commit c028fac

Browse files
committed
fix: important bug fixes
1 parent f3adf0c commit c028fac

File tree

20 files changed

+4123
-1138
lines changed

20 files changed

+4123
-1138
lines changed

dist/pilott-0.0.0-py3-none-any.whl

31 KB
Binary file not shown.

dist/pilott-0.0.0.tar.gz

24.2 KB
Binary file not shown.
15 KB
Binary file not shown.

pilott/core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pilott.core.router import TaskRouter, TaskPriority
66
from pilott.core.role import AgentRole
77
from pilott.core.status import AgentStatus
8+
from pilott.core.task import Task
89

910
__all__ = [
1011
'AgentRole',

pilott/core/agent.py

Lines changed: 379 additions & 73 deletions
Large diffs are not rendered by default.

pilott/core/config.py

Lines changed: 164 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,103 @@
1+
import shutil
12
from typing import Optional, List, Dict, Any
2-
from pydantic import BaseModel, ConfigDict, Field
3+
from pydantic import BaseModel, ConfigDict, Field, SecretStr, field_validator
34
from pilott.core.role import AgentRole
4-
5+
from pathlib import Path
6+
from cryptography.fernet import Fernet
7+
import json
8+
9+
10+
class SecureConfig:
11+
"""Handles secure storage and retrieval of sensitive config values"""
12+
13+
def __init__(self, key_path: Optional[Path] = None):
14+
self._key_path = key_path
15+
if key_path and key_path.exists():
16+
self.key = key_path.read_bytes()
17+
else:
18+
self.key = Fernet.generate_key()
19+
if key_path:
20+
key_path.parent.mkdir(parents=True, exist_ok=True)
21+
key_path.write_bytes(self.key)
22+
self.cipher = Fernet(self.key)
23+
24+
def encrypt(self, value: str) -> bytes:
25+
if not value:
26+
raise ValueError("Cannot encrypt empty value")
27+
return self.cipher.encrypt(value.encode())
28+
29+
def decrypt(self, value: bytes) -> str:
30+
if not value:
31+
raise ValueError("Cannot decrypt empty value")
32+
return self.cipher.decrypt(value).decode()
33+
34+
def cleanup(self):
35+
try:
36+
if self._key_path and self._key_path.exists():
37+
self._key_path.unlink()
38+
except Exception:
39+
pass
540

641
class LLMConfig(BaseModel):
7-
"""Configuration for LLM integration"""
42+
"""Enhanced configuration for LLM integration"""
843
model_config = ConfigDict(
944
arbitrary_types_allowed=True,
1045
use_enum_values=True
1146
)
1247

1348
model_name: str
1449
provider: str
15-
api_key: str
16-
temperature: float = 0.7
17-
max_tokens: int = 2000
50+
api_key: SecretStr
51+
temperature: float = Field(ge=0.0, le=1.0, default=0.7)
52+
max_tokens: int = Field(gt=0, default=2000)
1853
function_calling_model: Optional[str] = None
1954
system_template: Optional[str] = None
2055
prompt_template: Optional[str] = None
56+
retry_attempts: int = Field(ge=0, default=3)
57+
timeout: float = Field(gt=0, default=30.0)
58+
_secure_config: Optional[SecureConfig] = None
59+
60+
def __init__(self, **data):
61+
super().__init__(**data)
62+
self._secure_config = SecureConfig()
63+
64+
@field_validator('api_key')
65+
def encrypt_api_key(cls, v):
66+
if isinstance(v, str):
67+
return SecretStr(v)
68+
return v
2169

2270
def to_dict(self) -> Dict[str, Any]:
23-
"""Convert to a dictionary with only basic Python types"""
2471
return {
25-
"model_name": str(self.model_name),
26-
"provider": str(self.provider),
27-
"api_key": str(self.api_key),
28-
"temperature": float(self.temperature),
29-
"max_tokens": int(self.max_tokens)
72+
"model_name": self.model_name,
73+
"provider": self.provider,
74+
"temperature": self.temperature,
75+
"max_tokens": self.max_tokens,
76+
"function_calling_model": self.function_calling_model
3077
}
3178

3279

3380
class LogConfig(BaseModel):
34-
"""Configuration for logging"""
81+
"""Enhanced logging configuration"""
3582
model_config = ConfigDict(arbitrary_types_allowed=True)
3683

3784
verbose: bool = False
3885
log_to_file: bool = False
39-
log_dir: str = "logs"
86+
log_dir: Path = Field(default=Path("logs"))
4087
log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
41-
log_level: str = "INFO"
88+
log_level: str = Field(default="INFO", pattern="^(DEBUG|INFO|WARNING|ERROR|CRITICAL)$")
89+
max_file_size: int = Field(default=10 * 1024 * 1024) # 10MB
90+
backup_count: int = Field(ge=0, default=5)
91+
log_rotation: str = Field(default="midnight")
92+
93+
@field_validator('log_dir')
94+
def create_log_dir(cls, v):
95+
v = Path(v)
96+
try:
97+
v.mkdir(parents=True, exist_ok=True)
98+
except Exception as e:
99+
raise ValueError(f"Failed to create log directory: {str(e)}")
100+
return v
42101

43102

44103
class AgentConfig(BaseModel):
@@ -58,13 +117,14 @@ class AgentConfig(BaseModel):
58117
# Knowledge and Tools
59118
knowledge_sources: List[str] = Field(default_factory=list)
60119
tools: List[str] = Field(default_factory=list)
120+
required_capabilities: List[str] = Field(default_factory=list)
61121

62122
# Execution Settings
63-
max_iterations: int = 20
64-
max_rpm: Optional[int] = None
65-
max_execution_time: Optional[int] = None
66-
retry_limit: int = 2
67-
code_execution_mode: str = "safe"
123+
max_iterations: int = Field(gt=0, default=20)
124+
max_rpm: Optional[int] = Field(gt=0, default=None)
125+
max_execution_time: Optional[int] = Field(gt=0, default=None)
126+
retry_limit: int = Field(ge=0, default=2)
127+
code_execution_mode: str = Field(default="safe", pattern="^(safe|restricted|unrestricted)$")
68128

69129
# Features
70130
memory_enabled: bool = True
@@ -73,23 +133,37 @@ class AgentConfig(BaseModel):
73133
use_cache: bool = True
74134
can_execute_code: bool = False
75135

76-
# Orchestration Settings
77-
max_child_agents: int = 10
78-
max_queue_size: int = 100
79-
max_task_complexity: int = 5
80-
delegation_threshold: float = 0.7
136+
# Resource Limits
137+
max_child_agents: int = Field(gt=0, default=10)
138+
max_queue_size: int = Field(gt=0, default=100)
139+
max_task_complexity: int = Field(ge=1, le=10, default=5)
140+
delegation_threshold: float = Field(ge=0.0, le=1.0, default=0.7)
141+
142+
# Performance Settings
143+
max_concurrent_tasks: int = Field(gt=0, default=5)
144+
task_timeout: int = Field(gt=0, default=300)
145+
resource_limits: Dict[str, float] = Field(
146+
default_factory=lambda: {
147+
"cpu_percent": 80.0,
148+
"memory_percent": 80.0,
149+
"disk_percent": 80.0
150+
}
151+
)
81152

82153
# WebSocket Configuration
83154
websocket_enabled: bool = True
84155
websocket_host: str = "localhost"
85-
websocket_port: int = 8765
156+
websocket_port: int = Field(ge=1024, le=65535, default=8765)
86157

87-
# Async Settings
88-
max_concurrent_tasks: int = 5
89-
task_timeout: int = 300 # seconds
158+
@field_validator('resource_limits')
159+
def validate_resource_limits(cls, v):
160+
for key, value in v.items():
161+
if value <= 0 or value > 100:
162+
raise ValueError(f"Resource limit {key} must be between 0 and 100")
163+
return v
90164

91165
def to_dict(self) -> Dict[str, Any]:
92-
"""Convert to a dictionary with only basic Python types"""
166+
"""Convert to a dictionary with type handling"""
93167
return {
94168
"role": str(self.role),
95169
"role_type": str(self.role_type),
@@ -98,6 +172,7 @@ def to_dict(self) -> Dict[str, Any]:
98172
"backstory": str(self.backstory) if self.backstory else None,
99173
"knowledge_sources": list(self.knowledge_sources),
100174
"tools": list(self.tools),
175+
"required_capabilities": list(self.required_capabilities),
101176
"max_iterations": int(self.max_iterations),
102177
"max_rpm": int(self.max_rpm) if self.max_rpm else None,
103178
"max_execution_time": int(self.max_execution_time) if self.max_execution_time else None,
@@ -113,5 +188,62 @@ def to_dict(self) -> Dict[str, Any]:
113188
"max_task_complexity": int(self.max_task_complexity),
114189
"delegation_threshold": float(self.delegation_threshold),
115190
"max_concurrent_tasks": int(self.max_concurrent_tasks),
116-
"task_timeout": int(self.task_timeout)
117-
}
191+
"task_timeout": int(self.task_timeout),
192+
"resource_limits": dict(self.resource_limits),
193+
"websocket_enabled": bool(self.websocket_enabled),
194+
"websocket_host": str(self.websocket_host),
195+
"websocket_port": int(self.websocket_port)
196+
}
197+
198+
@classmethod
199+
def from_file(cls, path: Path) -> 'AgentConfig':
200+
"""Load configuration from file with proper error handling"""
201+
if not path.exists():
202+
raise FileNotFoundError(f"Config file not found: {path}")
203+
try:
204+
with open(path, 'r') as f:
205+
data = json.load(f)
206+
return cls(**data)
207+
except json.JSONDecodeError as e:
208+
raise ValueError(f"Invalid JSON in config file: {str(e)}")
209+
except Exception as e:
210+
raise ValueError(f"Failed to load config: {str(e)}")
211+
212+
@property
213+
def has_sensitive_data(self) -> bool:
214+
"""Check if config contains sensitive data"""
215+
sensitive_patterns = ['password', 'secret', 'key', 'token', 'auth']
216+
dict_data = self.to_dict()
217+
return any(
218+
pattern in str(value).lower() or pattern in str(key).lower()
219+
for pattern in sensitive_patterns
220+
for key, value in dict_data.items()
221+
)
222+
223+
def save_to_file(self, path: Path):
224+
"""Save configuration to file with backup"""
225+
path = Path(path)
226+
backup_path = None
227+
228+
try:
229+
# Create backup if file exists
230+
if path.exists():
231+
backup_path = path.with_suffix('.bak')
232+
shutil.copy2(path, backup_path)
233+
234+
# Ensure directory exists
235+
path.parent.mkdir(parents=True, exist_ok=True)
236+
237+
# Save new config
238+
with open(path, 'w') as f:
239+
json.dump(self.to_dict(), f, indent=2)
240+
241+
# Remove backup if everything succeeded
242+
if backup_path and backup_path.exists():
243+
backup_path.unlink()
244+
245+
except Exception as e:
246+
# Restore backup if save failed
247+
if backup_path and backup_path.exists():
248+
shutil.copy2(backup_path, path)
249+
raise ValueError(f"Failed to save config: {str(e)}")

0 commit comments

Comments
 (0)